You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

AEADEncryptor.cs 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. using NLog;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Net;
  6. using System.Text;
  7. using Shadowsocks.Encryption.CircularBuffer;
  8. using Shadowsocks.Controller;
  9. using Shadowsocks.Encryption.Exception;
  namespace Shadowsocks.Encryption.AEAD
  {
      /// <summary>
      /// Base class for AEAD-based encryptors. It owns the salt exchange, master/session
      /// key derivation, per-direction nonces, the length-prefixed chunk framing used for
      /// TCP and the single-shot packet format used for UDP. Subclasses supply the cipher
      /// table (<see cref="getCiphers"/>) and the raw AEAD primitives
      /// (<see cref="cipherEncrypt"/> / <see cref="cipherDecrypt"/>).
      /// </summary>
      public abstract class AEADEncryptor
          : EncryptorBase
      {
          private static readonly Logger logger = LogManager.GetCurrentClassLogger();

          // We are using the same saltLen and keyLen
          private const string Info = "ss-subkey";
          private static readonly byte[] InfoBytes = Encoding.ASCII.GetBytes(Info);

          // for UDP only; this buffer is static (shared by every instance), which is why
          // both UDP paths below only touch it while holding lock (_udpTmpBuf).
          protected static byte[] _udpTmpBuf = new byte[65536];

          // every connection should create its own buffer
          private readonly ByteCircularBuffer _encCircularBuffer = new ByteCircularBuffer(MAX_INPUT_SIZE * 2);
          private readonly ByteCircularBuffer _decCircularBuffer = new ByteCircularBuffer(MAX_INPUT_SIZE * 2);

          /// <summary>Number of bytes of the encrypted chunk-length prefix (before its tag).</summary>
          public const int CHUNK_LEN_BYTES = 2;

          /// <summary>Largest accepted plaintext chunk length (0x3FFF); larger values are rejected.</summary>
          public const uint CHUNK_LEN_MASK = 0x3FFFu;

          protected Dictionary<string, EncryptorInfo> ciphers;
          protected string _method;
          protected int _cipher;
          // internal name in the crypto library
          protected string _innerLibName;
          protected EncryptorInfo CipherInfo;

          // Master key derived from the password in InitKey. This is deliberately a
          // per-instance field: a static key would be overwritten (and possibly resized)
          // by every new encryptor, so a connection using another password or cipher could
          // clobber this connection's key before its session key is derived.
          // NOTE(review): presumably consumed by subclass InitCipher overrides via
          // DeriveSessionKey — verify that no subclass accesses it from a static context.
          protected byte[] _Masterkey = null;
          protected byte[] _sessionKey;

          protected int keyLen;
          protected int saltLen;
          protected int tagLen;
          protected int nonceLen;

          protected byte[] _encryptSalt;
          protected byte[] _decryptSalt;

          protected object _nonceIncrementLock = new object();
          protected byte[] _encNonce;
          protected byte[] _decNonce;

          // Is first packet
          protected bool _decryptSaltReceived;
          protected bool _encryptSaltSent;

          // Is first chunk(tcp request)
          protected bool _tcpRequestSent;

          /// <summary>
          /// Looks up <paramref name="method"/> in the subclass cipher table, derives the
          /// master key from <paramref name="password"/> and starts both nonces at all zeros.
          /// </summary>
          /// <param name="method">Cipher name; matched case-insensitively (invariant culture).</param>
          /// <param name="password">Shared password the master key is derived from.</param>
          /// <exception cref="System.Exception">The method is unknown ("method not found").</exception>
          public AEADEncryptor(string method, string password)
              : base(method, password)
          {
              InitEncryptorInfo(method);
              InitKey(password);
              // Initialize all-zero nonce for each connection
              _encNonce = new byte[nonceLen];
              _decNonce = new byte[nonceLen];
          }

          /// <summary>Returns the table of supported cipher names and their parameters.</summary>
          protected abstract Dictionary<string, EncryptorInfo> getCiphers();

          /// <summary>
          /// Resolves <paramref name="method"/> to its <see cref="EncryptorInfo"/> and caches
          /// the key, salt, tag and nonce sizes.
          /// </summary>
          /// <exception cref="System.Exception">
          /// The method is not in the cipher table, or its entry has <c>Type == 0</c>.
          /// </exception>
          protected void InitEncryptorInfo(string method)
          {
              // Invariant culture: a culture-sensitive ToLower() turns 'I' into dotless 'ı'
              // under e.g. tr-TR, so an upper-case name such as "CHACHA20-IETF-POLY1305"
              // would no longer match its table key.
              method = method.ToLowerInvariant();
              _method = method;
              ciphers = getCiphers();

              // TryGetValue so an unknown name yields the intended "method not found"
              // error rather than a bare KeyNotFoundException from the indexer.
              EncryptorInfo info;
              if (!ciphers.TryGetValue(_method, out info))
              {
                  throw new System.Exception("method not found");
              }
              CipherInfo = info;

              _innerLibName = CipherInfo.InnerLibName;
              _cipher = CipherInfo.Type;
              if (_cipher == 0)
              {
                  throw new System.Exception("method not found");
              }
              keyLen = CipherInfo.KeySize;
              saltLen = CipherInfo.SaltSize;
              tagLen = CipherInfo.TagSize;
              nonceLen = CipherInfo.NonceSize;
          }

          /// <summary>
          /// Derives this instance's master key from <paramref name="password"/> (UTF-8)
          /// and allocates the session-key buffer if needed.
          /// </summary>
          protected void InitKey(string password)
          {
              byte[] passbuf = Encoding.UTF8.GetBytes(password);
              // init master key; DeriveKey overwrites all keyLen bytes, so a fresh array is
              // only needed when none exists yet or the size differs.
              if (_Masterkey == null || _Masterkey.Length != keyLen)
              {
                  _Masterkey = new byte[keyLen];
              }
              DeriveKey(passbuf, _Masterkey, keyLen);
              // init session key
              if (_sessionKey == null)
              {
                  _sessionKey = new byte[keyLen];
              }
          }

          /// <summary>
          /// Fills the first <paramref name="keylen"/> bytes of <paramref name="key"/> with
          /// MD5(p) || MD5(MD5(p) || p) || ..., where p is <paramref name="password"/>.
          /// The output is truncated to <paramref name="keylen"/> bytes.
          /// </summary>
          public void DeriveKey(byte[] password, byte[] key, int keylen)
          {
              // scratch space for "previous digest || password"
              byte[] result = new byte[password.Length + MD5_LEN];
              int i = 0;
              byte[] md5sum = null;
              while (i < keylen)
              {
                  if (i == 0)
                  {
                      md5sum = MbedTLS.MD5(password);
                  }
                  else
                  {
                      Array.Copy(md5sum, 0, result, 0, MD5_LEN);
                      Array.Copy(password, 0, result, MD5_LEN, password.Length);
                      md5sum = MbedTLS.MD5(result);
                  }
                  // the last block may be only partially needed
                  Array.Copy(md5sum, 0, key, i, Math.Min(MD5_LEN, keylen - i));
                  i += MD5_LEN;
              }
          }

          /// <summary>
          /// Derives <paramref name="sessionKey"/> (keyLen bytes) from <paramref name="salt"/>
          /// and <paramref name="masterKey"/> with info string "ss-subkey" via MbedTLS.hkdf.
          /// </summary>
          /// <exception cref="System.Exception">MbedTLS.hkdf returned a non-zero status.</exception>
          public void DeriveSessionKey(byte[] salt, byte[] masterKey, byte[] sessionKey)
          {
              int ret = MbedTLS.hkdf(salt, saltLen, masterKey, keyLen, InfoBytes, InfoBytes.Length, sessionKey,
                  keyLen);
              if (ret != 0) throw new System.Exception("failed to generate session key");
          }

          /// <summary>Advances the encryption or decryption nonce via Sodium.sodium_increment.</summary>
          protected void IncrementNonce(bool isEncrypt)
          {
              lock (_nonceIncrementLock)
              {
                  Sodium.sodium_increment(isEncrypt ? _encNonce : _decNonce, nonceLen);
              }
          }

          /// <summary>
          /// Stores a copy of the first saltLen bytes of <paramref name="salt"/> as the
          /// encryption or decryption salt. Subclasses override this to set up their cipher.
          /// </summary>
          public virtual void InitCipher(byte[] salt, bool isEncrypt, bool isUdp)
          {
              if (isEncrypt)
              {
                  _encryptSalt = new byte[saltLen];
                  Array.Copy(salt, _encryptSalt, saltLen);
              }
              else
              {
                  _decryptSalt = new byte[saltLen];
                  Array.Copy(salt, _decryptSalt, saltLen);
              }
              logger.Dump("Salt", salt, saltLen);
          }

          /// <summary>Fills the first <paramref name="length"/> bytes of <paramref name="buf"/> from RNG.</summary>
          public static void randBytes(byte[] buf, int length) { RNG.GetBytes(buf, length); }

          /// <summary>AEAD-encrypts <paramref name="plen"/> bytes; <paramref name="clen"/> receives the output length.</summary>
          public abstract void cipherEncrypt(byte[] plaintext, uint plen, byte[] ciphertext, ref uint clen);

          /// <summary>AEAD-decrypts <paramref name="clen"/> bytes; <paramref name="plen"/> receives the output length.</summary>
          public abstract void cipherDecrypt(byte[] ciphertext, uint clen, byte[] plaintext, ref uint plen);

          #region TCP

          /// <summary>
          /// Buffers <paramref name="length"/> plaintext bytes and emits as many encrypted
          /// chunks as fit into <paramref name="outbuf"/>. The first call prefixes a fresh
          /// random salt. The first chunk carries exactly AddrBufLength bytes (presumably the
          /// request address header — confirm with callers). Remaining data is split into
          /// chunks of at most CHUNK_LEN_MASK bytes. Anything that does not fit before outbuf
          /// nears TCPHandler.BufferSize stays buffered for the next call.
          /// </summary>
          public override void Encrypt(byte[] buf, int length, byte[] outbuf, out int outlength)
          {
              Debug.Assert(_encCircularBuffer != null, "_encCircularBuffer != null");
              _encCircularBuffer.Put(buf, 0, length);
              outlength = 0;
              logger.Trace("---Start Encryption");

              if (!_encryptSaltSent)
              {
                  _encryptSaltSent = true;
                  // Generate salt
                  byte[] saltBytes = new byte[saltLen];
                  randBytes(saltBytes, saltLen);
                  InitCipher(saltBytes, true, false);
                  Array.Copy(saltBytes, 0, outbuf, 0, saltLen);
                  outlength = saltLen;
                  logger.Trace($"_encryptSaltSent outlength {outlength}");
              }

              if (!_tcpRequestSent)
              {
                  _tcpRequestSent = true;
                  // The first TCP request
                  int encAddrBufLength;
                  byte[] encAddrBufBytes = new byte[AddrBufLength + tagLen * 2 + CHUNK_LEN_BYTES];
                  byte[] addrBytes = _encCircularBuffer.Get(AddrBufLength);
                  ChunkEncrypt(addrBytes, AddrBufLength, encAddrBufBytes, out encAddrBufLength);
                  Debug.Assert(encAddrBufLength == AddrBufLength + tagLen * 2 + CHUNK_LEN_BYTES);
                  Array.Copy(encAddrBufBytes, 0, outbuf, outlength, encAddrBufLength);
                  outlength += encAddrBufLength;
                  logger.Trace($"_tcpRequestSent outlength {outlength}");
              }

              // handle other chunks
              while (true)
              {
                  uint bufSize = (uint)_encCircularBuffer.Size;
                  if (bufSize == 0) return;

                  var chunklength = (int)Math.Min(bufSize, CHUNK_LEN_MASK);
                  byte[] chunkBytes = _encCircularBuffer.Get(chunklength);
                  int encChunkLength;
                  byte[] encChunkBytes = new byte[chunklength + tagLen * 2 + CHUNK_LEN_BYTES];
                  ChunkEncrypt(chunkBytes, chunklength, encChunkBytes, out encChunkLength);
                  Debug.Assert(encChunkLength == chunklength + tagLen * 2 + CHUNK_LEN_BYTES);
                  Buffer.BlockCopy(encChunkBytes, 0, outbuf, outlength, encChunkLength);
                  outlength += encChunkLength;
                  logger.Trace("chunks enc outlength " + outlength);

                  // check if we have enough space for outbuf
                  if (outlength + TCPHandler.ChunkOverheadSize > TCPHandler.BufferSize)
                  {
                      logger.Trace("enc outbuf almost full, giving up");
                      return;
                  }

                  bufSize = (uint)_encCircularBuffer.Size;
                  if (bufSize == 0)
                  {
                      logger.Trace("No more data to encrypt, leaving");
                      return;
                  }
              }
          }

          /// <summary>
          /// Buffers <paramref name="length"/> ciphertext bytes and decrypts every complete
          /// chunk into <paramref name="outbuf"/>. The leading salt is consumed first.
          /// Incomplete data stays buffered until a later call supplies the rest.
          /// </summary>
          /// <exception cref="CryptoErrorException">A decrypted chunk length exceeds CHUNK_LEN_MASK.</exception>
          public override void Decrypt(byte[] buf, int length, byte[] outbuf, out int outlength)
          {
              Debug.Assert(_decCircularBuffer != null, "_decCircularBuffer != null");
              int bufSize;
              outlength = 0;
              // drop all into buffer
              _decCircularBuffer.Put(buf, 0, length);
              logger.Trace("---Start Decryption");

              if (!_decryptSaltReceived)
              {
                  bufSize = _decCircularBuffer.Size;
                  // check if we get the leading salt (with exactly saltLen bytes there is
                  // nothing after it to decrypt yet, so waiting loses nothing)
                  if (bufSize <= saltLen)
                  {
                      // need more
                      return;
                  }
                  _decryptSaltReceived = true;
                  byte[] salt = _decCircularBuffer.Get(saltLen);
                  InitCipher(salt, false, false);
                  logger.Trace("get salt len " + saltLen);
              }

              // handle chunks
              while (true)
              {
                  bufSize = _decCircularBuffer.Size;
                  // check if we have any data
                  if (bufSize <= 0)
                  {
                      logger.Trace("No data in _decCircularBuffer");
                      return;
                  }

                  // first get chunk length; a chunk always needs more than the length
                  // field plus its tag, so waiting here is safe
                  if (bufSize <= CHUNK_LEN_BYTES + tagLen)
                  {
                      // so we only have chunk length and its tag?
                      return;
                  }

                  #region Chunk Decryption
                  // Peek (not Get): if the payload turns out to be incomplete we return
                  // without consuming the length or advancing the nonce, so the next call
                  // re-decrypts the same length field with the same nonce.
                  byte[] encLenBytes = _decCircularBuffer.Peek(CHUNK_LEN_BYTES + tagLen);
                  uint decChunkLenLength = 0;
                  byte[] decChunkLenBytes = new byte[CHUNK_LEN_BYTES];
                  // try to dec chunk len
                  cipherDecrypt(encLenBytes, CHUNK_LEN_BYTES + (uint)tagLen, decChunkLenBytes, ref decChunkLenLength);
                  Debug.Assert(decChunkLenLength == CHUNK_LEN_BYTES);
                  // finally we get the real chunk len (sent in network byte order)
                  ushort chunkLen = (ushort)IPAddress.NetworkToHostOrder((short)BitConverter.ToUInt16(decChunkLenBytes, 0));
                  if (chunkLen > CHUNK_LEN_MASK)
                  {
                      // we get invalid chunk
                      logger.Error($"Invalid chunk length: {chunkLen}");
                      throw new CryptoErrorException();
                  }
                  logger.Trace("Get the real chunk len:" + chunkLen);

                  bufSize = _decCircularBuffer.Size;
                  if (bufSize < CHUNK_LEN_BYTES + tagLen /* we haven't remove them */ + chunkLen + tagLen)
                  {
                      logger.Trace("No more data to decrypt one chunk");
                      return;
                  }
                  IncrementNonce(false);

                  // we have enough data to decrypt one chunk
                  // drop chunk len and its tag from buffer
                  _decCircularBuffer.Skip(CHUNK_LEN_BYTES + tagLen);
                  byte[] encChunkBytes = _decCircularBuffer.Get(chunkLen + tagLen);
                  byte[] decChunkBytes = new byte[chunkLen];
                  uint decChunkLen = 0;
                  cipherDecrypt(encChunkBytes, chunkLen + (uint)tagLen, decChunkBytes, ref decChunkLen);
                  Debug.Assert(decChunkLen == chunkLen);
                  IncrementNonce(false);
                  #endregion

                  // output to outbuf
                  Buffer.BlockCopy(decChunkBytes, 0, outbuf, outlength, (int)decChunkLen);
                  outlength += (int)decChunkLen;
                  logger.Trace("aead dec outlength " + outlength);
                  // NOTE(review): fixed 100-byte headroom, unlike Encrypt's
                  // TCPHandler.ChunkOverheadSize — confirm it covers the largest chunk.
                  if (outlength + 100 > TCPHandler.BufferSize)
                  {
                      logger.Trace("dec outbuf almost full, giving up");
                      return;
                  }

                  bufSize = _decCircularBuffer.Size;
                  // check if we already done all of them
                  if (bufSize <= 0)
                  {
                      logger.Trace("No data in _decCircularBuffer, already all done");
                      return;
                  }
              }
          }

          #endregion

          #region UDP

          /// <summary>
          /// Encrypts one datagram. The output is a fresh random salt (saltLen bytes)
          /// followed by the AEAD ciphertext of <paramref name="buf"/>, including its tag.
          /// </summary>
          public override void EncryptUDP(byte[] buf, int length, byte[] outbuf, out int outlength)
          {
              // Generate salt
              randBytes(outbuf, saltLen);
              InitCipher(outbuf, true, true);
              uint olen = 0;
              lock (_udpTmpBuf)
              {
                  cipherEncrypt(buf, (uint)length, _udpTmpBuf, ref olen);
                  Debug.Assert(olen == length + tagLen);
                  Buffer.BlockCopy(_udpTmpBuf, 0, outbuf, saltLen, (int)olen);
                  outlength = (int)(saltLen + olen);
              }
          }

          /// <summary>
          /// Decrypts one datagram laid out as salt || ciphertext-with-tag. Note that this
          /// shifts the ciphertext to the start of <paramref name="buf"/> in place.
          /// </summary>
          /// <exception cref="CryptoErrorException">
          /// <paramref name="length"/> is too short to hold a salt and a tag.
          /// </exception>
          public override void DecryptUDP(byte[] buf, int length, byte[] outbuf, out int outlength)
          {
              // The datagram comes from the network: it must hold at least the salt and
              // one tag. Reject it before reading the salt or shifting the buffer, since a
              // shorter one would give BlockCopy a negative count or hand the cipher fewer
              // bytes than a tag.
              if (length < saltLen + tagLen)
              {
                  logger.Error($"Invalid UDP packet length: {length}");
                  throw new CryptoErrorException();
              }

              InitCipher(buf, false, true);
              uint olen = 0;
              lock (_udpTmpBuf)
              {
                  // copy remaining data to first pos
                  Buffer.BlockCopy(buf, saltLen, buf, 0, length - saltLen);
                  cipherDecrypt(buf, (uint)(length - saltLen), _udpTmpBuf, ref olen);
                  Buffer.BlockCopy(_udpTmpBuf, 0, outbuf, 0, (int)olen);
                  outlength = (int)olen;
              }
          }

          #endregion

          /// <summary>
          /// Encrypts one TCP chunk as [enc(len) || tag][enc(payload) || tag]. The length is
          /// a 2-byte big-endian value. Each of the two encryptions uses, and then advances,
          /// the encryption nonce.
          /// </summary>
          /// <exception cref="CryptoErrorException"><paramref name="plainLen"/> exceeds CHUNK_LEN_MASK.</exception>
          // we know the plaintext length before encryption, so we can do it in one operation
          private void ChunkEncrypt(byte[] plaintext, int plainLen, byte[] ciphertext, out int cipherLen)
          {
              if (plainLen > CHUNK_LEN_MASK)
              {
                  logger.Error("enc chunk too big");
                  throw new CryptoErrorException();
              }

              // encrypt len
              byte[] encLenBytes = new byte[CHUNK_LEN_BYTES + tagLen];
              uint encChunkLenLength = 0;
              byte[] lenbuf = BitConverter.GetBytes((ushort)IPAddress.HostToNetworkOrder((short)plainLen));
              cipherEncrypt(lenbuf, CHUNK_LEN_BYTES, encLenBytes, ref encChunkLenLength);
              Debug.Assert(encChunkLenLength == CHUNK_LEN_BYTES + tagLen);
              IncrementNonce(true);

              // encrypt corresponding data
              byte[] encBytes = new byte[plainLen + tagLen];
              uint encBufLength = 0;
              cipherEncrypt(plaintext, (uint)plainLen, encBytes, ref encBufLength);
              Debug.Assert(encBufLength == plainLen + tagLen);
              IncrementNonce(true);

              // construct outbuf
              Array.Copy(encLenBytes, 0, ciphertext, 0, (int)encChunkLenLength);
              Buffer.BlockCopy(encBytes, 0, ciphertext, (int)encChunkLenLength, (int)encBufLength);
              cipherLen = (int)(encChunkLenLength + encBufLength);
          }
      }
  }