using System;
using System.Collections.Generic;
using System.IO;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Shadowsocks.Net
{
    /// <summary>
    /// Read-through stream that records the first <see cref="MaxCache"/> bytes read from the
    /// underlying stream, so that several listeners in a duty-chain can each re-read
    /// (via <see cref="ReadFirstBlock"/>) the first packet of a connection.
    /// Writes are forwarded to the underlying stream unchanged.
    /// </summary>
    public class CachedNetworkStream : Stream
    {
        // A 256-byte first-packet buffer should be enough for 99.999...% of situations:
        // socks5:     0x05 0x....
        // http-pac:   GET /pac HTTP/1.1
        // http-proxy: /[a-z]+ .+ HTTP\/1\.[01]/i
        public const int MaxCache = 256;

        /// <summary>
        /// The socket wrapped by this stream; not set when constructed from a plain <see cref="Stream"/>.
        /// </summary>
        public Socket Socket { get; private set; }

        private readonly Stream s;
        private readonly byte[] cache = new byte[MaxCache];

        // Number of valid bytes in `cache` (never exceeds MaxCache).
        private long cachePtr = 0;

        // Logical read position. It runs past cachePtr only once data beyond the cache capacity
        // has been read, and in that case cachePtr == MaxCache.
        private long readPtr = 0;

        /// <summary>
        /// Wraps <paramref name="socket"/> in a <see cref="NetworkStream"/>.
        /// </summary>
        /// <param name="socket">Connected socket to read from and write to.</param>
        public CachedNetworkStream(Socket socket)
        {
            s = new NetworkStream(socket);
            Socket = socket;
        }

        /// <summary>
        /// Only for test purpose: wraps an arbitrary stream. <see cref="Socket"/> is left unset.
        /// </summary>
        /// <param name="stream">Stream to read from and write to.</param>
        public CachedNetworkStream(Stream stream)
        {
            s = stream;
        }

        public override bool CanRead => s.CanRead;

        // Seekable while we haven't run out of cache, i.e. every byte read so far is still cached.
        // (Previously `==`, which wrongly refused to seek after a partial re-read of the cache.)
        public override bool CanSeek => readPtr <= cachePtr;

        public override bool CanWrite => s.CanWrite;

        public override long Length => s.Length;

        /// <summary>
        /// Current read position. Setting it goes through <see cref="Seek"/> and has the same restrictions.
        /// </summary>
        public override long Position
        {
            get => readPtr;
            set => Seek(value, SeekOrigin.Begin);
        }

        public override void Flush()
        {
            s.Flush();
        }

        public override Task FlushAsync(CancellationToken cancellationToken)
        {
            return s.FlushAsync(cancellationToken);
        }

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public override int Read(byte[] buffer, int offset, int count)
        {
            return Read(buffer.AsSpan(offset, count));
        }

        /// <summary>
        /// Reads into <paramref name="buffer"/>. Bytes still in the cache are served first. Any
        /// remainder is read from the underlying stream, and as much of it as fits is appended to the cache.
        /// </summary>
        /// <param name="buffer">Destination buffer.</param>
        /// <returns>Number of bytes written to <paramref name="buffer"/>; 0 at end of stream.</returns>
        [MethodImpl(MethodImplOptions.AggressiveOptimization)]
        public override int Read(Span<byte> buffer)
        {
            // r: readPtr, c: cachePtr, e: endPtr (= r + buffer.Length)
            // ptr    0    r   c    e
            // cached #########
            // read        ^^^^-----      ^ from cache, - from stream
            //
            // ptr    0    c   r    e
            // cached #####
            // read            -----      everything from stream
            int fromCache = ReadFromCache(buffer);
            if (fromCache == buffer.Length)
            {
                return fromCache;
            }

            // NOTE(review): when a read spans the cache boundary the underlying stream is still read,
            // so on a NetworkStream this blocks until more data arrives even though cached bytes are in hand.
            Span<byte> rest = buffer[fromCache..];
            int fromStream = s.Read(rest);
            RecordStreamRead(rest[..fromStream]);
            return fromCache + fromStream;
        }

        public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        {
            return ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
        }

        /// <summary>
        /// Asynchronous counterpart of <see cref="Read(Span{byte})"/> with identical cache bookkeeping.
        /// </summary>
        /// <param name="buffer">Destination buffer.</param>
        /// <param name="cancellationToken">Forwarded to the underlying stream's read.</param>
        /// <returns>Number of bytes written to <paramref name="buffer"/>; 0 at end of stream.</returns>
        public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
        {
            int fromCache = ReadFromCache(buffer.Span);
            if (fromCache == buffer.Length)
            {
                return fromCache;
            }

            Memory<byte> rest = buffer[fromCache..];
            int fromStream = await s.ReadAsync(rest, cancellationToken).ConfigureAwait(false);
            RecordStreamRead(rest.Span[..fromStream]);
            return fromCache + fromStream;
        }

        // Copies the cached bytes between readPtr and cachePtr (as many as fit) into buffer
        // and advances readPtr past them. Returns the number of bytes copied.
        private int ReadFromCache(Span<byte> buffer)
        {
            long available = cachePtr - readPtr;
            if (available <= 0 || buffer.IsEmpty)
            {
                return 0;
            }

            int count = (int)Math.Min(available, buffer.Length);
            cache.AsSpan((int)readPtr, count).CopyTo(buffer);
            readPtr += count;
            return count;
        }

        // Appends as much of data (freshly read from the underlying stream) as still fits into
        // the cache, then advances readPtr past all of it.
        private void RecordStreamRead(ReadOnlySpan<byte> data)
        {
            int toCache = (int)Math.Min(MaxCache - cachePtr, data.Length);
            if (toCache > 0)
            {
                data[..toCache].CopyTo(cache.AsSpan((int)cachePtr));
                cachePtr += toCache;
            }
            readPtr += data.Length;
        }

        /// <summary>
        /// Rewinds to the start and reads at most <see cref="MaxCache"/> bytes. Because the request is
        /// capped at the cache size, anything pulled from the underlying stream is cached, so it never
        /// reads into the non-cache range and a later handler can re-read the same block.
        /// </summary>
        /// <param name="buffer">Destination buffer; only its first <see cref="MaxCache"/> bytes are used.</param>
        /// <returns>Number of bytes written to <paramref name="buffer"/>.</returns>
        /// <exception cref="NotSupportedException">Data beyond the cache has already been read.</exception>
        public int ReadFirstBlock(Span<byte> buffer)
        {
            Seek(0, SeekOrigin.Begin);
            int len = Math.Min(MaxCache, buffer.Length);
            return Read(buffer[0..len]);
        }

        /// <summary>
        /// Seeks within the cached range. Only supported while no non-cached data has been read.
        /// </summary>
        /// <param name="offset">Target position; must lie in [0, number of cached bytes].</param>
        /// <param name="origin">Must be <see cref="SeekOrigin.Begin"/>, otherwise <see cref="NotSupportedException"/> is thrown.</param>
        /// <returns>The new position.</returns>
        /// <exception cref="NotSupportedException">Non-cached data was read, origin is not Begin, or offset is outside the cache.</exception>
        public override long Seek(long offset, SeekOrigin origin)
        {
            if (!CanSeek) throw new NotSupportedException("Non cache data has been read");
            if (origin != SeekOrigin.Begin) throw new NotSupportedException("We don't know network stream's length");
            if (offset < 0 || offset > cachePtr) throw new NotSupportedException("Can't seek to uncached position");
            readPtr = offset;
            return Position;
        }

        /// <summary>
        /// Forwarded to the underlying stream (a <see cref="NetworkStream"/> does not support it).
        /// </summary>
        public override void SetLength(long value)
        {
            s.SetLength(value);
        }

        /// <summary>
        /// Write to the underlying stream.
        /// </summary>
        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        {
            return s.WriteAsync(buffer, offset, count, cancellationToken);
        }

        /// <summary>
        /// Write to the underlying stream.
        /// </summary>
        public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
        {
            return s.WriteAsync(buffer, cancellationToken);
        }

        /// <summary>
        /// Write to the underlying stream.
        /// </summary>
        public override void Write(byte[] buffer, int offset, int count)
        {
            s.Write(buffer, offset, count);
        }

        /// <summary>
        /// Write to the underlying stream.
        /// </summary>
        public override void Write(ReadOnlySpan<byte> buffer)
        {
            s.Write(buffer);
        }

        protected override void Dispose(bool disposing)
        {
            // Only touch the managed inner stream on an explicit Dispose/Close, not from a finalizer.
            if (disposing)
            {
                s.Dispose();
            }
            base.Dispose(disposing);
        }
    }
}