You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

opr_proxy.h 31 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. /**
  2. * \file dnn/test/common/opr_proxy.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
#include "test/common/deduce_layout_proxy.h"
#include "test/common/exec_proxy.h"
#include "test/common/inspect_type.h"
#include "test/common/opr_trait.h"
#include "test/common/timer.h"
#include "test/common/workspace_wrapper.h"

#include <algorithm>
#include <cstdio>
#include <limits>
#include <memory>
  20. namespace megdnn {
  21. namespace test {
/*!
 * \brief default proxy: combines layout deduction and exec dispatch,
 * both selected by the operator's compile-time traits.
 *
 * \tparam Opr megdnn operator type
 * \tparam arity total number of tensors (inputs + outputs)
 * \tparam has_workspace whether Opr::exec takes a workspace argument
 * \tparam can_deduce_layout whether Opr can deduce its output layout
 */
template <typename Opr, size_t arity = OprTrait<Opr>::arity,
          bool has_workspace = OprTrait<Opr>::has_workspace,
          bool can_deduce_layout = OprTrait<Opr>::can_deduce_layout>
struct OprProxyDefaultImpl
        : public DeduceLayoutProxy<Opr, arity, can_deduce_layout>,
          public ExecProxy<Opr, arity, has_workspace> {};
//! primary OprProxy template: forwards to the trait-driven default impl;
//! specialized below for operators that need custom deduce/exec behavior
template <typename Opr>
struct OprProxy : public OprProxyDefaultImpl<Opr> {};
//! proxy variant that runs weight preprocessing before exec; the default
//! behaves like OprProxy and is specialized for conv-style oprs below
template <typename Opr>
struct OprWeightPreprocessProxy : public OprProxyDefaultImpl<Opr> {};
//! adapter from vector-style tensor args to single-tensor calls;
//! NOTE(review): intentionally empty here — any specializations live
//! elsewhere in the test framework
template <typename Opr>
struct OprProxyVectorToSingle {};
  34. template <>
  35. struct OprProxy<ElemwiseForward> {
  36. static void deduce_layout(ElemwiseForward* opr,
  37. TensorLayoutArray& layouts) {
  38. megdnn_assert(layouts.size() >= 2);
  39. auto inp = layouts;
  40. inp.pop_back();
  41. opr->deduce_layout(inp, layouts.back());
  42. }
  43. static void exec(ElemwiseForward* opr, const TensorNDArray& tensors) {
  44. megdnn_assert(tensors.size() >= 2);
  45. auto inp = tensors;
  46. inp.pop_back();
  47. opr->exec(inp, tensors.back());
  48. }
  49. };
  50. template <>
  51. struct OprProxy<ElemwiseMultiType> {
  52. static void deduce_layout(ElemwiseMultiType* opr,
  53. TensorLayoutArray& layouts) {
  54. megdnn_assert(layouts.size() >= 2);
  55. auto inp = layouts;
  56. inp.pop_back();
  57. opr->deduce_layout(inp, layouts.back());
  58. }
  59. static void exec(ElemwiseMultiType* opr, const TensorNDArray& tensors) {
  60. megdnn_assert(tensors.size() >= 2);
  61. auto inp = tensors;
  62. inp.pop_back();
  63. opr->exec(inp, tensors.back());
  64. }
  65. };
  66. template <>
  67. struct OprProxy<ConcatForward> {
  68. static void deduce_layout(ConcatForward* opr, TensorLayoutArray& layouts) {
  69. megdnn_assert(layouts.size() >= 2);
  70. auto inp = layouts;
  71. inp.pop_back();
  72. opr->deduce_layout(inp, layouts.back());
  73. }
  74. static void exec(ConcatForward* opr, const TensorNDArray& tensors) {
  75. megdnn_assert(tensors.size() >= 2);
  76. auto inp = tensors;
  77. inp.pop_back();
  78. TensorLayoutArray layouts(tensors.size());
  79. std::transform(tensors.begin(), tensors.end(), layouts.begin(),
  80. [](const TensorND& tensor) { return tensor.layout; });
  81. auto inp_layouts = layouts;
  82. inp_layouts.pop_back();
  83. WorkspaceWrapper W(opr->handle(), opr->get_workspace_in_bytes(
  84. inp_layouts, layouts.back()));
  85. auto inp_tensors = tensors;
  86. inp_tensors.pop_back();
  87. opr->exec(inp_tensors, tensors.back(), W.workspace());
  88. }
  89. };
  90. template <>
  91. struct OprProxy<SplitForward> : DeduceLayoutProxy<SplitForward, 0, false> {
  92. static void exec(SplitForward* opr, const TensorNDArray& tensors) {
  93. megdnn_assert(tensors.size() >= 2);
  94. auto out = tensors;
  95. out.erase(out.begin());
  96. TensorLayoutArray layouts(tensors.size());
  97. std::transform(tensors.begin(), tensors.end(), layouts.begin(),
  98. [](const TensorND& tensor) { return tensor.layout; });
  99. auto out_layouts = layouts;
  100. out_layouts.erase(out_layouts.begin());
  101. WorkspaceWrapper W(
  102. opr->handle(),
  103. opr->get_workspace_in_bytes(layouts.front(), out_layouts));
  104. auto out_tensors = tensors;
  105. out_tensors.erase(out_tensors.begin());
  106. opr->exec(tensors.front(), out_tensors, W.workspace());
  107. }
  108. };
//! OprProxy impl for ternary oprs with profiling support
template <class Opr, int arity>
struct OprProxyProfilingBase
        : public DeduceLayoutProxy<Opr, arity,
                                   OprTrait<Opr>::can_deduce_layout> {
    //! iteration counts used by the profiling loops in derived proxies
    size_t warmup_times = 10, exec_times = 100;
    //! whether to enable profiling
    bool m_profiling;
    //! shared workspace, grown on demand via update()
    WorkspaceWrapper W;
    //! target algo setup by profiler; it can also be directly specified by the
    //! caller
    typename Opr::Algorithm* target_algo = nullptr;
    OprProxyProfilingBase(bool profile = false) { m_profiling = profile; }
    //! used for alloc tensor for weight preprocess
    static std::shared_ptr<TensorNDArray> alloc_tensors(
            Handle* handle, const TensorLayoutArray& layouts) {
        // deleter mirrors the allocation below: raw_ptr was stored as
        // (allocation - low_byte), so add low_byte back before freeing
        auto deleter = [handle](TensorNDArray* ptr) {
            for (auto&& i : *ptr) {
                auto pdata = static_cast<dt_byte*>(i.raw_ptr) +
                             i.layout.span().low_byte;
                megdnn_free(handle, pdata);
            }
            delete ptr;
        };
        std::shared_ptr<TensorNDArray> ret{new TensorNDArray, deleter};
        for (size_t i = 0; i < layouts.size(); ++i) {
            auto span = layouts[i].span();
            // store base pointer offset by -low_byte so that indexing with
            // the layout's span stays inside the allocated buffer
            ret->emplace_back(static_cast<dt_byte*>(
                                      megdnn_malloc(handle, span.dist_byte())) -
                                      span.low_byte,
                              layouts[i]);
        }
        return ret;
    }
};
template <class Opr>
struct OprProxyProfilingTernary : public OprProxyProfilingBase<Opr, 3> {
    using Base = OprProxyProfilingBase<Opr, 3>;
    using OprProxyProfilingBase<Opr, 3>::OprProxyProfilingBase;
    //! execute an arity-3 opr; when profiling is enabled and no algorithm
    //! has been selected yet, time all candidates and cache the fastest
    //! one in Base::target_algo for subsequent calls
    void exec(Opr* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == 3);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        if (Base::m_profiling && !Base::target_algo) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 opr->get_all_algorithms(tensors[0].layout, tensors[1].layout,
                                         tensors[2].layout)) {
                opr->execution_policy().algorithm = algo;
                auto workspace_size = opr->get_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout,
                        tensors[2].layout);
                Base::W.update(workspace_size);
                // untimed warm-up runs
                for (size_t times = 0; times < Base::warmup_times; ++times)
                    opr->exec(tensors[0], tensors[1], tensors[2],
                              Base::W.workspace());
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    opr->exec(tensors[0], tensors[1], tensors[2],
                              Base::W.workspace());
                }
                // synchronize before stopping: exec may run asynchronously
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo->name());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_algo = algo;
                }
            }
            // commit the winner and size the workspace for it
            opr->execution_policy().algorithm = Base::target_algo;
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout);
            Base::W.update(workspace_size);
        }
        if (!Base::target_algo) {
            // not profiling: just ensure the workspace fits the current algo
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout);
            Base::W.update(workspace_size);
        }
        opr->exec(tensors[0], tensors[1], tensors[2], Base::W.workspace());
    }
};
//! declare a profiling OprProxy specialization for an arity-3 opr
#define DEF_PROF3(c)                                                 \
    template <>                                                      \
    struct OprProxy<c> : public OprProxyProfilingTernary<c> {        \
        using OprProxyProfilingTernary<c>::OprProxyProfilingTernary; \
    }
DEF_PROF3(ConvolutionBackwardData);
DEF_PROF3(ConvolutionBackwardFilter);
DEF_PROF3(LocalShareForward);
DEF_PROF3(LocalShareBackwardData);
DEF_PROF3(LocalShareBackwardFilter);
#undef DEF_PROF3
//! ConvolutionForward needs its own proxy because exec/get_workspace take
//! an extra PreprocessedFilter* argument (always nullptr here: no
//! preprocessing in this proxy)
template <>
struct OprProxy<ConvolutionForward>
        : public OprProxyProfilingTernary<ConvolutionForward> {
    using OprProxyProfilingTernary<ConvolutionForward>::OprProxyProfilingTernary;
    //! same flow as OprProxyProfilingTernary::exec, with nullptr passed
    //! for the preprocessed filter
    void exec(ConvolutionForward* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == 3);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        if (Base::m_profiling && !Base::target_algo) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 opr->get_all_algorithms(tensors[0].layout, tensors[1].layout,
                                         tensors[2].layout)) {
                opr->execution_policy().algorithm = algo;
                auto workspace_size = opr->get_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout, tensors[2].layout,
                        nullptr);
                Base::W.update(workspace_size);
                // untimed warm-up runs
                for (size_t times = 0; times < Base::warmup_times; ++times)
                    opr->exec(tensors[0], tensors[1], tensors[2], nullptr,
                              Base::W.workspace());
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    opr->exec(tensors[0], tensors[1], tensors[2], nullptr,
                              Base::W.workspace());
                }
                // synchronize before stopping: exec may run asynchronously
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo->name());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_algo = algo;
                }
            }
            // commit the winner and size the workspace for it
            opr->execution_policy().algorithm = Base::target_algo;
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    nullptr);
            Base::W.update(workspace_size);
        }
        if (!Base::target_algo) {
            // not profiling: just ensure the workspace fits the current algo
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    nullptr);
            Base::W.update(workspace_size);
        }
        opr->exec(tensors[0], tensors[1], tensors[2], nullptr,
                  Base::W.workspace());
    }
};
//! like OprProxy<ConvolutionForward>, but runs the algorithm's weight
//! preprocessing pass and feeds the resulting PreprocessedFilter into exec
template <>
struct OprWeightPreprocessProxy<ConvolutionForward>
        : public OprProxyProfilingTernary<ConvolutionForward> {
    using OprProxyProfilingTernary<ConvolutionForward>::OprProxyProfilingTernary;
    void exec(ConvolutionForward* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == 3);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        if (Base::m_profiling && !Base::target_algo) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 opr->get_all_algorithms(tensors[0].layout, tensors[1].layout,
                                         tensors[2].layout)) {
                opr->execution_policy().algorithm = algo;
                // preprocess the filter for this candidate before timing it
                auto preprocess_tensors = weight_prerocess(opr, tensors, algo);
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                ConvolutionForward::PreprocessedFilter preprocessed_filter{
                        algo, *preprocess_tensors};
                auto workspace_size = opr->get_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout, tensors[2].layout,
                        &preprocessed_filter);
                Base::W.update(workspace_size);
                // untimed warm-up runs
                for (size_t times = 0; times < Base::warmup_times; ++times)
                    opr->exec(tensors[0], tensors[1], tensors[2],
                              &preprocessed_filter, Base::W.workspace());
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    opr->exec(tensors[0], tensors[1], tensors[2],
                              &preprocessed_filter, Base::W.workspace());
                }
                // synchronize before stopping: exec may run asynchronously
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo->name());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_algo = algo;
                }
            }
            // commit the winner; re-preprocess with it to size the workspace
            opr->execution_policy().algorithm = Base::target_algo;
            auto preprocess_tensors =
                    weight_prerocess(opr, tensors, Base::target_algo);
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            ConvolutionForward::PreprocessedFilter preprocessed_filter{
                    Base::target_algo, *preprocess_tensors};
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    &preprocessed_filter);
            Base::W.update(workspace_size);
        }
        // final run always uses a freshly preprocessed filter
        auto preprocess_tensors =
                weight_prerocess(opr, tensors, Base::target_algo);
        megcoreSynchronize(opr->handle()->megcore_computing_handle());
        ConvolutionForward::PreprocessedFilter preprocessed_filter{
                Base::target_algo, *preprocess_tensors};
        if (!Base::target_algo) {
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    &preprocessed_filter);
            Base::W.update(workspace_size);
        }
        opr->exec(tensors[0], tensors[1], tensors[2], &preprocessed_filter,
                  Base::W.workspace());
    }
    //! handle weight preprocess
    //! NOTE(review): "prerocess" is a typo, but the name is part of this
    //! proxy's interface and is kept unchanged
    std::shared_ptr<TensorNDArray> weight_prerocess(
            ConvolutionForward* opr, const TensorNDArray& tensors,
            ConvolutionForward::Algorithm* algo) {
        auto weight_perprocess_layouts = opr->deduce_preprocessed_filter_layout(
                tensors[0].layout, tensors[1].layout, tensors[2].layout);
        auto preprocessed_filter_tensors_ptr =
                alloc_tensors(opr->handle(), weight_perprocess_layouts);
        ConvolutionForward::PreprocessedFilter preprocessed_filter{
                algo, *preprocessed_filter_tensors_ptr};
        size_t preprocess_workspace_size =
                opr->get_preprocess_workspace_in_bytes(tensors[0].layout,
                                                       tensors[1].layout,
                                                       tensors[2].layout);
        WorkspaceWrapper preprocess_workspace(opr->handle(),
                                              preprocess_workspace_size);
        // note: the filter (tensors[1]) is passed as a TensorND, src/dst as
        // layouts only — matching exec_preprocess's signature
        opr->exec_preprocess(tensors[0].layout, tensors[1], tensors[2].layout,
                             &preprocessed_filter,
                             preprocess_workspace.workspace());
        return preprocessed_filter_tensors_ptr;
    }
};
//! profiling proxy for arity-5 oprs (e.g. deformable conv, batch conv-bias)
template <class Opr>
struct OprProxyProfiling5 : public OprProxyProfilingBase<Opr, 5> {
    using Base = OprProxyProfilingBase<Opr, 5>;
    using OprProxyProfilingBase<Opr, 5>::OprProxyProfilingBase;
    //! execute an arity-5 opr; when profiling is enabled and no algorithm
    //! has been selected yet, time all candidates and cache the fastest one
    void exec(Opr* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == 5);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        if (Base::m_profiling && !Base::target_algo) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 opr->get_all_algorithms(tensors[0].layout, tensors[1].layout,
                                         tensors[2].layout, tensors[3].layout,
                                         tensors[4].layout)) {
                opr->execution_policy().algorithm = algo;
                auto workspace_size = opr->get_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout, tensors[2].layout,
                        tensors[3].layout, tensors[4].layout);
                Base::W.update(workspace_size);
                // untimed warm-up runs
                for (size_t times = 0; times < Base::warmup_times; ++times)
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], Base::W.workspace());
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], Base::W.workspace());
                }
                // synchronize before stopping: exec may run asynchronously
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo->name());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_algo = algo;
                }
            }
            // commit the winner and size the workspace for it
            opr->execution_policy().algorithm = Base::target_algo;
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout);
            Base::W.update(workspace_size);
        }
        if (!Base::target_algo) {
            // not profiling: just ensure the workspace fits the current algo
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout);
            Base::W.update(workspace_size);
        }
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
                  Base::W.workspace());
    }
};
//! declare a profiling OprProxy specialization for an arity-5 opr
#define DEF_PROF5(c)                                       \
    template <>                                            \
    struct OprProxy<c> : public OprProxyProfiling5<c> {    \
        using OprProxyProfiling5<c>::OprProxyProfiling5;   \
    }
DEF_PROF5(DeformableConvForward);
DEF_PROF5(DeformableConvBackwardFilter);
DEF_PROF5(BatchConvBiasForward);
#undef DEF_PROF5
//! ConvBiasForward needs its own proxy because exec/get_workspace take an
//! extra PreprocessedFilter* argument (always nullptr here)
template <>
struct OprProxy<ConvBiasForward> : public OprProxyProfiling5<ConvBiasForward> {
    using OprProxyProfiling5<ConvBiasForward>::OprProxyProfiling5;
    //! same flow as OprProxyProfiling5::exec, with nullptr passed for the
    //! preprocessed filter
    void exec(ConvBiasForward* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == 5);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        if (Base::m_profiling && !Base::target_algo) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 opr->get_all_algorithms(tensors[0].layout, tensors[1].layout,
                                         tensors[2].layout, tensors[3].layout,
                                         tensors[4].layout)) {
                opr->execution_policy().algorithm = algo;
                auto workspace_size = opr->get_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout, tensors[2].layout,
                        tensors[3].layout, tensors[4].layout, nullptr);
                Base::W.update(workspace_size);
                // untimed warm-up runs
                for (size_t times = 0; times < Base::warmup_times; ++times)
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], nullptr, Base::W.workspace());
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], nullptr, Base::W.workspace());
                }
                // synchronize before stopping: exec may run asynchronously
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo->name());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_algo = algo;
                }
            }
            // commit the winner and size the workspace for it
            opr->execution_policy().algorithm = Base::target_algo;
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout, nullptr);
            Base::W.update(workspace_size);
        }
        if (!Base::target_algo) {
            // not profiling: just ensure the workspace fits the current algo
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout, nullptr);
            Base::W.update(workspace_size);
        }
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
                  nullptr, Base::W.workspace());
    }
};
//! like OprProxy<ConvBiasForward>, but runs the algorithm's weight
//! preprocessing pass and feeds the resulting PreprocessedFilter into exec
template <>
struct OprWeightPreprocessProxy<ConvBiasForward>
        : public OprProxyProfiling5<ConvBiasForward> {
    using OprProxyProfiling5<ConvBiasForward>::OprProxyProfiling5;
    void exec(ConvBiasForward* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == 5);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        if (Base::m_profiling && !Base::target_algo) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 opr->get_all_algorithms(tensors[0].layout, tensors[1].layout,
                                         tensors[2].layout, tensors[3].layout,
                                         tensors[4].layout)) {
                opr->execution_policy().algorithm = algo;
                // preprocess the filter for this candidate before timing it
                auto preprocess_tensors = weight_prerocess(opr, tensors, algo);
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                ConvBiasForward::PreprocessedFilter preprocessed_filter{
                        algo, *preprocess_tensors};
                auto workspace_size = opr->get_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout, tensors[2].layout,
                        tensors[3].layout, tensors[4].layout,
                        &preprocessed_filter);
                Base::W.update(workspace_size);
                // untimed warm-up runs
                for (size_t times = 0; times < Base::warmup_times; ++times)
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], &preprocessed_filter,
                              Base::W.workspace());
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], &preprocessed_filter,
                              Base::W.workspace());
                }
                // synchronize before stopping: exec may run asynchronously
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo->name());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_algo = algo;
                }
            }
            // commit the winner; re-preprocess with it to size the workspace
            opr->execution_policy().algorithm = Base::target_algo;
            auto preprocess_tensors =
                    weight_prerocess(opr, tensors, Base::target_algo);
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            ConvBiasForward::PreprocessedFilter preprocessed_filter{
                    Base::target_algo, *preprocess_tensors};
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout, &preprocessed_filter);
            Base::W.update(workspace_size);
        }
        // final run always uses a freshly preprocessed filter
        auto preprocess_tensors =
                weight_prerocess(opr, tensors, Base::target_algo);
        megcoreSynchronize(opr->handle()->megcore_computing_handle());
        ConvBiasForward::PreprocessedFilter preprocessed_filter{
                Base::target_algo, *preprocess_tensors};
        if (!Base::target_algo) {
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout, &preprocessed_filter);
            Base::W.update(workspace_size);
        }
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
                  &preprocessed_filter, Base::W.workspace());
    }
    //! handle weight preprocess
    //! NOTE(review): "prerocess" is a typo, but the name is part of this
    //! proxy's interface and is kept unchanged
    std::shared_ptr<TensorNDArray> weight_prerocess(
            ConvBiasForward* opr, const TensorNDArray& tensors,
            ConvBiasForward::Algorithm* algo) {
        auto weight_perprocess_layouts = opr->deduce_preprocessed_filter_layout(
                tensors[0].layout, tensors[1].layout, tensors[2].layout,
                tensors[3].layout, tensors[4].layout);
        auto preprocessed_filter_tensors_ptr =
                alloc_tensors(opr->handle(), weight_perprocess_layouts);
        ConvBiasForward::PreprocessedFilter preprocessed_filter{
                algo, *preprocessed_filter_tensors_ptr};
        size_t preprocess_workspace_size =
                opr->get_preprocess_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout, tensors[2].layout,
                        tensors[3].layout, tensors[4].layout);
        WorkspaceWrapper preprocess_workspace(opr->handle(),
                                              preprocess_workspace_size);
        // note: filter (tensors[1]) and bias (tensors[2]) are passed as
        // TensorNDs, the rest as layouts — matching exec_preprocess
        opr->exec_preprocess(tensors[0].layout, tensors[1], tensors[2],
                             tensors[3].layout, tensors[4].layout,
                             &preprocessed_filter,
                             preprocess_workspace.workspace());
        return preprocessed_filter_tensors_ptr;
    }
};
//! profiling proxy for arity-8 oprs (e.g. deformable conv backward data)
template <class Opr>
struct OprProxyProfiling8 : public OprProxyProfilingBase<Opr, 8> {
    using Base = OprProxyProfilingBase<Opr, 8>;
    using OprProxyProfilingBase<Opr, 8>::OprProxyProfilingBase;
    //! execute an arity-8 opr; when profiling is enabled and no algorithm
    //! has been selected yet, time all candidates and cache the fastest one
    void exec(Opr* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == 8);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        if (Base::m_profiling && !Base::target_algo) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo : opr->get_all_algorithms(
                         tensors[0].layout, tensors[1].layout,
                         tensors[2].layout, tensors[3].layout,
                         tensors[4].layout, tensors[5].layout,
                         tensors[6].layout, tensors[7].layout)) {
                opr->execution_policy().algorithm = algo;
                auto workspace_size = opr->get_workspace_in_bytes(
                        tensors[0].layout, tensors[1].layout, tensors[2].layout,
                        tensors[3].layout, tensors[4].layout, tensors[5].layout,
                        tensors[6].layout, tensors[7].layout);
                Base::W.update(workspace_size);
                // untimed warm-up runs
                for (size_t times = 0; times < Base::warmup_times; ++times)
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], tensors[5], tensors[6], tensors[7],
                              Base::W.workspace());
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                              tensors[4], tensors[5], tensors[6], tensors[7],
                              Base::W.workspace());
                }
                // synchronize before stopping: exec may run asynchronously
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo->name());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_algo = algo;
                }
            }
            // commit the winner and size the workspace for it
            opr->execution_policy().algorithm = Base::target_algo;
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout, tensors[5].layout,
                    tensors[6].layout, tensors[7].layout);
            Base::W.update(workspace_size);
        }
        if (!Base::target_algo) {
            // not profiling: just ensure the workspace fits the current algo
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors[0].layout, tensors[1].layout, tensors[2].layout,
                    tensors[3].layout, tensors[4].layout, tensors[5].layout,
                    tensors[6].layout, tensors[7].layout);
            Base::W.update(workspace_size);
        }
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
                  tensors[5], tensors[6], tensors[7], Base::W.workspace());
    }
};
//! declare a profiling OprProxy specialization for an arity-8 opr
#define DEF_PROF8(c)                                       \
    template <>                                            \
    struct OprProxy<c> : public OprProxyProfiling8<c> {    \
        using OprProxyProfiling8<c>::OprProxyProfiling8;   \
    }
DEF_PROF8(DeformableConvBackwardData);
#undef DEF_PROF8
  629. } // namespace test
  630. } // namespace megdnn
  631. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台