You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

AEADEncryptor.cs 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. using NLog;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Net;
  6. using System.Text;
  7. using Shadowsocks.Encryption.CircularBuffer;
  8. using Shadowsocks.Controller;
  9. using Shadowsocks.Encryption.Exception;
  10. using Shadowsocks.Encryption.Stream;
  11. namespace Shadowsocks.Encryption.AEAD
  12. {
  13. public abstract class AEADEncryptor
  14. : EncryptorBase
  15. {
  private static readonly Logger logger = LogManager.GetCurrentClassLogger();

  // We are using the same saltLen and keyLen
  // "info" input handed to HKDF when the session key is derived (see DeriveSessionKey).
  private const string Info = "ss-subkey";
  private static readonly byte[] InfoBytes = Encoding.ASCII.GetBytes(Info);

  // for UDP only: scratch buffer shared by all instances; EncryptUDP/DecryptUDP
  // hold a lock on this array while they use it.
  protected static byte[] _udpTmpBuf = new byte[65536];

  // every connection should create its own buffer
  // Encrypt queues pending plaintext in the first one, Decrypt queues pending
  // ciphertext in the second one. Neither field is ever reassigned.
  private readonly ByteCircularBuffer _encCircularBuffer = new ByteCircularBuffer(MAX_INPUT_SIZE * 2);
  private readonly ByteCircularBuffer _decCircularBuffer = new ByteCircularBuffer(MAX_INPUT_SIZE * 2);

  // Size in bytes of the chunk length field that precedes every TCP chunk.
  public const int CHUNK_LEN_BYTES = 2;
  // Largest payload length accepted for a single TCP chunk.
  public const uint CHUNK_LEN_MASK = 0x3FFFu;

  // Cipher table returned by getCiphers(), and the selected (lower-cased) method name.
  protected Dictionary<string, EncryptorInfo> ciphers;
  protected string _method;
  // CipherInfo.Type of the selected method; 0 is treated as "method not found".
  protected int _cipher;
  // internal name in the crypto library
  protected string _innerLibName;
  protected EncryptorInfo CipherInfo;

  // NOTE(review): static, so this key is shared by every encryptor instance and is
  // rewritten by each InitKey call - confirm that sharing is intended.
  protected static byte[] _Masterkey = null;
  protected byte[] _sessionKey;

  // Sizes in bytes, copied from CipherInfo by InitEncryptorInfo.
  protected int keyLen;
  protected int saltLen;
  protected int tagLen;
  protected int nonceLen;

  // Private copies of the salts passed to InitCipher for each direction.
  protected byte[] _encryptSalt;
  protected byte[] _decryptSalt;

  // Guards nonce increments; the nonces start as all-zero arrays of nonceLen bytes.
  protected object _nonceIncrementLock = new object();
  protected byte[] _encNonce;
  protected byte[] _decNonce;

  // Is first packet
  protected bool _decryptSaltReceived;
  protected bool _encryptSaltSent;
  // Is first chunk(tcp request)
  protected bool _tcpRequestSent;
  // Builds an encryptor for the given method name and password: resolves the
  // cipher parameters, derives the master key, then resets both nonces.
  public AEADEncryptor(string method, string password)
      : base(method, password)
  {
      // Cipher parameters must be resolved first; the key and nonce sizes come from them.
      InitEncryptorInfo(method);
      InitKey(password);

      // Every connection starts with an all-zero nonce in each direction.
      _decNonce = new byte[nonceLen];
      _encNonce = new byte[nonceLen];
  }
  // Supplies the table of ciphers supported by the concrete encryptor.
  // InitEncryptorInfo looks the table up with the lower-cased method name.
  protected abstract Dictionary<string, EncryptorInfo> getCiphers();

  // Resolves the cipher description for the given method name and caches its
  // parameters (library name, type, key/salt/tag/nonce sizes) in the instance.
  // Throws KeyNotFoundException when the method is missing from the table and
  // System.Exception when the entry has type 0.
  protected void InitEncryptorInfo(string method)
  {
      // Method names are protocol identifiers, so lower-case them without
      // applying the rules of the current UI culture.
      method = method.ToLowerInvariant();
      _method = method;
      ciphers = getCiphers();

      EncryptorInfo info;
      if (!ciphers.TryGetValue(_method, out info)) {
          throw new KeyNotFoundException("method not found: " + _method);
      }
      CipherInfo = info;

      _innerLibName = CipherInfo.InnerLibName;
      _cipher = CipherInfo.Type;
      if (_cipher == 0) {
          throw new System.Exception("method not found");
      }

      keyLen = CipherInfo.KeySize;
      saltLen = CipherInfo.SaltSize;
      tagLen = CipherInfo.TagSize;
      nonceLen = CipherInfo.NonceSize;
  }
  // Derives the master key from the password and makes sure a session key
  // buffer of keyLen bytes exists.
  protected void InitKey(string password)
  {
      byte[] passwordBytes = Encoding.UTF8.GetBytes(password);

      // The master key buffer must be exactly keyLen bytes long.
      if (_Masterkey == null) {
          _Masterkey = new byte[keyLen];
      } else if (_Masterkey.Length != keyLen) {
          Array.Resize(ref _Masterkey, keyLen);
      }
      DeriveKey(passwordBytes, _Masterkey, keyLen);

      // The session key itself is filled in later; only allocate it here.
      if (_sessionKey == null) {
          _sessionKey = new byte[keyLen];
      }
  }
  // Fills `key` from `password` by delegating to StreamEncryptor.LegacyDeriveKey.
  public void DeriveKey(byte[] password, byte[] key, int keylen)
      => StreamEncryptor.LegacyDeriveKey(password, key, keylen);
  // Derives the session key into `sessionKey` via MbedTLS.hkdf, using `salt`
  // (saltLen bytes), `masterKey` (keyLen bytes) and the fixed info bytes.
  // Throws System.Exception when hkdf reports a non-zero status.
  public void DeriveSessionKey(byte[] salt, byte[] masterKey, byte[] sessionKey)
  {
      int status = MbedTLS.hkdf(
          salt, saltLen,
          masterKey, keyLen,
          InfoBytes, InfoBytes.Length,
          sessionKey, keyLen);

      if (status == 0) {
          return;
      }
      throw new System.Exception("failed to generate session key");
  }
  // Advances the encryption nonce (isEncrypt == true) or the decryption nonce
  // by one step via Sodium.sodium_increment, under the nonce lock.
  protected void IncrementNonce(bool isEncrypt)
  {
      lock (_nonceIncrementLock) {
          byte[] nonce = isEncrypt ? _encNonce : _decNonce;
          Sodium.sodium_increment(nonce, nonceLen);
      }
  }
  // Stores a private copy of the first saltLen bytes of `salt` as the encrypt
  // or decrypt salt, then dumps the salt to the log. `isUdp` is not used by
  // this base implementation.
  public virtual void InitCipher(byte[] salt, bool isEncrypt, bool isUdp)
  {
      byte[] saltCopy = new byte[saltLen];
      if (isEncrypt) {
          _encryptSalt = saltCopy;
      } else {
          _decryptSalt = saltCopy;
      }
      // Copy after assigning so the field already refers to the new buffer.
      Array.Copy(salt, saltCopy, saltLen);

      logger.Dump("Salt", salt, saltLen);
  }
  // Fills the first `length` bytes of `buf` using RNG.GetBytes.
  public static void randBytes(byte[] buf, int length)
  {
      RNG.GetBytes(buf, length);
  }

  // Cipher primitive supplied by the concrete encryptor: encrypts `plen` bytes of
  // `plaintext` into `ciphertext`. Callers in this class pass `clen` as 0 and
  // expect it to come back as plen + tagLen.
  public abstract void cipherEncrypt(byte[] plaintext, uint plen, byte[] ciphertext, ref uint clen);

  // Cipher primitive supplied by the concrete encryptor: decrypts `clen` bytes of
  // `ciphertext` (tag included) into `plaintext`. Callers in this class pass
  // `plen` as 0 and expect it to come back as the plaintext length.
  public abstract void cipherDecrypt(byte[] ciphertext, uint clen, byte[] plaintext, ref uint plen);
  115. #region TCP
  // Encrypts TCP stream data. The input is queued in the encryption buffer; on
  // the first call a fresh salt is written to `outbuf`, followed by the first
  // AddrBufLength bytes sealed as their own chunk. Remaining queued bytes are
  // emitted as chunks of at most CHUNK_LEN_MASK bytes until the queue is empty
  // or `outbuf` is nearly full. `outlength` receives the bytes written.
  public override void Encrypt(byte[] buf, int length, byte[] outbuf, out int outlength)
  {
      Debug.Assert(_encCircularBuffer != null, "_encCircularBuffer != null");
      _encCircularBuffer.Put(buf, 0, length);
      outlength = 0;
      logger.Trace("---Start Encryption");

      if (!_encryptSaltSent) {
          _encryptSaltSent = true;
          // A new random salt opens the stream and initialises the cipher.
          byte[] salt = new byte[saltLen];
          randBytes(salt, saltLen);
          InitCipher(salt, true, false);
          Array.Copy(salt, 0, outbuf, 0, saltLen);
          outlength = saltLen;
          logger.Trace($"_encryptSaltSent outlength {outlength}");
      }

      if (!_tcpRequestSent) {
          _tcpRequestSent = true;
          // The first TCP request: the address header goes out as a chunk of its own.
          byte[] sealedHeader = new byte[AddrBufLength + tagLen * 2 + CHUNK_LEN_BYTES];
          byte[] header = _encCircularBuffer.Get(AddrBufLength);
          int sealedHeaderLen;
          ChunkEncrypt(header, AddrBufLength, sealedHeader, out sealedHeaderLen);
          Debug.Assert(sealedHeaderLen == AddrBufLength + tagLen * 2 + CHUNK_LEN_BYTES);
          Array.Copy(sealedHeader, 0, outbuf, outlength, sealedHeaderLen);
          outlength += sealedHeaderLen;
          logger.Trace($"_tcpRequestSent outlength {outlength}");
      }

      // Nothing else queued: done.
      if ((uint)_encCircularBuffer.Size <= 0) {
          return;
      }

      // handle other chunks
      do {
          uint pending = (uint)_encCircularBuffer.Size;
          var take = (int)Math.Min(pending, CHUNK_LEN_MASK);
          byte[] plainChunk = _encCircularBuffer.Get(take);

          int sealedLen;
          byte[] sealedChunk = new byte[take + tagLen * 2 + CHUNK_LEN_BYTES];
          ChunkEncrypt(plainChunk, take, sealedChunk, out sealedLen);
          Debug.Assert(sealedLen == take + tagLen * 2 + CHUNK_LEN_BYTES);

          Buffer.BlockCopy(sealedChunk, 0, outbuf, outlength, sealedLen);
          outlength += sealedLen;
          logger.Trace("chunks enc outlength " + outlength);

          // check if we have enough space for outbuf
          if (outlength + TCPHandler.ChunkOverheadSize > TCPHandler.BufferSize) {
              logger.Trace("enc outbuf almost full, giving up");
              return;
          }
      } while ((uint)_encCircularBuffer.Size > 0);

      logger.Trace("No more data to encrypt, leaving");
  }
  // Decrypts TCP stream data. The input is queued in the decryption buffer; the
  // leading salt is consumed once, then complete chunks (sealed length field
  // followed by sealed payload) are decrypted into `outbuf` until the queue runs
  // out of complete chunks or `outbuf` is nearly full. Incomplete data stays
  // queued for the next call. `outlength` receives the bytes written.
  // Throws CryptoErrorException when a decoded chunk length exceeds CHUNK_LEN_MASK.
  public override void Decrypt(byte[] buf, int length, byte[] outbuf, out int outlength)
  {
      Debug.Assert(_decCircularBuffer != null, "_decCircularBuffer != null");
      int bufSize;
      outlength = 0;
      // drop all into buffer
      _decCircularBuffer.Put(buf, 0, length);

      logger.Trace("---Start Decryption");
      if (!_decryptSaltReceived) {
          bufSize = _decCircularBuffer.Size;
          // The salt is complete as soon as saltLen bytes are queued; only wait
          // when fewer than that have arrived.
          if (bufSize < saltLen) {
              // need more
              return;
          }
          _decryptSaltReceived = true;
          byte[] salt = _decCircularBuffer.Get(saltLen);
          InitCipher(salt, false, false);
          logger.Trace("get salt len " + saltLen);
      }

      // handle chunks
      while (true) {
          bufSize = _decCircularBuffer.Size;
          // check if we have any data
          if (bufSize <= 0) {
              logger.Trace("No data in _decCircularBuffer");
              return;
          }

          // first get chunk length
          if (bufSize <= CHUNK_LEN_BYTES + tagLen) {
              // so we only have chunk length and its tag?
              return;
          }

          // Peek (do not consume) the sealed length field: if the payload has not
          // fully arrived yet, the same field is decrypted again on the next call
          // with the same, not yet incremented, nonce.
          byte[] encLenBytes = _decCircularBuffer.Peek(CHUNK_LEN_BYTES + tagLen);
          uint decChunkLenLength = 0;
          byte[] decChunkLenBytes = new byte[CHUNK_LEN_BYTES];
          // try to dec chunk len
          cipherDecrypt(encLenBytes, CHUNK_LEN_BYTES + (uint)tagLen, decChunkLenBytes, ref decChunkLenLength);
          Debug.Assert(decChunkLenLength == CHUNK_LEN_BYTES);

          // finally we get the real chunk len: the field is big-endian (network
          // order), decoded here without a signed intermediate cast.
          ushort chunkLen = (ushort)((decChunkLenBytes[0] << 8) | decChunkLenBytes[1]);
          if (chunkLen > CHUNK_LEN_MASK)
          {
              // we get invalid chunk
              logger.Error($"Invalid chunk length: {chunkLen}");
              throw new CryptoErrorException();
          }
          logger.Trace("Get the real chunk len:" + chunkLen);

          bufSize = _decCircularBuffer.Size;
          if (bufSize < CHUNK_LEN_BYTES + tagLen /* we haven't remove them */+ chunkLen + tagLen) {
              logger.Trace("No more data to decrypt one chunk");
              return;
          }
          // The length field is accepted: move on to the nonce for the payload.
          IncrementNonce(false);

          // we have enough data to decrypt one chunk
          // drop chunk len and its tag from buffer
          _decCircularBuffer.Skip(CHUNK_LEN_BYTES + tagLen);
          byte[] encChunkBytes = _decCircularBuffer.Get(chunkLen + tagLen);
          byte[] decChunkBytes = new byte[chunkLen];
          uint decChunkLen = 0;
          cipherDecrypt(encChunkBytes, chunkLen + (uint)tagLen, decChunkBytes, ref decChunkLen);
          Debug.Assert(decChunkLen == chunkLen);
          IncrementNonce(false);

          // output to outbuf
          Buffer.BlockCopy(decChunkBytes, 0, outbuf, outlength, (int) decChunkLen);
          outlength += (int)decChunkLen;
          logger.Trace("aead dec outlength " + outlength);

          // Stop once fewer than 100 bytes of headroom remain below TCPHandler.BufferSize.
          if (outlength + 100 > TCPHandler.BufferSize)
          {
              logger.Trace("dec outbuf almost full, giving up");
              return;
          }

          bufSize = _decCircularBuffer.Size;
          // check if we already done all of them
          if (bufSize <= 0) {
              logger.Trace("No data in _decCircularBuffer, already all done");
              return;
          }
      }
  }
  251. #endregion
  252. #region UDP
  // Encrypts one UDP datagram: writes a fresh random salt to the start of
  // `outbuf`, initialises the cipher with it, and appends the ciphertext of
  // `buf[0..length)`. `outlength` receives saltLen plus the ciphertext length.
  public override void EncryptUDP(byte[] buf, int length, byte[] outbuf, out int outlength)
  {
      // The salt is generated directly in the output buffer.
      randBytes(outbuf, saltLen);
      InitCipher(outbuf, true, true);

      // The scratch buffer is shared by all instances, so hold its lock while using it.
      lock (_udpTmpBuf) {
          uint sealedLen = 0;
          cipherEncrypt(buf, (uint) length, _udpTmpBuf, ref sealedLen);
          Debug.Assert(sealedLen == length + tagLen);

          Buffer.BlockCopy(_udpTmpBuf, 0, outbuf, saltLen, (int) sealedLen);
          outlength = (int) (saltLen + sealedLen);
      }
  }
  // Decrypts one UDP datagram laid out as salt followed by ciphertext. The
  // cipher is initialised from the leading salt, the ciphertext is shifted to
  // the start of `buf` (so `buf` is modified in place), decrypted through the
  // shared scratch buffer and copied to `outbuf`. `outlength` receives the
  // plaintext length.
  // Throws CryptoErrorException when the datagram is too short to contain a
  // salt and an authentication tag.
  public override void DecryptUDP(byte[] buf, int length, byte[] outbuf, out int outlength)
  {
      // Datagrams come from the network: reject one that cannot hold salt + tag
      // instead of passing a negative byte count to the copy below.
      int cipherLen = length - saltLen;
      if (cipherLen < tagLen) {
          logger.Error($"Invalid UDP packet length: {length}");
          throw new CryptoErrorException();
      }

      InitCipher(buf, false, true);
      uint olen = 0;
      lock (_udpTmpBuf) {
          // copy remaining data to first pos
          Buffer.BlockCopy(buf, saltLen, buf, 0, cipherLen);
          cipherDecrypt(buf, (uint) cipherLen, _udpTmpBuf, ref olen);
          Buffer.BlockCopy(_udpTmpBuf, 0, outbuf, 0, (int) olen);
          outlength = (int) olen;
      }
  }
  278. #endregion
  // Seals one TCP chunk into `ciphertext`: the encrypted 2-byte length field
  // first, then the encrypted payload, each followed by its tag. The plaintext
  // length is known up front, so both parts are produced in a single call.
  // The encryption nonce is incremented after each of the two parts.
  // Throws CryptoErrorException when plainLen exceeds CHUNK_LEN_MASK.
  private void ChunkEncrypt(byte[] plaintext, int plainLen, byte[] ciphertext, out int cipherLen)
  {
      if (plainLen > CHUNK_LEN_MASK) {
          logger.Error("enc chunk too big");
          throw new CryptoErrorException();
      }

      // Part 1: the payload length, converted to network byte order.
      byte[] sealedLength = new byte[CHUNK_LEN_BYTES + tagLen];
      uint sealedLengthSize = 0;
      byte[] lengthField = BitConverter.GetBytes((ushort) IPAddress.HostToNetworkOrder((short)plainLen));
      cipherEncrypt(lengthField, CHUNK_LEN_BYTES, sealedLength, ref sealedLengthSize);
      Debug.Assert(sealedLengthSize == CHUNK_LEN_BYTES + tagLen);
      IncrementNonce(true);

      // Part 2: the payload itself.
      byte[] sealedPayload = new byte[plainLen + tagLen];
      uint sealedPayloadSize = 0;
      cipherEncrypt(plaintext, (uint) plainLen, sealedPayload, ref sealedPayloadSize);
      Debug.Assert(sealedPayloadSize == plainLen + tagLen);
      IncrementNonce(true);

      // Lay both parts out back to back in the caller's buffer.
      int payloadOffset = (int) sealedLengthSize;
      Array.Copy(sealedLength, 0, ciphertext, 0, payloadOffset);
      Buffer.BlockCopy(sealedPayload, 0, ciphertext, payloadOffset, (int) sealedPayloadSize);
      cipherLen = (int) (sealedLengthSize + sealedPayloadSize);
  }
  304. }
  305. }