You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

AEADEncryptor.cs 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.Net;
  5. using System.Text;
  6. using Shadowsocks.Encryption.CircularBuffer;
  7. using Shadowsocks.Controller;
  8. using Shadowsocks.Encryption.Exception;
  9. using Shadowsocks.Encryption.Stream;
  namespace Shadowsocks.Encryption.AEAD
  {
      /// <summary>
      /// Base class for AEAD ciphers. It implements TCP chunk framing and UDP per-packet framing
      /// on top of the primitive <c>cipherEncrypt</c>/<c>cipherDecrypt</c> operations that
      /// concrete subclasses supply.
      /// </summary>
      /// <remarks>
      /// TCP stream: [salt] then repeated chunks of
      /// [encrypted 2-byte big-endian length + tag][encrypted payload + tag].
      /// UDP packet: [salt][encrypted payload + tag].
      /// </remarks>
      public abstract class AEADEncryptor
          : EncryptorBase
      {
          // We are using the same saltLen and keyLen
          private const string Info = "ss-subkey";
          private static readonly byte[] InfoBytes = Encoding.ASCII.GetBytes(Info);

          // for UDP only; shared by all instances, so every use is wrapped in lock (_udpTmpBuf)
          protected static byte[] _udpTmpBuf = new byte[65536];

          // every connection should create its own buffer
          private ByteCircularBuffer _encCircularBuffer = new ByteCircularBuffer(MAX_INPUT_SIZE * 2);
          private ByteCircularBuffer _decCircularBuffer = new ByteCircularBuffer(MAX_INPUT_SIZE * 2);

          /// <summary>Size in bytes of the chunk-length header that precedes every TCP chunk.</summary>
          public const int CHUNK_LEN_BYTES = 2;

          /// <summary>Largest payload length a single TCP chunk may carry; longer lengths are rejected.</summary>
          public const uint CHUNK_LEN_MASK = 0x3FFFu;

          protected Dictionary<string, EncryptorInfo> ciphers;
          protected string _method;
          protected int _cipher;
          // internal name in the crypto library
          protected string _innerLibName;
          protected EncryptorInfo CipherInfo;

          // NOTE(review): _Masterkey is static, so all AEADEncryptor instances share one buffer and
          // each constructor re-derives (and possibly resizes) it in InitKey. If subclasses derive the
          // session key from _Masterkey after construction (presumably in their InitCipher overrides;
          // not visible here), instances created with different passwords or methods can observe each
          // other's master key. Confirm this is intended before running several AEAD servers at once.
          protected static byte[] _Masterkey = null;
          protected byte[] _sessionKey;
          protected int keyLen;
          protected int saltLen;
          protected int tagLen;
          protected int nonceLen;
          protected byte[] _encryptSalt;
          protected byte[] _decryptSalt;
          protected object _nonceIncrementLock = new object();
          protected byte[] _encNonce;
          protected byte[] _decNonce;

          // Is first packet
          protected bool _decryptSaltReceived;
          protected bool _encryptSaltSent;

          // Is first chunk(tcp request)
          protected bool _tcpRequestSent;

          /// <summary>
          /// Resolves the cipher parameters for <paramref name="method"/>, derives the master key from
          /// <paramref name="password"/>, and creates all-zero encrypt/decrypt nonces.
          /// </summary>
          /// <param name="method">Cipher name; matched case-insensitively against <c>getCiphers()</c>.</param>
          /// <param name="password">Shared password; its UTF-8 bytes feed the key derivation.</param>
          /// <exception cref="System.Exception">The method is unknown or has cipher type 0.</exception>
          public AEADEncryptor(string method, string password)
              : base(method, password)
          {
              InitEncryptorInfo(method);
              InitKey(password);
              // Initialize all-zero nonce for each connection
              _encNonce = new byte[nonceLen];
              _decNonce = new byte[nonceLen];
          }

          /// <summary>Returns the table of supported cipher names and their parameters.</summary>
          protected abstract Dictionary<string, EncryptorInfo> getCiphers();

          /// <summary>
          /// Looks up <paramref name="method"/> in the subclass cipher table and caches its library
          /// name, type and key/salt/tag/nonce sizes.
          /// </summary>
          /// <exception cref="System.Exception">"method not found" when the name is unknown or the type is 0.</exception>
          protected void InitEncryptorInfo(string method)
          {
              // Culture-invariant lowering: under e.g. tr-TR, ToLower() maps "I" to a dotless "ı",
              // which would make names such as "CHACHA20-IETF-POLY1305" miss the cipher table.
              method = method.ToLowerInvariant();
              _method = method;
              ciphers = getCiphers();
              if (!ciphers.TryGetValue(_method, out EncryptorInfo info)) {
                  // Report an unknown name the same way as a zero cipher type below.
                  throw new System.Exception("method not found");
              }
              CipherInfo = info;
              _innerLibName = CipherInfo.InnerLibName;
              _cipher = CipherInfo.Type;
              if (_cipher == 0) {
                  throw new System.Exception("method not found");
              }
              keyLen = CipherInfo.KeySize;
              saltLen = CipherInfo.SaltSize;
              tagLen = CipherInfo.TagSize;
              nonceLen = CipherInfo.NonceSize;
          }

          /// <summary>
          /// Derives the (shared, static) master key from <paramref name="password"/> and allocates
          /// this instance's session-key buffer if needed.
          /// </summary>
          protected void InitKey(string password)
          {
              byte[] passbuf = Encoding.UTF8.GetBytes(password);
              // init master key; resized when a cipher with a different key length is used
              if (_Masterkey == null) _Masterkey = new byte[keyLen];
              if (_Masterkey.Length != keyLen) Array.Resize(ref _Masterkey, keyLen);
              DeriveKey(passbuf, _Masterkey, keyLen);
              // init session key
              if (_sessionKey == null) _sessionKey = new byte[keyLen];
          }

          /// <summary>Fills the first <paramref name="keylen"/> bytes of <paramref name="key"/> using the legacy stream-cipher KDF.</summary>
          public void DeriveKey(byte[] password, byte[] key, int keylen)
          {
              StreamEncryptor.LegacyDeriveKey(password, key, keylen);
          }

          /// <summary>
          /// Derives <paramref name="sessionKey"/> from <paramref name="masterKey"/> and
          /// <paramref name="salt"/> via HKDF with the "ss-subkey" info string.
          /// </summary>
          /// <exception cref="System.Exception">The underlying HKDF call returned non-zero.</exception>
          public void DeriveSessionKey(byte[] salt, byte[] masterKey, byte[] sessionKey)
          {
              int ret = MbedTLS.hkdf(salt, saltLen, masterKey, keyLen, InfoBytes, InfoBytes.Length, sessionKey,
                  keyLen);
              if (ret != 0) throw new System.Exception("failed to generate session key");
          }

          /// <summary>Increments the encrypt or decrypt nonce in place under <c>_nonceIncrementLock</c>.</summary>
          protected void IncrementNonce(bool isEncrypt)
          {
              lock (_nonceIncrementLock) {
                  Sodium.sodium_increment(isEncrypt ? _encNonce : _decNonce, nonceLen);
              }
          }

          /// <summary>
          /// Stores a private copy of the first <c>saltLen</c> bytes of <paramref name="salt"/> as
          /// the encrypt or decrypt salt. Subclasses override this to set up their cipher state.
          /// </summary>
          public virtual void InitCipher(byte[] salt, bool isEncrypt, bool isUdp)
          {
              if (isEncrypt) {
                  _encryptSalt = new byte[saltLen];
                  Array.Copy(salt, _encryptSalt, saltLen);
              } else {
                  _decryptSalt = new byte[saltLen];
                  Array.Copy(salt, _decryptSalt, saltLen);
              }
              Logging.Dump("Salt", salt, saltLen);
          }

          /// <summary>Fills the first <paramref name="length"/> bytes of <paramref name="buf"/> with random bytes.</summary>
          public static void randBytes(byte[] buf, int length) { RNG.GetBytes(buf, length); }

          /// <summary>Encrypts <paramref name="plen"/> bytes with the current encrypt nonce; <paramref name="clen"/> receives the output length (plaintext + tag).</summary>
          public abstract void cipherEncrypt(byte[] plaintext, uint plen, byte[] ciphertext, ref uint clen);

          /// <summary>Decrypts <paramref name="clen"/> bytes with the current decrypt nonce; <paramref name="plen"/> receives the plaintext length.</summary>
          public abstract void cipherDecrypt(byte[] ciphertext, uint clen, byte[] plaintext, ref uint plen);

          #region TCP

          /// <summary>
          /// Buffers <paramref name="length"/> bytes of plaintext and writes as many framed chunks as
          /// fit into <paramref name="outbuf"/>. The first call also emits the salt.
          /// </summary>
          /// <remarks>
          /// Unencrypted bytes stay buffered for the next call when the output buffer is nearly full.
          /// </remarks>
          public override void Encrypt(byte[] buf, int length, byte[] outbuf, out int outlength)
          {
              Debug.Assert(_encCircularBuffer != null, "_encCircularBuffer != null");
              _encCircularBuffer.Put(buf, 0, length);
              outlength = 0;
              Logging.Debug("---Start Encryption");
              if (!_encryptSaltSent) {
                  _encryptSaltSent = true;
                  // Generate salt
                  byte[] saltBytes = new byte[saltLen];
                  randBytes(saltBytes, saltLen);
                  InitCipher(saltBytes, true, false);
                  Array.Copy(saltBytes, 0, outbuf, 0, saltLen);
                  outlength = saltLen;
                  Logging.Debug($"_encryptSaltSent outlength {outlength}");
              }
              if (!_tcpRequestSent) {
                  _tcpRequestSent = true;
                  // The first AddrBufLength bytes (presumably the request's address header) are sent
                  // as a chunk of their own.
                  int encAddrBufLength;
                  byte[] encAddrBufBytes = new byte[AddrBufLength + tagLen * 2 + CHUNK_LEN_BYTES];
                  byte[] addrBytes = _encCircularBuffer.Get(AddrBufLength);
                  ChunkEncrypt(addrBytes, AddrBufLength, encAddrBufBytes, out encAddrBufLength);
                  Debug.Assert(encAddrBufLength == AddrBufLength + tagLen * 2 + CHUNK_LEN_BYTES);
                  Array.Copy(encAddrBufBytes, 0, outbuf, outlength, encAddrBufLength);
                  outlength += encAddrBufLength;
                  Logging.Debug($"_tcpRequestSent outlength {outlength}");
              }
              // handle other chunks, at most CHUNK_LEN_MASK plaintext bytes each
              while (true) {
                  uint bufSize = (uint)_encCircularBuffer.Size;
                  if (bufSize == 0) {
                      return;
                  }
                  var chunklength = (int)Math.Min(bufSize, CHUNK_LEN_MASK);
                  byte[] chunkBytes = _encCircularBuffer.Get(chunklength);
                  int encChunkLength;
                  byte[] encChunkBytes = new byte[chunklength + tagLen * 2 + CHUNK_LEN_BYTES];
                  ChunkEncrypt(chunkBytes, chunklength, encChunkBytes, out encChunkLength);
                  Debug.Assert(encChunkLength == chunklength + tagLen * 2 + CHUNK_LEN_BYTES);
                  Buffer.BlockCopy(encChunkBytes, 0, outbuf, outlength, encChunkLength);
                  outlength += encChunkLength;
                  Logging.Debug("chunks enc outlength " + outlength);
                  // check if we have enough space for outbuf
                  if (outlength + TCPHandler.ChunkOverheadSize > TCPHandler.BufferSize) {
                      Logging.Debug("enc outbuf almost full, giving up");
                      return;
                  }
                  bufSize = (uint)_encCircularBuffer.Size;
                  if (bufSize == 0) {
                      Logging.Debug("No more data to encrypt, leaving");
                      return;
                  }
              }
          }

          /// <summary>
          /// Buffers <paramref name="length"/> bytes of ciphertext and writes the plaintext of every
          /// complete chunk to <paramref name="outbuf"/>. The leading salt is consumed first.
          /// </summary>
          /// <remarks>
          /// Incomplete chunks stay buffered. Their length header is re-decrypted with the same nonce
          /// on the next call, because the nonce only advances once a whole chunk is available.
          /// </remarks>
          /// <exception cref="CryptoErrorException">A decrypted chunk length exceeds <see cref="CHUNK_LEN_MASK"/>.</exception>
          public override void Decrypt(byte[] buf, int length, byte[] outbuf, out int outlength)
          {
              Debug.Assert(_decCircularBuffer != null, "_decCircularBuffer != null");
              int bufSize;
              outlength = 0;
              // drop all into buffer
              _decCircularBuffer.Put(buf, 0, length);
              Logging.Debug("---Start Decryption");
              if (!_decryptSaltReceived) {
                  bufSize = _decCircularBuffer.Size;
                  // check if we get the leading salt
                  if (bufSize <= saltLen) {
                      // need more
                      return;
                  }
                  _decryptSaltReceived = true;
                  byte[] salt = _decCircularBuffer.Get(saltLen);
                  InitCipher(salt, false, false);
                  Logging.Debug("get salt len " + saltLen);
              }
              // handle chunks
              while (true) {
                  bufSize = _decCircularBuffer.Size;
                  // check if we have any data
                  if (bufSize <= 0) {
                      Logging.Debug("No data in _decCircularBuffer");
                      return;
                  }
                  // first get chunk length
                  if (bufSize <= CHUNK_LEN_BYTES + tagLen) {
                      // so we only have chunk length and its tag?
                      return;
                  }

                  #region Chunk Decryption
                  // Peek (not Get): the header stays buffered in case the payload is incomplete.
                  byte[] encLenBytes = _decCircularBuffer.Peek(CHUNK_LEN_BYTES + tagLen);
                  uint decChunkLenLength = 0;
                  byte[] decChunkLenBytes = new byte[CHUNK_LEN_BYTES];
                  // try to dec chunk len
                  cipherDecrypt(encLenBytes, CHUNK_LEN_BYTES + (uint)tagLen, decChunkLenBytes, ref decChunkLenLength);
                  Debug.Assert(decChunkLenLength == CHUNK_LEN_BYTES);
                  // finally we get the real chunk len (sent in network byte order)
                  ushort chunkLen = (ushort)IPAddress.NetworkToHostOrder((short)BitConverter.ToUInt16(decChunkLenBytes, 0));
                  if (chunkLen > CHUNK_LEN_MASK) {
                      // we get invalid chunk
                      Logging.Error($"Invalid chunk length: {chunkLen}");
                      throw new CryptoErrorException();
                  }
                  Logging.Debug("Get the real chunk len:" + chunkLen);
                  bufSize = _decCircularBuffer.Size;
                  if (bufSize < CHUNK_LEN_BYTES + tagLen /* we haven't remove them */ + chunkLen + tagLen) {
                      Logging.Debug("No more data to decrypt one chunk");
                      return;
                  }
                  IncrementNonce(false);
                  // we have enough data to decrypt one chunk
                  // drop chunk len and its tag from buffer
                  _decCircularBuffer.Skip(CHUNK_LEN_BYTES + tagLen);
                  byte[] encChunkBytes = _decCircularBuffer.Get(chunkLen + tagLen);
                  byte[] decChunkBytes = new byte[chunkLen];
                  uint decChunkLen = 0;
                  cipherDecrypt(encChunkBytes, chunkLen + (uint)tagLen, decChunkBytes, ref decChunkLen);
                  Debug.Assert(decChunkLen == chunkLen);
                  IncrementNonce(false);
                  #endregion

                  // output to outbuf
                  Buffer.BlockCopy(decChunkBytes, 0, outbuf, outlength, (int)decChunkLen);
                  outlength += (int)decChunkLen;
                  Logging.Debug("aead dec outlength " + outlength);
                  // keep a small reserve before the end of the output buffer
                  if (outlength + 100 > TCPHandler.BufferSize) {
                      Logging.Debug("dec outbuf almost full, giving up");
                      return;
                  }
                  bufSize = _decCircularBuffer.Size;
                  // check if we already done all of them
                  if (bufSize <= 0) {
                      Logging.Debug("No data in _decCircularBuffer, already all done");
                      return;
                  }
              }
          }

          #endregion

          #region UDP

          /// <summary>
          /// Encrypts one UDP datagram as [fresh random salt][ciphertext + tag] into
          /// <paramref name="outbuf"/>.
          /// </summary>
          public override void EncryptUDP(byte[] buf, int length, byte[] outbuf, out int outlength)
          {
              // Generate salt directly at the start of the output packet
              randBytes(outbuf, saltLen);
              InitCipher(outbuf, true, true);
              uint olen = 0;
              lock (_udpTmpBuf) {
                  cipherEncrypt(buf, (uint)length, _udpTmpBuf, ref olen);
                  Debug.Assert(olen == length + tagLen);
                  Buffer.BlockCopy(_udpTmpBuf, 0, outbuf, saltLen, (int)olen);
                  outlength = (int)(saltLen + olen);
              }
          }

          /// <summary>
          /// Decrypts one UDP datagram laid out as [salt][ciphertext + tag].
          /// </summary>
          /// <remarks>
          /// <paramref name="buf"/> is modified in place: the ciphertext is shifted to offset 0.
          /// </remarks>
          /// <exception cref="CryptoErrorException">
          /// <paramref name="length"/> is too short to hold a salt and an authentication tag.
          /// </exception>
          public override void DecryptUDP(byte[] buf, int length, byte[] outbuf, out int outlength)
          {
              // Datagrams are untrusted input. Reject ones that cannot even hold a salt and a tag:
              // otherwise the BlockCopy below gets a negative count, or the cipher is handed fewer
              // bytes than a single tag.
              if (length < saltLen + tagLen) {
                  Logging.Error($"Invalid UDP packet length: {length}");
                  throw new CryptoErrorException();
              }
              InitCipher(buf, false, true);
              uint olen = 0;
              lock (_udpTmpBuf) {
                  // copy remaining data to first pos
                  Buffer.BlockCopy(buf, saltLen, buf, 0, length - saltLen);
                  cipherDecrypt(buf, (uint)(length - saltLen), _udpTmpBuf, ref olen);
                  Buffer.BlockCopy(_udpTmpBuf, 0, outbuf, 0, (int)olen);
                  outlength = (int)olen;
              }
          }

          #endregion

          /// <summary>
          /// Encrypts one TCP chunk into <paramref name="ciphertext"/> as
          /// [encrypted length + tag][encrypted payload + tag], advancing the encrypt nonce after
          /// each of the two cipher operations.
          /// </summary>
          /// <remarks>
          /// We know the plaintext length before encryption, so we can do it in one operation.
          /// </remarks>
          /// <exception cref="CryptoErrorException"><paramref name="plainLen"/> exceeds <see cref="CHUNK_LEN_MASK"/>.</exception>
          private void ChunkEncrypt(byte[] plaintext, int plainLen, byte[] ciphertext, out int cipherLen)
          {
              if (plainLen > CHUNK_LEN_MASK) {
                  Logging.Error("enc chunk too big");
                  throw new CryptoErrorException();
              }
              // encrypt len (network byte order)
              byte[] encLenBytes = new byte[CHUNK_LEN_BYTES + tagLen];
              uint encChunkLenLength = 0;
              byte[] lenbuf = BitConverter.GetBytes((ushort)IPAddress.HostToNetworkOrder((short)plainLen));
              cipherEncrypt(lenbuf, CHUNK_LEN_BYTES, encLenBytes, ref encChunkLenLength);
              Debug.Assert(encChunkLenLength == CHUNK_LEN_BYTES + tagLen);
              IncrementNonce(true);
              // encrypt corresponding data
              byte[] encBytes = new byte[plainLen + tagLen];
              uint encBufLength = 0;
              cipherEncrypt(plaintext, (uint)plainLen, encBytes, ref encBufLength);
              Debug.Assert(encBufLength == plainLen + tagLen);
              IncrementNonce(true);
              // construct outbuf
              Array.Copy(encLenBytes, 0, ciphertext, 0, (int)encChunkLenLength);
              Buffer.BlockCopy(encBytes, 0, ciphertext, (int)encChunkLenLength, (int)encBufLength);
              cipherLen = (int)(encChunkLenLength + encBufLength);
          }
      }
  }