You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

api_cache.h 14 kB


  1. /**
  2. * \file dnn/src/common/api_cache.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  #include <atomic>
  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <cstring>
  #include <memory>
  #include <mutex>
  #include <string>
  #include <tuple>
  #include <type_traits>
  #include <unordered_map>
  #include <utility>

  #include "megdnn/thin/function.h"
  #include "./utils.h"
  21. namespace megdnn {
  // Reader/writer spin lock, adapted from
  // https://jfdube.wordpress.com/2014/01/03/implementing-a-recursive-read-write-spinlock/
  //
  // All state lives in one 32-bit atomic:
  // - The low 31 bits count active readers.
  // - The top bit marks a writer.
  // A writer first sets the top bit, which prevents new readers from entering.
  // It then spins until the reader count drops to zero.
  //
  // This implementation is NOT recursive:
  // - A writer must not lock again.
  // - A thread that already holds a read lock will spin forever if it takes
  //   another read lock while a writer is waiting.
  // All waiting is pure busy-spinning, so keep critical sections short.
  class RWSpin {
  public:
      //! Lockable handle exposing lock()/unlock(), bound to one acquire/release
      //! pair of RWSpin member functions so it can be used with RAII guards.
      class Lock {
      private:
          using Method = void (RWSpin::*)(void);
          RWSpin* m_owner;
          Method m_acquire;
          Method m_release;

      public:
          Lock(RWSpin* spin, Method lock, Method unlock)
                  : m_owner{spin}, m_acquire{lock}, m_release{unlock} {}
          void lock() { (m_owner->*m_acquire)(); }
          void unlock() { (m_owner->*m_release)(); }
      };

  private:
      std::atomic<uint32_t> m_atomic{0};
      static constexpr uint32_t sm_reader_mask = 0x7FFFFFFF;
      static constexpr uint32_t sm_writer_mask = 0x80000000;

      void _reader_lock() {
          uint32_t observed = m_atomic.load();
          for (;;) {
              // only succeed from a state whose writer bit is clear
              observed &= sm_reader_mask;
              if (m_atomic.compare_exchange_strong(observed, observed + 1)) {
                  return;
              }
          }
      }

      void _reader_unlock() { m_atomic.fetch_sub(1); }

      void _writer_lock() {
          uint32_t observed = m_atomic.load();
          for (;;) {
              // claim the writer bit while keeping the current reader count
              observed &= sm_reader_mask;
              if (m_atomic.compare_exchange_strong(observed,
                                                   observed | sm_writer_mask)) {
                  break;
              }
          }
          // new readers are now blocked; wait for the existing ones to leave
          while (m_atomic.load() != sm_writer_mask) {
          }
      }

      void _writer_unlock() {
          // the writer lock is held, so the state is exactly sm_writer_mask
          m_atomic.store(0);
      }

  public:
      Lock reader() {
          return Lock{this, &RWSpin::_reader_lock, &RWSpin::_reader_unlock};
      }
      Lock writer() {
          return Lock{this, &RWSpin::_writer_lock, &RWSpin::_writer_unlock};
      }
  };
  68. template <typename TSignature>
  69. class FunctionCache;
  70. template <typename TRet, typename... TArgs>
  71. class FunctionCache<TRet(TArgs...)> {
  72. public:
  73. using key_t = std::string;
  74. using value_t = TRet;
  75. using key_mapper_t = thin_function<key_t(TArgs...)>;
  76. using value_mapper_t = thin_function<value_t(TArgs...)>;
  77. using storage_t = std::unordered_map<key_t, value_t>;
  78. storage_t storage;
  79. key_mapper_t key_mapper;
  80. value_mapper_t value_mapper;
  81. RWSpin spin;
  82. public:
  83. TRet operator()(TArgs... args) {
  84. key_t key = key_mapper(args...);
  85. auto reader_lock = spin.reader();
  86. auto writer_lock = spin.writer();
  87. {
  88. MEGDNN_LOCK_GUARD(reader_lock);
  89. auto iter = storage.find(key);
  90. if (iter != storage.end()) {
  91. return iter->second;
  92. }
  93. }
  94. // RWSpin doesn't support upgrade
  95. {
  96. MEGDNN_LOCK_GUARD(writer_lock);
  97. if (storage.count(key) != 0) {
  98. return storage[key];
  99. }
  100. value_t ret = value_mapper(std::forward<TArgs>(args)...);
  101. storage[key] = ret;
  102. return ret;
  103. }
  104. }
  105. };
  //! FIFO byte buffer. Values appended with write_plain() are read back in the
  //! same order with read_plain(). Only trivially copyable types are allowed,
  //! since values are stored as their raw bytes.
  class StringSerializer {
  private:
      std::string m_buffer;
      size_t m_cursor = 0;

  public:
      //! Read the next value of type T and advance the read cursor.
      //! Precondition: at least sizeof(T) unread bytes remain. This is checked
      //! in debug builds; reading past the end would be undefined behavior.
      template <typename T>
      T read_plain() {
          static_assert(std::is_trivially_copyable<T>::value, "invalid type");
          assert(m_cursor + sizeof(T) <= m_buffer.size() &&
                 "StringSerializer: read past end of buffer");
          T ret;
          std::memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
          m_cursor += sizeof(T);
          return ret;
      }

      //! Append the raw bytes of `value` to the buffer.
      template <typename T>
      void write_plain(T value) {
          static_assert(std::is_trivially_copyable<T>::value,
                        "type should be trivially copyable");
          m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
      }

      //! Move the accumulated bytes out. The internal buffer is left in a
      //! moved-from state and the read cursor is not reset.
      std::string take() { return std::move(m_buffer); }

      //! Replace the buffer (taking ownership without a copy) and rewind the
      //! read cursor to the beginning.
      void reset(std::string new_buf) {
          m_cursor = 0;
          m_buffer = std::move(new_buf);
      }
  };
  //! Placeholder type. It is passed as the `prev` argument at the start of a
  //! (de)serialization chain and serves as the payload of the placeholder
  //! return type Param<Empty>.
  struct Empty {};

  //! Shift every index of a sequence by N:
  //!   index_sequence<a, b, ..., z>  ->  index_sequence<N+a, N+b, ..., N+z>
  //! Used here only inside decltype() to compute index ranges.
  template <std::size_t N, std::size_t... Seq>
  static auto inc_index_sequence(std::index_sequence<Seq...>)
          -> std::index_sequence<(N + Seq)...> {
      return std::index_sequence<(N + Seq)...>{};
  }
  140. template <typename... TParams>
  141. class ParamBundle {
  142. private:
  143. // out: Min, Min+1, ..., Max
  144. template <std::size_t Min, std::size_t Max>
  145. using make_index_range = decltype(
  146. inc_index_sequence<Min>(std::make_index_sequence<Max - Min>()));
  147. // store params in a tuple
  148. using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>;
  149. storage_t m_storage;
  150. // deconstruct tuple and call functor
  151. template <typename TFunctor, size_t... Indices>
  152. auto call_helper(TFunctor functor, std::index_sequence<Indices...>) {
  153. return functor(std::get<Indices>(m_storage).value...);
  154. }
  155. template <size_t Index, size_t... Indices, typename TPrev>
  156. auto serialize_helper(StringSerializer& ser, TPrev&& prev,
  157. std::index_sequence<Index, Indices...>) {
  158. return serialize_helper(ser,
  159. std::get<Index>(m_storage).serialize(ser, prev),
  160. std::index_sequence<Indices...>());
  161. }
  162. template <typename TPrev>
  163. auto serialize_helper(StringSerializer& ser, TPrev&& prev,
  164. std::index_sequence<>) {}
  165. template <size_t Index, size_t... Indices, typename TPrev>
  166. auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
  167. std::index_sequence<Index, Indices...>) {
  168. return deserialize_helper(
  169. ser, std::get<Index>(m_storage).deserialize(ser, prev),
  170. std::index_sequence<Indices...>());
  171. }
  172. template <typename TPrev>
  173. auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
  174. std::index_sequence<>) {}
  175. template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
  176. void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg,
  177. TArgs&&... args) {
  178. std::get<Index>(m_storage).value = arg;
  179. set_values_helper(std::index_sequence<Indices...>(),
  180. std::forward<TArgs>(args)...);
  181. }
  182. template <size_t... Indices>
  183. void set_values_helper(std::index_sequence<Indices...>) {
  184. static_assert(sizeof...(Indices) == 0, "redundant indices");
  185. }
  186. public:
  187. template <typename TFunctor>
  188. auto call_by(TFunctor&& functor) {
  189. return call_helper(std::forward<TFunctor>(functor),
  190. std::make_index_sequence<sizeof...(TParams)>());
  191. }
  192. // recursively store params into ser
  193. template <size_t NBegin, size_t NEnd>
  194. void serialize_params(StringSerializer& ser) {
  195. static_assert(NEnd >= NBegin, "invalid range");
  196. serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
  197. }
  198. // recursively load params from ser
  199. template <size_t NBegin, size_t NEnd>
  200. void deserialize_params(StringSerializer& ser) {
  201. static_assert(NEnd >= NBegin, "invalid range");
  202. deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
  203. }
  204. // recursively set params into m_storage
  205. template <size_t NBegin, size_t NEnd, typename... TArgs>
  206. void set_values(TArgs&&... args) {
  207. set_values_helper(make_index_range<NBegin, NEnd>(),
  208. std::forward<TArgs>(args)...);
  209. }
  210. };
  211. template <typename T>
  212. class Param {
  213. public:
  214. T value;
  215. Empty serialize(StringSerializer& ser, Empty) {
  216. ser.write_plain(value);
  217. return Empty{};
  218. }
  219. Empty deserialize(StringSerializer& ser, Empty) {
  220. value = ser.read_plain<T>();
  221. return Empty{};
  222. }
  223. };
  224. template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>,
  225. typename TOutputs = std::tuple<>>
  226. class FunctionCacheBuilder {
  227. private:
  228. // decl value with type of tuple-of-args
  229. static auto declargs()
  230. -> decltype(std::tuple_cat(std::declval<TInputs>(),
  231. std::declval<TOutputs>())) {
  232. return {};
  233. }
  234. template <size_t... Indices>
  235. static auto declfunction_helper(std::index_sequence<Indices...>)
  236. -> thin_function<decltype(std::declval<TRet>().value)(
  237. decltype(std::get<Indices>(declargs()).value)...)> {
  238. return {};
  239. }
  240. // decl value with type of original function
  241. static auto declfunction() {
  242. return declfunction_helper(
  243. std::make_index_sequence<std::tuple_size<TInputs>::value +
  244. std::tuple_size<TOutputs>::value>());
  245. }
  246. template <size_t... Indices>
  247. static auto declbundle_helper(std::index_sequence<Indices...>)
  248. -> ParamBundle<decltype(std::get<Indices>(declargs()))...> {
  249. return {};
  250. }
  251. // decl value with type of bundle-of-args
  252. static auto declbundle() {
  253. return declbundle_helper(
  254. std::make_index_sequence<std::tuple_size<TInputs>::value +
  255. std::tuple_size<TOutputs>::value>());
  256. }
  257. // type of original function
  258. using function_t = decltype(declfunction());
  259. // type of bundle-of-args
  260. using bundle_t = decltype(declbundle());
  261. public:
  262. // declare new return type, cannot be override
  263. template <typename TNewRet>
  264. auto ret() {
  265. static_assert(std::is_same<TRet, Param<Empty>>::value,
  266. "return value redefinition");
  267. return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
  268. }
  269. // declare new input
  270. template <typename TNewInput>
  271. auto input() {
  272. using TNewInputs = decltype(
  273. std::tuple_cat(std::declval<TInputs>(),
  274. std::make_tuple(std::declval<TNewInput>())));
  275. return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
  276. }
  277. // declare new output
  278. template <typename TNewOutput>
  279. auto output() {
  280. using TNewOutputs = decltype(
  281. std::tuple_cat(std::declval<TOutputs>(),
  282. std::make_tuple(std::declval<TNewOutput>())));
  283. return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
  284. }
  285. // summary
  286. template <typename TFunctor>
  287. function_t build(TFunctor func) {
  288. auto cache = std::make_shared<FunctionCache<std::string(bundle_t)>>();
  289. // bundle -> ser(in args)
  290. cache->key_mapper = [](bundle_t bundle) {
  291. StringSerializer ser;
  292. bundle.template serialize_params<0,
  293. std::tuple_size<TInputs>::value>(
  294. ser);
  295. return ser.take();
  296. };
  297. // bundle -> ser(out args)
  298. cache->value_mapper = [=](bundle_t bundle) {
  299. StringSerializer ser;
  300. TRet ret;
  301. ret.value = bundle.call_by(func);
  302. ret.serialize(ser, Empty{});
  303. bundle.template serialize_params<
  304. std::tuple_size<TInputs>::value,
  305. std::tuple_size<TInputs>::value +
  306. std::tuple_size<TOutputs>::value>(ser);
  307. return ser.take();
  308. };
  309. return [=](auto&&... args) mutable {
  310. bundle_t bundle;
  311. TRet ret;
  312. StringSerializer ser;
  313. static_assert(
  314. sizeof...(args) == std::tuple_size<TInputs>::value +
  315. std::tuple_size<TOutputs>::value,
  316. "args count mismatch");
  317. bundle.template set_values<0, sizeof...(args)>(
  318. std::forward<decltype(args)>(args)...);
  319. ser.reset((*cache)(bundle));
  320. ret.deserialize(ser, Empty{});
  321. constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
  322. constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
  323. bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(
  324. ser);
  325. return ret.value;
  326. };
  327. }
  328. };
  329. template <typename T>
  330. class RefParam {
  331. public:
  332. T* value;
  333. Empty serialize(StringSerializer& ser, Empty) {
  334. ser.write_plain(*value);
  335. return Empty{};
  336. }
  337. Empty deserialize(StringSerializer& ser, Empty) {
  338. *value = ser.read_plain<T>();
  339. return Empty{};
  340. }
  341. };
  342. // like RefParam but return *value while ser and deser. Working with ArrayParam
  343. template <typename T>
  344. class RefArraySizeParam {
  345. public:
  346. T* value;
  347. T serialize(StringSerializer& ser, Empty) {
  348. ser.write_plain(*value);
  349. return *value;
  350. }
  351. T deserialize(StringSerializer& ser, Empty) {
  352. return *value = ser.read_plain<T>();
  353. }
  354. };
  355. // accept array length from previous param. Working with RefArraySizeParam
  356. template <typename TSize, typename TItem>
  357. class ArrayParam {
  358. public:
  359. TItem* value;
  360. Empty serialize(StringSerializer& ser, TSize size) {
  361. for (TSize i = 0; i < size; ++i) {
  362. ser.write_plain(value[i]);
  363. }
  364. return Empty{};
  365. }
  366. Empty deserialize(StringSerializer& ser, TSize size) {
  367. for (TSize i = 0; i < size; ++i) {
  368. value[i] = ser.read_plain<TItem>();
  369. }
  370. return Empty{};
  371. }
  372. };
  373. } // namespace megdnn

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。