using Shadowsocks.Models;
using Shadowsocks.Net.Crypto;
using Splat;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading.Tasks;

namespace Shadowsocks.Net
{
    /// <summary>
    /// Relays UDP datagrams between local clients and a Shadowsocks server.
    /// Each local client endpoint gets its own <see cref="UDPHandler"/> (and its own
    /// upstream socket), kept in a bounded LRU cache.
    /// </summary>
    public class UDPRelay : DatagramService
    {
        private readonly Server _server;

        // TODO: choose a smart number
        private readonly LRUCache<IPEndPoint, UDPHandler> _cache = new LRUCache<IPEndPoint, UDPHandler>(512);

        // NOTE(review): these counters are not updated anywhere in this file.
        public long outbound = 0;
        public long inbound = 0;

        public UDPRelay(Server server)
        {
            _server = server;
        }

        /// <summary>
        /// Forwards one datagram received on <paramref name="socket"/> from <paramref name="client"/>
        /// to the server, creating (and starting) a per-client handler on first contact.
        /// </summary>
        /// <param name="packet">Raw datagram; its first 3 bytes are a header that is stripped before encryption.</param>
        /// <param name="socket">Local socket the datagram arrived on; replies are sent back through it.</param>
        /// <param name="client">Sender of the datagram (expected to be an <see cref="IPEndPoint"/>).</param>
        /// <returns>
        /// <c>false</c> if <paramref name="socket"/> is not UDP or the packet is shorter than 4 bytes
        /// (3-byte header plus at least one payload byte); otherwise <c>true</c> after sending.
        /// </returns>
        public override async Task<bool> Handle(Memory<byte> packet, Socket socket, EndPoint client)
        {
            if (socket.ProtocolType != ProtocolType.Udp)
            {
                return false;
            }
            if (packet.Length < 4)
            {
                return false;
            }

            IPEndPoint remoteEndPoint = (IPEndPoint)client;

            // NOTE(review): get + add is not atomic. If Handle can run concurrently for the same
            // endpoint, two handlers may be created; LRUCache.add then keeps the newer one and
            // closes the older one rather than corrupting the cache.
            UDPHandler handler = _cache.get(remoteEndPoint);
            if (handler == null)
            {
                handler = new UDPHandler(socket, _server, remoteEndPoint);
                handler.Receive();
                _cache.add(remoteEndPoint, handler);
            }
            await handler.SendAsync(packet);
            return true;
        }

        /// <summary>
        /// Relay state for a single local client: a dedicated UDP socket to the server,
        /// encryption of client-to-server datagrams, and decryption of server replies back
        /// to the client through the shared local socket.
        /// </summary>
        public class UDPHandler : IEnableLogger
        {
            private static readonly MemoryPool<byte> pool = MemoryPool<byte>.Shared;

            private readonly Socket _local;
            private readonly Socket _remote;

            private readonly Server _server;

            // Receive buffer for datagrams coming back from the server.
            private readonly byte[] _buffer = new byte[65536];

            private readonly IPEndPoint _localEndPoint;
            private readonly IPEndPoint _remoteEndPoint;

            /// <summary>Wildcard address matching the upstream socket's address family.</summary>
            /// <exception cref="NotSupportedException">The family is neither IPv4 nor IPv6.</exception>
            private IPAddress ListenAddress
            {
                get
                {
                    return _remote.AddressFamily switch
                    {
                        AddressFamily.InterNetwork => IPAddress.Any,
                        AddressFamily.InterNetworkV6 => IPAddress.IPv6Any,
                        _ => throw new NotSupportedException(),
                    };
                }
            }

            /// <param name="local">Socket on which the client's datagrams arrive; replies go out through it.</param>
            /// <param name="server">Server configuration (host, port, method, password).</param>
            /// <param name="localEndPoint">The client's address; replies are sent here.</param>
            public UDPHandler(Socket local, Server server, IPEndPoint localEndPoint)
            {
                _local = local;
                _server = server;
                _localEndPoint = localEndPoint;

                // TODO async resolving
                bool parsed = IPAddress.TryParse(server.Host, out IPAddress ipAddress);
                if (!parsed)
                {
                    IPHostEntry ipHostInfo = Dns.GetHostEntry(server.Host);
                    ipAddress = ipHostInfo.AddressList[0];
                }
                _remoteEndPoint = new IPEndPoint(ipAddress, server.Port);
                _remote = new Socket(_remoteEndPoint.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
                try
                {
                    _remote.Bind(new IPEndPoint(ListenAddress, 0));
                }
                catch
                {
                    // Don't leak the socket if binding fails; the object is never handed out.
                    _remote.Dispose();
                    throw;
                }
            }

            /// <summary>
            /// Encrypts <paramref name="data"/> (minus its 3-byte header) and sends it to the server.
            /// </summary>
            /// <exception cref="InvalidOperationException">The pooled memory is not array-backed.</exception>
            public async Task SendAsync(ReadOnlyMemory<byte> data)
            {
                using ICrypto encryptor = CryptoFactory.GetEncryptor(_server.Method, _server.Password);
                // Extra headroom, presumably for per-packet cipher overhead.
                using IMemoryOwner<byte> mem = pool.Rent(data.Length + 1000);

                int outlen = encryptor.EncryptUDP(data.Span[3..], mem.Memory.Span);
                this.Log().Debug($"{_localEndPoint} {_remoteEndPoint} {outlen} UDP Relay up");

                // The ArraySegment overload of SendToAsync needs the backing array.
                if (!MemoryMarshal.TryGetArray(mem.Memory[..outlen], out ArraySegment<byte> outData))
                {
                    throw new InvalidOperationException("Can't extract underly array segment");
                }
                await _remote.SendToAsync(outData, SocketFlags.None, _remoteEndPoint);
            }

            /// <summary>
            /// Receives datagrams from the server until the upstream socket is closed or a receive
            /// fails, relaying each one to the local client. A failure while decrypting or forwarding
            /// a single datagram is logged and does not stop the loop.
            /// </summary>
            public async Task ReceiveAsync()
            {
                EndPoint remoteEndPoint = new IPEndPoint(ListenAddress, 0);
                this.Log().Debug($"++++++Receive Server Port, size:" + _buffer.Length);
                try
                {
                    while (true)
                    {
                        var result = await _remote.ReceiveFromAsync(_buffer, SocketFlags.None, remoteEndPoint);
                        try
                        {
                            await RelayDownAsync(result.ReceivedBytes);
                        }
                        catch (Exception e) when (!(e is ObjectDisposedException))
                        {
                            // One bad datagram (e.g. one that fails to decrypt) must not kill the
                            // handler, which stays cached and would otherwise never relay again.
                            this.Log().Warn(e, "");
                        }
                    }
                }
                catch (ObjectDisposedException)
                {
                    // Expected after Close() (e.g. LRU eviction) disposes the socket.
                    this.Log().Debug($"{_remoteEndPoint} {_localEndPoint} UDP handler closed");
                }
                catch (Exception e)
                {
                    this.Log().Warn(e, "");
                }
            }

            /// <summary>
            /// Decrypts the first <paramref name="bytesRead"/> bytes of <c>_buffer</c> and sends the
            /// plaintext to the local client behind a zeroed 3-byte header.
            /// </summary>
            private async Task RelayDownAsync(int bytesRead)
            {
                using IMemoryOwner<byte> owner = pool.Rent(bytesRead + 3);
                Memory<byte> o = owner.Memory;

                // MemoryPool does not guarantee zeroed memory, and only o[3..] is written by the
                // decryptor; clear the header so stale pooled bytes are never sent to the client.
                o.Span[..3].Clear();

                using ICrypto encryptor = CryptoFactory.GetEncryptor(_server.Method, _server.Password);
                int outlen = encryptor.DecryptUDP(o.Span[3..], _buffer.AsSpan(0, bytesRead));
                this.Log().Debug($"{_remoteEndPoint} {_localEndPoint} {outlen} UDP Relay down");

                if (!MemoryMarshal.TryGetArray(o[..(outlen + 3)], out ArraySegment<byte> data))
                {
                    throw new InvalidOperationException("Can't extract underly array segment");
                }
                await _local.SendToAsync(data, SocketFlags.None, _localEndPoint);
            }

            /// <summary>Starts <see cref="ReceiveAsync"/> without awaiting it (fire-and-forget).</summary>
            public void Receive()
            {
                _ = ReceiveAsync();
            }

            /// <summary>Closes the upstream socket, which also ends the receive loop.</summary>
            public void Close()
            {
                try
                {
                    _remote?.Close();
                }
                catch (ObjectDisposedException)
                {
                    // Already closed; nothing to do.
                }
                catch (Exception e)
                {
                    // Best-effort close: report, but never propagate into cache eviction.
                    this.Log().Warn(e, "");
                }
            }
        }
    }

    #region LRU cache

    // cc by-sa 3.0 http://stackoverflow.com/a/3719378/1124054
    /// <summary>
    /// Fixed-capacity least-recently-used cache of UDP handlers. All operations take a private
    /// lock. Handlers that are evicted, or replaced by <see cref="add"/>, are closed.
    /// </summary>
    class LRUCache<K, V> where V : UDPRelay.UDPHandler
    {
        private readonly int capacity;
        private readonly object _sync = new object();
        private readonly Dictionary<K, LinkedListNode<LRUCacheItem<K, V>>> cacheMap = new Dictionary<K, LinkedListNode<LRUCacheItem<K, V>>>();
        // Ordered from least (First) to most (Last) recently used.
        private readonly LinkedList<LRUCacheItem<K, V>> lruList = new LinkedList<LRUCacheItem<K, V>>();

        /// <exception cref="ArgumentOutOfRangeException"><paramref name="capacity"/> is less than 1.</exception>
        public LRUCache(int capacity)
        {
            if (capacity < 1)
            {
                throw new ArgumentOutOfRangeException(nameof(capacity), capacity, "Capacity must be at least 1.");
            }
            this.capacity = capacity;
        }

        /// <summary>
        /// Returns the value for <paramref name="key"/> and marks it most recently used,
        /// or <c>default(V)</c> if absent.
        /// </summary>
        public V get(K key)
        {
            lock (_sync)
            {
                if (!cacheMap.TryGetValue(key, out LinkedListNode<LRUCacheItem<K, V>> node))
                {
                    return default(V);
                }
                lruList.Remove(node);
                lruList.AddLast(node);
                return node.Value.value;
            }
        }

        /// <summary>
        /// Stores <paramref name="val"/> under <paramref name="key"/> as most recently used.
        /// If the key already exists, its value is replaced and the previous handler is closed.
        /// Otherwise, when the cache is full, the least recently used entry is evicted and closed.
        /// </summary>
        public void add(K key, V val)
        {
            lock (_sync)
            {
                if (cacheMap.TryGetValue(key, out LinkedListNode<LRUCacheItem<K, V>> existing))
                {
                    // Update in place: Dictionary.Add would throw after the list was already mutated.
                    V previous = existing.Value.value;
                    existing.Value.value = val;
                    lruList.Remove(existing);
                    lruList.AddLast(existing);
                    if (!ReferenceEquals(previous, val))
                    {
                        // The superseded handler is no longer reachable; release its socket.
                        previous?.Close();
                    }
                    return;
                }

                if (cacheMap.Count >= capacity)
                {
                    RemoveFirst();
                }

                var node = new LinkedListNode<LRUCacheItem<K, V>>(new LRUCacheItem<K, V>(key, val));
                lruList.AddLast(node);
                cacheMap.Add(key, node);
            }
        }

        /// <summary>Evicts and closes the least recently used entry, if any. Caller holds the lock.</summary>
        private void RemoveFirst()
        {
            LinkedListNode<LRUCacheItem<K, V>> node = lruList.First;
            if (node == null)
            {
                return;
            }
            lruList.RemoveFirst();
            cacheMap.Remove(node.Value.key);
            node.Value.value?.Close();
        }
    }

    /// <summary>Key/value pair stored in the LRU list so eviction can find the dictionary key.</summary>
    class LRUCacheItem<K, V>
    {
        public LRUCacheItem(K k, V v)
        {
            key = k;
            value = v;
        }

        public K key;
        public V value;
    }

    #endregion
}