Browse Source

Add serialization source generator

feature/3.0-serialization-generator
FiniteReality 4 years ago
parent
commit
091cc6bfb1
19 changed files with 963 additions and 138 deletions
  1. +0
    -4
      samples/PingPong/PingPong.csproj
  2. +241
    -0
      src/Models/Channel.cs
  3. +68
    -0
      src/Models/ChannelType.cs
  4. +7
    -0
      src/Models/Discord.Net.Models.csproj
  5. +3
    -0
      src/Serialization/DiscriminatedUnionAttribute.cs
  6. +5
    -2
      src/Serialization/DiscriminatedUnionMemberAttribute.cs
  7. +16
    -0
      src/Serialization/GenerateSerializerAttribute.cs
  8. +5
    -0
      tools/SourceGenerators/Directory.Build.props
  9. +4
    -0
      tools/SourceGenerators/IsExternalInit.cs
  10. +4
    -0
      tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.csproj
  11. +6
    -0
      tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.props
  12. +0
    -37
      tools/SourceGenerators/Serialization/SerializationSourceGenerator.ConverterSource.cs
  13. +15
    -31
      tools/SourceGenerators/Serialization/SerializationSourceGenerator.OptionsSource.cs
  14. +106
    -64
      tools/SourceGenerators/Serialization/SerializationSourceGenerator.cs
  15. +253
    -0
      tools/SourceGenerators/Serialization/Structure/SerializedType.cs
  16. +119
    -0
      tools/SourceGenerators/Serialization/Structure/SerializedTypeUtils.cs
  17. +27
    -0
      tools/SourceGenerators/Serialization/SymbolExtensions.cs
  18. +25
    -0
      tools/SourceGenerators/Serialization/Utils.cs
  19. +59
    -0
      tools/SourceGenerators/Serialization/VisibleTypeVisitor.cs

+ 0
- 4
samples/PingPong/PingPong.csproj View File

@@ -10,10 +10,6 @@
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>
<ProjectReference Include="../../tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Hosting" /> <PackageReference Include="Microsoft.Extensions.Hosting" />
</ItemGroup> </ItemGroup>




+ 241
- 0
src/Models/Channel.cs View File

@@ -0,0 +1,241 @@
using System;
using Discord.Net.Serialization;

namespace Discord.Net.Models
{
/// <summary>
/// Represents a guild or DM channel within Discord.
/// </summary>
/// <remarks>
/// <see href="https://discord.com/developers/docs/resources/channel#channel-object-channel-structure"/>
/// </remarks>
/// <param name="Id">
/// The id of this channel.
/// </param>
/// <param name="Type">
/// The type of channel.
/// </param>
[DiscriminatedUnion(nameof(Channel.Type))]
[GenerateSerializer]
public record Channel(
ChannelType Type,
Snowflake Id);

/// <summary>
/// Represents a text channel within a server.
/// </summary>
[DiscriminatedUnionMember(ChannelType.GuildText)]
[GenerateSerializer]
public record GuildTextChannel(
Snowflake Id,
Snowflake GuildId,
int Position,
/*Overwrite[] PermissionOverwrites,*/
string Name,
string? Topic,
bool Nsfw,
Snowflake LastMessageId,
int RateLimitPerUser,
Snowflake? ParentId,
DateTimeOffset? LastPinTimestamp)
: Channel(
ChannelType.GuildText,
Id);

/*

/// <summary>
/// Represents a direct message between users.
/// </summary>
[DiscriminatedUnionMember(ChannelType.DM)]
[GenerateSerializer]
public record DMChannel(
Snowflake Id,
User[] Recipients)
: Channel(
ChannelType.DM,
Id);

/// <summary>
/// Represents a voice channel within a server.
/// </summary>
[DiscriminatedUnionMember(ChannelType.GuildVoice)]
[GenerateSerializer]
public record GuildVoiceChannel(
Snowflake Id,
Snowflake GuildId,
int Position,
Overwrite[] PermissionOverwrites,
string Name,
bool Nsfw,
int Bitrate,
int UserLimit,
Snowflake? ParentId,
string? RtcRegion,
int VideoQualityMode)
: Channel(
ChannelType.GuildVoice,
Id);

/// <summary>
/// Represents a direct message between multiple users.
/// </summary>
[DiscriminatedUnionMember(ChannelType.GroupDM)]
[GenerateSerializer]
public record GroupDMChannel(
Snowflake Id,
string Name,
Snowflake LastMessageId,
User[] Recipients,
string? Icon,
Snowflake? OwnerId,
Snowflake? ApplicationId,
DateTimeOffset? LastPinTimestamp)
: Channel(
ChannelType.GroupDM,
Id);

/// <summary>
/// Represents an organizational category that contains up to 50 channels.
/// </summary>
[DiscriminatedUnionMember(ChannelType.GuildCategory)]
[GenerateSerializer]
public record GuildCategoryChannel(
Snowflake Id,
Snowflake GuildId,
int Position,
Overwrite[] PermissionOverwrites,
string Name)
: Channel(
ChannelType.GuildCategory,
Id);

/// <summary>
/// Represents a channel that users can follow and crosspost into their own
/// server.
/// </summary>
[DiscriminatedUnionMember(ChannelType.GuildNews)]
[GenerateSerializer]
public record GuildNewsChannel(
Snowflake Id,
Snowflake GuildId,
int Position,
Overwrite[] PermissionOverwrites,
string Name,
string? Topic,
bool Nsfw,
Snowflake? LastMessageId,
int RateLimitPerUser,
Snowflake? ParentId,
Snowflake? LastPinTimestamp)
: Channel(
ChannelType.GuildNews,
Id);

/// <summary>
/// Represents a channel in which game developers can sell their game on
/// Discord.
/// </summary>
[DiscriminatedUnionMember(ChannelType.GuildStore)]
[GenerateSerializer]
public record GuildStoreChannel(
Snowflake Id,
Snowflake GuildId,
int Position,
Overwrite[] PermissionOverwrites, // I guess???
string? Name,
Snowflake? ParentId)
: Channel(
ChannelType.GuildStore,
Id);

/// <summary>
/// Represents a temporary sub-channel within a
/// <see cref="GuildNewsChannel"/>.
/// </summary>
[DiscriminatedUnionMember(ChannelType.GuildNewsThread)]
[GenerateSerializer]
public record GuildNewsThreadChannel(
Snowflake Id,
Snowflake GuildId,
int Position,
Overwrite[] PermissionOverwrites, // I guess??
string Name,
Snowflake? LastMessageId,
Snowflake? ParentId,
Snowflake? LastPinTimestamp,
int MessageCount,
int MemberCount,
ThreadMetadata ThreadMetadata,
ThreadMember Member)
: Channel(
ChannelType.GuildNewsThread,
Id);

/// <summary>
/// Represents a temporary sub-channel within a
/// <see cref="GuildTextChannel"/>.
/// </summary>
[DiscriminatedUnionMember(ChannelType.GuildPublicThread)]
[GenerateSerializer]
public record GuildPublicThreadChannel(
Snowflake Id,
Snowflake GuildId,
int Position,
Overwrite[] PermissionOverwrites, // I guess??
string Name,
Snowflake? LastMessageId,
Snowflake? ParentId,
Snowflake? LastPinTimestamp,
int MessageCount,
int MemberCount,
ThreadMetadata ThreadMetadata,
ThreadMember Member)
: Channel(
ChannelType.GuildPublicThread,
Id);

/// <summary>
/// Represents a temporary sub-channel within a
/// <see cref="GuildTextChannel"/> that is only viewable by those invited
/// and those with the MANAGE_THREADS permission.
/// </summary>
[DiscriminatedUnionMember(ChannelType.GuildPrivateThread)]
[GenerateSerializer]
public record GuildPrivateThreadChannel(
Snowflake Id,
Snowflake GuildId,
int Position,
Overwrite[] PermissionOverwrites, // I guess???
string Name,
Snowflake? LastMessageId,
Snowflake? ParentId,
Snowflake? LastPinTimestamp,
int MessageCount,
int MemberCount,
ThreadMetadata ThreadMetadata,
ThreadMember Member)
: Channel(
ChannelType.GuildPrivateThread,
Id);

/// <summary>
/// Represents a voice channel for hosting events with an audience.
/// </summary>
[DiscriminatedUnionMember(ChannelType.GuildStageVoice)]
[GenerateSerializer]
public record GuildStageVoiceChannel(
Snowflake Id,
Snowflake GuildId,
int Position,
Overwrite[] PermissionOverwrites,
string Name,
int Bitrate,
int UserLimit,
string? RtcRegion)
: Channel(
ChannelType.GuildStageVoice,
Id);

*/
}

+ 68
- 0
src/Models/ChannelType.cs View File

@@ -0,0 +1,68 @@
namespace Discord.Net.Models
{
/// <summary>
/// Declares an enum which represents the type of a <see cref="Channel"/>.
/// </summary>
/// <remarks>
/// <see href="https://discord.com/developers/docs/resources/channel#channel-object-channel-types"/>
/// </remarks>
public enum ChannelType
{
/// <summary>
/// A text channel within a server.
/// </summary>
GuildText = 0,

/// <summary>
/// A direct message between users.
/// </summary>
DM = 1,

/// <summary>
/// A voice channel within a server.
/// </summary>
GuildVoice = 2,

/// <summary>
/// A direct message between multiple users.
/// </summary>
GroupDM = 3,

/// <summary>
/// An organizational category that contains up to 50 channels.
/// </summary>
GuildCategory = 4,

/// <summary>
/// A channel that users can follow and crosspost into their own server.
/// </summary>
GuildNews = 5,

/// <summary>
/// A channel in which game developers can sell their game on Discord.
/// </summary>
GuildStore = 6,

/// <summary>
/// A temporary sub-channel within a <see cref="GuildNews"/> channel.
/// </summary>
GuildNewsThread = 10,

/// <summary>
/// A temporary sub-channel within a <see cref="GuildText"/> channel.
/// </summary>
GuildPublicThread = 11,

/// <summary>
/// A temporary sub-channel within a <see cref="GuildText"/> channel
/// that is only viewable by those invited and those with the
/// MANAGE_THREADS permission.
/// </summary>
GuildPrivateThread = 12,

/// <summary>
/// A voice channel for hosting events with an audience.
/// </summary>
GuildStageVoice = 13
}
}

+ 7
- 0
src/Models/Discord.Net.Models.csproj View File

@@ -7,6 +7,8 @@
$(Description) $(Description)
Shared models between the Discord REST API and Gateway. Shared models between the Discord REST API and Gateway.
</Description> </Description>

<DiscordNet_SerializationGenerator_OptionsTypeNamespace>Discord.Net.Serialization</DiscordNet_SerializationGenerator_OptionsTypeNamespace>
</PropertyGroup> </PropertyGroup>


<ItemGroup> <ItemGroup>
@@ -14,6 +16,11 @@
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>
<CompilerVisibleProperty Include="DiscordNet_SerializationGenerator_OptionsTypeNamespace" />
<ProjectReference Include="../../tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="../Serialization/Discord.Net.Serialization.csproj" /> <ProjectReference Include="../Serialization/Discord.Net.Serialization.csproj" />
</ItemGroup> </ItemGroup>




+ 3
- 0
src/Serialization/DiscriminatedUnionAttribute.cs View File

@@ -5,6 +5,9 @@ namespace Discord.Net.Serialization
/// <summary> /// <summary>
/// Defines an attribute used to mark discriminated unions. /// Defines an attribute used to mark discriminated unions.
/// </summary> /// </summary>
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct,
AllowMultiple = false,
Inherited = false)]
public class DiscriminatedUnionAttribute : Attribute public class DiscriminatedUnionAttribute : Attribute
{ {
/// <summary> /// <summary>


+ 5
- 2
src/Serialization/DiscriminatedUnionMemberAttribute.cs View File

@@ -5,12 +5,15 @@ namespace Discord.Net.Serialization
/// <summary> /// <summary>
/// Defines an attribute used to mark members of discriminated unions. /// Defines an attribute used to mark members of discriminated unions.
/// </summary> /// </summary>
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct,
AllowMultiple = false,
Inherited = false)]
public class DiscriminatedUnionMemberAttribute : Attribute public class DiscriminatedUnionMemberAttribute : Attribute
{ {
/// <summary> /// <summary>
/// Gets the discriminator value used to identify this member type. /// Gets the discriminator value used to identify this member type.
/// </summary> /// </summary>
public string Discriminator { get; }
public object Discriminator { get; }


/// <summary> /// <summary>
/// Creates a new <see cref="DiscriminatedUnionMemberAttribute"/> /// Creates a new <see cref="DiscriminatedUnionMemberAttribute"/>
@@ -19,7 +22,7 @@ namespace Discord.Net.Serialization
/// <param name="discriminator"> /// <param name="discriminator">
/// The discriminator value used to identify this member type. /// The discriminator value used to identify this member type.
/// </param> /// </param>
public DiscriminatedUnionMemberAttribute(string discriminator)
public DiscriminatedUnionMemberAttribute(object discriminator)
{ {
Discriminator = discriminator; Discriminator = discriminator;
} }


+ 16
- 0
src/Serialization/GenerateSerializerAttribute.cs View File

@@ -0,0 +1,16 @@
using System;

namespace Discord.Net.Serialization
{
/// <summary>
/// Defines an attribute which informs the serializer generator to generate
/// a serializer for this type.
/// </summary>
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct,
AllowMultiple = false,
Inherited = false)]
public class GenerateSerializerAttribute : Attribute
{

}
}

+ 5
- 0
tools/SourceGenerators/Directory.Build.props View File

@@ -21,6 +21,7 @@
<PropertyGroup> <PropertyGroup>
<GenerateDocumentationFile>false</GenerateDocumentationFile> <GenerateDocumentationFile>false</GenerateDocumentationFile>
<NoPackageAnalysis>true</NoPackageAnalysis> <NoPackageAnalysis>true</NoPackageAnalysis>
<IncludeBuildOutput>false</IncludeBuildOutput>
<!-- Disable release tracking analyzers due to weird behaviour with OmniSharp --> <!-- Disable release tracking analyzers due to weird behaviour with OmniSharp -->
<NoWarn>$(NoWarn);RS2000;RS2001;RS2002;RS2003;RS2004;RS2005;RS2006;RS2007;RS2008</NoWarn> <NoWarn>$(NoWarn);RS2000;RS2001;RS2002;RS2003;RS2004;RS2005;RS2006;RS2007;RS2008</NoWarn>
</PropertyGroup> </PropertyGroup>
@@ -33,4 +34,8 @@
<None Include="$(OutputPath)$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" /> <None Include="$(OutputPath)$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" />
</ItemGroup> </ItemGroup>


<ItemGroup>
<Compile Include="$(MSBuildThisFileDirectory)\IsExternalInit.cs" Link="IsExternalInit.cs" />
</ItemGroup>

</Project> </Project>

+ 4
- 0
tools/SourceGenerators/IsExternalInit.cs View File

@@ -0,0 +1,4 @@
namespace System.Runtime.CompilerServices
{
internal static class IsExternalInit { }
}

+ 4
- 0
tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.csproj View File

@@ -4,4 +4,8 @@
<TargetFramework>netstandard2.0</TargetFramework> <TargetFramework>netstandard2.0</TargetFramework>
</PropertyGroup> </PropertyGroup>


<ItemGroup>
<None Include="Discord.Net.SourceGenerators.Serialization.props" Pack="true" PackagePath="build" Visible="false" />
</ItemGroup>

</Project> </Project>

+ 6
- 0
tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.props View File

@@ -0,0 +1,6 @@
<Project>
<ItemGroup>
<CompilerVisibleProperty Include="DiscordNet_SerializationGenerator_OptionsTypeNamespace" />
<CompilerVisibleProperty Include="DiscordNet_SerializationGenerator_SearchThroughReferencedAssemblies" />
</ItemGroup>
</Project>

+ 0
- 37
tools/SourceGenerators/Serialization/SerializationSourceGenerator.ConverterSource.cs View File

@@ -1,37 +0,0 @@
using Microsoft.CodeAnalysis;

namespace Discord.Net.SourceGenerators.Serialization
{
public partial class SerializationSourceGenerator
{
private static string GenerateConverter(INamedTypeSymbol @class)
{
return $@"
using System;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace Discord.Net.Serialization.Converters
{{
public class {@class.Name}Converter : JsonConverter<{@class.ToDisplayString()}>
{{
public override {@class.ToDisplayString()} Read(
ref Utf8JsonReader reader,
Type typeToConvert,
JsonSerializerOptions options)
{{
return default;
}}

public override void Write(
Utf8JsonWriter writer,
{@class.ToDisplayString()} value,
JsonSerializerOptions options)
{{
writer.WriteNull();
}}
}}
}}";
}
}
}

+ 15
- 31
tools/SourceGenerators/Serialization/SerializationSourceGenerator.OptionsSource.cs View File

@@ -5,20 +5,26 @@ namespace Discord.Net.SourceGenerators.Serialization
{ {
public partial class SerializationSourceGenerator public partial class SerializationSourceGenerator
{ {
private static string GenerateSerializerOptionsTemplateSourceCode()
private static string GenerateSerializerOptionsSourceCode(
string @namespace,
IEnumerable<SerializedType> converters)
{ {
return @"
using System;
var snippets = string.Join("\n",
converters.Select(
x => $" options.Converters.Add(new {@namespace}.Internal.Converters.{x.ConverterTypeName}());"));

return $@"using System;
using System.Text.Json; using System.Text.Json;
using Discord.Net.Serialization.Converters;


namespace Discord.Net.Serialization
{
namespace {@namespace}
{{
/// <summary> /// <summary>
/// Defines extension methods for adding Discord.Net JSON converters to a /// Defines extension methods for adding Discord.Net JSON converters to a
/// <see cref=""JsonSerializerOptions""/> instance. /// <see cref=""JsonSerializerOptions""/> instance.
/// </summary> /// </summary>
public static partial class JsonSerializerOptionsExtensions
{
public static class JsonSerializerOptionsExtensions
{{
/// <summary> /// <summary>
/// Adds Discord.Net JSON converters to the passed /// Adds Discord.Net JSON converters to the passed
/// <see cref=""JsonSerializerOptions""/>. /// <see cref=""JsonSerializerOptions""/>.
@@ -30,33 +36,11 @@ namespace Discord.Net.Serialization
/// The modified <see cref=""JsonSerializerOptions""/>, so this method /// The modified <see cref=""JsonSerializerOptions""/>, so this method
/// can be chained. /// can be chained.
/// </returns> /// </returns>
public static partial JsonSerializerOptions WithDiscordNetConverters(
this JsonSerializerOptions options);
}
}";
}

private static string GenerateSerializerOptionsSourceCode(
List<string> converters)
{
var snippets = string.Join("\n",
converters.Select(
x => $"options.Converters.Add(new {x}());"));

return $@"
using System;
using System.Text.Json;
using Discord.Net.Serialization.Converters;

namespace Discord.Net.Serialization
{{
public static partial class JsonSerializerOptionsExtensions
{{
public static partial JsonSerializerOptions WithDiscordNetConverters(
public static JsonSerializerOptions WithDiscordNetConverters(
this JsonSerializerOptions options) this JsonSerializerOptions options)
{{ {{
options.Converters.Add(new OptionalConverterFactory()); options.Converters.Add(new OptionalConverterFactory());
{snippets}
{snippets}


return options; return options;
}} }}


+ 106
- 64
tools/SourceGenerators/Serialization/SerializationSourceGenerator.cs View File

@@ -3,7 +3,7 @@ using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Reflection;
using System.Threading;
using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.CSharp.Syntax;


@@ -14,94 +14,136 @@ namespace Discord.Net.SourceGenerators.Serialization
{ {
public void Execute(GeneratorExecutionContext context) public void Execute(GeneratorExecutionContext context)
{ {
if (!context.AnalyzerConfigOptions.GlobalOptions.TryGetValue(
"build_property.DiscordNet_SerializationGenerator_OptionsTypeNamespace",
out var serializerOptionsNamespace))
throw new InvalidOperationException(
"Missing output namespace. Set DiscordNet_SerializationGenerator_OptionsTypeNamespace in your project file.");

bool searchThroughReferencedAssemblies =
context.AnalyzerConfigOptions.GlobalOptions.TryGetValue(
"build_property.DiscordNet_SerializationGenerator_SearchThroughReferencedAssemblies",
out var _);

var generateSerializerAttribute = context.Compilation
.GetTypeByMetadataName(
"Discord.Net.Serialization.GenerateSerializerAttribute");
var discriminatedUnionSymbol = context.Compilation
.GetTypeByMetadataName(
"Discord.Net.Serialization.DiscriminatedUnionAttribute");
var discriminatedUnionMemberSymbol = context.Compilation
.GetTypeByMetadataName(
"Discord.Net.Serialization.DiscriminatedUnionMemberAttribute");

Debug.Assert(generateSerializerAttribute != null);
Debug.Assert(discriminatedUnionSymbol != null);
Debug.Assert(discriminatedUnionMemberSymbol != null);
Debug.Assert(context.SyntaxContextReceiver != null);

var receiver = (SyntaxReceiver)context.SyntaxContextReceiver!; var receiver = (SyntaxReceiver)context.SyntaxContextReceiver!;
var converters = new List<string>();
var symbolsToBuild = receiver.GetSerializedTypes(
context.Compilation);


foreach (var @class in receiver.Classes)
if (searchThroughReferencedAssemblies)
{ {
var semanticModel = context.Compilation.GetSemanticModel(
@class.SyntaxTree);
var visitor = new VisibleTypeVisitor(context.CancellationToken);
foreach (var module in context.Compilation.Assembly.Modules)
foreach (var reference in module.ReferencedAssemblySymbols)
visitor.Visit(reference);


if (semanticModel.GetDeclaredSymbol(@class) is
not INamedTypeSymbol classSymbol)
throw new InvalidOperationException(
"Could not find named type symbol for " +
$"{@class.Identifier}");
symbolsToBuild = symbolsToBuild
.Concat(visitor.GetVisibleTypes());
}


context.AddSource(
$"Converters.{classSymbol.Name}",
GenerateConverter(classSymbol));
var types = SerializedTypeUtils.BuildTypeTrees(
generateSerializerAttribute: generateSerializerAttribute!,
discriminatedUnionSymbol: discriminatedUnionSymbol!,
discriminatedUnionMemberSymbol: discriminatedUnionMemberSymbol!,
symbolsToBuild: symbolsToBuild);


converters.Add($"{classSymbol.Name}Converter");
foreach (var type in types)
{
context.AddSource($"Converters.{type.ConverterTypeName}",
type.GenerateSourceCode(serializerOptionsNamespace));

if (type is DiscriminatedUnionSerializedType duDeclaration)
foreach (var member in duDeclaration.Members)
context.AddSource(
$"Converters.{type.ConverterTypeName}.{member.ConverterTypeName}",
member.GenerateSourceCode(serializerOptionsNamespace));
} }


context.AddSource("SerializerOptions.Complete",
GenerateSerializerOptionsSourceCode(converters));
context.AddSource("SerializerOptions",
GenerateSerializerOptionsSourceCode(
serializerOptionsNamespace, types));
} }


public void Initialize(GeneratorInitializationContext context) public void Initialize(GeneratorInitializationContext context)
{
context.RegisterForPostInitialization(PostInitialize);
context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
}

public static void PostInitialize(
GeneratorPostInitializationContext context)
=> context.AddSource("SerializerOptions.Template",
GenerateSerializerOptionsTemplateSourceCode());
=> context.RegisterForSyntaxNotifications(
() => new SyntaxReceiver());


internal class SyntaxReceiver : ISyntaxContextReceiver
private class SyntaxReceiver : ISyntaxContextReceiver
{ {
public List<ClassDeclarationSyntax> Classes { get; } = new();

private readonly Dictionary<string, INamedTypeSymbol> _interestingAttributes
= new();
private readonly List<SyntaxNode> _classes;


public void OnVisitSyntaxNode(GeneratorSyntaxContext context)
public SyntaxReceiver()
{ {
_ = GetOrAddAttribute(_interestingAttributes,
context.SemanticModel,
"Discord.Net.Serialization.DiscriminatedUnionAttribute");
_ = GetOrAddAttribute(_interestingAttributes,
context.SemanticModel,
"Discord.Net.Serialization.DiscriminatedUnionMemberAttribute");
_classes = new();
}


if (context.Node is ClassDeclarationSyntax classDecl
&& classDecl.AttributeLists is
SyntaxList<AttributeListSyntax> attrList
&& attrList.Any(
list => list.Attributes
.Any(a => IsInterestingAttribute(a,
context.SemanticModel,
_interestingAttributes.Values))))
public IEnumerable<INamedTypeSymbol> GetSerializedTypes(
Compilation compilation)
{
foreach (var @class in _classes)
{ {
Classes.Add(classDecl);
var semanticModel = compilation.GetSemanticModel(
@class.SyntaxTree);

if (semanticModel.GetDeclaredSymbol(@class) is
INamedTypeSymbol classSymbol)
yield return classSymbol;
} }
} }


private static INamedTypeSymbol GetOrAddAttribute(
Dictionary<string, INamedTypeSymbol> cache,
SemanticModel model, string name)
private INamedTypeSymbol? _generateSerializerAttributeSymbol;
public void OnVisitSyntaxNode(GeneratorSyntaxContext context)
{ {
if (!cache.TryGetValue(name, out var type))
_generateSerializerAttributeSymbol ??=
context.SemanticModel.Compilation.GetTypeByMetadataName(
"Discord.Net.Serialization.GenerateSerializerAttribute");

Debug.Assert(_generateSerializerAttributeSymbol != null);

if (context.Node is ClassDeclarationSyntax classDeclaration
&& classDeclaration.AttributeLists is
SyntaxList<AttributeListSyntax> classAttributeLists
&& classAttributeLists.Any(
list => list.Attributes.Any(
n => IsAttribute(n, context.SemanticModel,
_generateSerializerAttributeSymbol!))))
{ {
type = model.Compilation.GetTypeByMetadataName(name);
Debug.Assert(type != null);
cache.Add(name, type!);
_classes.Add(classDeclaration);
}
else if (context.Node is RecordDeclarationSyntax recordDeclaration
&& recordDeclaration.AttributeLists is
SyntaxList<AttributeListSyntax> recordAttributeLists
&& recordAttributeLists.Any(
list => list.Attributes.Any(
n => IsAttribute(n, context.SemanticModel,
_generateSerializerAttributeSymbol!))))
{
_classes.Add(recordDeclaration);
} }


return type!;
}

private static bool IsInterestingAttribute(
AttributeSyntax attribute, SemanticModel model,
IEnumerable<INamedTypeSymbol> interestingAttributes)
{
var typeInfo = model.GetTypeInfo(attribute.Name);
static bool IsAttribute(AttributeSyntax attribute,
SemanticModel model, INamedTypeSymbol expected)
{
var typeInfo = model.GetTypeInfo(attribute.Name);


return interestingAttributes.Any(
x => SymbolEqualityComparer.Default
.Equals(typeInfo.Type, x));
return SymbolEqualityComparer.Default.Equals(
typeInfo.Type, expected);
}
} }
} }
} }


+ 253
- 0
tools/SourceGenerators/Serialization/Structure/SerializedType.cs View File

@@ -0,0 +1,253 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using static Discord.Net.SourceGenerators.Serialization.Utils;

namespace Discord.Net.SourceGenerators.Serialization
{
internal record SerializedType(
INamedTypeSymbol Declaration)
{
public virtual string ConverterTypeName
=> $"{Declaration.Name}Converter";

protected virtual IEnumerable<IPropertySymbol> SymbolsToSerialize
=> Declaration.GetProperties(includeInherited: true)
.Where(x => !x.IsReadOnly);

public virtual string GenerateSourceCode(string outputNamespace)
{
var deserializers = SymbolsToSerialize
.Select(GenerateFieldReader);

var bytes = string.Join("\n",
deserializers.Select(x => x.utf8));
var fields = string.Join("\n",
deserializers.Select(x => x.field));
var readers = string.Join("\n",
deserializers.Select(x => x.reader));

var fieldUnassigned = string.Join("\n || ",
deserializers
.Where(x => x.type.NullableAnnotation != NullableAnnotation.Annotated)
.Select(
x => $"{x.snakeCase}OrDefault is not {x.type} {x.snakeCase}"));

var constructorParams = string.Join(",\n",
deserializers
.Select(x => $" {x.name}: {x.snakeCase}{(x.type.NullableAnnotation == NullableAnnotation.Annotated ? "OrDefault" : "")}"));

return $@"using System;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace {outputNamespace}.Converters
{{
internal class {ConverterTypeName} : JsonConverter<{Declaration.ToDisplayString()}>
{{
{bytes}

public override {Declaration.ToDisplayString()}? Read(
ref Utf8JsonReader reader,
Type typeToConvert,
JsonSerializerOptions options)
{{
if (reader.TokenType != JsonTokenType.StartObject)
throw new JsonException(""Expected StartObject"");

{fields}

while (reader.Read())
{{
if (reader.TokenType == JsonTokenType.EndObject)
break;

if (reader.TokenType != JsonTokenType.PropertyName)
throw new JsonException(""Expected PropertyName"");

{readers}
else if (!reader.Read())
throw new JsonException();

if (reader.TokenType == JsonTokenType.StartArray
|| reader.TokenType == JsonTokenType.StartObject)
reader.Skip();
}}

if ({fieldUnassigned})
throw new JsonException(""Missing field"");

return new {Declaration.ToDisplayString()}(
{constructorParams}
);
}}

public override void Write(
Utf8JsonWriter writer,
{Declaration.ToDisplayString()} value,
JsonSerializerOptions options)
{{
writer.WriteNullValue();
}}
}}
}}";

static (string name, ITypeSymbol type, string snakeCase, string utf8, string field, string reader)
GenerateFieldReader(IPropertySymbol member, int position)
{
var needsNullableAnnotation = false;
if (member.Type.IsValueType
&& member.Type.OriginalDefinition.SpecialType != SpecialType.System_Nullable_T)
needsNullableAnnotation = true;

var snakeCase = ConvertToSnakeCase(member.Name);
return (member.Name, member.Type, snakeCase,
$@" private static ReadOnlySpan<byte> {member.Name}Bytes => new byte[]
{{
// {snakeCase}
{string.Join(", ", Encoding.UTF8.GetBytes(snakeCase))}
}};",
$" {member.Type.WithNullableAnnotation(NullableAnnotation.Annotated).ToDisplayString()}{(needsNullableAnnotation ? "?" : "")} {snakeCase}OrDefault = default;",
$@" {(position > 0 ? "else " : "")}if (reader.ValueTextEquals({member.Name}Bytes))
{{
if (!reader.Read())
throw new JsonException(""Expected value"");

var cvt = options.GetConverter(
typeof({member.Type.WithNullableAnnotation(NullableAnnotation.NotAnnotated).ToDisplayString()}));

if (cvt is JsonConverter<{member.Type.WithNullableAnnotation(NullableAnnotation.NotAnnotated).ToDisplayString()}> converter)
{snakeCase}OrDefault = converter.Read(ref reader,
typeof({member.Type.WithNullableAnnotation(NullableAnnotation.NotAnnotated).ToDisplayString()}),
options);
else
{snakeCase}OrDefault = JsonSerializer.Deserialize<{member.Type.ToDisplayString()}>(
ref reader, options);
}}");
}
}
}

internal record DiscriminatedUnionSerializedType(
INamedTypeSymbol Declaration,
ISymbol Discriminator)
: SerializedType(Declaration)
{
public List<DiscriminatedUnionMemberSerializedType> Members { get; }
= new();

public override string GenerateSourceCode(string outputNamespace)
{
var discriminatorField = ConvertToSnakeCase(Discriminator.Name);

var discriminatorType = Discriminator switch
{
IPropertySymbol prop => prop.Type,
IFieldSymbol field => field.Type,
_ => throw new InvalidOperationException(
"Unsupported discriminator member type")
};

var switchCaseMembers = string.Join(",\n",
Members.Select(
x => $@" {x.DiscriminatorValue.ToCSharpString()}
=> JsonSerializer.Deserialize(ref copy,
typeof({x.Declaration.ToDisplayString()}), options)"));

return $@"using System;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace {outputNamespace}.Internal.Converters
{{
internal class {ConverterTypeName} : JsonConverter<{Declaration.ToDisplayString()}>
{{
private static ReadOnlySpan<byte> DiscriminatorBytes => new byte[]
{{
// {discriminatorField}
{string.Join(", ", Encoding.UTF8.GetBytes(discriminatorField))}
}};

public override {Declaration.ToDisplayString()}? Read(
ref Utf8JsonReader reader,
Type typeToConvert,
JsonSerializerOptions options)
{{
var copy = reader;
if (reader.TokenType != JsonTokenType.StartObject)
throw new JsonException(""Expected StartObject"");

{discriminatorType.ToDisplayString()}? discriminator = null;

while (reader.Read())
{{
if (reader.TokenType == JsonTokenType.EndObject)
break;

if (reader.TokenType != JsonTokenType.PropertyName)
throw new JsonException(""Expected PropertyName"");

if (reader.ValueTextEquals(DiscriminatorBytes))
{{
if (!reader.Read())
throw new JsonException(""Expected value"");

var cvt = options.GetConverter(
typeof({discriminatorType.ToDisplayString()}));

if (cvt is JsonConverter<{discriminatorType.ToDisplayString()}> converter)
discriminator = converter.Read(ref reader,
typeof({discriminatorType.ToDisplayString()}),
options);
else
discriminator = JsonSerializer
.Deserialize<{discriminatorType.ToDisplayString()}>(
ref reader, options);
}}
else if (!reader.Read())
throw new JsonException(""Expected value"");

if (reader.TokenType == JsonTokenType.StartArray
|| reader.TokenType == JsonTokenType.StartObject)
reader.Skip();
}}

var result = discriminator switch
{{
{switchCaseMembers},
_ => throw new JsonException(""Unknown discriminator value"")
}} as {Declaration.ToDisplayString()};

reader = copy;
return result;
}}

public override void Write(
Utf8JsonWriter writer,
{Declaration.ToDisplayString()} value,
JsonSerializerOptions options)
{{
writer.WriteNullValue();
}}
}}
}}";
}
}

internal record DiscriminatedUnionMemberSerializedType(
INamedTypeSymbol Declaration,
TypedConstant DiscriminatorValue)
: SerializedType(Declaration)
{
public DiscriminatedUnionSerializedType? DiscriminatedUnionDeclaration
{ get; init; }

protected override IEnumerable<IPropertySymbol> SymbolsToSerialize
=> base.SymbolsToSerialize
.Where(x => !SymbolEqualityComparer.Default.Equals(x,
DiscriminatedUnionDeclaration?.Discriminator));
}
}

+ 119
- 0
tools/SourceGenerators/Serialization/Structure/SerializedTypeUtils.cs View File

@@ -0,0 +1,119 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;

namespace Discord.Net.SourceGenerators.Serialization
{
internal static class SerializedTypeUtils
{
public static List<SerializedType> BuildTypeTrees(
INamedTypeSymbol generateSerializerAttribute,
INamedTypeSymbol discriminatedUnionSymbol,
INamedTypeSymbol discriminatedUnionMemberSymbol,
IEnumerable<INamedTypeSymbol> symbolsToBuild)
{
var types = new List<SerializedType>();

FindAllSerializedTypes(types, generateSerializerAttribute,
discriminatedUnionSymbol, discriminatedUnionMemberSymbol,
symbolsToBuild);

// Now, move DU members into their relevant DU declaration.
int x = 0;
while (x < types.Count)
{
var type = types[x];
if (type is DiscriminatedUnionMemberSerializedType duMember)
{
var declaration = types.FirstOrDefault(
x => SymbolEqualityComparer.Default.Equals(
x.Declaration, duMember.Declaration.BaseType));

if (declaration is not DiscriminatedUnionSerializedType duDeclaration)
throw new InvalidOperationException(
"Could not find DU declaration for DU " +
$"member {duMember.Declaration.ToDisplayString()}");

duDeclaration.Members.Add(duMember with
{
DiscriminatedUnionDeclaration = duDeclaration
});
types.RemoveAt(x);
continue;
}

x++;
}

return types;
}

private static void FindAllSerializedTypes(
List<SerializedType> types,
INamedTypeSymbol generateSerializerAttribute,
INamedTypeSymbol discriminatedUnionSymbol,
INamedTypeSymbol discriminatedUnionMemberSymbol,
IEnumerable<INamedTypeSymbol> symbolsToBuild)
{
foreach (var type in symbolsToBuild)
{
var generateSerializer = type.GetAttributes()
.Any(x => SymbolEqualityComparer.Default
.Equals(x.AttributeClass, generateSerializerAttribute));

if (!generateSerializer)
continue;

var duDeclaration = type.GetAttributes()
.FirstOrDefault(x => SymbolEqualityComparer.Default
.Equals(x.AttributeClass, discriminatedUnionSymbol));

if (duDeclaration != null)
{
if (duDeclaration
.ConstructorArguments
.FirstOrDefault()
.Value is not string memberName)
throw new InvalidOperationException(
"Failed to get DU discriminator member name");

var member = type.GetMembers(memberName)
.FirstOrDefault(
x => x is IPropertySymbol or IFieldSymbol);

if (member is null)
throw new InvalidOperationException(
"Failed to get DU discriminator member symbol");

types.Add(new DiscriminatedUnionSerializedType(
type, member));

continue;
}

var duMemberDeclaration = type
.GetAttributes()
.FirstOrDefault(x => SymbolEqualityComparer.Default
.Equals(x.AttributeClass,
discriminatedUnionMemberSymbol));

if (duMemberDeclaration != null)
{
if (duMemberDeclaration.ConstructorArguments.Length == 0
|| duMemberDeclaration.ConstructorArguments[0].IsNull)
throw new InvalidOperationException(
"Failed to get DU discriminator value");

types.Add(new DiscriminatedUnionMemberSerializedType(
type, duMemberDeclaration.ConstructorArguments[0]));

continue;
}

types.Add(new SerializedType(type));
}
}
}
}

+ 27
- 0
tools/SourceGenerators/Serialization/SymbolExtensions.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.SymbolStore;
using Microsoft.CodeAnalysis;

namespace Discord.Net.SourceGenerators.Serialization
{
internal static class SymbolExtensions
{
public static IEnumerable<IPropertySymbol> GetProperties(
this INamedTypeSymbol symbol,
bool includeInherited)
{
var s = symbol;
do
{
foreach (var member in s.GetMembers())
if (member is IPropertySymbol property)
yield return property;

s = s.BaseType;
}
while (includeInherited && s != null);
}
}
}

+ 25
- 0
tools/SourceGenerators/Serialization/Utils.cs View File

@@ -0,0 +1,25 @@
using System.Text;

namespace Discord.Net.SourceGenerators.Serialization
{
internal static class Utils
{
private static readonly StringBuilder CaseChangeBuffer = new();

public static string ConvertToSnakeCase(string value)
{
foreach (var c in value)
{
if (char.IsUpper(c) && CaseChangeBuffer.Length > 0)
_ = CaseChangeBuffer.Append('_');

_ = CaseChangeBuffer.Append(char.ToLower(c));
}

var result = CaseChangeBuffer.ToString();
_ = CaseChangeBuffer.Clear();

return result;
}
}
}

+ 59
- 0
tools/SourceGenerators/Serialization/VisibleTypeVisitor.cs View File

@@ -0,0 +1,59 @@
using System.Collections.Generic;
using System.Threading;
using Microsoft.CodeAnalysis;

namespace Discord.Net.SourceGenerators.Serialization
{
internal sealed class VisibleTypeVisitor
: SymbolVisitor
{
private readonly CancellationToken _cancellationToken;
private readonly HashSet<INamedTypeSymbol> _typeSymbols;

public VisibleTypeVisitor(CancellationToken cancellationToken)
{
_cancellationToken = cancellationToken;
_typeSymbols = new(SymbolEqualityComparer.Default);
}

public IEnumerable<INamedTypeSymbol> GetVisibleTypes()
=> _typeSymbols;

public override void VisitAssembly(IAssemblySymbol symbol)
{
_cancellationToken.ThrowIfCancellationRequested();
symbol.GlobalNamespace.Accept(this);
}

public override void VisitNamespace(INamespaceSymbol symbol)
{
foreach (var member in symbol.GetMembers())
{
_cancellationToken.ThrowIfCancellationRequested();
member.Accept(this);
}
}

public override void VisitNamedType(INamedTypeSymbol symbol)
{
_cancellationToken.ThrowIfCancellationRequested();

var isVisible = symbol.DeclaredAccessibility switch
{
Accessibility.Protected => true,
Accessibility.ProtectedOrInternal => true,
Accessibility.Public => true,
_ => false,
};

if (!isVisible || !_typeSymbols.Add(symbol))
return;

foreach (var member in symbol.GetTypeMembers())
{
_cancellationToken.ThrowIfCancellationRequested();
member.Accept(this);
}
}
}
}

Loading…
Cancel
Save