From 9a5ebd667163bb04d2d5578ae1144d5ae03555c0 Mon Sep 17 00:00:00 2001 From: Student Main Date: Sun, 29 Mar 2020 23:45:27 +0800 Subject: [PATCH] CachedNetworkStream to provide first chunk cache for Socket --- .../Controller/CachedNetworkStream.cs | 222 +++++++++++++++++++++ shadowsocks-csharp/Controller/Service/Listener.cs | 2 + shadowsocks-csharp/Controller/Service/PACServer.cs | 10 +- .../Controller/Service/PortForwarder.cs | 7 + shadowsocks-csharp/Controller/Service/TCPRelay.cs | 7 + shadowsocks-csharp/Controller/Service/UDPRelay.cs | 7 + test/CachedNetworkStreamTest.cs | 84 ++++++++ test/CryptographyTest.cs | 32 +-- test/ShadowsocksTest.csproj | 2 + test/TestUtils.cs | 41 ++++ 10 files changed, 380 insertions(+), 34 deletions(-) create mode 100644 shadowsocks-csharp/Controller/CachedNetworkStream.cs create mode 100644 test/CachedNetworkStreamTest.cs create mode 100644 test/TestUtils.cs diff --git a/shadowsocks-csharp/Controller/CachedNetworkStream.cs b/shadowsocks-csharp/Controller/CachedNetworkStream.cs new file mode 100644 index 00000000..00738945 --- /dev/null +++ b/shadowsocks-csharp/Controller/CachedNetworkStream.cs @@ -0,0 +1,222 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Shadowsocks.Controller +{ + // cache first packet for duty-chain pattern listener + public class CachedNetworkStream : Stream + { + // 256 byte first packet buffer should enough for 99.999...% situation + // socks5: 0x05 0x.... + // http-pac: GET /pac HTTP/1.1 + // http-proxy: /[a-z]+ .+ HTTP\/1\.[01]/i + + public const int MaxCache = 256; + + public Socket Socket { get; private set; } + + private readonly Stream s; + + private byte[] cache = new byte[MaxCache]; + private long cachePtr = 0; + + private long readPtr = 0; + + public CachedNetworkStream(Socket socket) + { + s = new NetworkStream(socket); + Socket = socket; + } + + /// + /// Only for test purpose + /// + /// + public CachedNetworkStream(Stream stream) + { + s = stream; + } + + public override bool CanRead => s.CanRead; + + // we haven't run out of cache + public override bool CanSeek => cachePtr == readPtr; + + public override bool CanWrite => s.CanWrite; + + public override long Length => s.Length; + + public override long Position { get => readPtr; set => Seek(value, SeekOrigin.Begin); } + + public override void Flush() + { + s.Flush(); + } + + //public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + //{ + // var endPtr = buffer.Length + readPtr; // expected ptr after operation + // var uncachedLen = Math.Max(endPtr - cachePtr, 0); // how many data from socket + // var cachedLen = buffer.Length - uncachedLen; // how many data from cache + // var emptyCacheLen = MaxCache - cachePtr; // how many cache remain + + // int readLen = 0; + // var cachedMem = buffer[..(int)cachedLen]; + // var uncachedMem = buffer[(int)cachedLen..]; + // if (cachedLen > 0) + // { + // cache[(int)readPtr..(int)(readPtr + cachedLen)].CopyTo(cachedMem); + + // readPtr += cachedLen; + // readLen += (int)cachedLen; + // } + // if (uncachedLen > 0) + // { + // int readStreamLen = await s.ReadAsync(cachedMem, cancellationToken); + + // int lengthToCache = (int)Math.Min(emptyCacheLen, readStreamLen); // how many data need to cache + // if (lengthToCache > 0) + // { + // uncachedMem[0..lengthToCache].CopyTo(cache[(int)cachePtr..]); + // cachePtr += lengthToCache; + // } + + // readPtr += readStreamLen; + // readLen += readStreamLen; + // } + // return readLen; + //} + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public override int Read(byte[] buffer, int offset, int count) + { + Span span = buffer.AsSpan(offset, count); + return Read(span); + } + + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + public override int Read(Span buffer) + { + // how many data from socket + + // r: readPtr, c: cachePtr, e: endPtr + // ptr 0 r c e + // cached ####+++++ + // read ++++ + + // ptr 0 c r e + // cached ##### + // read +++++ + + var endPtr = buffer.Length + readPtr; // expected ptr after operation + var uncachedLen = Math.Max(endPtr - Math.Max(cachePtr, readPtr), 0); + var cachedLen = buffer.Length - uncachedLen; // how many data from cache + var emptyCacheLen = MaxCache - cachePtr; // how many cache remain + + int readLen = 0; + + Span cachedSpan = buffer[..(int)cachedLen]; + Span uncachedSpan = buffer[(int)cachedLen..]; + if (cachedLen > 0) + { + cache[(int)readPtr..(int)(readPtr + cachedLen)].CopyTo(cachedSpan); + + readPtr += cachedLen; + readLen += (int)cachedLen; + } + if (uncachedLen > 0) + { + int readStreamLen = s.Read(uncachedSpan); + + // how many data need to cache + int lengthToCache = (int)Math.Min(emptyCacheLen, readStreamLen); + if (lengthToCache > 0) + { + uncachedSpan[0..lengthToCache].ToArray().CopyTo(cache, cachePtr); + cachePtr += lengthToCache; + } + + readPtr += readStreamLen; + readLen += readStreamLen; + } + return readLen; + } + + /// + /// Read first block, will never read into non-cache range + /// + /// + /// + public int ReadFirstBlock(Span buffer) + { + Seek(0, SeekOrigin.Begin); + int len = Math.Min(MaxCache, buffer.Length); + return Read(buffer[0..len]); + } + + /// + /// Seek position, only support seek to cached range when we haven't read into non-cache range + /// + /// + /// Set it to System.IO.SeekOrigin.Begin, otherwise it will throw System.NotSupportedException + /// + /// + /// + /// + public override long Seek(long offset, SeekOrigin origin) + { + if (!CanSeek) throw new NotSupportedException("Non cache data has been read"); + if (origin != SeekOrigin.Begin) throw new NotSupportedException("We don't know network stream's length"); + if (offset < 0 || offset > cachePtr) throw new NotSupportedException("Can't seek to uncached position"); + + readPtr = offset; + return Position; + } + + /// + /// Useless + /// + /// + /// + public override void SetLength(long value) + { + s.SetLength(value); + } + + /// + /// Write to underly stream + /// + /// + /// + /// + /// + /// + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return s.WriteAsync(buffer, offset, count, cancellationToken); + } + + /// + /// Write to underly stream + /// + /// + /// + /// + public override void Write(byte[] buffer, int offset, int count) + { + s.Write(buffer, offset, count); + } + + protected override void Dispose(bool disposing) + { + s.Dispose(); + base.Dispose(disposing); + } + } +} diff --git a/shadowsocks-csharp/Controller/Service/Listener.cs b/shadowsocks-csharp/Controller/Service/Listener.cs index 2627587c..6457bad0 100644 --- a/shadowsocks-csharp/Controller/Service/Listener.cs +++ b/shadowsocks-csharp/Controller/Service/Listener.cs @@ -24,6 +24,8 @@ namespace Shadowsocks.Controller { public abstract bool Handle(byte[] firstPacket, int length, Socket socket, object state); + public abstract bool Handle(CachedNetworkStream stream, object state); + public virtual void Stop() { } } diff --git a/shadowsocks-csharp/Controller/Service/PACServer.cs b/shadowsocks-csharp/Controller/Service/PACServer.cs index 9e02b68d..612d2cce 100644 --- a/shadowsocks-csharp/Controller/Service/PACServer.cs +++ b/shadowsocks-csharp/Controller/Service/PACServer.cs @@ -53,6 +53,13 @@ namespace Shadowsocks.Controller return HttpServerUtilityUrlToken.Encode(CryptoUtils.MD5(Encoding.ASCII.GetBytes(content))); } + public override bool Handle(CachedNetworkStream stream, object state) + { + byte[] fp = new byte[256]; + int len = stream.ReadFirstBlock(fp); + return Handle(fp, len, stream.Socket, state); + } + public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) { if (socket.ProtocolType != ProtocolType.Tcp) @@ -154,8 +161,6 @@ namespace Shadowsocks.Controller } } - - public void SendResponse(Socket socket, bool useSocks) { try @@ -195,7 +200,6 @@ Connection: Close { } } - private string GetPACAddress(IPEndPoint localEndPoint, bool useSocks) { return localEndPoint.AddressFamily == AddressFamily.InterNetworkV6 diff --git a/shadowsocks-csharp/Controller/Service/PortForwarder.cs b/shadowsocks-csharp/Controller/Service/PortForwarder.cs index 8284fd11..3da12086 100644 --- a/shadowsocks-csharp/Controller/Service/PortForwarder.cs +++ b/shadowsocks-csharp/Controller/Service/PortForwarder.cs @@ -15,6 +15,13 @@ namespace Shadowsocks.Controller _targetPort = targetPort; } + public override bool Handle(CachedNetworkStream stream, object state) + { + byte[] fp = new byte[256]; + int len = stream.ReadFirstBlock(fp); + return Handle(fp, len, stream.Socket, state); + } + public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) { if (socket.ProtocolType != ProtocolType.Tcp) diff --git a/shadowsocks-csharp/Controller/Service/TCPRelay.cs b/shadowsocks-csharp/Controller/Service/TCPRelay.cs index fe626dac..53206127 100644 --- a/shadowsocks-csharp/Controller/Service/TCPRelay.cs +++ b/shadowsocks-csharp/Controller/Service/TCPRelay.cs @@ -33,6 +33,13 @@ namespace Shadowsocks.Controller _lastSweepTime = DateTime.Now; } + public override bool Handle(CachedNetworkStream stream, object state) + { + byte[] fp = new byte[256]; + int len = stream.ReadFirstBlock(fp); + return Handle(fp, len, stream.Socket, state); + } + public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) { if (socket.ProtocolType != ProtocolType.Tcp diff --git a/shadowsocks-csharp/Controller/Service/UDPRelay.cs b/shadowsocks-csharp/Controller/Service/UDPRelay.cs index 42a647d9..839a3184 100644 --- a/shadowsocks-csharp/Controller/Service/UDPRelay.cs +++ b/shadowsocks-csharp/Controller/Service/UDPRelay.cs @@ -25,6 +25,13 @@ namespace Shadowsocks.Controller this._controller = controller; } + public override bool Handle(CachedNetworkStream stream, object state) + { + byte[] fp = new byte[256]; + int len = stream.ReadFirstBlock(fp); + return Handle(fp, len, stream.Socket, state); + } + public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) { if (socket.ProtocolType != ProtocolType.Udp) diff --git a/test/CachedNetworkStreamTest.cs b/test/CachedNetworkStreamTest.cs new file mode 100644 index 00000000..c1391866 --- /dev/null +++ b/test/CachedNetworkStreamTest.cs @@ -0,0 +1,84 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using Shadowsocks.Controller; + +namespace Shadowsocks.Test +{ + [TestClass] + public class CachedNetworkStreamTest + { + byte[] b0 = new byte[256]; + byte[] b1 = new byte[256]; + byte[] b2 = new byte[1024]; + + // [TestInitialize] + [TestInitialize] + public void init() + { + for (int i = 0; i < 256; i++) + { + b0[i] = (byte)i; + b1[i] = (byte)(255 - i); + } + + b0.CopyTo(b2, 0); + b1.CopyTo(b2, 256); + b0.CopyTo(b2, 512); + } + + [TestMethod] + public void StreamTest() + { + using MemoryStream ms = new MemoryStream(b2); + using CachedNetworkStream s = new CachedNetworkStream(ms); + + byte[] o = new byte[128]; + + Assert.AreEqual(128, s.Read(o, 0, 128)); + TestUtils.ArrayEqual(b0[0..128], o); + + Assert.AreEqual(64, s.Read(o, 0, 64)); + TestUtils.ArrayEqual(b0[128..192], o[0..64]); + + s.Seek(0, SeekOrigin.Begin); + Assert.AreEqual(64, s.Read(o, 0, 64)); + TestUtils.ArrayEqual(b0[0..64], o[0..64]); + // refuse to go out of cached range + Assert.ThrowsException(() => + { + s.Seek(193, SeekOrigin.Begin); + }); + Assert.AreEqual(128, s.Read(o, 0, 128)); + TestUtils.ArrayEqual(b0[64..192], o); + + Assert.IsTrue(s.CanSeek); + Assert.AreEqual(128, s.Read(o, 0, 128)); + TestUtils.ArrayEqual(b0[192..256], o[0..64]); + TestUtils.ArrayEqual(b1[0..64], o[64..128]); + + Assert.IsFalse(s.CanSeek); + // refuse to go back when non-cache data has been read + Assert.ThrowsException(() => + { + s.Seek(0, SeekOrigin.Begin); + }); + + // read in non-cache range + Assert.AreEqual(64, s.Read(o, 0, 64)); + s.Read(o, 0, 128); + Assert.AreEqual(512, s.Position); + + Assert.AreEqual(128, s.Read(o, 0, 128)); + TestUtils.ArrayEqual(b0[0..128], o); + s.Read(o, 0, 128); + s.Read(o, 0, 128); + s.Read(o, 0, 128); + + // read at eos + Assert.AreEqual(0, s.Read(o, 0, 128)); + } + } +} diff --git a/test/CryptographyTest.cs b/test/CryptographyTest.cs index 57e987a9..f90f6f81 100644 --- a/test/CryptographyTest.cs +++ b/test/CryptographyTest.cs @@ -42,7 +42,7 @@ namespace Shadowsocks.Test //encryptor.Encrypt(plain, length, cipher, out int outLen); //decryptor.Decrypt(cipher, outLen, plain2, out int outLen2); Assert.AreEqual(length, outLen2); - ArrayEqual(plain.AsSpan(0, length).ToArray(), plain2.AsSpan(0, length).ToArray()); + TestUtils.ArrayEqual(plain.AsSpan(0, length).ToArray(), plain2.AsSpan(0, length).ToArray()); } const string password = "barfoo!"; @@ -70,36 +70,6 @@ namespace Shadowsocks.Test throw; } } - - private void ArrayEqual(IEnumerable expected, IEnumerable actual, string msg = "") - { - var e1 = expected.GetEnumerator(); - var e2 = actual.GetEnumerator(); - int ctr = 0; - while (true) - { - var e1next = e1.MoveNext(); - var e2next = e2.MoveNext(); - - if (e1next && e2next) - { - Assert.AreEqual(e1.Current, e2.Current, "at " + ctr); - } - else if (!e1next && !e2next) - { - return; - } - else if (!e1next) - { - Assert.Fail($"actual longer than expected ({ctr}) {msg}"); - } - else - { - Assert.Fail($"actual shorter than expected ({ctr}) {msg}"); - } - } - } - private static bool encryptionFailed = false; private void TestEncryptionMethod(Type enc, string method) diff --git a/test/ShadowsocksTest.csproj b/test/ShadowsocksTest.csproj index 9c3db5f1..ecda496b 100644 --- a/test/ShadowsocksTest.csproj +++ b/test/ShadowsocksTest.csproj @@ -4,6 +4,8 @@ netcoreapp3.1 false + + Shadowsocks.Test diff --git a/test/TestUtils.cs b/test/TestUtils.cs new file mode 100644 index 00000000..b7ac979e --- /dev/null +++ b/test/TestUtils.cs @@ -0,0 +1,41 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Shadowsocks.Test +{ + class TestUtils + { + public static void ArrayEqual(IEnumerable expected, IEnumerable actual, string msg = "") + { + var e1 = expected.GetEnumerator(); + var e2 = actual.GetEnumerator(); + int ctr = 0; + while (true) + { + var e1next = e1.MoveNext(); + var e2next = e2.MoveNext(); + + if (e1next && e2next) + { + Assert.AreEqual(e1.Current, e2.Current, "at " + ctr); + } + else if (!e1next && !e2next) + { + return; + } + else if (!e1next) + { + Assert.Fail($"actual longer than expected ({ctr}) {msg}"); + } + else + { + Assert.Fail($"actual shorter than expected ({ctr}) {msg}"); + } + ctr++; + } + } + + } +}