@@ -0,0 +1,222 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Net.Sockets; | |||
using System.Runtime.CompilerServices; | |||
using System.Text; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
namespace Shadowsocks.Controller | |||
{ | |||
// cache first packet for duty-chain pattern listener | |||
public class CachedNetworkStream : Stream | |||
{ | |||
// 256 byte first packet buffer should enough for 99.999...% situation | |||
// socks5: 0x05 0x.... | |||
// http-pac: GET /pac HTTP/1.1 | |||
// http-proxy: /[a-z]+ .+ HTTP\/1\.[01]/i | |||
public const int MaxCache = 256; | |||
public Socket Socket { get; private set; } | |||
private readonly Stream s; | |||
private byte[] cache = new byte[MaxCache]; | |||
private long cachePtr = 0; | |||
private long readPtr = 0; | |||
public CachedNetworkStream(Socket socket) | |||
{ | |||
s = new NetworkStream(socket); | |||
Socket = socket; | |||
} | |||
/// <summary> | |||
/// Only for test purpose | |||
/// </summary> | |||
/// <param name="stream"></param> | |||
public CachedNetworkStream(Stream stream) | |||
{ | |||
s = stream; | |||
} | |||
public override bool CanRead => s.CanRead; | |||
// we haven't run out of cache | |||
public override bool CanSeek => cachePtr == readPtr; | |||
public override bool CanWrite => s.CanWrite; | |||
public override long Length => s.Length; | |||
public override long Position { get => readPtr; set => Seek(value, SeekOrigin.Begin); } | |||
public override void Flush() | |||
{ | |||
s.Flush(); | |||
} | |||
//public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default) | |||
//{ | |||
// var endPtr = buffer.Length + readPtr; // expected ptr after operation | |||
// var uncachedLen = Math.Max(endPtr - cachePtr, 0); // how many data from socket | |||
// var cachedLen = buffer.Length - uncachedLen; // how many data from cache | |||
// var emptyCacheLen = MaxCache - cachePtr; // how many cache remain | |||
// int readLen = 0; | |||
// var cachedMem = buffer[..(int)cachedLen]; | |||
// var uncachedMem = buffer[(int)cachedLen..]; | |||
// if (cachedLen > 0) | |||
// { | |||
// cache[(int)readPtr..(int)(readPtr + cachedLen)].CopyTo(cachedMem); | |||
// readPtr += cachedLen; | |||
// readLen += (int)cachedLen; | |||
// } | |||
// if (uncachedLen > 0) | |||
// { | |||
// int readStreamLen = await s.ReadAsync(cachedMem, cancellationToken); | |||
// int lengthToCache = (int)Math.Min(emptyCacheLen, readStreamLen); // how many data need to cache | |||
// if (lengthToCache > 0) | |||
// { | |||
// uncachedMem[0..lengthToCache].CopyTo(cache[(int)cachePtr..]); | |||
// cachePtr += lengthToCache; | |||
// } | |||
// readPtr += readStreamLen; | |||
// readLen += readStreamLen; | |||
// } | |||
// return readLen; | |||
//} | |||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
public override int Read(byte[] buffer, int offset, int count) | |||
{ | |||
Span<byte> span = buffer.AsSpan(offset, count); | |||
return Read(span); | |||
} | |||
[MethodImpl(MethodImplOptions.AggressiveOptimization)] | |||
public override int Read(Span<byte> buffer) | |||
{ | |||
// how many data from socket | |||
// r: readPtr, c: cachePtr, e: endPtr | |||
// ptr 0 r c e | |||
// cached ####+++++ | |||
// read ++++ | |||
// ptr 0 c r e | |||
// cached ##### | |||
// read +++++ | |||
var endPtr = buffer.Length + readPtr; // expected ptr after operation | |||
var uncachedLen = Math.Max(endPtr - Math.Max(cachePtr, readPtr), 0); | |||
var cachedLen = buffer.Length - uncachedLen; // how many data from cache | |||
var emptyCacheLen = MaxCache - cachePtr; // how many cache remain | |||
int readLen = 0; | |||
Span<byte> cachedSpan = buffer[..(int)cachedLen]; | |||
Span<byte> uncachedSpan = buffer[(int)cachedLen..]; | |||
if (cachedLen > 0) | |||
{ | |||
cache[(int)readPtr..(int)(readPtr + cachedLen)].CopyTo(cachedSpan); | |||
readPtr += cachedLen; | |||
readLen += (int)cachedLen; | |||
} | |||
if (uncachedLen > 0) | |||
{ | |||
int readStreamLen = s.Read(uncachedSpan); | |||
// how many data need to cache | |||
int lengthToCache = (int)Math.Min(emptyCacheLen, readStreamLen); | |||
if (lengthToCache > 0) | |||
{ | |||
uncachedSpan[0..lengthToCache].ToArray().CopyTo(cache, cachePtr); | |||
cachePtr += lengthToCache; | |||
} | |||
readPtr += readStreamLen; | |||
readLen += readStreamLen; | |||
} | |||
return readLen; | |||
} | |||
/// <summary> | |||
/// Read first block, will never read into non-cache range | |||
/// </summary> | |||
/// <param name="buffer"></param> | |||
/// <returns></returns> | |||
public int ReadFirstBlock(Span<byte> buffer) | |||
{ | |||
Seek(0, SeekOrigin.Begin); | |||
int len = Math.Min(MaxCache, buffer.Length); | |||
return Read(buffer[0..len]); | |||
} | |||
/// <summary> | |||
/// Seek position, only support seek to cached range when we haven't read into non-cache range | |||
/// </summary> | |||
/// <param name="offset"></param> | |||
/// <param name="origin">Set it to System.IO.SeekOrigin.Begin, otherwise it will throw System.NotSupportedException</param> | |||
/// <exception cref="IOException"></exception> | |||
/// <exception cref="NotSupportedException"></exception> | |||
/// <exception cref="ObjectDisposedException"></exception> | |||
/// <returns></returns> | |||
public override long Seek(long offset, SeekOrigin origin) | |||
{ | |||
if (!CanSeek) throw new NotSupportedException("Non cache data has been read"); | |||
if (origin != SeekOrigin.Begin) throw new NotSupportedException("We don't know network stream's length"); | |||
if (offset < 0 || offset > cachePtr) throw new NotSupportedException("Can't seek to uncached position"); | |||
readPtr = offset; | |||
return Position; | |||
} | |||
/// <summary> | |||
/// Useless | |||
/// </summary> | |||
/// <param name="value"></param> | |||
/// <exception cref="NotSupportedException"></exception> | |||
public override void SetLength(long value) | |||
{ | |||
s.SetLength(value); | |||
} | |||
/// <summary> | |||
/// Write to underly stream | |||
/// </summary> | |||
/// <param name="buffer"></param> | |||
/// <param name="offset"></param> | |||
/// <param name="count"></param> | |||
/// <param name="cancellationToken"></param> | |||
/// <returns></returns> | |||
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
return s.WriteAsync(buffer, offset, count, cancellationToken); | |||
} | |||
/// <summary> | |||
/// Write to underly stream | |||
/// </summary> | |||
/// <param name="buffer"></param> | |||
/// <param name="offset"></param> | |||
/// <param name="count"></param> | |||
public override void Write(byte[] buffer, int offset, int count) | |||
{ | |||
s.Write(buffer, offset, count); | |||
} | |||
protected override void Dispose(bool disposing) | |||
{ | |||
s.Dispose(); | |||
base.Dispose(disposing); | |||
} | |||
} | |||
} |
@@ -24,6 +24,8 @@ namespace Shadowsocks.Controller | |||
{ | |||
public abstract bool Handle(byte[] firstPacket, int length, Socket socket, object state); | |||
public abstract bool Handle(CachedNetworkStream stream, object state); | |||
public virtual void Stop() { } | |||
} | |||
@@ -53,6 +53,13 @@ namespace Shadowsocks.Controller | |||
return HttpServerUtilityUrlToken.Encode(CryptoUtils.MD5(Encoding.ASCII.GetBytes(content))); | |||
} | |||
public override bool Handle(CachedNetworkStream stream, object state) | |||
{ | |||
byte[] fp = new byte[256]; | |||
int len = stream.ReadFirstBlock(fp); | |||
return Handle(fp, len, stream.Socket, state); | |||
} | |||
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) | |||
{ | |||
if (socket.ProtocolType != ProtocolType.Tcp) | |||
@@ -154,8 +161,6 @@ namespace Shadowsocks.Controller | |||
} | |||
} | |||
public void SendResponse(Socket socket, bool useSocks) | |||
{ | |||
try | |||
@@ -195,7 +200,6 @@ Connection: Close | |||
{ } | |||
} | |||
private string GetPACAddress(IPEndPoint localEndPoint, bool useSocks) | |||
{ | |||
return localEndPoint.AddressFamily == AddressFamily.InterNetworkV6 | |||
@@ -15,6 +15,13 @@ namespace Shadowsocks.Controller | |||
_targetPort = targetPort; | |||
} | |||
public override bool Handle(CachedNetworkStream stream, object state) | |||
{ | |||
byte[] fp = new byte[256]; | |||
int len = stream.ReadFirstBlock(fp); | |||
return Handle(fp, len, stream.Socket, state); | |||
} | |||
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) | |||
{ | |||
if (socket.ProtocolType != ProtocolType.Tcp) | |||
@@ -33,6 +33,13 @@ namespace Shadowsocks.Controller | |||
_lastSweepTime = DateTime.Now; | |||
} | |||
public override bool Handle(CachedNetworkStream stream, object state) | |||
{ | |||
byte[] fp = new byte[256]; | |||
int len = stream.ReadFirstBlock(fp); | |||
return Handle(fp, len, stream.Socket, state); | |||
} | |||
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) | |||
{ | |||
if (socket.ProtocolType != ProtocolType.Tcp | |||
@@ -25,6 +25,13 @@ namespace Shadowsocks.Controller | |||
this._controller = controller; | |||
} | |||
public override bool Handle(CachedNetworkStream stream, object state) | |||
{ | |||
byte[] fp = new byte[256]; | |||
int len = stream.ReadFirstBlock(fp); | |||
return Handle(fp, len, stream.Socket, state); | |||
} | |||
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) | |||
{ | |||
if (socket.ProtocolType != ProtocolType.Udp) | |||
@@ -0,0 +1,84 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Text; | |||
using Shadowsocks.Controller; | |||
namespace Shadowsocks.Test | |||
{ | |||
[TestClass] | |||
public class CachedNetworkStreamTest | |||
{ | |||
byte[] b0 = new byte[256]; | |||
byte[] b1 = new byte[256]; | |||
byte[] b2 = new byte[1024]; | |||
// [TestInitialize] | |||
[TestInitialize] | |||
public void init() | |||
{ | |||
for (int i = 0; i < 256; i++) | |||
{ | |||
b0[i] = (byte)i; | |||
b1[i] = (byte)(255 - i); | |||
} | |||
b0.CopyTo(b2, 0); | |||
b1.CopyTo(b2, 256); | |||
b0.CopyTo(b2, 512); | |||
} | |||
[TestMethod] | |||
public void StreamTest() | |||
{ | |||
using MemoryStream ms = new MemoryStream(b2); | |||
using CachedNetworkStream s = new CachedNetworkStream(ms); | |||
byte[] o = new byte[128]; | |||
Assert.AreEqual(128, s.Read(o, 0, 128)); | |||
TestUtils.ArrayEqual(b0[0..128], o); | |||
Assert.AreEqual(64, s.Read(o, 0, 64)); | |||
TestUtils.ArrayEqual(b0[128..192], o[0..64]); | |||
s.Seek(0, SeekOrigin.Begin); | |||
Assert.AreEqual(64, s.Read(o, 0, 64)); | |||
TestUtils.ArrayEqual(b0[0..64], o[0..64]); | |||
// refuse to go out of cached range | |||
Assert.ThrowsException<NotSupportedException>(() => | |||
{ | |||
s.Seek(193, SeekOrigin.Begin); | |||
}); | |||
Assert.AreEqual(128, s.Read(o, 0, 128)); | |||
TestUtils.ArrayEqual(b0[64..192], o); | |||
Assert.IsTrue(s.CanSeek); | |||
Assert.AreEqual(128, s.Read(o, 0, 128)); | |||
TestUtils.ArrayEqual(b0[192..256], o[0..64]); | |||
TestUtils.ArrayEqual(b1[0..64], o[64..128]); | |||
Assert.IsFalse(s.CanSeek); | |||
// refuse to go back when non-cache data has been read | |||
Assert.ThrowsException<NotSupportedException>(() => | |||
{ | |||
s.Seek(0, SeekOrigin.Begin); | |||
}); | |||
// read in non-cache range | |||
Assert.AreEqual(64, s.Read(o, 0, 64)); | |||
s.Read(o, 0, 128); | |||
Assert.AreEqual(512, s.Position); | |||
Assert.AreEqual(128, s.Read(o, 0, 128)); | |||
TestUtils.ArrayEqual(b0[0..128], o); | |||
s.Read(o, 0, 128); | |||
s.Read(o, 0, 128); | |||
s.Read(o, 0, 128); | |||
// read at eos | |||
Assert.AreEqual(0, s.Read(o, 0, 128)); | |||
} | |||
} | |||
} |
@@ -42,7 +42,7 @@ namespace Shadowsocks.Test | |||
//encryptor.Encrypt(plain, length, cipher, out int outLen); | |||
//decryptor.Decrypt(cipher, outLen, plain2, out int outLen2); | |||
Assert.AreEqual(length, outLen2); | |||
ArrayEqual<byte>(plain.AsSpan(0, length).ToArray(), plain2.AsSpan(0, length).ToArray()); | |||
TestUtils.ArrayEqual<byte>(plain.AsSpan(0, length).ToArray(), plain2.AsSpan(0, length).ToArray()); | |||
} | |||
const string password = "barfoo!"; | |||
@@ -70,36 +70,6 @@ namespace Shadowsocks.Test | |||
throw; | |||
} | |||
} | |||
private void ArrayEqual<T>(IEnumerable<T> expected, IEnumerable<T> actual, string msg = "") | |||
{ | |||
var e1 = expected.GetEnumerator(); | |||
var e2 = actual.GetEnumerator(); | |||
int ctr = 0; | |||
while (true) | |||
{ | |||
var e1next = e1.MoveNext(); | |||
var e2next = e2.MoveNext(); | |||
if (e1next && e2next) | |||
{ | |||
Assert.AreEqual(e1.Current, e2.Current, "at " + ctr); | |||
} | |||
else if (!e1next && !e2next) | |||
{ | |||
return; | |||
} | |||
else if (!e1next) | |||
{ | |||
Assert.Fail($"actual longer than expected ({ctr}) {msg}"); | |||
} | |||
else | |||
{ | |||
Assert.Fail($"actual shorter than expected ({ctr}) {msg}"); | |||
} | |||
} | |||
} | |||
private static bool encryptionFailed = false; | |||
private void TestEncryptionMethod(Type enc, string method) | |||
@@ -4,6 +4,8 @@ | |||
<TargetFramework>netcoreapp3.1</TargetFramework> | |||
<IsPackable>false</IsPackable> | |||
<RootNamespace>Shadowsocks.Test</RootNamespace> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
@@ -0,0 +1,41 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Shadowsocks.Test | |||
{ | |||
class TestUtils | |||
{ | |||
public static void ArrayEqual<T>(IEnumerable<T> expected, IEnumerable<T> actual, string msg = "") | |||
{ | |||
var e1 = expected.GetEnumerator(); | |||
var e2 = actual.GetEnumerator(); | |||
int ctr = 0; | |||
while (true) | |||
{ | |||
var e1next = e1.MoveNext(); | |||
var e2next = e2.MoveNext(); | |||
if (e1next && e2next) | |||
{ | |||
Assert.AreEqual(e1.Current, e2.Current, "at " + ctr); | |||
} | |||
else if (!e1next && !e2next) | |||
{ | |||
return; | |||
} | |||
else if (!e1next) | |||
{ | |||
Assert.Fail($"actual longer than expected ({ctr}) {msg}"); | |||
} | |||
else | |||
{ | |||
Assert.Fail($"actual shorter than expected ({ctr}) {msg}"); | |||
} | |||
ctr++; | |||
} | |||
} | |||
} | |||
} |