diff --git a/shadowsocks-csharp/Controller/CachedNetworkStream.cs b/shadowsocks-csharp/Controller/CachedNetworkStream.cs
new file mode 100644
index 00000000..00738945
--- /dev/null
+++ b/shadowsocks-csharp/Controller/CachedNetworkStream.cs
@@ -0,0 +1,222 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Net.Sockets;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Shadowsocks.Controller
+{
+ // cache first packet for duty-chain pattern listener
+ public class CachedNetworkStream : Stream
+ {
+ // 256 byte first packet buffer should enough for 99.999...% situation
+ // socks5: 0x05 0x....
+ // http-pac: GET /pac HTTP/1.1
+ // http-proxy: /[a-z]+ .+ HTTP\/1\.[01]/i
+
+ public const int MaxCache = 256;
+
+ public Socket Socket { get; private set; }
+
+ private readonly Stream s;
+
+ private byte[] cache = new byte[MaxCache];
+ private long cachePtr = 0;
+
+ private long readPtr = 0;
+
+ public CachedNetworkStream(Socket socket)
+ {
+ s = new NetworkStream(socket);
+ Socket = socket;
+ }
+
+ ///
+ /// Only for test purpose
+ ///
+ ///
+ public CachedNetworkStream(Stream stream)
+ {
+ s = stream;
+ }
+
+ public override bool CanRead => s.CanRead;
+
+ // we haven't run out of cache
+ public override bool CanSeek => cachePtr == readPtr;
+
+ public override bool CanWrite => s.CanWrite;
+
+ public override long Length => s.Length;
+
+ public override long Position { get => readPtr; set => Seek(value, SeekOrigin.Begin); }
+
+ public override void Flush()
+ {
+ s.Flush();
+ }
+
+ //public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default)
+ //{
+ // var endPtr = buffer.Length + readPtr; // expected ptr after operation
+ // var uncachedLen = Math.Max(endPtr - cachePtr, 0); // how many data from socket
+ // var cachedLen = buffer.Length - uncachedLen; // how many data from cache
+ // var emptyCacheLen = MaxCache - cachePtr; // how many cache remain
+
+ // int readLen = 0;
+ // var cachedMem = buffer[..(int)cachedLen];
+ // var uncachedMem = buffer[(int)cachedLen..];
+ // if (cachedLen > 0)
+ // {
+ // cache[(int)readPtr..(int)(readPtr + cachedLen)].CopyTo(cachedMem);
+
+ // readPtr += cachedLen;
+ // readLen += (int)cachedLen;
+ // }
+ // if (uncachedLen > 0)
+ // {
+ // int readStreamLen = await s.ReadAsync(cachedMem, cancellationToken);
+
+ // int lengthToCache = (int)Math.Min(emptyCacheLen, readStreamLen); // how many data need to cache
+ // if (lengthToCache > 0)
+ // {
+ // uncachedMem[0..lengthToCache].CopyTo(cache[(int)cachePtr..]);
+ // cachePtr += lengthToCache;
+ // }
+
+ // readPtr += readStreamLen;
+ // readLen += readStreamLen;
+ // }
+ // return readLen;
+ //}
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public override int Read(byte[] buffer, int offset, int count)
+ {
+ Span span = buffer.AsSpan(offset, count);
+ return Read(span);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveOptimization)]
+ public override int Read(Span buffer)
+ {
+ // how many data from socket
+
+ // r: readPtr, c: cachePtr, e: endPtr
+ // ptr 0 r c e
+ // cached ####+++++
+ // read ++++
+
+ // ptr 0 c r e
+ // cached #####
+ // read +++++
+
+ var endPtr = buffer.Length + readPtr; // expected ptr after operation
+ var uncachedLen = Math.Max(endPtr - Math.Max(cachePtr, readPtr), 0);
+ var cachedLen = buffer.Length - uncachedLen; // how many data from cache
+ var emptyCacheLen = MaxCache - cachePtr; // how many cache remain
+
+ int readLen = 0;
+
+ Span cachedSpan = buffer[..(int)cachedLen];
+ Span uncachedSpan = buffer[(int)cachedLen..];
+ if (cachedLen > 0)
+ {
+ cache[(int)readPtr..(int)(readPtr + cachedLen)].CopyTo(cachedSpan);
+
+ readPtr += cachedLen;
+ readLen += (int)cachedLen;
+ }
+ if (uncachedLen > 0)
+ {
+ int readStreamLen = s.Read(uncachedSpan);
+
+ // how many data need to cache
+ int lengthToCache = (int)Math.Min(emptyCacheLen, readStreamLen);
+ if (lengthToCache > 0)
+ {
+ uncachedSpan[0..lengthToCache].ToArray().CopyTo(cache, cachePtr);
+ cachePtr += lengthToCache;
+ }
+
+ readPtr += readStreamLen;
+ readLen += readStreamLen;
+ }
+ return readLen;
+ }
+
+ ///
+ /// Read first block, will never read into non-cache range
+ ///
+ ///
+ ///
+ public int ReadFirstBlock(Span buffer)
+ {
+ Seek(0, SeekOrigin.Begin);
+ int len = Math.Min(MaxCache, buffer.Length);
+ return Read(buffer[0..len]);
+ }
+
+ ///
+ /// Seek position, only support seek to cached range when we haven't read into non-cache range
+ ///
+ ///
+ /// Set it to System.IO.SeekOrigin.Begin, otherwise it will throw System.NotSupportedException
+ ///
+ ///
+ ///
+ ///
+ public override long Seek(long offset, SeekOrigin origin)
+ {
+ if (!CanSeek) throw new NotSupportedException("Non cache data has been read");
+ if (origin != SeekOrigin.Begin) throw new NotSupportedException("We don't know network stream's length");
+ if (offset < 0 || offset > cachePtr) throw new NotSupportedException("Can't seek to uncached position");
+
+ readPtr = offset;
+ return Position;
+ }
+
+ ///
+ /// Useless
+ ///
+ ///
+ ///
+ public override void SetLength(long value)
+ {
+ s.SetLength(value);
+ }
+
+ ///
+ /// Write to underly stream
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ {
+ return s.WriteAsync(buffer, offset, count, cancellationToken);
+ }
+
+ ///
+ /// Write to underly stream
+ ///
+ ///
+ ///
+ ///
+ public override void Write(byte[] buffer, int offset, int count)
+ {
+ s.Write(buffer, offset, count);
+ }
+
+ protected override void Dispose(bool disposing)
+ {
+ s.Dispose();
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/shadowsocks-csharp/Controller/Service/Listener.cs b/shadowsocks-csharp/Controller/Service/Listener.cs
index 2627587c..6457bad0 100644
--- a/shadowsocks-csharp/Controller/Service/Listener.cs
+++ b/shadowsocks-csharp/Controller/Service/Listener.cs
@@ -24,6 +24,8 @@ namespace Shadowsocks.Controller
{
public abstract bool Handle(byte[] firstPacket, int length, Socket socket, object state);
+ public abstract bool Handle(CachedNetworkStream stream, object state);
+
public virtual void Stop() { }
}
diff --git a/shadowsocks-csharp/Controller/Service/PACServer.cs b/shadowsocks-csharp/Controller/Service/PACServer.cs
index 9e02b68d..612d2cce 100644
--- a/shadowsocks-csharp/Controller/Service/PACServer.cs
+++ b/shadowsocks-csharp/Controller/Service/PACServer.cs
@@ -53,6 +53,13 @@ namespace Shadowsocks.Controller
return HttpServerUtilityUrlToken.Encode(CryptoUtils.MD5(Encoding.ASCII.GetBytes(content)));
}
+ public override bool Handle(CachedNetworkStream stream, object state)
+ {
+ byte[] fp = new byte[256];
+ int len = stream.ReadFirstBlock(fp);
+ return Handle(fp, len, stream.Socket, state);
+ }
+
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state)
{
if (socket.ProtocolType != ProtocolType.Tcp)
@@ -154,8 +161,6 @@ namespace Shadowsocks.Controller
}
}
-
-
public void SendResponse(Socket socket, bool useSocks)
{
try
@@ -195,7 +200,6 @@ Connection: Close
{ }
}
-
private string GetPACAddress(IPEndPoint localEndPoint, bool useSocks)
{
return localEndPoint.AddressFamily == AddressFamily.InterNetworkV6
diff --git a/shadowsocks-csharp/Controller/Service/PortForwarder.cs b/shadowsocks-csharp/Controller/Service/PortForwarder.cs
index 8284fd11..3da12086 100644
--- a/shadowsocks-csharp/Controller/Service/PortForwarder.cs
+++ b/shadowsocks-csharp/Controller/Service/PortForwarder.cs
@@ -15,6 +15,13 @@ namespace Shadowsocks.Controller
_targetPort = targetPort;
}
+ public override bool Handle(CachedNetworkStream stream, object state)
+ {
+ byte[] fp = new byte[256];
+ int len = stream.ReadFirstBlock(fp);
+ return Handle(fp, len, stream.Socket, state);
+ }
+
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state)
{
if (socket.ProtocolType != ProtocolType.Tcp)
diff --git a/shadowsocks-csharp/Controller/Service/TCPRelay.cs b/shadowsocks-csharp/Controller/Service/TCPRelay.cs
index fe626dac..53206127 100644
--- a/shadowsocks-csharp/Controller/Service/TCPRelay.cs
+++ b/shadowsocks-csharp/Controller/Service/TCPRelay.cs
@@ -33,6 +33,13 @@ namespace Shadowsocks.Controller
_lastSweepTime = DateTime.Now;
}
+ public override bool Handle(CachedNetworkStream stream, object state)
+ {
+ byte[] fp = new byte[256];
+ int len = stream.ReadFirstBlock(fp);
+ return Handle(fp, len, stream.Socket, state);
+ }
+
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state)
{
if (socket.ProtocolType != ProtocolType.Tcp
diff --git a/shadowsocks-csharp/Controller/Service/UDPRelay.cs b/shadowsocks-csharp/Controller/Service/UDPRelay.cs
index 42a647d9..839a3184 100644
--- a/shadowsocks-csharp/Controller/Service/UDPRelay.cs
+++ b/shadowsocks-csharp/Controller/Service/UDPRelay.cs
@@ -25,6 +25,13 @@ namespace Shadowsocks.Controller
this._controller = controller;
}
+ public override bool Handle(CachedNetworkStream stream, object state)
+ {
+ byte[] fp = new byte[256];
+ int len = stream.ReadFirstBlock(fp);
+ return Handle(fp, len, stream.Socket, state);
+ }
+
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state)
{
if (socket.ProtocolType != ProtocolType.Udp)
diff --git a/test/CachedNetworkStreamTest.cs b/test/CachedNetworkStreamTest.cs
new file mode 100644
index 00000000..c1391866
--- /dev/null
+++ b/test/CachedNetworkStreamTest.cs
@@ -0,0 +1,84 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Text;
+using Shadowsocks.Controller;
+
+namespace Shadowsocks.Test
+{
+ [TestClass]
+ public class CachedNetworkStreamTest
+ {
+ byte[] b0 = new byte[256];
+ byte[] b1 = new byte[256];
+ byte[] b2 = new byte[1024];
+
+ // [TestInitialize]
+ [TestInitialize]
+ public void init()
+ {
+ for (int i = 0; i < 256; i++)
+ {
+ b0[i] = (byte)i;
+ b1[i] = (byte)(255 - i);
+ }
+
+ b0.CopyTo(b2, 0);
+ b1.CopyTo(b2, 256);
+ b0.CopyTo(b2, 512);
+ }
+
+ [TestMethod]
+ public void StreamTest()
+ {
+ using MemoryStream ms = new MemoryStream(b2);
+ using CachedNetworkStream s = new CachedNetworkStream(ms);
+
+ byte[] o = new byte[128];
+
+ Assert.AreEqual(128, s.Read(o, 0, 128));
+ TestUtils.ArrayEqual(b0[0..128], o);
+
+ Assert.AreEqual(64, s.Read(o, 0, 64));
+ TestUtils.ArrayEqual(b0[128..192], o[0..64]);
+
+ s.Seek(0, SeekOrigin.Begin);
+ Assert.AreEqual(64, s.Read(o, 0, 64));
+ TestUtils.ArrayEqual(b0[0..64], o[0..64]);
+ // refuse to go out of cached range
+ Assert.ThrowsException(() =>
+ {
+ s.Seek(193, SeekOrigin.Begin);
+ });
+ Assert.AreEqual(128, s.Read(o, 0, 128));
+ TestUtils.ArrayEqual(b0[64..192], o);
+
+ Assert.IsTrue(s.CanSeek);
+ Assert.AreEqual(128, s.Read(o, 0, 128));
+ TestUtils.ArrayEqual(b0[192..256], o[0..64]);
+ TestUtils.ArrayEqual(b1[0..64], o[64..128]);
+
+ Assert.IsFalse(s.CanSeek);
+ // refuse to go back when non-cache data has been read
+ Assert.ThrowsException(() =>
+ {
+ s.Seek(0, SeekOrigin.Begin);
+ });
+
+ // read in non-cache range
+ Assert.AreEqual(64, s.Read(o, 0, 64));
+ s.Read(o, 0, 128);
+ Assert.AreEqual(512, s.Position);
+
+ Assert.AreEqual(128, s.Read(o, 0, 128));
+ TestUtils.ArrayEqual(b0[0..128], o);
+ s.Read(o, 0, 128);
+ s.Read(o, 0, 128);
+ s.Read(o, 0, 128);
+
+ // read at eos
+ Assert.AreEqual(0, s.Read(o, 0, 128));
+ }
+ }
+}
diff --git a/test/CryptographyTest.cs b/test/CryptographyTest.cs
index 57e987a9..f90f6f81 100644
--- a/test/CryptographyTest.cs
+++ b/test/CryptographyTest.cs
@@ -42,7 +42,7 @@ namespace Shadowsocks.Test
//encryptor.Encrypt(plain, length, cipher, out int outLen);
//decryptor.Decrypt(cipher, outLen, plain2, out int outLen2);
Assert.AreEqual(length, outLen2);
- ArrayEqual(plain.AsSpan(0, length).ToArray(), plain2.AsSpan(0, length).ToArray());
+ TestUtils.ArrayEqual(plain.AsSpan(0, length).ToArray(), plain2.AsSpan(0, length).ToArray());
}
const string password = "barfoo!";
@@ -70,36 +70,6 @@ namespace Shadowsocks.Test
throw;
}
}
-
- private void ArrayEqual(IEnumerable expected, IEnumerable actual, string msg = "")
- {
- var e1 = expected.GetEnumerator();
- var e2 = actual.GetEnumerator();
- int ctr = 0;
- while (true)
- {
- var e1next = e1.MoveNext();
- var e2next = e2.MoveNext();
-
- if (e1next && e2next)
- {
- Assert.AreEqual(e1.Current, e2.Current, "at " + ctr);
- }
- else if (!e1next && !e2next)
- {
- return;
- }
- else if (!e1next)
- {
- Assert.Fail($"actual longer than expected ({ctr}) {msg}");
- }
- else
- {
- Assert.Fail($"actual shorter than expected ({ctr}) {msg}");
- }
- }
- }
-
private static bool encryptionFailed = false;
private void TestEncryptionMethod(Type enc, string method)
diff --git a/test/ShadowsocksTest.csproj b/test/ShadowsocksTest.csproj
index 9c3db5f1..ecda496b 100644
--- a/test/ShadowsocksTest.csproj
+++ b/test/ShadowsocksTest.csproj
@@ -4,6 +4,8 @@
netcoreapp3.1
false
+
+ Shadowsocks.Test
diff --git a/test/TestUtils.cs b/test/TestUtils.cs
new file mode 100644
index 00000000..b7ac979e
--- /dev/null
+++ b/test/TestUtils.cs
@@ -0,0 +1,41 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Shadowsocks.Test
+{
+ class TestUtils
+ {
+ public static void ArrayEqual(IEnumerable expected, IEnumerable actual, string msg = "")
+ {
+ var e1 = expected.GetEnumerator();
+ var e2 = actual.GetEnumerator();
+ int ctr = 0;
+ while (true)
+ {
+ var e1next = e1.MoveNext();
+ var e2next = e2.MoveNext();
+
+ if (e1next && e2next)
+ {
+ Assert.AreEqual(e1.Current, e2.Current, "at " + ctr);
+ }
+ else if (!e1next && !e2next)
+ {
+ return;
+ }
+ else if (!e1next)
+ {
+ Assert.Fail($"actual longer than expected ({ctr}) {msg}");
+ }
+ else
+ {
+ Assert.Fail($"actual shorter than expected ({ctr}) {msg}");
+ }
+ ctr++;
+ }
+ }
+
+ }
+}