using System;
using System.Net;
using System.Net.Sockets;
using System.Threading;

namespace Shadowsocks.Util.Sockets
{
    /*
     * A wrapped socket class that supports both IPv4 and IPv6, depending on the
     * connected remote endpoint.
     *
     * If the server address is a host name, it may resolve to both IPv4 and IPv6
     * addresses. Rather than resolving and choosing an address ourselves, we rely on
     * the static Socket.ConnectAsync(), which tries each resolved address internally
     * and hands back an established socket connection.
     */
    public class WrappedSocket : IDisposable
    {
        /// <summary>
        /// Local endpoint of the connected socket, or <c>null</c> when no connection has been established.
        /// </summary>
        public EndPoint LocalEndPoint => _activeSocket?.LocalEndPoint;

        // Guards the hand-off of the connected socket against a concurrent Dispose().
        // Only taken during connection completion and close, and never held while a user
        // callback runs. A Monitor is used because SpinLock.TryEnter() without a timeout
        // returns immediately (lock NOT taken) under contention, which defeats the purpose.
        private readonly object _socketSyncLock = new object();

        private volatile bool _disposed;

        private bool Connected => _activeSocket != null;

        private Socket _activeSocket;

        /// <summary>
        /// Starts connecting to <paramref name="remoteEP"/>. <paramref name="callback"/> is invoked
        /// once the attempt finishes, possibly on the calling thread if the connect completes
        /// synchronously. Pass the <see cref="IAsyncResult"/> it receives to <see cref="EndConnect"/>.
        /// </summary>
        /// <param name="remoteEP">Endpoint to connect to; a host-name endpoint may resolve to several addresses.</param>
        /// <param name="callback">Completion callback; may be <c>null</c>.</param>
        /// <param name="state">User state exposed as <see cref="IAsyncResult.AsyncState"/>.</param>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        /// <exception cref="SocketException">A connection is already established (<see cref="SocketError.IsConnected"/>).</exception>
        public void BeginConnect(EndPoint remoteEP, AsyncCallback callback, object state)
        {
            ThrowIfDisposed();
            if (Connected)
            {
                throw new SocketException((int) SocketError.IsConnected);
            }

            var arg = new SocketAsyncEventArgs
            {
                RemoteEndPoint = remoteEP,
                UserToken = new TcpUserToken(callback, state)
            };
            arg.Completed += OnTcpConnectCompleted;

            bool pending;
            try
            {
                pending = Socket.ConnectAsync(SocketType.Stream, ProtocolType.Tcp, arg);
            }
            catch
            {
                // The operation never started, so nobody else will release the event args.
                arg.Completed -= OnTcpConnectCompleted;
                arg.Dispose();
                throw;
            }

            // ConnectAsync returns false when it completed synchronously; in that case the
            // Completed event is NOT raised, so the completion must be processed here.
            if (!pending)
            {
                OnTcpConnectCompleted(this, arg);
            }
        }

        // Minimal IAsyncResult handed to the BeginConnect callback. It carries the user state
        // and, on failure, the exception that EndConnect rethrows. Note: it always reports
        // IsCompleted and CompletedSynchronously as true and exposes no wait handle.
        private class FakeAsyncResult : IAsyncResult
        {
            public bool IsCompleted { get; } = true;

            public WaitHandle AsyncWaitHandle { get; } = null;

            public object AsyncState { get; set; }

            public bool CompletedSynchronously { get; } = true;

            public Exception InternalException { get; set; } = null;
        }

        // Carries the caller's callback and state through SocketAsyncEventArgs.UserToken.
        private class TcpUserToken
        {
            public AsyncCallback Callback { get; }

            public object AsyncState { get; }

            public TcpUserToken(AsyncCallback callback, object state)
            {
                Callback = callback;
                AsyncState = state;
            }
        }

        // Completion handler for the connect started in BeginConnect. Publishes the connected
        // socket (or records the failure) and then invokes the user callback exactly once.
        private void OnTcpConnectCompleted(object sender, SocketAsyncEventArgs args)
        {
            using (args)
            {
                args.Completed -= OnTcpConnectCompleted;
                var token = (TcpUserToken) args.UserToken;

                Exception error = null;
                if (args.SocketError != SocketError.Success)
                {
                    // ConnectByNameError is populated when the target was a DnsEndPoint.
                    error = args.ConnectByNameError ?? new SocketException((int) args.SocketError);
                }
                else
                {
                    lock (_socketSyncLock)
                    {
                        if (Connected)
                        {
                            // A connection is already established: drop this one and report
                            // it rather than leaving the caller's callback uninvoked.
                            args.ConnectSocket.FullClose();
                            error = new SocketException((int) SocketError.IsConnected);
                        }
                        else
                        {
                            _activeSocket = args.ConnectSocket;
                            if (_disposed)
                            {
                                // Dispose() ran while connecting: close right away.
                                // EndConnect() will then throw ObjectDisposedException.
                                _activeSocket.FullClose();
                            }
                        }
                    }
                }

                // Invoked outside the lock so the callback may call Dispose() or any other member,
                // and so a slow callback cannot block Dispose() on another thread.
                token.Callback?.Invoke(new FakeAsyncResult
                {
                    AsyncState = token.AsyncState,
                    InternalException = error
                });
            }
        }

        /// <summary>
        /// Completes a connect started by <see cref="BeginConnect"/>.
        /// </summary>
        /// <param name="asyncResult">The result passed to the <see cref="BeginConnect"/> callback.</param>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        /// <exception cref="ArgumentException"><paramref name="asyncResult"/> did not come from <see cref="BeginConnect"/>.</exception>
        /// <exception cref="Exception">Rethrows the connect failure recorded for this attempt, if any.</exception>
        public void EndConnect(IAsyncResult asyncResult)
        {
            ThrowIfDisposed();

            var r = asyncResult as FakeAsyncResult;
            if (r == null)
            {
                throw new ArgumentException("Invalid asyncResult.", nameof(asyncResult));
            }

            if (r.InternalException != null)
            {
                throw r.InternalException;
            }
        }

        /// <summary>
        /// Marks this instance as disposed and closes the connected socket, if any.
        /// Subsequent calls are no-ops.
        /// </summary>
        public void Dispose()
        {
            if (_disposed)
            {
                return;
            }

            lock (_socketSyncLock)
            {
                // Re-check under the lock so two racing callers do not both close the socket.
                if (_disposed)
                {
                    return;
                }

                _disposed = true;
                _activeSocket?.FullClose();
            }
        }

        /// <summary>Begins an asynchronous send on the connected socket.</summary>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        /// <exception cref="SocketException">Not connected (<see cref="SocketError.NotConnected"/>).</exception>
        public IAsyncResult BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags,
            AsyncCallback callback,
            object state)
        {
            ThrowIfDisposed();
            ThrowIfNotConnected();

            return _activeSocket.BeginSend(buffer, offset, size, socketFlags, callback, state);
        }

        /// <summary>Completes a send started by <see cref="BeginSend"/> and returns the byte count.</summary>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        /// <exception cref="SocketException">Not connected (<see cref="SocketError.NotConnected"/>).</exception>
        public int EndSend(IAsyncResult asyncResult)
        {
            ThrowIfDisposed();
            ThrowIfNotConnected();

            return _activeSocket.EndSend(asyncResult);
        }

        /// <summary>Begins an asynchronous receive on the connected socket.</summary>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        /// <exception cref="SocketException">Not connected (<see cref="SocketError.NotConnected"/>).</exception>
        public IAsyncResult BeginReceive(byte[] buffer, int offset, int size, SocketFlags socketFlags,
            AsyncCallback callback,
            object state)
        {
            ThrowIfDisposed();
            ThrowIfNotConnected();

            return _activeSocket.BeginReceive(buffer, offset, size, socketFlags, callback, state);
        }

        /// <summary>Completes a receive started by <see cref="BeginReceive"/> and returns the byte count.</summary>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        /// <exception cref="SocketException">Not connected (<see cref="SocketError.NotConnected"/>).</exception>
        public int EndReceive(IAsyncResult asyncResult)
        {
            ThrowIfDisposed();
            ThrowIfNotConnected();

            return _activeSocket.EndReceive(asyncResult);
        }

        /// <summary>
        /// Shuts down the connected socket in the given direction; does nothing when not connected.
        /// </summary>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        public void Shutdown(SocketShutdown how)
        {
            ThrowIfDisposed();
            if (!Connected)
            {
                return;
            }

            _activeSocket.Shutdown(how);
        }

        /// <summary>Sets a boolean socket option (encoded as 1/0) on the connected socket.</summary>
        public void SetSocketOption(SocketOptionLevel optionLevel, SocketOptionName optionName, bool optionValue)
        {
            SetSocketOption(optionLevel, optionName, optionValue ? 1 : 0);
        }

        /// <summary>Sets an integer socket option on the connected socket.</summary>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        /// <exception cref="SocketException">Not connected (<see cref="SocketError.NotConnected"/>).</exception>
        public void SetSocketOption(SocketOptionLevel optionLevel, SocketOptionName optionName, int optionValue)
        {
            ThrowIfDisposed();
            ThrowIfNotConnected();

            _activeSocket.SetSocketOption(optionLevel, optionName, optionValue);
        }

        // Throws ObjectDisposedException once Dispose() has been called.
        private void ThrowIfDisposed()
        {
            if (_disposed)
            {
                throw new ObjectDisposedException(GetType().FullName);
            }
        }

        // Throws SocketException(NotConnected) while no connection has been established.
        private void ThrowIfNotConnected()
        {
            if (!Connected)
            {
                throw new SocketException((int) SocketError.NotConnected);
            }
        }
    }
}