You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

opr_proxy.h 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
/**
 * \file dnn/test/common/opr_proxy.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "test/common/deduce_layout_proxy.h"
#include "test/common/exec_proxy.h"
#include "test/common/inspect_type.h"
#include "test/common/opr_trait.h"
#include "test/common/timer.h"
#include "test/common/workspace_wrapper.h"

#include <algorithm>
#include <cstdio>
#include <limits>
  19. namespace megdnn {
  20. namespace test {
  21. template <typename Opr, size_t arity = OprTrait<Opr>::arity,
  22. bool has_workspace = OprTrait<Opr>::has_workspace,
  23. bool can_deduce_layout = OprTrait<Opr>::can_deduce_layout>
  24. struct OprProxyDefaultImpl
  25. : public DeduceLayoutProxy<Opr, arity, can_deduce_layout>,
  26. public ExecProxy<Opr, arity, has_workspace> {};
  27. template <typename Opr>
  28. struct OprProxy : public OprProxyDefaultImpl<Opr> {};
  29. template <typename Opr>
  30. struct OprProxyVectorToSingle {};
  31. template <>
  32. struct OprProxy<ElemwiseForward> {
  33. static void deduce_layout(ElemwiseForward* opr,
  34. TensorLayoutArray& layouts) {
  35. megdnn_assert(layouts.size() >= 2);
  36. auto inp = layouts;
  37. inp.pop_back();
  38. opr->deduce_layout(inp, layouts.back());
  39. }
  40. static void exec(ElemwiseForward* opr, const TensorNDArray& tensors) {
  41. megdnn_assert(tensors.size() >= 2);
  42. auto inp = tensors;
  43. inp.pop_back();
  44. opr->exec(inp, tensors.back());
  45. }
  46. };
  47. template <>
  48. struct OprProxy<ElemwiseMultiType> {
  49. static void deduce_layout(ElemwiseMultiType* opr,
  50. TensorLayoutArray& layouts) {
  51. megdnn_assert(layouts.size() >= 2);
  52. auto inp = layouts;
  53. inp.pop_back();
  54. opr->deduce_layout(inp, layouts.back());
  55. }
  56. static void exec(ElemwiseMultiType* opr, const TensorNDArray& tensors) {
  57. megdnn_assert(tensors.size() >= 2);
  58. auto inp = tensors;
  59. inp.pop_back();
  60. opr->exec(inp, tensors.back());
  61. }
  62. };
  63. template <>
  64. struct OprProxy<ConcatForward> {
  65. static void deduce_layout(ConcatForward* opr, TensorLayoutArray& layouts) {
  66. megdnn_assert(layouts.size() >= 2);
  67. auto inp = layouts;
  68. inp.pop_back();
  69. opr->deduce_layout(inp, layouts.back());
  70. }
  71. static void exec(ConcatForward* opr, const TensorNDArray& tensors) {
  72. megdnn_assert(tensors.size() >= 2);
  73. auto inp = tensors;
  74. inp.pop_back();
  75. TensorLayoutArray layouts(tensors.size());
  76. std::transform(tensors.begin(), tensors.end(), layouts.begin(),
  77. [](const TensorND& tensor) { return tensor.layout; });
  78. auto inp_layouts = layouts;
  79. inp_layouts.pop_back();
  80. WorkspaceWrapper W(opr->handle(), opr->get_workspace_in_bytes(
  81. inp_layouts, layouts.back()));
  82. auto inp_tensors = tensors;
  83. inp_tensors.pop_back();
  84. opr->exec(inp_tensors, tensors.back(), W.workspace());
  85. }
  86. };
  87. template <>
  88. struct OprProxy<SplitForward> : DeduceLayoutProxy<SplitForward, 0, false> {
  89. static void exec(SplitForward* opr, const TensorNDArray& tensors) {
  90. megdnn_assert(tensors.size() >= 2);
  91. auto out = tensors;
  92. out.erase(out.begin());
  93. TensorLayoutArray layouts(tensors.size());
  94. std::transform(tensors.begin(), tensors.end(), layouts.begin(),
  95. [](const TensorND& tensor) { return tensor.layout; });
  96. auto out_layouts = layouts;
  97. out_layouts.erase(out_layouts.begin());
  98. WorkspaceWrapper W(
  99. opr->handle(),
  100. opr->get_workspace_in_bytes(layouts.front(), out_layouts));
  101. auto out_tensors = tensors;
  102. out_tensors.erase(out_tensors.begin());
  103. opr->exec(tensors.front(), out_tensors, W.workspace());
  104. }
  105. };
  106. //! OprProxy impl for tenary oprs with profiling support
  107. template <class Opr, int arity>
  108. struct OprProxyProfilingBase
  109. : public DeduceLayoutProxy<Opr, arity,
  110. OprTrait<Opr>::can_deduce_layout> {
  111. size_t warmup_times = 10, exec_times = 100;
  112. //! whether to enable profiling
  113. bool m_profiling;
  114. WorkspaceWrapper W;
  115. //! target algo setup by profiler; it can also be directly specified by the
  116. //! caller
  117. typename Opr::Algorithm* target_algo = nullptr;
  118. OprProxyProfilingBase(bool profile = false) { m_profiling = profile; }
  119. };
  120. template <class Opr>
  121. struct OprProxyProfilingTernary : public OprProxyProfilingBase<Opr, 3> {
  122. using Base = OprProxyProfilingBase<Opr, 3>;
  123. using OprProxyProfilingBase<Opr, 3>::OprProxyProfilingBase;
  124. void exec(Opr* opr, const TensorNDArray& tensors) {
  125. megdnn_assert(tensors.size() == 3);
  126. if (!Base::W.valid()) {
  127. Base::W = WorkspaceWrapper(opr->handle(), 0);
  128. }
  129. if (Base::m_profiling && !Base::target_algo) {
  130. size_t min_time = std::numeric_limits<size_t>::max();
  131. for (auto algo :
  132. opr->get_all_algorithms(tensors[0].layout, tensors[1].layout,
  133. tensors[2].layout)) {
  134. opr->execution_policy().algorithm = algo;
  135. auto workspace_size = opr->get_workspace_in_bytes(
  136. tensors[0].layout, tensors[1].layout,
  137. tensors[2].layout);
  138. Base::W.update(workspace_size);
  139. for (size_t times = 0; times < Base::warmup_times; ++times)
  140. opr->exec(tensors[0], tensors[1], tensors[2],
  141. Base::W.workspace());
  142. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  143. Timer timer;
  144. timer.start();
  145. for (size_t times = 0; times < Base::exec_times; ++times) {
  146. opr->exec(tensors[0], tensors[1], tensors[2],
  147. Base::W.workspace());
  148. }
  149. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  150. timer.stop();
  151. printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
  152. algo->name());
  153. if (min_time > timer.get_time_in_us()) {
  154. min_time = timer.get_time_in_us();
  155. Base::target_algo = algo;
  156. }
  157. }
  158. opr->execution_policy().algorithm = Base::target_algo;
  159. auto workspace_size = opr->get_workspace_in_bytes(
  160. tensors[0].layout, tensors[1].layout, tensors[2].layout);
  161. Base::W.update(workspace_size);
  162. }
  163. if (!Base::target_algo) {
  164. auto workspace_size = opr->get_workspace_in_bytes(
  165. tensors[0].layout, tensors[1].layout, tensors[2].layout);
  166. Base::W.update(workspace_size);
  167. }
  168. opr->exec(tensors[0], tensors[1], tensors[2], Base::W.workspace());
  169. }
  170. };
  171. #define DEF_PROF3(c) \
  172. template <> \
  173. struct OprProxy<c> : public OprProxyProfilingTernary<c> { \
  174. using OprProxyProfilingTernary<c>::OprProxyProfilingTernary; \
  175. }
  176. DEF_PROF3(ConvolutionForward);
  177. DEF_PROF3(ConvolutionBackwardData);
  178. DEF_PROF3(ConvolutionBackwardFilter);
  179. DEF_PROF3(LocalShareForward);
  180. DEF_PROF3(LocalShareBackwardData);
  181. DEF_PROF3(LocalShareBackwardFilter);
  182. #undef DEF_PROF3
  183. template <class Opr>
  184. struct OprProxyProfiling5 : public OprProxyProfilingBase<Opr, 5> {
  185. using Base = OprProxyProfilingBase<Opr, 5>;
  186. using OprProxyProfilingBase<Opr, 5>::OprProxyProfilingBase;
  187. void exec(Opr* opr, const TensorNDArray& tensors) {
  188. megdnn_assert(tensors.size() == 5);
  189. if (!Base::W.valid()) {
  190. Base::W = WorkspaceWrapper(opr->handle(), 0);
  191. }
  192. if (Base::m_profiling && !Base::target_algo) {
  193. size_t min_time = std::numeric_limits<size_t>::max();
  194. for (auto algo :
  195. opr->get_all_algorithms(tensors[0].layout, tensors[1].layout,
  196. tensors[2].layout, tensors[3].layout,
  197. tensors[4].layout)) {
  198. opr->execution_policy().algorithm = algo;
  199. auto workspace_size = opr->get_workspace_in_bytes(
  200. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  201. tensors[3].layout, tensors[4].layout);
  202. Base::W.update(workspace_size);
  203. for (size_t times = 0; times < Base::warmup_times; ++times)
  204. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
  205. tensors[4], Base::W.workspace());
  206. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  207. Timer timer;
  208. timer.start();
  209. for (size_t times = 0; times < Base::exec_times; ++times) {
  210. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
  211. tensors[4], Base::W.workspace());
  212. }
  213. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  214. timer.stop();
  215. printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
  216. algo->name());
  217. if (min_time > timer.get_time_in_us()) {
  218. min_time = timer.get_time_in_us();
  219. Base::target_algo = algo;
  220. }
  221. }
  222. opr->execution_policy().algorithm = Base::target_algo;
  223. auto workspace_size = opr->get_workspace_in_bytes(
  224. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  225. tensors[3].layout, tensors[4].layout);
  226. Base::W.update(workspace_size);
  227. }
  228. if (!Base::target_algo) {
  229. auto workspace_size = opr->get_workspace_in_bytes(
  230. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  231. tensors[3].layout, tensors[4].layout);
  232. Base::W.update(workspace_size);
  233. }
  234. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  235. Base::W.workspace());
  236. }
  237. };
  238. #define DEF_PROF5(c) \
  239. template <> \
  240. struct OprProxy<c> : public OprProxyProfiling5<c> { \
  241. using OprProxyProfiling5<c>::OprProxyProfiling5; \
  242. }
  243. DEF_PROF5(DeformableConvForward);
  244. DEF_PROF5(DeformableConvBackwardFilter);
  245. DEF_PROF5(ConvBiasForward);
  246. DEF_PROF5(BatchConvBiasForward);
  247. #undef DEF_PROF5
  248. template <class Opr>
  249. struct OprProxyProfiling8 : public OprProxyProfilingBase<Opr, 8> {
  250. using Base = OprProxyProfilingBase<Opr, 8>;
  251. using OprProxyProfilingBase<Opr, 8>::OprProxyProfilingBase;
  252. void exec(Opr* opr, const TensorNDArray& tensors) {
  253. megdnn_assert(tensors.size() == 8);
  254. if (!Base::W.valid()) {
  255. Base::W = WorkspaceWrapper(opr->handle(), 0);
  256. }
  257. if (Base::m_profiling && !Base::target_algo) {
  258. size_t min_time = std::numeric_limits<size_t>::max();
  259. for (auto algo : opr->get_all_algorithms(
  260. tensors[0].layout, tensors[1].layout,
  261. tensors[2].layout, tensors[3].layout,
  262. tensors[4].layout, tensors[5].layout,
  263. tensors[6].layout, tensors[7].layout)) {
  264. opr->execution_policy().algorithm = algo;
  265. auto workspace_size = opr->get_workspace_in_bytes(
  266. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  267. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  268. tensors[6].layout, tensors[7].layout);
  269. Base::W.update(workspace_size);
  270. for (size_t times = 0; times < Base::warmup_times; ++times)
  271. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
  272. tensors[4], tensors[5], tensors[6], tensors[7],
  273. Base::W.workspace());
  274. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  275. Timer timer;
  276. timer.start();
  277. for (size_t times = 0; times < Base::exec_times; ++times) {
  278. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
  279. tensors[4], tensors[5], tensors[6], tensors[7],
  280. Base::W.workspace());
  281. }
  282. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  283. timer.stop();
  284. printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
  285. algo->name());
  286. if (min_time > timer.get_time_in_us()) {
  287. min_time = timer.get_time_in_us();
  288. Base::target_algo = algo;
  289. }
  290. }
  291. opr->execution_policy().algorithm = Base::target_algo;
  292. auto workspace_size = opr->get_workspace_in_bytes(
  293. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  294. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  295. tensors[6].layout, tensors[7].layout);
  296. Base::W.update(workspace_size);
  297. }
  298. if (!Base::target_algo) {
  299. auto workspace_size = opr->get_workspace_in_bytes(
  300. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  301. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  302. tensors[6].layout, tensors[7].layout);
  303. Base::W.update(workspace_size);
  304. }
  305. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  306. tensors[5], tensors[6], tensors[7], Base::W.workspace());
  307. }
  308. };
  309. #define DEF_PROF8(c) \
  310. template <> \
  311. struct OprProxy<c> : public OprProxyProfiling8<c> { \
  312. using OprProxyProfiling8<c>::OprProxyProfiling8; \
  313. }
  314. DEF_PROF8(DeformableConvBackwardData);
  315. #undef DEF_PROF8
  316. } // namespace test
  317. } // namespace megdnn
  318. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台