You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

redis_cache.cpp 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. /**
  2. * \file lite/src/mge/algo_cache/redis_cache.cpp
  3. *
  4. * This file is part of MegEngine, a deep learning framework developed by
  5. * Megvii.
  6. *
  7. * \copyright Copyright (c) 2020-2020 Megvii Inc. All rights reserved.
  8. */
  9. #include "lite_build_config.h"
  10. #if !defined(WIN32) && LITE_BUILD_WITH_MGE && LITE_WITH_CUDA
  11. #include "../../misc.h"
  12. #include "redis_cache.h"
  13. #include <iostream>
  14. #include <vector>
  15. namespace {
  16. /*
  17. ** Translation Table as described in RFC1113
  18. */
  19. static const char cb64[] =
  20. "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
  21. /*
  22. ** Translation Table to decode:
  23. *https://github.com/dgiardini/imgcalkap/blob/master/base64.c
  24. */
  25. static const char cd64[] =
  26. "|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`"
  27. "abcdefghijklmnopq";
  28. /*
  29. ** encodeblock
  30. **
  31. ** encode 3 8-bit binary bytes as 4 '6-bit' characters
  32. */
  33. void encodeblock(unsigned char in[3], unsigned char out[4], int len) {
  34. out[0] = cb64[in[0] >> 2];
  35. out[1] = cb64[((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4)];
  36. out[2] = (unsigned char)(len > 1 ? cb64[((in[1] & 0x0f) << 2) |
  37. ((in[2] & 0xc0) >> 6)]
  38. : '=');
  39. out[3] = (unsigned char)(len > 2 ? cb64[in[2] & 0x3f] : '=');
  40. }
  41. /*
  42. ** decodeblock
  43. **
  44. ** decode 4 '6-bit' characters into 3 8-bit binary bytes
  45. */
  46. void decodeblock(unsigned char in[4], unsigned char out[3]) {
  47. out[0] = (unsigned char)(in[0] << 2 | in[1] >> 4);
  48. out[1] = (unsigned char)(in[1] << 4 | in[2] >> 2);
  49. out[2] = (unsigned char)(((in[2] << 6) & 0xc0) | in[3]);
  50. }
  51. /**
  52. * Encode string to base64 string
  53. * @param input - source string
  54. * @param outdata - target base64 string
  55. * @param linesize - max size of line
  56. */
  57. void encode(const std::vector<std::uint8_t>& input,
  58. std::vector<std::uint8_t>& outdata, int linesize = 76) {
  59. outdata.clear();
  60. unsigned char in[3], out[4];
  61. int i, len, blocksout = 0;
  62. size_t j = 0;
  63. auto* indata = reinterpret_cast<const unsigned char*>(input.data());
  64. unsigned int insize = input.size();
  65. while (j <= insize) {
  66. len = 0;
  67. for (i = 0; i < 3; i++) {
  68. in[i] = (unsigned char)indata[j];
  69. j++;
  70. if (j <= insize) {
  71. len++;
  72. } else {
  73. in[i] = 0;
  74. }
  75. }
  76. if (len) {
  77. encodeblock(in, out, len);
  78. for (i = 0; i < 4; i++) {
  79. outdata.push_back(out[i]);
  80. }
  81. blocksout++;
  82. }
  83. if (blocksout >= (linesize / 4) || (j == insize)) {
  84. if (blocksout) {
  85. outdata.push_back('\r');
  86. outdata.push_back('\n');
  87. }
  88. blocksout = 0;
  89. }
  90. }
  91. }
  92. /**
  93. * Decode base64 string ot source
  94. * @param input - base64 string
  95. * @param outdata - source string
  96. */
  97. void decode(const std::vector<std::uint8_t>& input,
  98. std::vector<std::uint8_t>& outdata) {
  99. outdata.clear();
  100. unsigned char in[4], out[3], v;
  101. int i, len;
  102. size_t j = 0;
  103. auto* indata = reinterpret_cast<const unsigned char*>(input.data());
  104. unsigned int insize = input.size();
  105. while (j <= insize) {
  106. for (len = 0, i = 0; i < 4 && (j <= insize); i++) {
  107. v = 0;
  108. while ((j <= insize) && v == 0) {
  109. v = (unsigned char)indata[j++];
  110. v = (unsigned char)((v < 43 || v > 122) ? 0 : cd64[v - 43]);
  111. if (v) {
  112. v = (unsigned char)((v == '$') ? 0 : v - 61);
  113. }
  114. }
  115. if (j <= insize) {
  116. len++;
  117. if (v) {
  118. in[i] = (unsigned char)(v - 1);
  119. }
  120. } else {
  121. in[i] = 0;
  122. }
  123. }
  124. if (len) {
  125. decodeblock(in, out);
  126. for (i = 0; i < len - 1; i++) {
  127. outdata.push_back(out[i]);
  128. }
  129. }
  130. }
  131. }
  132. /**
  133. * Encode binary data to base64 buffer
  134. * @param input - source data
  135. * @param outdata - target base64 buffer
  136. * @param linesize
  137. */
  138. void encode(const std::string& input, std::string& outdata, int linesize = 76) {
  139. std::vector<std::uint8_t> out;
  140. std::vector<std::uint8_t> in(input.begin(), input.end());
  141. encode(in, out, linesize);
  142. outdata = std::string(out.begin(), out.end());
  143. }
  144. /**
  145. * Decode base64 buffer to source binary data
  146. * @param input - base64 buffer
  147. * @param outdata - source binary data
  148. */
  149. void decode(const std::string& input, std::string& outdata) {
  150. std::vector<std::uint8_t> in(input.begin(), input.end());
  151. std::vector<std::uint8_t> out;
  152. decode(in, out);
  153. outdata = std::string(out.begin(), out.end());
  154. }
  155. } // namespace
  156. using namespace lite;
  157. RedisCache::RedisCache(std::string redis_ip, size_t port, std::string password)
  158. : m_ip(redis_ip), m_port(port), m_password(password) {
  159. m_client.auth(password);
  160. m_client.connect(
  161. m_ip, m_port,
  162. [](const std::string& host, std::size_t port,
  163. cpp_redis::connect_state status) {
  164. if (status == cpp_redis::connect_state::dropped) {
  165. LITE_LOG("client disconnected from %s.", host.c_str());
  166. LITE_LOG("Redis server connect to %s :%zu failed.",
  167. host.c_str(), port);
  168. }
  169. },
  170. std::uint32_t(200));
  171. }
  172. mgb::Maybe<mgb::PersistentCache::Blob> RedisCache::get(
  173. const std::string& category, const mgb::PersistentCache::Blob& key) {
  174. LITE_LOCK_GUARD(m_mtx);
  175. if (m_old == nullptr) {
  176. return mgb::None;
  177. }
  178. auto mem_result = m_old->get(category, key);
  179. if (mem_result.valid())
  180. return mem_result;
  181. std::string key_str(static_cast<const char*>(key.ptr), key.size);
  182. std::string redis_key_str;
  183. encode(category + '@' + key_str, redis_key_str, 24);
  184. auto result = m_client.get(redis_key_str);
  185. m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(100));
  186. LITE_ASSERT(is_valid());
  187. auto content = result.get();
  188. if (content.is_null())
  189. return mgb::None;
  190. std::string decode_content;
  191. decode(content.as_string(), decode_content);
  192. m_old->put(category, key, {decode_content.data(), decode_content.length()});
  193. return m_old->get(category, key);
  194. }
  195. void RedisCache::put(const std::string& category, const Blob& key,
  196. const mgb::PersistentCache::Blob& value) {
  197. // ScopedTimer t1(std::string("put") + category);
  198. LITE_LOCK_GUARD(m_mtx);
  199. std::string key_str(static_cast<const char*>(key.ptr), key.size);
  200. std::string redis_key_str;
  201. encode(category + '@' + key_str, redis_key_str);
  202. std::string value_str(static_cast<const char*>(value.ptr), value.size);
  203. std::string redis_value_str;
  204. encode(value_str, redis_value_str);
  205. auto result = m_client.set(redis_key_str, redis_value_str);
  206. if (m_old == nullptr) {
  207. return;
  208. }
  209. m_old->put(category, key, value);
  210. m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(100));
  211. LITE_ASSERT(is_valid());
  212. }
  213. #endif
  214. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。