Skip to content

Commit

Permalink
Refactor and split source generator into multiple files
Browse files Browse the repository at this point in the history
  • Loading branch information
bash committed Jun 22, 2021
1 parent 6a7e6f6 commit 5261258
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 62 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace System.Linq.Async.SourceGenerator
{
internal sealed record AsyncMethod(IMethodSymbol Symbol, MethodDeclarationSyntax Syntax);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using System.Collections.Generic;

using Microsoft.CodeAnalysis;

namespace System.Linq.Async.SourceGenerator
{
internal sealed record AsyncMethodGrouping(SyntaxTree SyntaxTree, IEnumerable<AsyncMethod> Methods);
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,77 +25,92 @@ public sealed class AsyncOverloadsGenerator : ISourceGenerator
public void Initialize(GeneratorInitializationContext context)
{
context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
context.RegisterForPostInitialization(c => c.AddSource("Attribute.cs", AttributeSource));
context.RegisterForPostInitialization(c => c.AddSource("GenerateAsyncOverloadAttribute", AttributeSource));
}

public void Execute(GeneratorExecutionContext context)
{
if (context.SyntaxReceiver is not SyntaxReceiver syntaxReceiver) return;

var supportFlatAsyncApi = context.ParseOptions.PreprocessorSymbolNames.Contains("SUPPORT_FLAT_ASYNC_API");
var attributeSymbol = context.Compilation.GetTypeByMetadataName("System.Linq.GenerateAsyncOverloadAttribute");

foreach (var grouping in syntaxReceiver.Candidates.GroupBy(c => c.SyntaxTree))
{
var model = context.Compilation.GetSemanticModel(grouping.Key);
var methodsBuilder = new StringBuilder();

foreach (var candidate in grouping)
{
var methodSymbol = model.GetDeclaredSymbol(candidate) ?? throw new NullReferenceException();

if (!methodSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass!, attributeSymbol))) continue;

var shortName = methodSymbol.Name.Replace("Core", "");
if (supportFlatAsyncApi)
{
shortName = shortName.Replace("Await", "").Replace("WithCancellation", "");
}

var publicMethod = MethodDeclaration(candidate.ReturnType, shortName)
.WithModifiers(TokenList(Token(TriviaList(), SyntaxKind.PublicKeyword, TriviaList(Space)), Token(TriviaList(), SyntaxKind.StaticKeyword, TriviaList(Space))))
.WithTypeParameterList(candidate.TypeParameterList)
.WithParameterList(candidate.ParameterList)
.WithConstraintClauses(candidate.ConstraintClauses)
.WithExpressionBody(ArrowExpressionClause(InvocationExpression(IdentifierName(methodSymbol.Name), ArgumentList(SeparatedList(candidate.ParameterList.Parameters.Select(p => Argument(IdentifierName(p.Identifier))))))))
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
.WithLeadingTrivia(candidate.GetLeadingTrivia().Where(t => t.GetStructure() is not DirectiveTriviaSyntax));

methodsBuilder.AppendLine(publicMethod.ToFullString());
}

if (methodsBuilder.Length == 0) continue;

var usings = grouping.Key.GetRoot() is CompilationUnitSyntax compilationUnit
? compilationUnit.Usings
: List<UsingDirectiveSyntax>();

var overloads = new StringBuilder();
overloads.AppendLine("#nullable enable");
overloads.AppendLine(usings.ToString());
overloads.AppendLine("namespace System.Linq");
overloads.AppendLine("{");
overloads.AppendLine(" partial class AsyncEnumerable");
overloads.AppendLine(" {");
overloads.AppendLine(methodsBuilder.ToString());
overloads.AppendLine(" }");
overloads.AppendLine("}");

context.AddSource($"{Path.GetFileNameWithoutExtension(grouping.Key.FilePath)}.AsyncOverloads.cs", overloads.ToString());
}
var options = GetGenerationOptions(context);
var methodsBySyntaxTree = GetMethodsGroupedBySyntaxTree(context, syntaxReceiver);

foreach (var grouping in methodsBySyntaxTree)
context.AddSource(
$"{Path.GetFileNameWithoutExtension(grouping.SyntaxTree.FilePath)}.AsyncOverloads",
GenerateOverloads(grouping, options));
}

private static GenerationOptions GetGenerationOptions(GeneratorExecutionContext context)
=> new(SupportFlatAsyncApi: context.ParseOptions.PreprocessorSymbolNames.Contains("SUPPORT_FLAT_ASYNC_API"));

private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver)
=> GetMethodsGroupedBySyntaxTree(
context,
syntaxReceiver,
GetAsyncOverloadAttributeSymbol(context));

private static string GenerateOverloads(AsyncMethodGrouping grouping, GenerationOptions options)
{
var usings = grouping.SyntaxTree.GetRoot() is CompilationUnitSyntax compilationUnit
? compilationUnit.Usings.ToString()
: string.Empty;

var overloads = new StringBuilder();
overloads.AppendLine("#nullable enable");
overloads.AppendLine(usings);
overloads.AppendLine("namespace System.Linq");
overloads.AppendLine("{");
overloads.AppendLine(" partial class AsyncEnumerable");
overloads.AppendLine(" {");

foreach (var method in grouping.Methods)
overloads.AppendLine(GenerateOverload(method, options));

overloads.AppendLine(" }");
overloads.AppendLine("}");

return overloads.ToString();
}

private sealed class SyntaxReceiver : ISyntaxReceiver
private static string GenerateOverload(AsyncMethod method, GenerationOptions options)
=> MethodDeclaration(method.Syntax.ReturnType, GetMethodName(method.Symbol, options))
.WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(method.Syntax.TypeParameterList)
.WithParameterList(method.Syntax.ParameterList)
.WithConstraintClauses(method.Syntax.ConstraintClauses)
.WithExpressionBody(ArrowExpressionClause(
InvocationExpression(
IdentifierName(method.Symbol.Name),
ArgumentList(
SeparatedList(
method.Syntax.ParameterList.Parameters
.Select(p => Argument(IdentifierName(p.Identifier))))))))
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
.WithLeadingTrivia(method.Syntax.GetLeadingTrivia().Where(t => t.GetStructure() is not DirectiveTriviaSyntax))
.NormalizeWhitespace()
.ToFullString();

private static INamedTypeSymbol GetAsyncOverloadAttributeSymbol(GeneratorExecutionContext context)
=> context.Compilation.GetTypeByMetadataName("System.Linq.GenerateAsyncOverloadAttribute") ?? throw new InvalidOperationException();

private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver, INamedTypeSymbol attributeSymbol)
=> from candidate in syntaxReceiver.Candidates
group candidate by candidate.SyntaxTree into grouping
let model = context.Compilation.GetSemanticModel(grouping.Key)
select new AsyncMethodGrouping(
grouping.Key,
from methodSyntax in grouping
let methodSymbol = model.GetDeclaredSymbol(methodSyntax) ?? throw new InvalidOperationException()
where methodSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass!, attributeSymbol))
select new AsyncMethod(methodSymbol, methodSyntax));

private static string GetMethodName(IMethodSymbol methodSymbol, GenerationOptions options)
{
public IList<MethodDeclarationSyntax> Candidates { get; } = new List<MethodDeclarationSyntax>();

public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
{
if (syntaxNode is MethodDeclarationSyntax { AttributeLists: { Count: >0 } } methodDeclarationSyntax)
{
Candidates.Add(methodDeclarationSyntax);
}
}
var methodName = methodSymbol.Name.Replace("Core", "");
return options.SupportFlatAsyncApi
? methodName.Replace("Await", "").Replace("WithCancellation", "")
: methodName;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
namespace System.Linq.Async.SourceGenerator
{
internal sealed record GenerationOptions(bool SupportFlatAsyncApi);
}
20 changes: 20 additions & 0 deletions Ix.NET/Source/System.Linq.Async.SourceGenerator/SyntaxReceiver.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using System.Collections.Generic;

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace System.Linq.Async.SourceGenerator
{
internal sealed class SyntaxReceiver : ISyntaxReceiver
{
public IList<MethodDeclarationSyntax> Candidates { get; } = new List<MethodDeclarationSyntax>();

public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
{
if (syntaxNode is MethodDeclarationSyntax { AttributeLists: { Count: >0 } } methodDeclarationSyntax)
{
Candidates.Add(methodDeclarationSyntax);
}
}
}
}

0 comments on commit 5261258

Please sign in to comment.