Browse Source

CachedNetworkStream to provide first chunk cache for Socket

pull/2865/head
Student Main 5 years ago
parent
commit
9a5ebd6671
10 changed files with 380 additions and 34 deletions
  1. +222
    -0
      shadowsocks-csharp/Controller/CachedNetworkStream.cs
  2. +2
    -0
      shadowsocks-csharp/Controller/Service/Listener.cs
  3. +7
    -3
      shadowsocks-csharp/Controller/Service/PACServer.cs
  4. +7
    -0
      shadowsocks-csharp/Controller/Service/PortForwarder.cs
  5. +7
    -0
      shadowsocks-csharp/Controller/Service/TCPRelay.cs
  6. +7
    -0
      shadowsocks-csharp/Controller/Service/UDPRelay.cs
  7. +84
    -0
      test/CachedNetworkStreamTest.cs
  8. +1
    -31
      test/CryptographyTest.cs
  9. +2
    -0
      test/ShadowsocksTest.csproj
  10. +41
    -0
      test/TestUtils.cs

+ 222
- 0
shadowsocks-csharp/Controller/CachedNetworkStream.cs View File

@@ -0,0 +1,222 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Shadowsocks.Controller
{
// cache first packet for duty-chain pattern listener
public class CachedNetworkStream : Stream
{
// 256 byte first packet buffer should enough for 99.999...% situation
// socks5: 0x05 0x....
// http-pac: GET /pac HTTP/1.1
// http-proxy: /[a-z]+ .+ HTTP\/1\.[01]/i

public const int MaxCache = 256;

public Socket Socket { get; private set; }

private readonly Stream s;

private byte[] cache = new byte[MaxCache];
private long cachePtr = 0;

private long readPtr = 0;

public CachedNetworkStream(Socket socket)
{
s = new NetworkStream(socket);
Socket = socket;
}

/// <summary>
/// Only for test purpose
/// </summary>
/// <param name="stream"></param>
public CachedNetworkStream(Stream stream)
{
s = stream;
}

public override bool CanRead => s.CanRead;

// we haven't run out of cache
public override bool CanSeek => cachePtr == readPtr;

public override bool CanWrite => s.CanWrite;

public override long Length => s.Length;

public override long Position { get => readPtr; set => Seek(value, SeekOrigin.Begin); }

public override void Flush()
{
s.Flush();
}

//public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
//{
// var endPtr = buffer.Length + readPtr; // expected ptr after operation
// var uncachedLen = Math.Max(endPtr - cachePtr, 0); // how many data from socket
// var cachedLen = buffer.Length - uncachedLen; // how many data from cache
// var emptyCacheLen = MaxCache - cachePtr; // how many cache remain

// int readLen = 0;
// var cachedMem = buffer[..(int)cachedLen];
// var uncachedMem = buffer[(int)cachedLen..];
// if (cachedLen > 0)
// {
// cache[(int)readPtr..(int)(readPtr + cachedLen)].CopyTo(cachedMem);

// readPtr += cachedLen;
// readLen += (int)cachedLen;
// }
// if (uncachedLen > 0)
// {
// int readStreamLen = await s.ReadAsync(cachedMem, cancellationToken);

// int lengthToCache = (int)Math.Min(emptyCacheLen, readStreamLen); // how many data need to cache
// if (lengthToCache > 0)
// {
// uncachedMem[0..lengthToCache].CopyTo(cache[(int)cachePtr..]);
// cachePtr += lengthToCache;
// }

// readPtr += readStreamLen;
// readLen += readStreamLen;
// }
// return readLen;
//}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override int Read(byte[] buffer, int offset, int count)
{
Span<byte> span = buffer.AsSpan(offset, count);
return Read(span);
}

[MethodImpl(MethodImplOptions.AggressiveOptimization)]
public override int Read(Span<byte> buffer)
{
// how many data from socket

// r: readPtr, c: cachePtr, e: endPtr
// ptr 0 r c e
// cached ####+++++
// read ++++

// ptr 0 c r e
// cached #####
// read +++++

var endPtr = buffer.Length + readPtr; // expected ptr after operation
var uncachedLen = Math.Max(endPtr - Math.Max(cachePtr, readPtr), 0);
var cachedLen = buffer.Length - uncachedLen; // how many data from cache
var emptyCacheLen = MaxCache - cachePtr; // how many cache remain

int readLen = 0;

Span<byte> cachedSpan = buffer[..(int)cachedLen];
Span<byte> uncachedSpan = buffer[(int)cachedLen..];
if (cachedLen > 0)
{
cache[(int)readPtr..(int)(readPtr + cachedLen)].CopyTo(cachedSpan);

readPtr += cachedLen;
readLen += (int)cachedLen;
}
if (uncachedLen > 0)
{
int readStreamLen = s.Read(uncachedSpan);

// how many data need to cache
int lengthToCache = (int)Math.Min(emptyCacheLen, readStreamLen);
if (lengthToCache > 0)
{
uncachedSpan[0..lengthToCache].ToArray().CopyTo(cache, cachePtr);
cachePtr += lengthToCache;
}

readPtr += readStreamLen;
readLen += readStreamLen;
}
return readLen;
}

/// <summary>
/// Read first block, will never read into non-cache range
/// </summary>
/// <param name="buffer"></param>
/// <returns></returns>
public int ReadFirstBlock(Span<byte> buffer)
{
Seek(0, SeekOrigin.Begin);
int len = Math.Min(MaxCache, buffer.Length);
return Read(buffer[0..len]);
}

/// <summary>
/// Seek position, only support seek to cached range when we haven't read into non-cache range
/// </summary>
/// <param name="offset"></param>
/// <param name="origin">Set it to System.IO.SeekOrigin.Begin, otherwise it will throw System.NotSupportedException</param>
/// <exception cref="IOException"></exception>
/// <exception cref="NotSupportedException"></exception>
/// <exception cref="ObjectDisposedException"></exception>
/// <returns></returns>
public override long Seek(long offset, SeekOrigin origin)
{
if (!CanSeek) throw new NotSupportedException("Non cache data has been read");
if (origin != SeekOrigin.Begin) throw new NotSupportedException("We don't know network stream's length");
if (offset < 0 || offset > cachePtr) throw new NotSupportedException("Can't seek to uncached position");

readPtr = offset;
return Position;
}

/// <summary>
/// Useless
/// </summary>
/// <param name="value"></param>
/// <exception cref="NotSupportedException"></exception>
public override void SetLength(long value)
{
s.SetLength(value);
}

/// <summary>
/// Write to underly stream
/// </summary>
/// <param name="buffer"></param>
/// <param name="offset"></param>
/// <param name="count"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return s.WriteAsync(buffer, offset, count, cancellationToken);
}

/// <summary>
/// Write to underly stream
/// </summary>
/// <param name="buffer"></param>
/// <param name="offset"></param>
/// <param name="count"></param>
public override void Write(byte[] buffer, int offset, int count)
{
s.Write(buffer, offset, count);
}

protected override void Dispose(bool disposing)
{
s.Dispose();
base.Dispose(disposing);
}
}
}

+ 2
- 0
shadowsocks-csharp/Controller/Service/Listener.cs View File

@@ -24,6 +24,8 @@ namespace Shadowsocks.Controller
{
public abstract bool Handle(byte[] firstPacket, int length, Socket socket, object state);
public abstract bool Handle(CachedNetworkStream stream, object state);
public virtual void Stop() { }
}


+ 7
- 3
shadowsocks-csharp/Controller/Service/PACServer.cs View File

@@ -53,6 +53,13 @@ namespace Shadowsocks.Controller
return HttpServerUtilityUrlToken.Encode(CryptoUtils.MD5(Encoding.ASCII.GetBytes(content)));
}
public override bool Handle(CachedNetworkStream stream, object state)
{
byte[] fp = new byte[256];
int len = stream.ReadFirstBlock(fp);
return Handle(fp, len, stream.Socket, state);
}
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state)
{
if (socket.ProtocolType != ProtocolType.Tcp)
@@ -154,8 +161,6 @@ namespace Shadowsocks.Controller
}
}
public void SendResponse(Socket socket, bool useSocks)
{
try
@@ -195,7 +200,6 @@ Connection: Close
{ }
}
private string GetPACAddress(IPEndPoint localEndPoint, bool useSocks)
{
return localEndPoint.AddressFamily == AddressFamily.InterNetworkV6


+ 7
- 0
shadowsocks-csharp/Controller/Service/PortForwarder.cs View File

@@ -15,6 +15,13 @@ namespace Shadowsocks.Controller
_targetPort = targetPort;
}
public override bool Handle(CachedNetworkStream stream, object state)
{
byte[] fp = new byte[256];
int len = stream.ReadFirstBlock(fp);
return Handle(fp, len, stream.Socket, state);
}
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state)
{
if (socket.ProtocolType != ProtocolType.Tcp)


+ 7
- 0
shadowsocks-csharp/Controller/Service/TCPRelay.cs View File

@@ -33,6 +33,13 @@ namespace Shadowsocks.Controller
_lastSweepTime = DateTime.Now;
}
public override bool Handle(CachedNetworkStream stream, object state)
{
byte[] fp = new byte[256];
int len = stream.ReadFirstBlock(fp);
return Handle(fp, len, stream.Socket, state);
}
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state)
{
if (socket.ProtocolType != ProtocolType.Tcp


+ 7
- 0
shadowsocks-csharp/Controller/Service/UDPRelay.cs View File

@@ -25,6 +25,13 @@ namespace Shadowsocks.Controller
this._controller = controller;
}
public override bool Handle(CachedNetworkStream stream, object state)
{
byte[] fp = new byte[256];
int len = stream.ReadFirstBlock(fp);
return Handle(fp, len, stream.Socket, state);
}
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state)
{
if (socket.ProtocolType != ProtocolType.Udp)


+ 84
- 0
test/CachedNetworkStreamTest.cs View File

@@ -0,0 +1,84 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Shadowsocks.Controller;

namespace Shadowsocks.Test
{
[TestClass]
public class CachedNetworkStreamTest
{
byte[] b0 = new byte[256];
byte[] b1 = new byte[256];
byte[] b2 = new byte[1024];

// [TestInitialize]
[TestInitialize]
public void init()
{
for (int i = 0; i < 256; i++)
{
b0[i] = (byte)i;
b1[i] = (byte)(255 - i);
}

b0.CopyTo(b2, 0);
b1.CopyTo(b2, 256);
b0.CopyTo(b2, 512);
}

[TestMethod]
public void StreamTest()
{
using MemoryStream ms = new MemoryStream(b2);
using CachedNetworkStream s = new CachedNetworkStream(ms);

byte[] o = new byte[128];

Assert.AreEqual(128, s.Read(o, 0, 128));
TestUtils.ArrayEqual(b0[0..128], o);

Assert.AreEqual(64, s.Read(o, 0, 64));
TestUtils.ArrayEqual(b0[128..192], o[0..64]);

s.Seek(0, SeekOrigin.Begin);
Assert.AreEqual(64, s.Read(o, 0, 64));
TestUtils.ArrayEqual(b0[0..64], o[0..64]);
// refuse to go out of cached range
Assert.ThrowsException<NotSupportedException>(() =>
{
s.Seek(193, SeekOrigin.Begin);
});
Assert.AreEqual(128, s.Read(o, 0, 128));
TestUtils.ArrayEqual(b0[64..192], o);

Assert.IsTrue(s.CanSeek);
Assert.AreEqual(128, s.Read(o, 0, 128));
TestUtils.ArrayEqual(b0[192..256], o[0..64]);
TestUtils.ArrayEqual(b1[0..64], o[64..128]);

Assert.IsFalse(s.CanSeek);
// refuse to go back when non-cache data has been read
Assert.ThrowsException<NotSupportedException>(() =>
{
s.Seek(0, SeekOrigin.Begin);
});

// read in non-cache range
Assert.AreEqual(64, s.Read(o, 0, 64));
s.Read(o, 0, 128);
Assert.AreEqual(512, s.Position);

Assert.AreEqual(128, s.Read(o, 0, 128));
TestUtils.ArrayEqual(b0[0..128], o);
s.Read(o, 0, 128);
s.Read(o, 0, 128);
s.Read(o, 0, 128);

// read at eos
Assert.AreEqual(0, s.Read(o, 0, 128));
}
}
}

+ 1
- 31
test/CryptographyTest.cs View File

@@ -42,7 +42,7 @@ namespace Shadowsocks.Test
//encryptor.Encrypt(plain, length, cipher, out int outLen);
//decryptor.Decrypt(cipher, outLen, plain2, out int outLen2);
Assert.AreEqual(length, outLen2);
ArrayEqual<byte>(plain.AsSpan(0, length).ToArray(), plain2.AsSpan(0, length).ToArray());
TestUtils.ArrayEqual<byte>(plain.AsSpan(0, length).ToArray(), plain2.AsSpan(0, length).ToArray());
}
const string password = "barfoo!";
@@ -70,36 +70,6 @@ namespace Shadowsocks.Test
throw;
}
}
private void ArrayEqual<T>(IEnumerable<T> expected, IEnumerable<T> actual, string msg = "")
{
var e1 = expected.GetEnumerator();
var e2 = actual.GetEnumerator();
int ctr = 0;
while (true)
{
var e1next = e1.MoveNext();
var e2next = e2.MoveNext();
if (e1next && e2next)
{
Assert.AreEqual(e1.Current, e2.Current, "at " + ctr);
}
else if (!e1next && !e2next)
{
return;
}
else if (!e1next)
{
Assert.Fail($"actual longer than expected ({ctr}) {msg}");
}
else
{
Assert.Fail($"actual shorter than expected ({ctr}) {msg}");
}
}
}
private static bool encryptionFailed = false;
private void TestEncryptionMethod(Type enc, string method)


+ 2
- 0
test/ShadowsocksTest.csproj View File

@@ -4,6 +4,8 @@
<TargetFramework>netcoreapp3.1</TargetFramework>
<IsPackable>false</IsPackable>
<RootNamespace>Shadowsocks.Test</RootNamespace>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">


+ 41
- 0
test/TestUtils.cs View File

@@ -0,0 +1,41 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;

namespace Shadowsocks.Test
{
class TestUtils
{
public static void ArrayEqual<T>(IEnumerable<T> expected, IEnumerable<T> actual, string msg = "")
{
var e1 = expected.GetEnumerator();
var e2 = actual.GetEnumerator();
int ctr = 0;
while (true)
{
var e1next = e1.MoveNext();
var e2next = e2.MoveNext();

if (e1next && e2next)
{
Assert.AreEqual(e1.Current, e2.Current, "at " + ctr);
}
else if (!e1next && !e2next)
{
return;
}
else if (!e1next)
{
Assert.Fail($"actual longer than expected ({ctr}) {msg}");
}
else
{
Assert.Fail($"actual shorter than expected ({ctr}) {msg}");
}
ctr++;
}
}

}
}

Loading…
Cancel
Save