You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

opr_proxy.h 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
/**
 * \file dnn/test/common/opr_proxy.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "test/common/deduce_layout_proxy.h"
#include "test/common/exec_proxy.h"
#include "test/common/inspect_type.h"
#include "test/common/opr_trait.h"
#include "test/common/timer.h"
#include "test/common/workspace_wrapper.h"

#include <algorithm>
#include <cstdio>
#include <limits>
  19. namespace megdnn {
  20. namespace test {
/**
 * Default proxy implementation: glues together layout deduction and
 * execution dispatch for an operator, with behavior selected at compile
 * time from the operator's OprTrait (arity, workspace requirement,
 * layout-deduction capability).
 */
template <typename Opr, size_t arity = OprTrait<Opr>::arity,
          bool has_workspace = OprTrait<Opr>::has_workspace,
          bool can_deduce_layout = OprTrait<Opr>::can_deduce_layout>
struct OprProxyDefaultImpl
        : public DeduceLayoutProxy<Opr, arity, can_deduce_layout>,
          public ExecProxy<Opr, arity, has_workspace> {};
//! Primary OprProxy template; operators with non-uniform calling
//! conventions (variadic inputs, profiling, extra arguments) get explicit
//! specializations below.
template <typename Opr>
struct OprProxy : public OprProxyDefaultImpl<Opr> {};
//! Empty primary template; presumably specialized elsewhere for oprs that
//! adapt vector-style arguments to single tensors — no specialization is
//! visible in this file.
template <typename Opr>
struct OprProxyVectorToSingle {};
  31. template <>
  32. struct OprProxy<ElemwiseForward> {
  33. static void deduce_layout(ElemwiseForward* opr,
  34. TensorLayoutArray& layouts) {
  35. megdnn_assert(layouts.size() >= 2);
  36. auto inp = layouts;
  37. inp.pop_back();
  38. opr->deduce_layout(inp, layouts.back());
  39. }
  40. static void exec(ElemwiseForward* opr, const TensorNDArray& tensors) {
  41. megdnn_assert(tensors.size() >= 2);
  42. auto inp = tensors;
  43. inp.pop_back();
  44. opr->exec(inp, tensors.back());
  45. }
  46. };
  47. template <>
  48. struct OprProxy<ElemwiseMultiType> {
  49. static void deduce_layout(ElemwiseMultiType* opr,
  50. TensorLayoutArray& layouts) {
  51. megdnn_assert(layouts.size() >= 2);
  52. auto inp = layouts;
  53. inp.pop_back();
  54. opr->deduce_layout(inp, layouts.back());
  55. }
  56. static void exec(ElemwiseMultiType* opr, const TensorNDArray& tensors) {
  57. megdnn_assert(tensors.size() >= 2);
  58. auto inp = tensors;
  59. inp.pop_back();
  60. opr->exec(inp, tensors.back());
  61. }
  62. };
  63. template <>
  64. struct OprProxy<ConcatForward> {
  65. static void deduce_layout(ConcatForward* opr, TensorLayoutArray& layouts) {
  66. megdnn_assert(layouts.size() >= 2);
  67. auto inp = layouts;
  68. inp.pop_back();
  69. opr->deduce_layout(inp, layouts.back());
  70. }
  71. static void exec(ConcatForward* opr, const TensorNDArray& tensors) {
  72. megdnn_assert(tensors.size() >= 2);
  73. auto inp = tensors;
  74. inp.pop_back();
  75. TensorLayoutArray layouts(tensors.size());
  76. std::transform(tensors.begin(), tensors.end(), layouts.begin(),
  77. [](const TensorND& tensor) { return tensor.layout; });
  78. auto inp_layouts = layouts;
  79. inp_layouts.pop_back();
  80. WorkspaceWrapper W(opr->handle(), opr->get_workspace_in_bytes(
  81. inp_layouts, layouts.back()));
  82. auto inp_tensors = tensors;
  83. inp_tensors.pop_back();
  84. opr->exec(inp_tensors, tensors.back(), W.workspace());
  85. }
  86. };
  87. template <>
  88. struct OprProxy<SplitForward> : DeduceLayoutProxy<SplitForward, 0, false> {
  89. static void exec(SplitForward* opr, const TensorNDArray& tensors) {
  90. megdnn_assert(tensors.size() >= 2);
  91. auto out = tensors;
  92. out.erase(out.begin());
  93. TensorLayoutArray layouts(tensors.size());
  94. std::transform(tensors.begin(), tensors.end(), layouts.begin(),
  95. [](const TensorND& tensor) { return tensor.layout; });
  96. auto out_layouts = layouts;
  97. out_layouts.erase(out_layouts.begin());
  98. WorkspaceWrapper W(
  99. opr->handle(),
  100. opr->get_workspace_in_bytes(layouts.front(), out_layouts));
  101. auto out_tensors = tensors;
  102. out_tensors.erase(out_tensors.begin());
  103. opr->exec(tensors.front(), out_tensors, W.workspace());
  104. }
  105. };
//! Base of OprProxy implementations with algorithm-profiling support:
//! holds the profiling knobs, the shared workspace, and the selected
//! algorithm.
template <class Opr, int arity>
struct OprProxyProfilingBase
        : public DeduceLayoutProxy<Opr, arity,
                                   OprTrait<Opr>::can_deduce_layout> {
    //! iteration counts used by derived exec(): untimed warmup runs,
    //! then timed runs
    size_t warmup_times = 10, exec_times = 100;
    //! whether to enable profiling
    bool m_profiling;
    //! workspace shared across profiling and the final execution
    WorkspaceWrapper W;
    //! target algo setup by profiler; it can also be directly specified by the
    //! caller
    typename Opr::Algorithm* target_algo = nullptr;
    OprProxyProfilingBase(bool profile = false) { m_profiling = profile; }
};
//! Profiling proxy for ternary oprs whose tensors are (in0, in1, out).
template <class Opr>
struct OprProxyProfilingTernary : public OprProxyProfilingBase<Opr, 3> {
    using Base = OprProxyProfilingBase<Opr, 3>;
    using OprProxyProfilingBase<Opr, 3>::OprProxyProfilingBase;

    /**
     * \brief Execute the operator; if profiling is enabled and no target
     * algorithm has been chosen yet, time every available algorithm
     * (warmup + timed runs) and cache the fastest in Base::target_algo.
     */
    void exec(Opr* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == 3);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        if (Base::m_profiling && !Base::target_algo) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 opr->get_all_algorithms(tensors[0].layout, tensors[1].layout,
                                         tensors[2].layout)) {
                opr->execution_policy().algorithm = algo;
                // workspace requirement may differ per algorithm
                auto workspace_size = opr->get_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout,
                        tensors[2].layout);
                Base::W.update(workspace_size);
                // untimed warmup runs
                for (size_t times = 0; times < Base::warmup_times; ++times)
                    opr->exec(tensors[0], tensors[1], tensors[2],
                              Base::W.workspace());
                // synchronize so warmup work is excluded from the timing
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    opr->exec(tensors[0], tensors[1], tensors[2],
                              Base::W.workspace());
                }
                // wait for possibly-async kernels before stopping the timer
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo->name());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_algo = algo;
                }
            }
            // commit the winning algorithm and size the workspace for it
            opr->execution_policy().algorithm = Base::target_algo;
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout);
            Base::W.update(workspace_size);
        }
        if (!Base::target_algo) {
            // profiling disabled and no algo set: size workspace for the
            // operator's current/default algorithm
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout);
            Base::W.update(workspace_size);
        }
        opr->exec(tensors[0], tensors[1], tensors[2], Base::W.workspace());
    }
};
//! Specialize OprProxy<c> as the ternary profiling proxy for operator c.
#define DEF_PROF3(c)                                                 \
    template <>                                                      \
    struct OprProxy<c> : public OprProxyProfilingTernary<c> {        \
        using OprProxyProfilingTernary<c>::OprProxyProfilingTernary; \
    }
DEF_PROF3(ConvolutionBackwardData);
DEF_PROF3(ConvolutionBackwardFilter);
DEF_PROF3(LocalShareForward);
DEF_PROF3(LocalShareBackwardData);
DEF_PROF3(LocalShareBackwardFilter);
#undef DEF_PROF3
//! TODO: it should adapt weight preprocess later
template <>
struct OprProxy<ConvolutionForward>
        : public OprProxyProfilingTernary<ConvolutionForward> {
    using OprProxyProfilingTernary<ConvolutionForward>::OprProxyProfilingTernary;
    //! Same flow as OprProxyProfilingTernary::exec, but ConvolutionForward's
    //! exec/get_workspace_in_bytes take one extra pointer argument, always
    //! nullptr here — presumably the preprocessed-filter handle referenced by
    //! the TODO above; confirm against the ConvolutionForward declaration.
    void exec(ConvolutionForward* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == 3);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        if (Base::m_profiling && !Base::target_algo) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 opr->get_all_algorithms(tensors[0].layout, tensors[1].layout,
                                         tensors[2].layout)) {
                opr->execution_policy().algorithm = algo;
                auto workspace_size = opr->get_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout, tensors[2].layout,
                        nullptr);
                Base::W.update(workspace_size);
                // untimed warmup runs
                for (size_t times = 0; times < Base::warmup_times; ++times)
                    opr->exec(tensors[0], tensors[1], tensors[2], nullptr,
                              Base::W.workspace());
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    opr->exec(tensors[0], tensors[1], tensors[2], nullptr,
                              Base::W.workspace());
                }
                // wait for possibly-async kernels before stopping the timer
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo->name());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_algo = algo;
                }
            }
            // commit the winning algorithm and size the workspace for it
            opr->execution_policy().algorithm = Base::target_algo;
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    nullptr);
            Base::W.update(workspace_size);
        }
        if (!Base::target_algo) {
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    nullptr);
            Base::W.update(workspace_size);
        }
        opr->exec(tensors[0], tensors[1], tensors[2], nullptr,
                  Base::W.workspace());
    }
};
//! Profiling proxy for oprs taking exactly 5 tensors; same
//! profile-then-execute flow as OprProxyProfilingTernary.
template <class Opr>
struct OprProxyProfiling5 : public OprProxyProfilingBase<Opr, 5> {
    using Base = OprProxyProfilingBase<Opr, 5>;
    using OprProxyProfilingBase<Opr, 5>::OprProxyProfilingBase;
    //! Execute the operator, profiling all algorithms first when enabled
    //! and no target algorithm has been selected yet.
    void exec(Opr* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == 5);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        if (Base::m_profiling && !Base::target_algo) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 opr->get_all_algorithms(tensors[0].layout, tensors[1].layout,
                                         tensors[2].layout, tensors[3].layout,
                                         tensors[4].layout)) {
                opr->execution_policy().algorithm = algo;
                auto workspace_size = opr->get_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout, tensors[2].layout,
                        tensors[3].layout, tensors[4].layout);
                Base::W.update(workspace_size);
                // untimed warmup runs
                for (size_t times = 0; times < Base::warmup_times; ++times)
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], Base::W.workspace());
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], Base::W.workspace());
                }
                // wait for possibly-async kernels before stopping the timer
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo->name());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_algo = algo;
                }
            }
            // commit the winning algorithm and size the workspace for it
            opr->execution_policy().algorithm = Base::target_algo;
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout);
            Base::W.update(workspace_size);
        }
        if (!Base::target_algo) {
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout);
            Base::W.update(workspace_size);
        }
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
                  Base::W.workspace());
    }
};
//! Specialize OprProxy<c> as the 5-tensor profiling proxy for operator c.
#define DEF_PROF5(c)                                       \
    template <>                                            \
    struct OprProxy<c> : public OprProxyProfiling5<c> {    \
        using OprProxyProfiling5<c>::OprProxyProfiling5;   \
    }
DEF_PROF5(DeformableConvForward);
DEF_PROF5(DeformableConvBackwardFilter);
// ConvBiasForward needs the extra preprocessed-filter argument, so it gets
// a hand-written specialization below instead of this macro.
//DEF_PROF5(ConvBiasForward);
DEF_PROF5(BatchConvBiasForward);
#undef DEF_PROF5
//! TODO: it should adapt weight preprocess later
template <>
struct OprProxy<ConvBiasForward> : public OprProxyProfiling5<ConvBiasForward> {
    using OprProxyProfiling5<ConvBiasForward>::OprProxyProfiling5;
    //! Same flow as OprProxyProfiling5::exec, but ConvBiasForward's
    //! exec/get_workspace_in_bytes take one extra pointer argument, always
    //! nullptr here — presumably the preprocessed-filter handle referenced by
    //! the TODO above; confirm against the ConvBiasForward declaration.
    void exec(ConvBiasForward* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == 5);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        if (Base::m_profiling && !Base::target_algo) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 opr->get_all_algorithms(tensors[0].layout, tensors[1].layout,
                                         tensors[2].layout, tensors[3].layout,
                                         tensors[4].layout)) {
                opr->execution_policy().algorithm = algo;
                auto workspace_size = opr->get_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout, tensors[2].layout,
                        tensors[3].layout, tensors[4].layout, nullptr);
                Base::W.update(workspace_size);
                // untimed warmup runs
                for (size_t times = 0; times < Base::warmup_times; ++times)
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], nullptr, Base::W.workspace());
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], nullptr, Base::W.workspace());
                }
                // wait for possibly-async kernels before stopping the timer
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo->name());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_algo = algo;
                }
            }
            // commit the winning algorithm and size the workspace for it
            opr->execution_policy().algorithm = Base::target_algo;
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout, nullptr);
            Base::W.update(workspace_size);
        }
        if (!Base::target_algo) {
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout, nullptr);
            Base::W.update(workspace_size);
        }
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
                  nullptr, Base::W.workspace());
    }
};
//! Profiling proxy for oprs taking exactly 8 tensors; same
//! profile-then-execute flow as OprProxyProfilingTernary.
template <class Opr>
struct OprProxyProfiling8 : public OprProxyProfilingBase<Opr, 8> {
    using Base = OprProxyProfilingBase<Opr, 8>;
    using OprProxyProfilingBase<Opr, 8>::OprProxyProfilingBase;
    //! Execute the operator, profiling all algorithms first when enabled
    //! and no target algorithm has been selected yet.
    void exec(Opr* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == 8);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        if (Base::m_profiling && !Base::target_algo) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo : opr->get_all_algorithms(
                         tensors[0].layout, tensors[1].layout,
                         tensors[2].layout, tensors[3].layout,
                         tensors[4].layout, tensors[5].layout,
                         tensors[6].layout, tensors[7].layout)) {
                opr->execution_policy().algorithm = algo;
                auto workspace_size = opr->get_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout, tensors[2].layout,
                        tensors[3].layout, tensors[4].layout, tensors[5].layout,
                        tensors[6].layout, tensors[7].layout);
                Base::W.update(workspace_size);
                // untimed warmup runs
                for (size_t times = 0; times < Base::warmup_times; ++times)
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], tensors[5], tensors[6], tensors[7],
                              Base::W.workspace());
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], tensors[5], tensors[6], tensors[7],
                              Base::W.workspace());
                }
                // wait for possibly-async kernels before stopping the timer
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo->name());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_algo = algo;
                }
            }
            // commit the winning algorithm and size the workspace for it
            opr->execution_policy().algorithm = Base::target_algo;
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout, tensors[5].layout,
                    tensors[6].layout, tensors[7].layout);
            Base::W.update(workspace_size);
        }
        if (!Base::target_algo) {
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout, tensors[5].layout,
                    tensors[6].layout, tensors[7].layout);
            Base::W.update(workspace_size);
        }
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
                  tensors[5], tensors[6], tensors[7], Base::W.workspace());
    }
};
//! Specialize OprProxy<c> as the 8-tensor profiling proxy for operator c.
#define DEF_PROF8(c)                                       \
    template <>                                            \
    struct OprProxy<c> : public OprProxyProfiling8<c> {    \
        using OprProxyProfiling8<c>::OprProxyProfiling8;   \
    }
DEF_PROF8(DeformableConvBackwardData);
#undef DEF_PROF8
  424. } // namespace test
  425. } // namespace megdnn
  426. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台