using System; using System.Diagnostics; using System.Net; using System.Net.Sockets; using System.Text;

namespace Shadowsocks.Protocol.Socks5
{
    /// <summary>
    /// Base class for SOCKS5 protocol messages. Besides the abstract
    /// serialization contract, it holds the SOCKS5 protocol constants and the
    /// helpers that encode/decode a SOCKS5 address (ATYP, address, 2-byte port).
    /// </summary>
    public abstract class Socks5Message : IProtocolMessage
    {
        public abstract int Serialize(Memory<byte> buffer);

        public abstract (bool success, int length) TryLoad(ReadOnlyMemory<byte> buffer);

        public abstract bool Equals(IProtocolMessage other);

        #region Socks5 constants

        public const byte AuthNone = 0;
        public const byte AuthGssApi = 1;
        public const byte AuthUserPass = 2;
        public const byte AuthChallengeHandshake = 3;
        public const byte AuthChallengeResponse = 5;
        public const byte AuthSsl = 6;
        public const byte AuthNds = 7;
        public const byte AuthMultiAuthenticationFramework = 8;
        public const byte AuthJsonParameterBlock = 9;
        public const byte AuthNoAcceptable = 0xff;

        public const byte AddressIPv4 = 1;
        public const byte AddressDomain = 3;
        public const byte AddressIPv6 = 4;

        public const byte CmdConnect = 1;
        public const byte CmdBind = 2;
        public const byte CmdUdpAssociation = 3;

        public const byte ReplySucceed = 0;
        public const byte ReplyFailure = 1;
        public const byte ReplyNotAllowed = 2;
        public const byte ReplyNetworkUnreachable = 3;
        public const byte ReplyHostUnreachable = 4;
        public const byte ReplyConnectionRefused = 5;
        public const byte ReplyTtlExpired = 6;
        public const byte ReplyCommandNotSupport = 7;
        public const byte ReplyAddressNotSupport = 8;

        #endregion

        // A new exception per throw: rethrowing one shared instance overwrites its
        // stack trace on every throw and shares mutable state across threads.
        private static NotSupportedException AddressNotSupported() =>
            new NotSupportedException("Socks5 only support IPv4, IPv6, Domain name address");

        #region Address convert

        private const int IPv4AddressLength = 4;
        private const int IPv6AddressLength = 16;

        // The domain length is carried in a single byte on the wire.
        private const int MaxDomainLength = 255;

        /// <summary>Splits a port into its big-endian (high, low) bytes.</summary>
        private static (byte high, byte low) ExpandPort(int port)
        {
            Debug.Assert(port >= 0 && port <= 65535);
            return ((byte) (port / 256), (byte) (port % 256));
        }

        /// <summary>Combines big-endian (high, low) bytes into a port number.</summary>
        private static int TransformPort(byte high, byte low) => high * 256 + low;

        /// <summary>Writes <paramref name="port"/> big-endian into the first two bytes of <paramref name="destination"/>.</summary>
        private static void WritePort(Span<byte> destination, int port)
        {
            var (high, low) = ExpandPort(port);
            destination[0] = high;
            destination[1] = low;
        }

        /// <summary>
        /// Encodes a host name via <c>Util.EncodeHostName</c> and rejects names
        /// that do not fit the one-byte SOCKS5 length prefix.
        /// </summary>
        /// <exception cref="NotSupportedException">The encoded name is longer than 255 characters.</exception>
        private static string EncodeDomain(string host)
        {
            var encoded = Util.EncodeHostName(host);
            if (encoded.Length > MaxDomainLength) throw new NotSupportedException("Host name too long");
            return encoded;
        }

        /// <summary>
        /// Returns the number of bytes <see cref="SerializeAddress"/> writes for <paramref name="endPoint"/>.
        /// </summary>
        /// <exception cref="NotSupportedException">
        /// The endpoint is not IPv4, IPv6 or a DNS name, or the host name is too long.
        /// </exception>
        protected static int NeededBytes(EndPoint endPoint) => endPoint switch
        {
            IPEndPoint ip when ip.AddressFamily == AddressFamily.InterNetwork => 1 + IPv4AddressLength + 2,
            IPEndPoint ip when ip.AddressFamily == AddressFamily.InterNetworkV6 => 1 + IPv6AddressLength + 2,
            DnsEndPoint dns => EncodeDomain(dns.Host).Length + 4,
            _ => throw AddressNotSupported(),
        };

        /// <summary>
        /// Writes <paramref name="endPoint"/> to <paramref name="buffer"/> as a SOCKS5 address:
        /// the ATYP byte, the address (4 or 16 raw bytes, or a length-prefixed ASCII
        /// domain), then the port as two big-endian bytes.
        /// </summary>
        /// <returns>The number of bytes written.</returns>
        /// <exception cref="NotSupportedException">
        /// The endpoint type is unsupported or the host name is longer than 255 characters.
        /// </exception>
        /// <remarks>If <paramref name="buffer"/> is too small, the exception built by <c>Util.BufferTooSmall</c> is thrown.</remarks>
        public static int SerializeAddress(Memory<byte> buffer, EndPoint endPoint)
        {
            var span = buffer.Span;
            switch (endPoint)
            {
                case IPEndPoint ipEndPoint when ipEndPoint.AddressFamily == AddressFamily.InterNetwork:
                    return WriteIPAddress(span, AddressIPv4, ipEndPoint, IPv4AddressLength);
                case IPEndPoint ipEndPoint when ipEndPoint.AddressFamily == AddressFamily.InterNetworkV6:
                    return WriteIPAddress(span, AddressIPv6, ipEndPoint, IPv6AddressLength);
                case DnsEndPoint dnsEndPoint:
                {
                    // 3 lHost [Host] port port
                    var host = EncodeDomain(dnsEndPoint.Host);
                    var total = host.Length + 4;
                    if (span.Length < total) throw Util.BufferTooSmall(total, span.Length, nameof(buffer));
                    span[0] = AddressDomain;
                    span[1] = (byte) host.Length;
                    Encoding.ASCII.GetBytes(host, span.Slice(2, host.Length));
                    WritePort(span.Slice(host.Length + 2), dnsEndPoint.Port);
                    return total;
                }
                default:
                    throw AddressNotSupported();
            }
        }

        /// <summary>
        /// Writes ATYP, the raw IP address bytes and the port; returns the byte count.
        /// </summary>
        private static int WriteIPAddress(Span<byte> buffer, byte addressType, IPEndPoint endPoint, int addressLength)
        {
            var total = 1 + addressLength + 2;
            if (buffer.Length < total) throw Util.BufferTooSmall(total, buffer.Length, nameof(buffer));
            buffer[0] = addressType;
            // TryWriteBytes must NOT be called inside Debug.Assert: that method is
            // [Conditional("DEBUG")], so in Release builds the whole call (including
            // its arguments) is removed and the address would never be written.
            var ok = endPoint.Address.TryWriteBytes(buffer.Slice(1, addressLength), out var written);
            Debug.Assert(ok && written == addressLength);
            WritePort(buffer.Slice(1 + addressLength), endPoint.Port);
            return total;
        }

        /// <summary>
        /// Parses a SOCKS5 address (ATYP, address, port) from the start of <paramref name="buffer"/>.
        /// </summary>
        /// <param name="buffer">Bytes beginning with the ATYP byte.</param>
        /// <param name="result">The parsed endpoint on success; <c>null</c> otherwise.</param>
        /// <returns>
        /// <c>(true, n)</c> when <c>n</c> bytes were consumed;
        /// <c>(false, n)</c> with <c>n &gt; 0</c> when at least <c>n</c> bytes are required;
        /// <c>(false, 0)</c> for an unknown address type or an empty domain name.
        /// </returns>
        public static (bool success, int length) TryParseAddress(ReadOnlyMemory<byte> buffer, out EndPoint result)
        {
            result = default;
            var span = buffer.Span;
            if (span.Length < 1) return (false, 1);

            switch (span[0])
            {
                case AddressIPv4:
                    if (span.Length < 7) return (false, 7);
                    result = new IPEndPoint(
                        new IPAddress(span.Slice(1, IPv4AddressLength)),
                        TransformPort(span[5], span[6]));
                    return (true, 7);

                case AddressDomain:
                {
                    if (span.Length < 2) return (false, 2);
                    int nameLength = span[1];
                    // DnsEndPoint rejects an empty host; treat it as malformed input
                    // instead of letting ArgumentException escape a Try-method.
                    if (nameLength == 0) return (false, 0);
                    var total = nameLength + 4;
                    if (span.Length < total) return (false, total);
                    result = new DnsEndPoint(
                        Encoding.ASCII.GetString(span.Slice(2, nameLength)),
                        TransformPort(span[nameLength + 2], span[nameLength + 3]));
                    return (true, total);
                }

                case AddressIPv6:
                    if (span.Length < 19) return (false, 19);
                    result = new IPEndPoint(
                        new IPAddress(span.Slice(1, IPv6AddressLength)),
                        TransformPort(span[17], span[18]));
                    return (true, 19);

                default:
                    return (false, 0);
            }
        }

        #endregion
    }
}