From 091cc6bfb145d734fc855e4929d43fed4c00d331 Mon Sep 17 00:00:00 2001 From: FiniteReality Date: Sat, 5 Jun 2021 19:07:34 +0100 Subject: [PATCH] Add serialization source generator --- samples/PingPong/PingPong.csproj | 4 - src/Models/Channel.cs | 241 ++++++++++++++++++++ src/Models/ChannelType.cs | 68 ++++++ src/Models/Discord.Net.Models.csproj | 7 + src/Serialization/DiscriminatedUnionAttribute.cs | 3 + .../DiscriminatedUnionMemberAttribute.cs | 7 +- src/Serialization/GenerateSerializerAttribute.cs | 16 ++ tools/SourceGenerators/Directory.Build.props | 5 + tools/SourceGenerators/IsExternalInit.cs | 4 + ...scord.Net.SourceGenerators.Serialization.csproj | 4 + ...iscord.Net.SourceGenerators.Serialization.props | 6 + ...SerializationSourceGenerator.ConverterSource.cs | 37 --- .../SerializationSourceGenerator.OptionsSource.cs | 46 ++-- .../Serialization/SerializationSourceGenerator.cs | 170 ++++++++------ .../Serialization/Structure/SerializedType.cs | 253 +++++++++++++++++++++ .../Serialization/Structure/SerializedTypeUtils.cs | 119 ++++++++++ .../Serialization/SymbolExtensions.cs | 27 +++ tools/SourceGenerators/Serialization/Utils.cs | 25 ++ .../Serialization/VisibleTypeVisitor.cs | 59 +++++ 19 files changed, 963 insertions(+), 138 deletions(-) create mode 100644 src/Models/Channel.cs create mode 100644 src/Models/ChannelType.cs create mode 100644 src/Serialization/GenerateSerializerAttribute.cs create mode 100644 tools/SourceGenerators/IsExternalInit.cs create mode 100644 tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.props delete mode 100644 tools/SourceGenerators/Serialization/SerializationSourceGenerator.ConverterSource.cs create mode 100644 tools/SourceGenerators/Serialization/Structure/SerializedType.cs create mode 100644 tools/SourceGenerators/Serialization/Structure/SerializedTypeUtils.cs create mode 100644 tools/SourceGenerators/Serialization/SymbolExtensions.cs create mode 100644 tools/SourceGenerators/Serialization/Utils.cs create mode 100644 tools/SourceGenerators/Serialization/VisibleTypeVisitor.cs diff --git a/samples/PingPong/PingPong.csproj b/samples/PingPong/PingPong.csproj index 51f903439..a11fbc8f1 100644 --- a/samples/PingPong/PingPong.csproj +++ b/samples/PingPong/PingPong.csproj @@ -10,10 +10,6 @@ - - - - diff --git a/src/Models/Channel.cs b/src/Models/Channel.cs new file mode 100644 index 000000000..af9c61452 --- /dev/null +++ b/src/Models/Channel.cs @@ -0,0 +1,241 @@ +using System; +using Discord.Net.Serialization; + +namespace Discord.Net.Models +{ + /// + /// Represents a guild or DM channel within Discord. + /// + /// + /// + /// + /// + /// The id of this channel. + /// + /// + /// The type of channel. + /// + [DiscriminatedUnion(nameof(Channel.Type))] + [GenerateSerializer] + public record Channel( + ChannelType Type, + Snowflake Id); + + /// + /// Represents a text channel within a server. + /// + [DiscriminatedUnionMember(ChannelType.GuildText)] + [GenerateSerializer] + public record GuildTextChannel( + Snowflake Id, + Snowflake GuildId, + int Position, + /*Overwrite[] PermissionOverwrites,*/ + string Name, + string? Topic, + bool Nsfw, + Snowflake LastMessageId, + int RateLimitPerUser, + Snowflake? ParentId, + DateTimeOffset? LastPinTimestamp) + : Channel( + ChannelType.GuildText, + Id); + + /* + + /// + /// Represents a direct message between users. + /// + [DiscriminatedUnionMember(ChannelType.DM)] + [GenerateSerializer] + public record DMChannel( + Snowflake Id, + User[] Recipients) + : Channel( + ChannelType.DM, + Id); + + /// + /// Represents a voice channel within a server. + /// + [DiscriminatedUnionMember(ChannelType.GuildVoice)] + [GenerateSerializer] + public record GuildVoiceChannel( + Snowflake Id, + Snowflake GuildId, + int Position, + Overwrite[] PermissionOverwrites, + string Name, + bool Nsfw, + int Bitrate, + int UserLimit, + Snowflake? ParentId, + string? RtcRegion, + int VideoQualityMode) + : Channel( + ChannelType.GuildVoice, + Id); + + /// + /// Represents a direct message between multiple users. + /// + [DiscriminatedUnionMember(ChannelType.GroupDM)] + [GenerateSerializer] + public record GroupDMChannel( + Snowflake Id, + string Name, + Snowflake LastMessageId, + User[] Recipients, + string? Icon, + Snowflake? OwnerId, + Snowflake? ApplicationId, + DateTimeOffset? LastPinTimestamp) + : Channel( + ChannelType.GroupDM, + Id); + + /// + /// Represents an organizational category that contains up to 50 channels. + /// + [DiscriminatedUnionMember(ChannelType.GuildCategory)] + [GenerateSerializer] + public record GuildCategoryChannel( + Snowflake Id, + Snowflake GuildId, + int Position, + Overwrite[] PermissionOverwrites, + string Name) + : Channel( + ChannelType.GuildCategory, + Id); + + /// + /// Represents a channel that users can follow and crosspost into their own + /// server. + /// + [DiscriminatedUnionMember(ChannelType.GuildNews)] + [GenerateSerializer] + public record GuildNewsChannel( + Snowflake Id, + Snowflake GuildId, + int Position, + Overwrite[] PermissionOverwrites, + string Name, + string? Topic, + bool Nsfw, + Snowflake? LastMessageId, + int RateLimitPerUser, + Snowflake? ParentId, + Snowflake? LastPinTimestamp) + : Channel( + ChannelType.GuildNews, + Id); + + /// + /// Represents a channel in which game developers can sell their game on + /// Discord. + /// + [DiscriminatedUnionMember(ChannelType.GuildStore)] + [GenerateSerializer] + public record GuildStoreChannel( + Snowflake Id, + Snowflake GuildId, + int Position, + Overwrite[] PermissionOverwrites, // I guess??? + string? Name, + Snowflake? ParentId) + : Channel( + ChannelType.GuildStore, + Id); + + /// + /// Represents a temporary sub-channel within a + /// . + /// + [DiscriminatedUnionMember(ChannelType.GuildNewsThread)] + [GenerateSerializer] + public record GuildNewsThreadChannel( + Snowflake Id, + Snowflake GuildId, + int Position, + Overwrite[] PermissionOverwrites, // I guess?? + string Name, + Snowflake? LastMessageId, + Snowflake? ParentId, + Snowflake? LastPinTimestamp, + int MessageCount, + int MemberCount, + ThreadMetadata ThreadMetadata, + ThreadMember Member) + : Channel( + ChannelType.GuildNewsThread, + Id); + + /// + /// Represents a temporary sub-channel within a + /// . + /// + [DiscriminatedUnionMember(ChannelType.GuildPublicThread)] + [GenerateSerializer] + public record GuildPublicThreadChannel( + Snowflake Id, + Snowflake GuildId, + int Position, + Overwrite[] PermissionOverwrites, // I guess?? + string Name, + Snowflake? LastMessageId, + Snowflake? ParentId, + Snowflake? LastPinTimestamp, + int MessageCount, + int MemberCount, + ThreadMetadata ThreadMetadata, + ThreadMember Member) + : Channel( + ChannelType.GuildPublicThread, + Id); + + /// + /// Represents a temporary sub-channel within a + /// that is only viewable by those invited + /// and those with the MANAGE_THREADS permission. + /// + [DiscriminatedUnionMember(ChannelType.GuildPrivateThread)] + [GenerateSerializer] + public record GuildPrivateThreadChannel( + Snowflake Id, + Snowflake GuildId, + int Position, + Overwrite[] PermissionOverwrites, // I guess??? + string Name, + Snowflake? LastMessageId, + Snowflake? ParentId, + Snowflake? LastPinTimestamp, + int MessageCount, + int MemberCount, + ThreadMetadata ThreadMetadata, + ThreadMember Member) + : Channel( + ChannelType.GuildPrivateThread, + Id); + + /// + /// Represents a voice channel for hosting events with an audience. + /// + [DiscriminatedUnionMember(ChannelType.GuildStageVoice)] + [GenerateSerializer] + public record GuildStageVoiceChannel( + Snowflake Id, + Snowflake GuildId, + int Position, + Overwrite[] PermissionOverwrites, + string Name, + int Bitrate, + int UserLimit, + string? RtcRegion) + : Channel( + ChannelType.GuildStageVoice, + Id); + + */ +} diff --git a/src/Models/ChannelType.cs b/src/Models/ChannelType.cs new file mode 100644 index 000000000..ef050a637 --- /dev/null +++ b/src/Models/ChannelType.cs @@ -0,0 +1,68 @@ +namespace Discord.Net.Models +{ + /// + /// Declares an enum which represents the type of a . + /// + /// + /// + /// + public enum ChannelType + { + /// + /// A text channel within a server. + /// + GuildText = 0, + + /// + /// A direct message between users. + /// + DM = 1, + + /// + /// A voice channel within a server. + /// + GuildVoice = 2, + + /// + /// A direct message between multiple users. + /// + GroupDM = 3, + + /// + /// An organizational category that contains up to 50 channels. + /// + GuildCategory = 4, + + /// + /// A channel that users can follow and crosspost into their own server. + /// + GuildNews = 5, + + /// + /// A channel in which game developers can sell their game on Discord. + /// + GuildStore = 6, + + /// + /// A temporary sub-channel within a channel. + /// + GuildNewsThread = 10, + + /// + /// A temporary sub-channel within a channel. + /// + GuildPublicThread = 11, + + /// + /// A temporary sub-channel within a channel + /// that is only viewable by those invited and those with the + /// MANAGE_THREADS permission. + /// + GuildPrivateThread = 12, + + /// + /// A voice channel for hosting events with an audience. + /// + GuildStageVoice = 13 + } +} diff --git a/src/Models/Discord.Net.Models.csproj b/src/Models/Discord.Net.Models.csproj index 80f81f32c..35a3bb821 100644 --- a/src/Models/Discord.Net.Models.csproj +++ b/src/Models/Discord.Net.Models.csproj @@ -7,6 +7,8 @@ $(Description) Shared models between the Discord REST API and Gateway. + + Discord.Net.Serialization @@ -14,6 +16,11 @@ + + + + + diff --git a/src/Serialization/DiscriminatedUnionAttribute.cs b/src/Serialization/DiscriminatedUnionAttribute.cs index e156d1260..1d2a00d48 100644 --- a/src/Serialization/DiscriminatedUnionAttribute.cs +++ b/src/Serialization/DiscriminatedUnionAttribute.cs @@ -5,6 +5,9 @@ namespace Discord.Net.Serialization /// /// Defines an attribute used to mark discriminated unions. /// + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct, + AllowMultiple = false, + Inherited = false)] public class DiscriminatedUnionAttribute : Attribute { /// diff --git a/src/Serialization/DiscriminatedUnionMemberAttribute.cs b/src/Serialization/DiscriminatedUnionMemberAttribute.cs index 619bb4070..4958d5864 100644 --- a/src/Serialization/DiscriminatedUnionMemberAttribute.cs +++ b/src/Serialization/DiscriminatedUnionMemberAttribute.cs @@ -5,12 +5,15 @@ namespace Discord.Net.Serialization /// /// Defines an attribute used to mark members of discriminated unions. /// + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct, + AllowMultiple = false, + Inherited = false)] public class DiscriminatedUnionMemberAttribute : Attribute { /// /// Gets the discriminator value used to identify this member type. /// - public string Discriminator { get; } + public object Discriminator { get; } /// /// Creates a new @@ -19,7 +22,7 @@ namespace Discord.Net.Serialization /// /// The discriminator value used to identify this member type. /// - public DiscriminatedUnionMemberAttribute(string discriminator) + public DiscriminatedUnionMemberAttribute(object discriminator) { Discriminator = discriminator; } diff --git a/src/Serialization/GenerateSerializerAttribute.cs b/src/Serialization/GenerateSerializerAttribute.cs new file mode 100644 index 000000000..138095872 --- /dev/null +++ b/src/Serialization/GenerateSerializerAttribute.cs @@ -0,0 +1,16 @@ +using System; + +namespace Discord.Net.Serialization +{ + /// + /// Defines an attribute which informs the serializer generator to generate + /// a serializer for this type. + /// + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct, + AllowMultiple = false, + Inherited = false)] + public class GenerateSerializerAttribute : Attribute + { + + } +} diff --git a/tools/SourceGenerators/Directory.Build.props b/tools/SourceGenerators/Directory.Build.props index 262b2c5c9..88a648206 100644 --- a/tools/SourceGenerators/Directory.Build.props +++ b/tools/SourceGenerators/Directory.Build.props @@ -21,6 +21,7 @@ false true + false $(NoWarn);RS2000;RS2001;RS2002;RS2003;RS2004;RS2005;RS2006;RS2007;RS2008 @@ -33,4 +34,8 @@ + + + + diff --git a/tools/SourceGenerators/IsExternalInit.cs b/tools/SourceGenerators/IsExternalInit.cs new file mode 100644 index 000000000..5ead4123a --- /dev/null +++ b/tools/SourceGenerators/IsExternalInit.cs @@ -0,0 +1,4 @@ +namespace System.Runtime.CompilerServices +{ + internal static class IsExternalInit { } +} diff --git a/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.csproj b/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.csproj index 9f5c4f4ab..864db7532 100644 --- a/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.csproj +++ b/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.csproj @@ -4,4 +4,8 @@ netstandard2.0 + + + + diff --git a/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.props b/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.props new file mode 100644 index 000000000..4d8e9fb92 --- /dev/null +++ b/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.props @@ -0,0 +1,6 @@ + + + + + + diff --git a/tools/SourceGenerators/Serialization/SerializationSourceGenerator.ConverterSource.cs b/tools/SourceGenerators/Serialization/SerializationSourceGenerator.ConverterSource.cs deleted file mode 100644 index 8b313aa79..000000000 --- a/tools/SourceGenerators/Serialization/SerializationSourceGenerator.ConverterSource.cs +++ /dev/null @@ -1,37 +0,0 @@ -using Microsoft.CodeAnalysis; - -namespace Discord.Net.SourceGenerators.Serialization -{ - public partial class SerializationSourceGenerator - { - private static string GenerateConverter(INamedTypeSymbol @class) - { -return $@" -using System; -using System.Text.Json; -using System.Text.Json.Serialization; - -namespace Discord.Net.Serialization.Converters -{{ - public class {@class.Name}Converter : JsonConverter<{@class.ToDisplayString()}> - {{ - public override {@class.ToDisplayString()} Read( - ref Utf8JsonReader reader, - Type typeToConvert, - JsonSerializerOptions options) - {{ - return default; - }} - - public override void Write( - Utf8JsonWriter writer, - {@class.ToDisplayString()} value, - JsonSerializerOptions options) - {{ - writer.WriteNull(); - }} - }} -}}"; - } - } -} diff --git a/tools/SourceGenerators/Serialization/SerializationSourceGenerator.OptionsSource.cs b/tools/SourceGenerators/Serialization/SerializationSourceGenerator.OptionsSource.cs index 41674de44..b01d3f942 100644 --- a/tools/SourceGenerators/Serialization/SerializationSourceGenerator.OptionsSource.cs +++ b/tools/SourceGenerators/Serialization/SerializationSourceGenerator.OptionsSource.cs @@ -5,20 +5,26 @@ namespace Discord.Net.SourceGenerators.Serialization { public partial class SerializationSourceGenerator { - private static string GenerateSerializerOptionsTemplateSourceCode() + private static string GenerateSerializerOptionsSourceCode( + string @namespace, + IEnumerable converters) { -return @" -using System; + var snippets = string.Join("\n", + converters.Select( + x => $" options.Converters.Add(new {@namespace}.Internal.Converters.{x.ConverterTypeName}());")); + +return $@"using System; using System.Text.Json; +using Discord.Net.Serialization.Converters; -namespace Discord.Net.Serialization -{ +namespace {@namespace} +{{ /// /// Defines extension methods for adding Discord.Net JSON converters to a /// instance. /// - public static partial class JsonSerializerOptionsExtensions - { + public static class JsonSerializerOptionsExtensions + {{ /// /// Adds Discord.Net JSON converters to the passed /// . @@ -30,33 +36,11 @@ namespace Discord.Net.Serialization /// The modified , so this method /// can be chained. /// - public static partial JsonSerializerOptions WithDiscordNetConverters( - this JsonSerializerOptions options); - } -}"; - } - - private static string GenerateSerializerOptionsSourceCode( - List converters) - { - var snippets = string.Join("\n", - converters.Select( - x => $"options.Converters.Add(new {x}());")); - -return $@" -using System; -using System.Text.Json; -using Discord.Net.Serialization.Converters; - -namespace Discord.Net.Serialization -{{ - public static partial class JsonSerializerOptionsExtensions - {{ - public static partial JsonSerializerOptions WithDiscordNetConverters( + public static JsonSerializerOptions WithDiscordNetConverters( this JsonSerializerOptions options) {{ options.Converters.Add(new OptionalConverterFactory()); - {snippets} +{snippets} return options; }} diff --git a/tools/SourceGenerators/Serialization/SerializationSourceGenerator.cs b/tools/SourceGenerators/Serialization/SerializationSourceGenerator.cs index 2b3ba8bcd..18297b095 100644 --- a/tools/SourceGenerators/Serialization/SerializationSourceGenerator.cs +++ b/tools/SourceGenerators/Serialization/SerializationSourceGenerator.cs @@ -3,7 +3,7 @@ using System.Collections; using System.Collections.Generic; using System.Diagnostics; using System.Linq; -using System.Reflection; +using System.Threading; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -14,94 +14,136 @@ namespace Discord.Net.SourceGenerators.Serialization { public void Execute(GeneratorExecutionContext context) { + if (!context.AnalyzerConfigOptions.GlobalOptions.TryGetValue( + "build_property.DiscordNet_SerializationGenerator_OptionsTypeNamespace", + out var serializerOptionsNamespace)) + throw new InvalidOperationException( + "Missing output namespace. Set DiscordNet_SerializationGenerator_OptionsTypeNamespace in your project file."); + + bool searchThroughReferencedAssemblies = + context.AnalyzerConfigOptions.GlobalOptions.TryGetValue( + "build_property.DiscordNet_SerializationGenerator_SearchThroughReferencedAssemblies", + out var _); + + var generateSerializerAttribute = context.Compilation + .GetTypeByMetadataName( + "Discord.Net.Serialization.GenerateSerializerAttribute"); + var discriminatedUnionSymbol = context.Compilation + .GetTypeByMetadataName( + "Discord.Net.Serialization.DiscriminatedUnionAttribute"); + var discriminatedUnionMemberSymbol = context.Compilation + .GetTypeByMetadataName( + "Discord.Net.Serialization.DiscriminatedUnionMemberAttribute"); + + Debug.Assert(generateSerializerAttribute != null); + Debug.Assert(discriminatedUnionSymbol != null); + Debug.Assert(discriminatedUnionMemberSymbol != null); + Debug.Assert(context.SyntaxContextReceiver != null); + var receiver = (SyntaxReceiver)context.SyntaxContextReceiver!; - var converters = new List(); + var symbolsToBuild = receiver.GetSerializedTypes( + context.Compilation); - foreach (var @class in receiver.Classes) + if (searchThroughReferencedAssemblies) { - var semanticModel = context.Compilation.GetSemanticModel( - @class.SyntaxTree); + var visitor = new VisibleTypeVisitor(context.CancellationToken); + foreach (var module in context.Compilation.Assembly.Modules) + foreach (var reference in module.ReferencedAssemblySymbols) + visitor.Visit(reference); - if (semanticModel.GetDeclaredSymbol(@class) is - not INamedTypeSymbol classSymbol) - throw new InvalidOperationException( - "Could not find named type symbol for " + - $"{@class.Identifier}"); + symbolsToBuild = symbolsToBuild + .Concat(visitor.GetVisibleTypes()); + } - context.AddSource( - $"Converters.{classSymbol.Name}", - GenerateConverter(classSymbol)); + var types = SerializedTypeUtils.BuildTypeTrees( + generateSerializerAttribute: generateSerializerAttribute!, + discriminatedUnionSymbol: discriminatedUnionSymbol!, + discriminatedUnionMemberSymbol: discriminatedUnionMemberSymbol!, + symbolsToBuild: symbolsToBuild); - converters.Add($"{classSymbol.Name}Converter"); + foreach (var type in types) + { + context.AddSource($"Converters.{type.ConverterTypeName}", + type.GenerateSourceCode(serializerOptionsNamespace)); + + if (type is DiscriminatedUnionSerializedType duDeclaration) + foreach (var member in duDeclaration.Members) + context.AddSource( + $"Converters.{type.ConverterTypeName}.{member.ConverterTypeName}", + member.GenerateSourceCode(serializerOptionsNamespace)); } - context.AddSource("SerializerOptions.Complete", - GenerateSerializerOptionsSourceCode(converters)); + context.AddSource("SerializerOptions", + GenerateSerializerOptionsSourceCode( + serializerOptionsNamespace, types)); } public void Initialize(GeneratorInitializationContext context) - { - context.RegisterForPostInitialization(PostInitialize); - context.RegisterForSyntaxNotifications(() => new SyntaxReceiver()); - } - - public static void PostInitialize( - GeneratorPostInitializationContext context) - => context.AddSource("SerializerOptions.Template", - GenerateSerializerOptionsTemplateSourceCode()); + => context.RegisterForSyntaxNotifications( + () => new SyntaxReceiver()); - internal class SyntaxReceiver : ISyntaxContextReceiver + private class SyntaxReceiver : ISyntaxContextReceiver { - public List Classes { get; } = new(); - - private readonly Dictionary _interestingAttributes - = new(); + private readonly List _classes; - public void OnVisitSyntaxNode(GeneratorSyntaxContext context) + public SyntaxReceiver() { - _ = GetOrAddAttribute(_interestingAttributes, - context.SemanticModel, - "Discord.Net.Serialization.DiscriminatedUnionAttribute"); - _ = GetOrAddAttribute(_interestingAttributes, - context.SemanticModel, - "Discord.Net.Serialization.DiscriminatedUnionMemberAttribute"); + _classes = new(); + } - if (context.Node is ClassDeclarationSyntax classDecl - && classDecl.AttributeLists is - SyntaxList attrList - && attrList.Any( - list => list.Attributes - .Any(a => IsInterestingAttribute(a, - context.SemanticModel, - _interestingAttributes.Values)))) + public IEnumerable GetSerializedTypes( + Compilation compilation) + { + foreach (var @class in _classes) { - Classes.Add(classDecl); + var semanticModel = compilation.GetSemanticModel( + @class.SyntaxTree); + + if (semanticModel.GetDeclaredSymbol(@class) is + INamedTypeSymbol classSymbol) + yield return classSymbol; } } - private static INamedTypeSymbol GetOrAddAttribute( - Dictionary cache, - SemanticModel model, string name) + private INamedTypeSymbol? _generateSerializerAttributeSymbol; + + public void OnVisitSyntaxNode(GeneratorSyntaxContext context) { - if (!cache.TryGetValue(name, out var type)) + _generateSerializerAttributeSymbol ??= + context.SemanticModel.Compilation.GetTypeByMetadataName( + "Discord.Net.Serialization.GenerateSerializerAttribute"); + + Debug.Assert(_generateSerializerAttributeSymbol != null); + + if (context.Node is ClassDeclarationSyntax classDeclaration + && classDeclaration.AttributeLists is + SyntaxList classAttributeLists + && classAttributeLists.Any( + list => list.Attributes.Any( + n => IsAttribute(n, context.SemanticModel, + _generateSerializerAttributeSymbol!)))) { - type = model.Compilation.GetTypeByMetadataName(name); - Debug.Assert(type != null); - cache.Add(name, type!); + _classes.Add(classDeclaration); + } + else if (context.Node is RecordDeclarationSyntax recordDeclaration + && recordDeclaration.AttributeLists is + SyntaxList recordAttributeLists + && recordAttributeLists.Any( + list => list.Attributes.Any( + n => IsAttribute(n, context.SemanticModel, + _generateSerializerAttributeSymbol!)))) + { + _classes.Add(recordDeclaration); } - return type!; - } - - private static bool IsInterestingAttribute( - AttributeSyntax attribute, SemanticModel model, - IEnumerable interestingAttributes) - { - var typeInfo = model.GetTypeInfo(attribute.Name); + static bool IsAttribute(AttributeSyntax attribute, + SemanticModel model, INamedTypeSymbol expected) + { + var typeInfo = model.GetTypeInfo(attribute.Name); - return interestingAttributes.Any( - x => SymbolEqualityComparer.Default - .Equals(typeInfo.Type, x)); + return SymbolEqualityComparer.Default.Equals( + typeInfo.Type, expected); + } } } } diff --git a/tools/SourceGenerators/Serialization/Structure/SerializedType.cs b/tools/SourceGenerators/Serialization/Structure/SerializedType.cs new file mode 100644 index 000000000..d57c64f83 --- /dev/null +++ b/tools/SourceGenerators/Serialization/Structure/SerializedType.cs @@ -0,0 +1,253 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using static Discord.Net.SourceGenerators.Serialization.Utils; + +namespace Discord.Net.SourceGenerators.Serialization +{ + internal record SerializedType( + INamedTypeSymbol Declaration) + { + public virtual string ConverterTypeName + => $"{Declaration.Name}Converter"; + + protected virtual IEnumerable SymbolsToSerialize + => Declaration.GetProperties(includeInherited: true) + .Where(x => !x.IsReadOnly); + + public virtual string GenerateSourceCode(string outputNamespace) + { + var deserializers = SymbolsToSerialize + .Select(GenerateFieldReader); + + var bytes = string.Join("\n", + deserializers.Select(x => x.utf8)); + var fields = string.Join("\n", + deserializers.Select(x => x.field)); + var readers = string.Join("\n", + deserializers.Select(x => x.reader)); + + var fieldUnassigned = string.Join("\n || ", + deserializers + .Where(x => x.type.NullableAnnotation != NullableAnnotation.Annotated) + .Select( + x => $"{x.snakeCase}OrDefault is not {x.type} {x.snakeCase}")); + + var constructorParams = string.Join(",\n", + deserializers + .Select(x => $" {x.name}: {x.snakeCase}{(x.type.NullableAnnotation == NullableAnnotation.Annotated ? "OrDefault" : "")}")); + +return $@"using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace {outputNamespace}.Converters +{{ + internal class {ConverterTypeName} : JsonConverter<{Declaration.ToDisplayString()}> + {{ +{bytes} + + public override {Declaration.ToDisplayString()}? Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options) + {{ + if (reader.TokenType != JsonTokenType.StartObject) + throw new JsonException(""Expected StartObject""); + +{fields} + + while (reader.Read()) + {{ + if (reader.TokenType == JsonTokenType.EndObject) + break; + + if (reader.TokenType != JsonTokenType.PropertyName) + throw new JsonException(""Expected PropertyName""); + +{readers} + else if (!reader.Read()) + throw new JsonException(); + + if (reader.TokenType == JsonTokenType.StartArray + || reader.TokenType == JsonTokenType.StartObject) + reader.Skip(); + }} + + if ({fieldUnassigned}) + throw new JsonException(""Missing field""); + + return new {Declaration.ToDisplayString()}( +{constructorParams} + ); + }} + + public override void Write( + Utf8JsonWriter writer, + {Declaration.ToDisplayString()} value, + JsonSerializerOptions options) + {{ + writer.WriteNullValue(); + }} + }} +}}"; + + static (string name, ITypeSymbol type, string snakeCase, string utf8, string field, string reader) + GenerateFieldReader(IPropertySymbol member, int position) + { + var needsNullableAnnotation = false; + if (member.Type.IsValueType + && member.Type.OriginalDefinition.SpecialType != SpecialType.System_Nullable_T) + needsNullableAnnotation = true; + + var snakeCase = ConvertToSnakeCase(member.Name); + return (member.Name, member.Type, snakeCase, +$@" private static ReadOnlySpan {member.Name}Bytes => new byte[] + {{ + // {snakeCase} + {string.Join(", ", Encoding.UTF8.GetBytes(snakeCase))} + }};", +$" {member.Type.WithNullableAnnotation(NullableAnnotation.Annotated).ToDisplayString()}{(needsNullableAnnotation ? "?" : "")} {snakeCase}OrDefault = default;", +$@" {(position > 0 ? "else " : "")}if (reader.ValueTextEquals({member.Name}Bytes)) + {{ + if (!reader.Read()) + throw new JsonException(""Expected value""); + + var cvt = options.GetConverter( + typeof({member.Type.WithNullableAnnotation(NullableAnnotation.NotAnnotated).ToDisplayString()})); + + if (cvt is JsonConverter<{member.Type.WithNullableAnnotation(NullableAnnotation.NotAnnotated).ToDisplayString()}> converter) + {snakeCase}OrDefault = converter.Read(ref reader, + typeof({member.Type.WithNullableAnnotation(NullableAnnotation.NotAnnotated).ToDisplayString()}), + options); + else + {snakeCase}OrDefault = JsonSerializer.Deserialize<{member.Type.ToDisplayString()}>( + ref reader, options); + }}"); + } + } + } + + internal record DiscriminatedUnionSerializedType( + INamedTypeSymbol Declaration, + ISymbol Discriminator) + : SerializedType(Declaration) + { + public List Members { get; } + = new(); + + public override string GenerateSourceCode(string outputNamespace) + { + var discriminatorField = ConvertToSnakeCase(Discriminator.Name); + + var discriminatorType = Discriminator switch + { + IPropertySymbol prop => prop.Type, + IFieldSymbol field => field.Type, + _ => throw new InvalidOperationException( + "Unsupported discriminator member type") + }; + + var switchCaseMembers = string.Join(",\n", + Members.Select( + x => $@" {x.DiscriminatorValue.ToCSharpString()} + => JsonSerializer.Deserialize(ref copy, + typeof({x.Declaration.ToDisplayString()}), options)")); + + return $@"using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace {outputNamespace}.Internal.Converters +{{ + internal class {ConverterTypeName} : JsonConverter<{Declaration.ToDisplayString()}> + {{ + private static ReadOnlySpan DiscriminatorBytes => new byte[] + {{ + // {discriminatorField} + {string.Join(", ", Encoding.UTF8.GetBytes(discriminatorField))} + }}; + + public override {Declaration.ToDisplayString()}? Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options) + {{ + var copy = reader; + if (reader.TokenType != JsonTokenType.StartObject) + throw new JsonException(""Expected StartObject""); + + {discriminatorType.ToDisplayString()}? discriminator = null; + + while (reader.Read()) + {{ + if (reader.TokenType == JsonTokenType.EndObject) + break; + + if (reader.TokenType != JsonTokenType.PropertyName) + throw new JsonException(""Expected PropertyName""); + + if (reader.ValueTextEquals(DiscriminatorBytes)) + {{ + if (!reader.Read()) + throw new JsonException(""Expected value""); + + var cvt = options.GetConverter( + typeof({discriminatorType.ToDisplayString()})); + + if (cvt is JsonConverter<{discriminatorType.ToDisplayString()}> converter) + discriminator = converter.Read(ref reader, + typeof({discriminatorType.ToDisplayString()}), + options); + else + discriminator = JsonSerializer + .Deserialize<{discriminatorType.ToDisplayString()}>( + ref reader, options); + }} + else if (!reader.Read()) + throw new JsonException(""Expected value""); + + if (reader.TokenType == JsonTokenType.StartArray + || reader.TokenType == JsonTokenType.StartObject) + reader.Skip(); + }} + + var result = discriminator switch + {{ +{switchCaseMembers}, + _ => throw new JsonException(""Unknown discriminator value"") + }} as {Declaration.ToDisplayString()}; + + reader = copy; + return result; + }} + + public override void Write( + Utf8JsonWriter writer, + {Declaration.ToDisplayString()} value, + JsonSerializerOptions options) + {{ + writer.WriteNullValue(); + }} + }} +}}"; + } + } + + internal record DiscriminatedUnionMemberSerializedType( + INamedTypeSymbol Declaration, + TypedConstant DiscriminatorValue) + : SerializedType(Declaration) + { + public DiscriminatedUnionSerializedType? DiscriminatedUnionDeclaration + { get; init; } + + protected override IEnumerable SymbolsToSerialize + => base.SymbolsToSerialize + .Where(x => !SymbolEqualityComparer.Default.Equals(x, + DiscriminatedUnionDeclaration?.Discriminator)); + } +} diff --git a/tools/SourceGenerators/Serialization/Structure/SerializedTypeUtils.cs b/tools/SourceGenerators/Serialization/Structure/SerializedTypeUtils.cs new file mode 100644 index 000000000..2b365bd3f --- /dev/null +++ b/tools/SourceGenerators/Serialization/Structure/SerializedTypeUtils.cs @@ -0,0 +1,119 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Discord.Net.SourceGenerators.Serialization +{ + internal static class SerializedTypeUtils + { + public static List BuildTypeTrees( + INamedTypeSymbol generateSerializerAttribute, + INamedTypeSymbol discriminatedUnionSymbol, + INamedTypeSymbol discriminatedUnionMemberSymbol, + IEnumerable symbolsToBuild) + { + var types = new List(); + + FindAllSerializedTypes(types, generateSerializerAttribute, + discriminatedUnionSymbol, discriminatedUnionMemberSymbol, + symbolsToBuild); + + // Now, move DU members into their relevant DU declaration. + int x = 0; + while (x < types.Count) + { + var type = types[x]; + if (type is DiscriminatedUnionMemberSerializedType duMember) + { + var declaration = types.FirstOrDefault( + x => SymbolEqualityComparer.Default.Equals( + x.Declaration, duMember.Declaration.BaseType)); + + if (declaration is not DiscriminatedUnionSerializedType duDeclaration) + throw new InvalidOperationException( + "Could not find DU declaration for DU " + + $"member {duMember.Declaration.ToDisplayString()}"); + + duDeclaration.Members.Add(duMember with + { + DiscriminatedUnionDeclaration = duDeclaration + }); + types.RemoveAt(x); + continue; + } + + x++; + } + + return types; + } + + private static void FindAllSerializedTypes( + List types, + INamedTypeSymbol generateSerializerAttribute, + INamedTypeSymbol discriminatedUnionSymbol, + INamedTypeSymbol discriminatedUnionMemberSymbol, + IEnumerable symbolsToBuild) + { + foreach (var type in symbolsToBuild) + { + var generateSerializer = type.GetAttributes() + .Any(x => SymbolEqualityComparer.Default + .Equals(x.AttributeClass, generateSerializerAttribute)); + + if (!generateSerializer) + continue; + + var duDeclaration = type.GetAttributes() + .FirstOrDefault(x => SymbolEqualityComparer.Default + .Equals(x.AttributeClass, discriminatedUnionSymbol)); + + if (duDeclaration != null) + { + if (duDeclaration + .ConstructorArguments + .FirstOrDefault() + .Value is not string memberName) + throw new InvalidOperationException( + "Failed to get DU discriminator member name"); + + var member = type.GetMembers(memberName) + .FirstOrDefault( + x => x is IPropertySymbol or IFieldSymbol); + + if (member is null) + throw new InvalidOperationException( + "Failed to get DU discriminator member symbol"); + + types.Add(new DiscriminatedUnionSerializedType( + type, member)); + + continue; + } + + var duMemberDeclaration = type + .GetAttributes() + .FirstOrDefault(x => SymbolEqualityComparer.Default + .Equals(x.AttributeClass, + discriminatedUnionMemberSymbol)); + + if (duMemberDeclaration != null) + { + if (duMemberDeclaration.ConstructorArguments.Length == 0 + || duMemberDeclaration.ConstructorArguments[0].IsNull) + throw new InvalidOperationException( + "Failed to get DU discriminator value"); + + types.Add(new DiscriminatedUnionMemberSerializedType( + type, duMemberDeclaration.ConstructorArguments[0])); + + continue; + } + + types.Add(new SerializedType(type)); + } + } + } +} diff --git a/tools/SourceGenerators/Serialization/SymbolExtensions.cs b/tools/SourceGenerators/Serialization/SymbolExtensions.cs new file mode 100644 index 000000000..c3b866bea --- /dev/null +++ b/tools/SourceGenerators/Serialization/SymbolExtensions.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.SymbolStore; +using Microsoft.CodeAnalysis; + +namespace Discord.Net.SourceGenerators.Serialization +{ + internal static class SymbolExtensions + { + public static IEnumerable GetProperties( + this INamedTypeSymbol symbol, + bool includeInherited) + { + var s = symbol; + do + { + foreach (var member in s.GetMembers()) + if (member is IPropertySymbol property) + yield return property; + + s = s.BaseType; + } + while (includeInherited && s != null); + } + } +} diff --git a/tools/SourceGenerators/Serialization/Utils.cs b/tools/SourceGenerators/Serialization/Utils.cs new file mode 100644 index 000000000..e26af1982 --- /dev/null +++ b/tools/SourceGenerators/Serialization/Utils.cs @@ -0,0 +1,25 @@ +using System.Text; + +namespace Discord.Net.SourceGenerators.Serialization +{ + internal static class Utils + { + private static readonly StringBuilder CaseChangeBuffer = new(); + + public static string ConvertToSnakeCase(string value) + { + foreach (var c in value) + { + if (char.IsUpper(c) && CaseChangeBuffer.Length > 0) + _ = CaseChangeBuffer.Append('_'); + + _ = CaseChangeBuffer.Append(char.ToLower(c)); + } + + var result = CaseChangeBuffer.ToString(); + _ = CaseChangeBuffer.Clear(); + + return result; + } + } +} diff --git a/tools/SourceGenerators/Serialization/VisibleTypeVisitor.cs b/tools/SourceGenerators/Serialization/VisibleTypeVisitor.cs new file mode 100644 index 000000000..3adebe575 --- /dev/null +++ b/tools/SourceGenerators/Serialization/VisibleTypeVisitor.cs @@ -0,0 +1,59 @@ +using System.Collections.Generic; +using System.Threading; +using Microsoft.CodeAnalysis; + +namespace Discord.Net.SourceGenerators.Serialization +{ + internal sealed class VisibleTypeVisitor + : SymbolVisitor + { + private readonly CancellationToken _cancellationToken; + private readonly HashSet _typeSymbols; + + public VisibleTypeVisitor(CancellationToken cancellationToken) + { + _cancellationToken = cancellationToken; + _typeSymbols = new(SymbolEqualityComparer.Default); + } + + public IEnumerable GetVisibleTypes() + => _typeSymbols; + + public override void VisitAssembly(IAssemblySymbol symbol) + { + _cancellationToken.ThrowIfCancellationRequested(); + symbol.GlobalNamespace.Accept(this); + } + + public override void VisitNamespace(INamespaceSymbol symbol) + { + foreach (var member in symbol.GetMembers()) + { + _cancellationToken.ThrowIfCancellationRequested(); + member.Accept(this); + } + } + + public override void VisitNamedType(INamedTypeSymbol symbol) + { + _cancellationToken.ThrowIfCancellationRequested(); + + var isVisible = symbol.DeclaredAccessibility switch + { + Accessibility.Protected => true, + Accessibility.ProtectedOrInternal => true, + Accessibility.Public => true, + _ => false, + }; + + if (!isVisible || !_typeSymbols.Add(symbol)) + return; + + foreach (var member in symbol.GetTypeMembers()) + { + _cancellationToken.ThrowIfCancellationRequested(); + member.Accept(this); + } + } + } +}