diff --git a/samples/PingPong/PingPong.csproj b/samples/PingPong/PingPong.csproj
index 51f903439..a11fbc8f1 100644
--- a/samples/PingPong/PingPong.csproj
+++ b/samples/PingPong/PingPong.csproj
@@ -10,10 +10,6 @@
-
-
-
-
diff --git a/src/Models/Channel.cs b/src/Models/Channel.cs
new file mode 100644
index 000000000..af9c61452
--- /dev/null
+++ b/src/Models/Channel.cs
@@ -0,0 +1,241 @@
+using System;
+using Discord.Net.Serialization;
+
+namespace Discord.Net.Models
+{
+ ///
+ /// Represents a guild or DM channel within Discord.
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// The id of this channel.
+ ///
+ ///
+ /// The type of channel.
+ ///
+ [DiscriminatedUnion(nameof(Channel.Type))]
+ [GenerateSerializer]
+ public record Channel(
+ ChannelType Type,
+ Snowflake Id);
+
+ ///
+ /// Represents a text channel within a server.
+ ///
+ [DiscriminatedUnionMember(ChannelType.GuildText)]
+ [GenerateSerializer]
+ public record GuildTextChannel(
+ Snowflake Id,
+ Snowflake GuildId,
+ int Position,
+ /*Overwrite[] PermissionOverwrites,*/
+ string Name,
+ string? Topic,
+ bool Nsfw,
+ Snowflake LastMessageId,
+ int RateLimitPerUser,
+ Snowflake? ParentId,
+ DateTimeOffset? LastPinTimestamp)
+ : Channel(
+ ChannelType.GuildText,
+ Id);
+
+ /*
+
+ ///
+ /// Represents a direct message between users.
+ ///
+ [DiscriminatedUnionMember(ChannelType.DM)]
+ [GenerateSerializer]
+ public record DMChannel(
+ Snowflake Id,
+ User[] Recipients)
+ : Channel(
+ ChannelType.DM,
+ Id);
+
+ ///
+ /// Represents a voice channel within a server.
+ ///
+ [DiscriminatedUnionMember(ChannelType.GuildVoice)]
+ [GenerateSerializer]
+ public record GuildVoiceChannel(
+ Snowflake Id,
+ Snowflake GuildId,
+ int Position,
+ Overwrite[] PermissionOverwrites,
+ string Name,
+ bool Nsfw,
+ int Bitrate,
+ int UserLimit,
+ Snowflake? ParentId,
+ string? RtcRegion,
+ int VideoQualityMode)
+ : Channel(
+ ChannelType.GuildVoice,
+ Id);
+
+ ///
+ /// Represents a direct message between multiple users.
+ ///
+ [DiscriminatedUnionMember(ChannelType.GroupDM)]
+ [GenerateSerializer]
+ public record GroupDMChannel(
+ Snowflake Id,
+ string Name,
+ Snowflake LastMessageId,
+ User[] Recipients,
+ string? Icon,
+ Snowflake? OwnerId,
+ Snowflake? ApplicationId,
+ DateTimeOffset? LastPinTimestamp)
+ : Channel(
+ ChannelType.GroupDM,
+ Id);
+
+ ///
+ /// Represents an organizational category that contains up to 50 channels.
+ ///
+ [DiscriminatedUnionMember(ChannelType.GuildCategory)]
+ [GenerateSerializer]
+ public record GuildCategoryChannel(
+ Snowflake Id,
+ Snowflake GuildId,
+ int Position,
+ Overwrite[] PermissionOverwrites,
+ string Name)
+ : Channel(
+ ChannelType.GuildCategory,
+ Id);
+
+ ///
+ /// Represents a channel that users can follow and crosspost into their own
+ /// server.
+ ///
+ [DiscriminatedUnionMember(ChannelType.GuildNews)]
+ [GenerateSerializer]
+ public record GuildNewsChannel(
+ Snowflake Id,
+ Snowflake GuildId,
+ int Position,
+ Overwrite[] PermissionOverwrites,
+ string Name,
+ string? Topic,
+ bool Nsfw,
+ Snowflake? LastMessageId,
+ int RateLimitPerUser,
+ Snowflake? ParentId,
+ Snowflake? LastPinTimestamp)
+ : Channel(
+ ChannelType.GuildNews,
+ Id);
+
+ ///
+ /// Represents a channel in which game developers can sell their game on
+ /// Discord.
+ ///
+ [DiscriminatedUnionMember(ChannelType.GuildStore)]
+ [GenerateSerializer]
+ public record GuildStoreChannel(
+ Snowflake Id,
+ Snowflake GuildId,
+ int Position,
+ Overwrite[] PermissionOverwrites, // I guess???
+ string? Name,
+ Snowflake? ParentId)
+ : Channel(
+ ChannelType.GuildStore,
+ Id);
+
+ ///
+ /// Represents a temporary sub-channel within a
+ /// .
+ ///
+ [DiscriminatedUnionMember(ChannelType.GuildNewsThread)]
+ [GenerateSerializer]
+ public record GuildNewsThreadChannel(
+ Snowflake Id,
+ Snowflake GuildId,
+ int Position,
+ Overwrite[] PermissionOverwrites, // I guess??
+ string Name,
+ Snowflake? LastMessageId,
+ Snowflake? ParentId,
+ Snowflake? LastPinTimestamp,
+ int MessageCount,
+ int MemberCount,
+ ThreadMetadata ThreadMetadata,
+ ThreadMember Member)
+ : Channel(
+ ChannelType.GuildNewsThread,
+ Id);
+
+ ///
+ /// Represents a temporary sub-channel within a
+ /// .
+ ///
+ [DiscriminatedUnionMember(ChannelType.GuildPublicThread)]
+ [GenerateSerializer]
+ public record GuildPublicThreadChannel(
+ Snowflake Id,
+ Snowflake GuildId,
+ int Position,
+ Overwrite[] PermissionOverwrites, // I guess??
+ string Name,
+ Snowflake? LastMessageId,
+ Snowflake? ParentId,
+ Snowflake? LastPinTimestamp,
+ int MessageCount,
+ int MemberCount,
+ ThreadMetadata ThreadMetadata,
+ ThreadMember Member)
+ : Channel(
+ ChannelType.GuildPublicThread,
+ Id);
+
+ ///
+ /// Represents a temporary sub-channel within a
+ /// that is only viewable by those invited
+ /// and those with the MANAGE_THREADS permission.
+ ///
+ [DiscriminatedUnionMember(ChannelType.GuildPrivateThread)]
+ [GenerateSerializer]
+ public record GuildPrivateThreadChannel(
+ Snowflake Id,
+ Snowflake GuildId,
+ int Position,
+ Overwrite[] PermissionOverwrites, // I guess???
+ string Name,
+ Snowflake? LastMessageId,
+ Snowflake? ParentId,
+ Snowflake? LastPinTimestamp,
+ int MessageCount,
+ int MemberCount,
+ ThreadMetadata ThreadMetadata,
+ ThreadMember Member)
+ : Channel(
+ ChannelType.GuildPrivateThread,
+ Id);
+
+ ///
+ /// Represents a voice channel for hosting events with an audience.
+ ///
+ [DiscriminatedUnionMember(ChannelType.GuildStageVoice)]
+ [GenerateSerializer]
+ public record GuildStageVoiceChannel(
+ Snowflake Id,
+ Snowflake GuildId,
+ int Position,
+ Overwrite[] PermissionOverwrites,
+ string Name,
+ int Bitrate,
+ int UserLimit,
+ string? RtcRegion)
+ : Channel(
+ ChannelType.GuildStageVoice,
+ Id);
+
+ */
+}
diff --git a/src/Models/ChannelType.cs b/src/Models/ChannelType.cs
new file mode 100644
index 000000000..ef050a637
--- /dev/null
+++ b/src/Models/ChannelType.cs
@@ -0,0 +1,68 @@
+namespace Discord.Net.Models
+{
+ ///
+ /// Declares an enum which represents the type of a .
+ ///
+ ///
+ ///
+ ///
+ public enum ChannelType
+ {
+ ///
+ /// A text channel within a server.
+ ///
+ GuildText = 0,
+
+ ///
+ /// A direct message between users.
+ ///
+ DM = 1,
+
+ ///
+ /// A voice channel within a server.
+ ///
+ GuildVoice = 2,
+
+ ///
+ /// A direct message between multiple users.
+ ///
+ GroupDM = 3,
+
+ ///
+ /// An organizational category that contains up to 50 channels.
+ ///
+ GuildCategory = 4,
+
+ ///
+ /// A channel that users can follow and crosspost into their own server.
+ ///
+ GuildNews = 5,
+
+ ///
+ /// A channel in which game developers can sell their game on Discord.
+ ///
+ GuildStore = 6,
+
+ ///
+ /// A temporary sub-channel within a channel.
+ ///
+ GuildNewsThread = 10,
+
+ ///
+ /// A temporary sub-channel within a channel.
+ ///
+ GuildPublicThread = 11,
+
+ ///
+ /// A temporary sub-channel within a channel
+ /// that is only viewable by those invited and those with the
+ /// MANAGE_THREADS permission.
+ ///
+ GuildPrivateThread = 12,
+
+ ///
+ /// A voice channel for hosting events with an audience.
+ ///
+ GuildStageVoice = 13
+ }
+}
diff --git a/src/Models/Discord.Net.Models.csproj b/src/Models/Discord.Net.Models.csproj
index 80f81f32c..35a3bb821 100644
--- a/src/Models/Discord.Net.Models.csproj
+++ b/src/Models/Discord.Net.Models.csproj
@@ -7,6 +7,8 @@
$(Description)
Shared models between the Discord REST API and Gateway.
+
+ Discord.Net.Serialization
@@ -14,6 +16,11 @@
+
+
+
+
+
diff --git a/src/Serialization/DiscriminatedUnionAttribute.cs b/src/Serialization/DiscriminatedUnionAttribute.cs
index e156d1260..1d2a00d48 100644
--- a/src/Serialization/DiscriminatedUnionAttribute.cs
+++ b/src/Serialization/DiscriminatedUnionAttribute.cs
@@ -5,6 +5,9 @@ namespace Discord.Net.Serialization
///
/// Defines an attribute used to mark discriminated unions.
///
+ [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct,
+ AllowMultiple = false,
+ Inherited = false)]
public class DiscriminatedUnionAttribute : Attribute
{
///
diff --git a/src/Serialization/DiscriminatedUnionMemberAttribute.cs b/src/Serialization/DiscriminatedUnionMemberAttribute.cs
index 619bb4070..4958d5864 100644
--- a/src/Serialization/DiscriminatedUnionMemberAttribute.cs
+++ b/src/Serialization/DiscriminatedUnionMemberAttribute.cs
@@ -5,12 +5,15 @@ namespace Discord.Net.Serialization
///
/// Defines an attribute used to mark members of discriminated unions.
///
+ [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct,
+ AllowMultiple = false,
+ Inherited = false)]
public class DiscriminatedUnionMemberAttribute : Attribute
{
///
/// Gets the discriminator value used to identify this member type.
///
- public string Discriminator { get; }
+ public object Discriminator { get; }
///
/// Creates a new
@@ -19,7 +22,7 @@ namespace Discord.Net.Serialization
///
/// The discriminator value used to identify this member type.
///
- public DiscriminatedUnionMemberAttribute(string discriminator)
+ public DiscriminatedUnionMemberAttribute(object discriminator)
{
Discriminator = discriminator;
}
diff --git a/src/Serialization/GenerateSerializerAttribute.cs b/src/Serialization/GenerateSerializerAttribute.cs
new file mode 100644
index 000000000..138095872
--- /dev/null
+++ b/src/Serialization/GenerateSerializerAttribute.cs
@@ -0,0 +1,16 @@
+using System;
+
+namespace Discord.Net.Serialization
+{
+ ///
+ /// Defines an attribute which informs the serializer generator to generate
+ /// a serializer for this type.
+ ///
+ [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct,
+ AllowMultiple = false,
+ Inherited = false)]
+ public class GenerateSerializerAttribute : Attribute
+ {
+
+ }
+}
diff --git a/tools/SourceGenerators/Directory.Build.props b/tools/SourceGenerators/Directory.Build.props
index 262b2c5c9..88a648206 100644
--- a/tools/SourceGenerators/Directory.Build.props
+++ b/tools/SourceGenerators/Directory.Build.props
@@ -21,6 +21,7 @@
false
true
+ false
$(NoWarn);RS2000;RS2001;RS2002;RS2003;RS2004;RS2005;RS2006;RS2007;RS2008
@@ -33,4 +34,8 @@
+
+
+
+
diff --git a/tools/SourceGenerators/IsExternalInit.cs b/tools/SourceGenerators/IsExternalInit.cs
new file mode 100644
index 000000000..5ead4123a
--- /dev/null
+++ b/tools/SourceGenerators/IsExternalInit.cs
@@ -0,0 +1,4 @@
+namespace System.Runtime.CompilerServices
+{
+ internal static class IsExternalInit { }
+}
diff --git a/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.csproj b/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.csproj
index 9f5c4f4ab..864db7532 100644
--- a/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.csproj
+++ b/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.csproj
@@ -4,4 +4,8 @@
netstandard2.0
+
+
+
+
diff --git a/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.props b/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.props
new file mode 100644
index 000000000..4d8e9fb92
--- /dev/null
+++ b/tools/SourceGenerators/Serialization/Discord.Net.SourceGenerators.Serialization.props
@@ -0,0 +1,6 @@
+
+
+
+
+
+
diff --git a/tools/SourceGenerators/Serialization/SerializationSourceGenerator.ConverterSource.cs b/tools/SourceGenerators/Serialization/SerializationSourceGenerator.ConverterSource.cs
deleted file mode 100644
index 8b313aa79..000000000
--- a/tools/SourceGenerators/Serialization/SerializationSourceGenerator.ConverterSource.cs
+++ /dev/null
@@ -1,37 +0,0 @@
-using Microsoft.CodeAnalysis;
-
-namespace Discord.Net.SourceGenerators.Serialization
-{
- public partial class SerializationSourceGenerator
- {
- private static string GenerateConverter(INamedTypeSymbol @class)
- {
-return $@"
-using System;
-using System.Text.Json;
-using System.Text.Json.Serialization;
-
-namespace Discord.Net.Serialization.Converters
-{{
- public class {@class.Name}Converter : JsonConverter<{@class.ToDisplayString()}>
- {{
- public override {@class.ToDisplayString()} Read(
- ref Utf8JsonReader reader,
- Type typeToConvert,
- JsonSerializerOptions options)
- {{
- return default;
- }}
-
- public override void Write(
- Utf8JsonWriter writer,
- {@class.ToDisplayString()} value,
- JsonSerializerOptions options)
- {{
- writer.WriteNull();
- }}
- }}
-}}";
- }
- }
-}
diff --git a/tools/SourceGenerators/Serialization/SerializationSourceGenerator.OptionsSource.cs b/tools/SourceGenerators/Serialization/SerializationSourceGenerator.OptionsSource.cs
index 41674de44..b01d3f942 100644
--- a/tools/SourceGenerators/Serialization/SerializationSourceGenerator.OptionsSource.cs
+++ b/tools/SourceGenerators/Serialization/SerializationSourceGenerator.OptionsSource.cs
@@ -5,20 +5,26 @@ namespace Discord.Net.SourceGenerators.Serialization
{
public partial class SerializationSourceGenerator
{
- private static string GenerateSerializerOptionsTemplateSourceCode()
+ private static string GenerateSerializerOptionsSourceCode(
+ string @namespace,
+ IEnumerable converters)
{
-return @"
-using System;
+ var snippets = string.Join("\n",
+ converters.Select(
+ x => $" options.Converters.Add(new {@namespace}.Internal.Converters.{x.ConverterTypeName}());"));
+
+return $@"using System;
using System.Text.Json;
+using Discord.Net.Serialization.Converters;
-namespace Discord.Net.Serialization
-{
+namespace {@namespace}
+{{
///
/// Defines extension methods for adding Discord.Net JSON converters to a
/// instance.
///
- public static partial class JsonSerializerOptionsExtensions
- {
+ public static class JsonSerializerOptionsExtensions
+ {{
///
/// Adds Discord.Net JSON converters to the passed
/// .
@@ -30,33 +36,11 @@ namespace Discord.Net.Serialization
/// The modified , so this method
/// can be chained.
///
- public static partial JsonSerializerOptions WithDiscordNetConverters(
- this JsonSerializerOptions options);
- }
-}";
- }
-
- private static string GenerateSerializerOptionsSourceCode(
- List converters)
- {
- var snippets = string.Join("\n",
- converters.Select(
- x => $"options.Converters.Add(new {x}());"));
-
-return $@"
-using System;
-using System.Text.Json;
-using Discord.Net.Serialization.Converters;
-
-namespace Discord.Net.Serialization
-{{
- public static partial class JsonSerializerOptionsExtensions
- {{
- public static partial JsonSerializerOptions WithDiscordNetConverters(
+ public static JsonSerializerOptions WithDiscordNetConverters(
this JsonSerializerOptions options)
{{
options.Converters.Add(new OptionalConverterFactory());
- {snippets}
+{snippets}
return options;
}}
diff --git a/tools/SourceGenerators/Serialization/SerializationSourceGenerator.cs b/tools/SourceGenerators/Serialization/SerializationSourceGenerator.cs
index 2b3ba8bcd..18297b095 100644
--- a/tools/SourceGenerators/Serialization/SerializationSourceGenerator.cs
+++ b/tools/SourceGenerators/Serialization/SerializationSourceGenerator.cs
@@ -3,7 +3,7 @@ using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
-using System.Reflection;
+using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
@@ -14,94 +14,136 @@ namespace Discord.Net.SourceGenerators.Serialization
{
public void Execute(GeneratorExecutionContext context)
{
+ if (!context.AnalyzerConfigOptions.GlobalOptions.TryGetValue(
+ "build_property.DiscordNet_SerializationGenerator_OptionsTypeNamespace",
+ out var serializerOptionsNamespace))
+ throw new InvalidOperationException(
+ "Missing output namespace. Set DiscordNet_SerializationGenerator_OptionsTypeNamespace in your project file.");
+
+ bool searchThroughReferencedAssemblies =
+ context.AnalyzerConfigOptions.GlobalOptions.TryGetValue(
+ "build_property.DiscordNet_SerializationGenerator_SearchThroughReferencedAssemblies",
+ out var _);
+
+ var generateSerializerAttribute = context.Compilation
+ .GetTypeByMetadataName(
+ "Discord.Net.Serialization.GenerateSerializerAttribute");
+ var discriminatedUnionSymbol = context.Compilation
+ .GetTypeByMetadataName(
+ "Discord.Net.Serialization.DiscriminatedUnionAttribute");
+ var discriminatedUnionMemberSymbol = context.Compilation
+ .GetTypeByMetadataName(
+ "Discord.Net.Serialization.DiscriminatedUnionMemberAttribute");
+
+ Debug.Assert(generateSerializerAttribute != null);
+ Debug.Assert(discriminatedUnionSymbol != null);
+ Debug.Assert(discriminatedUnionMemberSymbol != null);
+ Debug.Assert(context.SyntaxContextReceiver != null);
+
var receiver = (SyntaxReceiver)context.SyntaxContextReceiver!;
- var converters = new List();
+ var symbolsToBuild = receiver.GetSerializedTypes(
+ context.Compilation);
- foreach (var @class in receiver.Classes)
+ if (searchThroughReferencedAssemblies)
{
- var semanticModel = context.Compilation.GetSemanticModel(
- @class.SyntaxTree);
+ var visitor = new VisibleTypeVisitor(context.CancellationToken);
+ foreach (var module in context.Compilation.Assembly.Modules)
+ foreach (var reference in module.ReferencedAssemblySymbols)
+ visitor.Visit(reference);
- if (semanticModel.GetDeclaredSymbol(@class) is
- not INamedTypeSymbol classSymbol)
- throw new InvalidOperationException(
- "Could not find named type symbol for " +
- $"{@class.Identifier}");
+ symbolsToBuild = symbolsToBuild
+ .Concat(visitor.GetVisibleTypes());
+ }
- context.AddSource(
- $"Converters.{classSymbol.Name}",
- GenerateConverter(classSymbol));
+ var types = SerializedTypeUtils.BuildTypeTrees(
+ generateSerializerAttribute: generateSerializerAttribute!,
+ discriminatedUnionSymbol: discriminatedUnionSymbol!,
+ discriminatedUnionMemberSymbol: discriminatedUnionMemberSymbol!,
+ symbolsToBuild: symbolsToBuild);
- converters.Add($"{classSymbol.Name}Converter");
+ foreach (var type in types)
+ {
+ context.AddSource($"Converters.{type.ConverterTypeName}",
+ type.GenerateSourceCode(serializerOptionsNamespace));
+
+ if (type is DiscriminatedUnionSerializedType duDeclaration)
+ foreach (var member in duDeclaration.Members)
+ context.AddSource(
+ $"Converters.{type.ConverterTypeName}.{member.ConverterTypeName}",
+ member.GenerateSourceCode(serializerOptionsNamespace));
}
- context.AddSource("SerializerOptions.Complete",
- GenerateSerializerOptionsSourceCode(converters));
+ context.AddSource("SerializerOptions",
+ GenerateSerializerOptionsSourceCode(
+ serializerOptionsNamespace, types));
}
public void Initialize(GeneratorInitializationContext context)
- {
- context.RegisterForPostInitialization(PostInitialize);
- context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
- }
-
- public static void PostInitialize(
- GeneratorPostInitializationContext context)
- => context.AddSource("SerializerOptions.Template",
- GenerateSerializerOptionsTemplateSourceCode());
+ => context.RegisterForSyntaxNotifications(
+ () => new SyntaxReceiver());
- internal class SyntaxReceiver : ISyntaxContextReceiver
+ private class SyntaxReceiver : ISyntaxContextReceiver
{
- public List Classes { get; } = new();
-
- private readonly Dictionary _interestingAttributes
- = new();
+ private readonly List _classes;
- public void OnVisitSyntaxNode(GeneratorSyntaxContext context)
+ public SyntaxReceiver()
{
- _ = GetOrAddAttribute(_interestingAttributes,
- context.SemanticModel,
- "Discord.Net.Serialization.DiscriminatedUnionAttribute");
- _ = GetOrAddAttribute(_interestingAttributes,
- context.SemanticModel,
- "Discord.Net.Serialization.DiscriminatedUnionMemberAttribute");
+ _classes = new();
+ }
- if (context.Node is ClassDeclarationSyntax classDecl
- && classDecl.AttributeLists is
- SyntaxList attrList
- && attrList.Any(
- list => list.Attributes
- .Any(a => IsInterestingAttribute(a,
- context.SemanticModel,
- _interestingAttributes.Values))))
+ public IEnumerable GetSerializedTypes(
+ Compilation compilation)
+ {
+ foreach (var @class in _classes)
{
- Classes.Add(classDecl);
+ var semanticModel = compilation.GetSemanticModel(
+ @class.SyntaxTree);
+
+ if (semanticModel.GetDeclaredSymbol(@class) is
+ INamedTypeSymbol classSymbol)
+ yield return classSymbol;
}
}
- private static INamedTypeSymbol GetOrAddAttribute(
- Dictionary cache,
- SemanticModel model, string name)
+ private INamedTypeSymbol? _generateSerializerAttributeSymbol;
+
+ public void OnVisitSyntaxNode(GeneratorSyntaxContext context)
{
- if (!cache.TryGetValue(name, out var type))
+ _generateSerializerAttributeSymbol ??=
+ context.SemanticModel.Compilation.GetTypeByMetadataName(
+ "Discord.Net.Serialization.GenerateSerializerAttribute");
+
+ Debug.Assert(_generateSerializerAttributeSymbol != null);
+
+ if (context.Node is ClassDeclarationSyntax classDeclaration
+ && classDeclaration.AttributeLists is
+ SyntaxList classAttributeLists
+ && classAttributeLists.Any(
+ list => list.Attributes.Any(
+ n => IsAttribute(n, context.SemanticModel,
+ _generateSerializerAttributeSymbol!))))
{
- type = model.Compilation.GetTypeByMetadataName(name);
- Debug.Assert(type != null);
- cache.Add(name, type!);
+ _classes.Add(classDeclaration);
+ }
+ else if (context.Node is RecordDeclarationSyntax recordDeclaration
+ && recordDeclaration.AttributeLists is
+ SyntaxList recordAttributeLists
+ && recordAttributeLists.Any(
+ list => list.Attributes.Any(
+ n => IsAttribute(n, context.SemanticModel,
+ _generateSerializerAttributeSymbol!))))
+ {
+ _classes.Add(recordDeclaration);
}
- return type!;
- }
-
- private static bool IsInterestingAttribute(
- AttributeSyntax attribute, SemanticModel model,
- IEnumerable interestingAttributes)
- {
- var typeInfo = model.GetTypeInfo(attribute.Name);
+ static bool IsAttribute(AttributeSyntax attribute,
+ SemanticModel model, INamedTypeSymbol expected)
+ {
+ var typeInfo = model.GetTypeInfo(attribute.Name);
- return interestingAttributes.Any(
- x => SymbolEqualityComparer.Default
- .Equals(typeInfo.Type, x));
+ return SymbolEqualityComparer.Default.Equals(
+ typeInfo.Type, expected);
+ }
}
}
}
diff --git a/tools/SourceGenerators/Serialization/Structure/SerializedType.cs b/tools/SourceGenerators/Serialization/Structure/SerializedType.cs
new file mode 100644
index 000000000..d57c64f83
--- /dev/null
+++ b/tools/SourceGenerators/Serialization/Structure/SerializedType.cs
@@ -0,0 +1,253 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using static Discord.Net.SourceGenerators.Serialization.Utils;
+
+namespace Discord.Net.SourceGenerators.Serialization
+{
+ internal record SerializedType(
+ INamedTypeSymbol Declaration)
+ {
+ public virtual string ConverterTypeName
+ => $"{Declaration.Name}Converter";
+
+ protected virtual IEnumerable SymbolsToSerialize
+ => Declaration.GetProperties(includeInherited: true)
+ .Where(x => !x.IsReadOnly);
+
+ public virtual string GenerateSourceCode(string outputNamespace)
+ {
+ var deserializers = SymbolsToSerialize
+ .Select(GenerateFieldReader);
+
+ var bytes = string.Join("\n",
+ deserializers.Select(x => x.utf8));
+ var fields = string.Join("\n",
+ deserializers.Select(x => x.field));
+ var readers = string.Join("\n",
+ deserializers.Select(x => x.reader));
+
+ var fieldUnassigned = string.Join("\n || ",
+ deserializers
+ .Where(x => x.type.NullableAnnotation != NullableAnnotation.Annotated)
+ .Select(
+ x => $"{x.snakeCase}OrDefault is not {x.type} {x.snakeCase}"));
+
+ var constructorParams = string.Join(",\n",
+ deserializers
+ .Select(x => $" {x.name}: {x.snakeCase}{(x.type.NullableAnnotation == NullableAnnotation.Annotated ? "OrDefault" : "")}"));
+
+return $@"using System;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace {outputNamespace}.Converters
+{{
+ internal class {ConverterTypeName} : JsonConverter<{Declaration.ToDisplayString()}>
+ {{
+{bytes}
+
+ public override {Declaration.ToDisplayString()}? Read(
+ ref Utf8JsonReader reader,
+ Type typeToConvert,
+ JsonSerializerOptions options)
+ {{
+ if (reader.TokenType != JsonTokenType.StartObject)
+ throw new JsonException(""Expected StartObject"");
+
+{fields}
+
+ while (reader.Read())
+ {{
+ if (reader.TokenType == JsonTokenType.EndObject)
+ break;
+
+ if (reader.TokenType != JsonTokenType.PropertyName)
+ throw new JsonException(""Expected PropertyName"");
+
+{readers}
+ else if (!reader.Read())
+ throw new JsonException();
+
+ if (reader.TokenType == JsonTokenType.StartArray
+ || reader.TokenType == JsonTokenType.StartObject)
+ reader.Skip();
+ }}
+
+ if ({fieldUnassigned})
+ throw new JsonException(""Missing field"");
+
+ return new {Declaration.ToDisplayString()}(
+{constructorParams}
+ );
+ }}
+
+ public override void Write(
+ Utf8JsonWriter writer,
+ {Declaration.ToDisplayString()} value,
+ JsonSerializerOptions options)
+ {{
+ writer.WriteNullValue();
+ }}
+ }}
+}}";
+
+ static (string name, ITypeSymbol type, string snakeCase, string utf8, string field, string reader)
+ GenerateFieldReader(IPropertySymbol member, int position)
+ {
+ var needsNullableAnnotation = false;
+ if (member.Type.IsValueType
+ && member.Type.OriginalDefinition.SpecialType != SpecialType.System_Nullable_T)
+ needsNullableAnnotation = true;
+
+ var snakeCase = ConvertToSnakeCase(member.Name);
+ return (member.Name, member.Type, snakeCase,
+$@" private static ReadOnlySpan {member.Name}Bytes => new byte[]
+ {{
+ // {snakeCase}
+ {string.Join(", ", Encoding.UTF8.GetBytes(snakeCase))}
+ }};",
+$" {member.Type.WithNullableAnnotation(NullableAnnotation.Annotated).ToDisplayString()}{(needsNullableAnnotation ? "?" : "")} {snakeCase}OrDefault = default;",
+$@" {(position > 0 ? "else " : "")}if (reader.ValueTextEquals({member.Name}Bytes))
+ {{
+ if (!reader.Read())
+ throw new JsonException(""Expected value"");
+
+ var cvt = options.GetConverter(
+ typeof({member.Type.WithNullableAnnotation(NullableAnnotation.NotAnnotated).ToDisplayString()}));
+
+ if (cvt is JsonConverter<{member.Type.WithNullableAnnotation(NullableAnnotation.NotAnnotated).ToDisplayString()}> converter)
+ {snakeCase}OrDefault = converter.Read(ref reader,
+ typeof({member.Type.WithNullableAnnotation(NullableAnnotation.NotAnnotated).ToDisplayString()}),
+ options);
+ else
+ {snakeCase}OrDefault = JsonSerializer.Deserialize<{member.Type.ToDisplayString()}>(
+ ref reader, options);
+ }}");
+ }
+ }
+ }
+
+ internal record DiscriminatedUnionSerializedType(
+ INamedTypeSymbol Declaration,
+ ISymbol Discriminator)
+ : SerializedType(Declaration)
+ {
+ public List Members { get; }
+ = new();
+
+ public override string GenerateSourceCode(string outputNamespace)
+ {
+ var discriminatorField = ConvertToSnakeCase(Discriminator.Name);
+
+ var discriminatorType = Discriminator switch
+ {
+ IPropertySymbol prop => prop.Type,
+ IFieldSymbol field => field.Type,
+ _ => throw new InvalidOperationException(
+ "Unsupported discriminator member type")
+ };
+
+ var switchCaseMembers = string.Join(",\n",
+ Members.Select(
+ x => $@" {x.DiscriminatorValue.ToCSharpString()}
+ => JsonSerializer.Deserialize(ref copy,
+ typeof({x.Declaration.ToDisplayString()}), options)"));
+
+ return $@"using System;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace {outputNamespace}.Internal.Converters
+{{
+ internal class {ConverterTypeName} : JsonConverter<{Declaration.ToDisplayString()}>
+ {{
+ private static ReadOnlySpan DiscriminatorBytes => new byte[]
+ {{
+ // {discriminatorField}
+ {string.Join(", ", Encoding.UTF8.GetBytes(discriminatorField))}
+ }};
+
+ public override {Declaration.ToDisplayString()}? Read(
+ ref Utf8JsonReader reader,
+ Type typeToConvert,
+ JsonSerializerOptions options)
+ {{
+ var copy = reader;
+ if (reader.TokenType != JsonTokenType.StartObject)
+ throw new JsonException(""Expected StartObject"");
+
+ {discriminatorType.ToDisplayString()}? discriminator = null;
+
+ while (reader.Read())
+ {{
+ if (reader.TokenType == JsonTokenType.EndObject)
+ break;
+
+ if (reader.TokenType != JsonTokenType.PropertyName)
+ throw new JsonException(""Expected PropertyName"");
+
+ if (reader.ValueTextEquals(DiscriminatorBytes))
+ {{
+ if (!reader.Read())
+ throw new JsonException(""Expected value"");
+
+ var cvt = options.GetConverter(
+ typeof({discriminatorType.ToDisplayString()}));
+
+ if (cvt is JsonConverter<{discriminatorType.ToDisplayString()}> converter)
+ discriminator = converter.Read(ref reader,
+ typeof({discriminatorType.ToDisplayString()}),
+ options);
+ else
+ discriminator = JsonSerializer
+ .Deserialize<{discriminatorType.ToDisplayString()}>(
+ ref reader, options);
+ }}
+ else if (!reader.Read())
+ throw new JsonException(""Expected value"");
+
+ if (reader.TokenType == JsonTokenType.StartArray
+ || reader.TokenType == JsonTokenType.StartObject)
+ reader.Skip();
+ }}
+
+ var result = discriminator switch
+ {{
+{switchCaseMembers},
+ _ => throw new JsonException(""Unknown discriminator value"")
+ }} as {Declaration.ToDisplayString()};
+
+ reader = copy;
+ return result;
+ }}
+
+ public override void Write(
+ Utf8JsonWriter writer,
+ {Declaration.ToDisplayString()} value,
+ JsonSerializerOptions options)
+ {{
+ writer.WriteNullValue();
+ }}
+ }}
+}}";
+ }
+ }
+
+ internal record DiscriminatedUnionMemberSerializedType(
+ INamedTypeSymbol Declaration,
+ TypedConstant DiscriminatorValue)
+ : SerializedType(Declaration)
+ {
+ public DiscriminatedUnionSerializedType? DiscriminatedUnionDeclaration
+ { get; init; }
+
+ protected override IEnumerable SymbolsToSerialize
+ => base.SymbolsToSerialize
+ .Where(x => !SymbolEqualityComparer.Default.Equals(x,
+ DiscriminatedUnionDeclaration?.Discriminator));
+ }
+}
diff --git a/tools/SourceGenerators/Serialization/Structure/SerializedTypeUtils.cs b/tools/SourceGenerators/Serialization/Structure/SerializedTypeUtils.cs
new file mode 100644
index 000000000..2b365bd3f
--- /dev/null
+++ b/tools/SourceGenerators/Serialization/Structure/SerializedTypeUtils.cs
@@ -0,0 +1,119 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.CodeAnalysis;
+
+namespace Discord.Net.SourceGenerators.Serialization
+{
+ internal static class SerializedTypeUtils
+ {
+ public static List BuildTypeTrees(
+ INamedTypeSymbol generateSerializerAttribute,
+ INamedTypeSymbol discriminatedUnionSymbol,
+ INamedTypeSymbol discriminatedUnionMemberSymbol,
+ IEnumerable symbolsToBuild)
+ {
+ var types = new List();
+
+ FindAllSerializedTypes(types, generateSerializerAttribute,
+ discriminatedUnionSymbol, discriminatedUnionMemberSymbol,
+ symbolsToBuild);
+
+ // Now, move DU members into their relevant DU declaration.
+ int x = 0;
+ while (x < types.Count)
+ {
+ var type = types[x];
+ if (type is DiscriminatedUnionMemberSerializedType duMember)
+ {
+ var declaration = types.FirstOrDefault(
+ x => SymbolEqualityComparer.Default.Equals(
+ x.Declaration, duMember.Declaration.BaseType));
+
+ if (declaration is not DiscriminatedUnionSerializedType duDeclaration)
+ throw new InvalidOperationException(
+ "Could not find DU declaration for DU " +
+ $"member {duMember.Declaration.ToDisplayString()}");
+
+ duDeclaration.Members.Add(duMember with
+ {
+ DiscriminatedUnionDeclaration = duDeclaration
+ });
+ types.RemoveAt(x);
+ continue;
+ }
+
+ x++;
+ }
+
+ return types;
+ }
+
+ private static void FindAllSerializedTypes(
+ List types,
+ INamedTypeSymbol generateSerializerAttribute,
+ INamedTypeSymbol discriminatedUnionSymbol,
+ INamedTypeSymbol discriminatedUnionMemberSymbol,
+ IEnumerable symbolsToBuild)
+ {
+ foreach (var type in symbolsToBuild)
+ {
+ var generateSerializer = type.GetAttributes()
+ .Any(x => SymbolEqualityComparer.Default
+ .Equals(x.AttributeClass, generateSerializerAttribute));
+
+ if (!generateSerializer)
+ continue;
+
+ var duDeclaration = type.GetAttributes()
+ .FirstOrDefault(x => SymbolEqualityComparer.Default
+ .Equals(x.AttributeClass, discriminatedUnionSymbol));
+
+ if (duDeclaration != null)
+ {
+ if (duDeclaration
+ .ConstructorArguments
+ .FirstOrDefault()
+ .Value is not string memberName)
+ throw new InvalidOperationException(
+ "Failed to get DU discriminator member name");
+
+ var member = type.GetMembers(memberName)
+ .FirstOrDefault(
+ x => x is IPropertySymbol or IFieldSymbol);
+
+ if (member is null)
+ throw new InvalidOperationException(
+ "Failed to get DU discriminator member symbol");
+
+ types.Add(new DiscriminatedUnionSerializedType(
+ type, member));
+
+ continue;
+ }
+
+ var duMemberDeclaration = type
+ .GetAttributes()
+ .FirstOrDefault(x => SymbolEqualityComparer.Default
+ .Equals(x.AttributeClass,
+ discriminatedUnionMemberSymbol));
+
+ if (duMemberDeclaration != null)
+ {
+ if (duMemberDeclaration.ConstructorArguments.Length == 0
+ || duMemberDeclaration.ConstructorArguments[0].IsNull)
+ throw new InvalidOperationException(
+ "Failed to get DU discriminator value");
+
+ types.Add(new DiscriminatedUnionMemberSerializedType(
+ type, duMemberDeclaration.ConstructorArguments[0]));
+
+ continue;
+ }
+
+ types.Add(new SerializedType(type));
+ }
+ }
+ }
+}
diff --git a/tools/SourceGenerators/Serialization/SymbolExtensions.cs b/tools/SourceGenerators/Serialization/SymbolExtensions.cs
new file mode 100644
index 000000000..c3b866bea
--- /dev/null
+++ b/tools/SourceGenerators/Serialization/SymbolExtensions.cs
@@ -0,0 +1,27 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Diagnostics.SymbolStore;
+using Microsoft.CodeAnalysis;
+
+namespace Discord.Net.SourceGenerators.Serialization
+{
+ internal static class SymbolExtensions
+ {
+ public static IEnumerable GetProperties(
+ this INamedTypeSymbol symbol,
+ bool includeInherited)
+ {
+ var s = symbol;
+ do
+ {
+ foreach (var member in s.GetMembers())
+ if (member is IPropertySymbol property)
+ yield return property;
+
+ s = s.BaseType;
+ }
+ while (includeInherited && s != null);
+ }
+ }
+}
diff --git a/tools/SourceGenerators/Serialization/Utils.cs b/tools/SourceGenerators/Serialization/Utils.cs
new file mode 100644
index 000000000..e26af1982
--- /dev/null
+++ b/tools/SourceGenerators/Serialization/Utils.cs
@@ -0,0 +1,25 @@
+using System.Text;
+
+namespace Discord.Net.SourceGenerators.Serialization
+{
+ internal static class Utils
+ {
+ private static readonly StringBuilder CaseChangeBuffer = new();
+
+ public static string ConvertToSnakeCase(string value)
+ {
+ foreach (var c in value)
+ {
+ if (char.IsUpper(c) && CaseChangeBuffer.Length > 0)
+ _ = CaseChangeBuffer.Append('_');
+
+ _ = CaseChangeBuffer.Append(char.ToLower(c));
+ }
+
+ var result = CaseChangeBuffer.ToString();
+ _ = CaseChangeBuffer.Clear();
+
+ return result;
+ }
+ }
+}
diff --git a/tools/SourceGenerators/Serialization/VisibleTypeVisitor.cs b/tools/SourceGenerators/Serialization/VisibleTypeVisitor.cs
new file mode 100644
index 000000000..3adebe575
--- /dev/null
+++ b/tools/SourceGenerators/Serialization/VisibleTypeVisitor.cs
@@ -0,0 +1,59 @@
+using System.Collections.Generic;
+using System.Threading;
+using Microsoft.CodeAnalysis;
+
+namespace Discord.Net.SourceGenerators.Serialization
+{
+ internal sealed class VisibleTypeVisitor
+ : SymbolVisitor
+ {
+ private readonly CancellationToken _cancellationToken;
+ private readonly HashSet _typeSymbols;
+
+ public VisibleTypeVisitor(CancellationToken cancellationToken)
+ {
+ _cancellationToken = cancellationToken;
+ _typeSymbols = new(SymbolEqualityComparer.Default);
+ }
+
+ public IEnumerable GetVisibleTypes()
+ => _typeSymbols;
+
+ public override void VisitAssembly(IAssemblySymbol symbol)
+ {
+ _cancellationToken.ThrowIfCancellationRequested();
+ symbol.GlobalNamespace.Accept(this);
+ }
+
+ public override void VisitNamespace(INamespaceSymbol symbol)
+ {
+ foreach (var member in symbol.GetMembers())
+ {
+ _cancellationToken.ThrowIfCancellationRequested();
+ member.Accept(this);
+ }
+ }
+
+ public override void VisitNamedType(INamedTypeSymbol symbol)
+ {
+ _cancellationToken.ThrowIfCancellationRequested();
+
+ var isVisible = symbol.DeclaredAccessibility switch
+ {
+ Accessibility.Protected => true,
+ Accessibility.ProtectedOrInternal => true,
+ Accessibility.Public => true,
+ _ => false,
+ };
+
+ if (!isVisible || !_typeSymbols.Add(symbol))
+ return;
+
+ foreach (var member in symbol.GetTypeMembers())
+ {
+ _cancellationToken.ThrowIfCancellationRequested();
+ member.Accept(this);
+ }
+ }
+ }
+}