You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

api_cache.h 11 kB


  1. /**
  2. * \file dnn/src/common/api_cache.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
#include <cassert>
#include <cstring>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include "megdnn/thin/function.h"
  18. namespace megdnn {
  19. template <typename... TArgs>
  20. class FunctionCache {
  21. public:
  22. using key_t = std::string;
  23. using value_t = std::string;
  24. using key_mapper_t = thin_function<key_t(TArgs...)>;
  25. using value_mapper_t = thin_function<value_t(TArgs...)>;
  26. using storage_t = std::unordered_map<key_t, value_t>;
  27. storage_t storage;
  28. key_mapper_t key_mapper;
  29. value_mapper_t value_mapper;
  30. value_t operator()(TArgs... args) {
  31. key_t key = key_mapper(args...);
  32. if (storage.count(key) == 0) {
  33. storage[key] = value_mapper(std::forward<TArgs>(args)...);
  34. }
  35. return storage[key];
  36. }
  37. };
  // FIFO byte buffer: values are read back (read_plain) in the same order
  // they were appended (write_plain), as raw object bytes.
  class StringSerializer {
  private:
      std::string m_buffer;
      size_t m_cursor = 0;

  public:
      /// Read the next sizeof(T) bytes as a T and advance the read cursor.
      /// Precondition: at least sizeof(T) unread bytes remain. This is checked
      /// with assert() in debug builds; it used to be an unchecked over-read.
      template <typename T>
      T read_plain() {
          static_assert(std::is_trivially_copyable<T>::value, "invalid type");
          assert(m_cursor <= m_buffer.size() &&
                 m_buffer.size() - m_cursor >= sizeof(T) &&
                 "StringSerializer: read past end of buffer");
          T ret;
          std::memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
          m_cursor += sizeof(T);
          return ret;
      }

      /// Append the raw bytes of `value` to the buffer.
      template <typename T>
      void write_plain(T value) {
          static_assert(std::is_trivially_copyable<T>::value,
                        "type should be trivially copyable");
          m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
      }

      /// Move the accumulated bytes out. Afterwards the serializer is
      /// guaranteed to be empty, with its read cursor reset.
      std::string take() {
          std::string out;
          out.swap(m_buffer);
          m_cursor = 0;
          return out;
      }

      /// Replace the buffer with `new_buf` and rewind the read cursor.
      void set(std::string new_buf) {
          m_cursor = 0;
          m_buffer = std::move(new_buf);  // avoid a second copy of the bytes
      }
  };

  // Placeholder type: the "no value" argument/result used to chain
  // serialize()/deserialize() calls between parameter wrappers.
  struct Empty {};
  67. template <typename... TParams>
  68. class ParamBundle {
  69. private:
  70. template <std::size_t N, std::size_t... Seq>
  71. static std::index_sequence<N + Seq...> add_all(
  72. std::index_sequence<Seq...>) {
  73. return {};
  74. }
  75. template <std::size_t Min, std::size_t Max>
  76. using make_index_range =
  77. decltype(add_all<Min>(std::make_index_sequence<Max - Min>()));
  78. using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>;
  79. storage_t m_storage;
  80. template <typename TFunctor, size_t... Indices>
  81. auto call_helper(TFunctor functor, std::index_sequence<Indices...>) {
  82. return functor(std::get<Indices>(m_storage).value...);
  83. }
  84. template <size_t Index, size_t... Indices, typename TPrev>
  85. auto serialize_helper(StringSerializer& ser, TPrev&& prev,
  86. std::index_sequence<Index, Indices...>) {
  87. return serialize_helper(ser,
  88. std::get<Index>(m_storage).serialize(ser, prev),
  89. std::index_sequence<Indices...>());
  90. }
  91. template <typename TPrev>
  92. auto serialize_helper(StringSerializer& ser, TPrev&& prev,
  93. std::index_sequence<>) {}
  94. template <size_t Index, size_t... Indices, typename TPrev>
  95. auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
  96. std::index_sequence<Index, Indices...>) {
  97. return deserialize_helper(
  98. ser, std::get<Index>(m_storage).deserialize(ser, prev),
  99. std::index_sequence<Indices...>());
  100. }
  101. template <typename TPrev>
  102. auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
  103. std::index_sequence<>) {}
  104. template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
  105. void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg,
  106. TArgs&&... args) {
  107. std::get<Index>(m_storage).value = arg;
  108. set_values_helper(std::index_sequence<Indices...>(),
  109. std::forward<TArgs>(args)...);
  110. }
  111. template <size_t... Indices>
  112. void set_values_helper(std::index_sequence<Indices...>) {
  113. static_assert(sizeof...(Indices) == 0, "redundant indices");
  114. }
  115. public:
  116. template <typename TFunctor>
  117. auto call_by(TFunctor&& functor) {
  118. return call_helper(std::forward<TFunctor>(functor),
  119. std::make_index_sequence<sizeof...(TParams)>());
  120. }
  121. template <size_t NBegin, size_t NEnd>
  122. void serialize_params(StringSerializer& ser) {
  123. static_assert(NEnd >= NBegin, "invalid range");
  124. serialize_helper(
  125. ser, Empty{},
  126. add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
  127. }
  128. template <size_t NBegin, size_t NEnd>
  129. void deserialize_params(StringSerializer& ser) {
  130. static_assert(NEnd >= NBegin, "invalid range");
  131. deserialize_helper(
  132. ser, Empty{},
  133. add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
  134. }
  135. template <size_t NBegin, size_t NEnd, typename... TArgs>
  136. void set_values(TArgs&&... args) {
  137. set_values_helper(
  138. add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()),
  139. std::forward<TArgs>(args)...);
  140. }
  141. };
  142. template <typename T>
  143. class Param {
  144. public:
  145. T value;
  146. Empty serialize(StringSerializer& ser, Empty) {
  147. ser.write_plain(value);
  148. return Empty{};
  149. }
  150. Empty deserialize(StringSerializer& ser, Empty) {
  151. value = ser.read_plain<T>();
  152. return Empty{};
  153. }
  154. };
  155. template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>,
  156. typename TOutputs = std::tuple<>>
  157. class FunctionCacheBuilder {
  158. private:
  159. static auto declargs()
  160. -> decltype(std::tuple_cat(std::declval<TInputs>(),
  161. std::declval<TOutputs>())) {
  162. return {};
  163. }
  164. template <size_t... Indices>
  165. static auto declfunction_helper(std::index_sequence<Indices...>)
  166. -> thin_function<decltype(std::declval<TRet>().value)(
  167. decltype(std::get<Indices>(declargs()).value)...)> {
  168. return {};
  169. }
  170. static auto declfunction() {
  171. return declfunction_helper(
  172. std::make_index_sequence<std::tuple_size<TInputs>::value +
  173. std::tuple_size<TOutputs>::value>());
  174. }
  175. template <size_t... Indices>
  176. static auto declbundle_helper(std::index_sequence<Indices...>)
  177. -> ParamBundle<decltype(std::get<Indices>(declargs()))...> {
  178. return {};
  179. }
  180. static auto declbundle() {
  181. return declbundle_helper(
  182. std::make_index_sequence<std::tuple_size<TInputs>::value +
  183. std::tuple_size<TOutputs>::value>());
  184. }
  185. using function_t = decltype(declfunction());
  186. using bundle_t = decltype(declbundle());
  187. public:
  188. template <typename TNewRet>
  189. auto ret() {
  190. static_assert(std::is_same<TRet, Param<Empty>>::value,
  191. "return value redefinition");
  192. return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
  193. }
  194. template <typename TNewInput>
  195. auto input() {
  196. using TNewInputs = decltype(
  197. std::tuple_cat(std::declval<TInputs>(),
  198. std::make_tuple(std::declval<TNewInput>())));
  199. return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
  200. }
  201. template <typename TNewOutput>
  202. auto output() {
  203. using TNewOutputs = decltype(
  204. std::tuple_cat(std::declval<TOutputs>(),
  205. std::make_tuple(std::declval<TNewOutput>())));
  206. return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
  207. }
  208. template <typename TFunctor>
  209. function_t build(TFunctor func) {
  210. FunctionCache<bundle_t> cache;
  211. cache.key_mapper = [](bundle_t bundle) {
  212. StringSerializer ser;
  213. bundle.template serialize_params<0,
  214. std::tuple_size<TInputs>::value>(
  215. ser);
  216. return ser.take();
  217. };
  218. cache.value_mapper = [=](bundle_t bundle) {
  219. StringSerializer ser;
  220. TRet ret;
  221. ret.value = bundle.call_by(func);
  222. ret.serialize(ser, Empty{});
  223. bundle.template serialize_params<
  224. std::tuple_size<TInputs>::value,
  225. std::tuple_size<TInputs>::value +
  226. std::tuple_size<TOutputs>::value>(ser);
  227. return ser.take();
  228. };
  229. return [=](auto&&... args) mutable {
  230. bundle_t bundle;
  231. TRet ret;
  232. StringSerializer ser;
  233. static_assert(
  234. sizeof...(args) == std::tuple_size<TInputs>::value +
  235. std::tuple_size<TOutputs>::value,
  236. "args count mismatch");
  237. bundle.template set_values<0, sizeof...(args)>(
  238. std::forward<decltype(args)>(args)...);
  239. ser.set(cache(bundle));
  240. ret.deserialize(ser, Empty{});
  241. constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
  242. constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
  243. bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(
  244. ser);
  245. return ret.value;
  246. };
  247. }
  248. };
  249. template <typename T>
  250. class RefParam {
  251. public:
  252. T* value;
  253. Empty serialize(StringSerializer& ser, Empty) {
  254. ser.write_plain(*value);
  255. return Empty{};
  256. }
  257. Empty deserialize(StringSerializer& ser, Empty) {
  258. *value = ser.read_plain<T>();
  259. return Empty{};
  260. }
  261. };
  262. template <typename T>
  263. class RefArraySizeParam {
  264. public:
  265. T* value;
  266. T serialize(StringSerializer& ser, Empty) {
  267. ser.write_plain(*value);
  268. return *value;
  269. }
  270. T deserialize(StringSerializer& ser, Empty) {
  271. return *value = ser.read_plain<T>();
  272. }
  273. };
  274. template <typename TSize, typename TItem>
  275. class ArrayParam {
  276. public:
  277. TItem* value;
  278. Empty serialize(StringSerializer& ser, TSize size) {
  279. for (TSize i = 0; i < size; ++i) {
  280. ser.write_plain(value[i]);
  281. }
  282. return Empty{};
  283. }
  284. Empty deserialize(StringSerializer& ser, TSize size) {
  285. for (TSize i = 0; i < size; ++i) {
  286. value[i] = ser.read_plain<TItem>();
  287. }
  288. return Empty{};
  289. }
  290. };
  291. } // namespace megdnn

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台