You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

opr_proxy.h 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595
  1. #pragma once
  2. #include "src/common/opr_trait.h"
  3. #include "test/common/deduce_layout_proxy.h"
  4. #include "test/common/exec_proxy.h"
  5. #include "test/common/fast_run_cache.h"
  6. #include "test/common/inspect_type.h"
  7. #include "test/common/opr_algo_proxy.h"
  8. #include "test/common/timer.h"
  9. #include "test/common/workspace_wrapper.h"
  10. #include <algorithm>
  11. #include <limits>
  12. #include <memory>
  13. #include <unordered_map>
  14. namespace megdnn {
  15. namespace test {
//! Compile-time bidirectional mapping between Algorithm::OprType enum values
//! and the corresponding megdnn operator classes. Both directions are needed:
//! OprFromOprTypeTrait resolves a runtime-dispatched opr_type to a concrete
//! operator type, and OprTypeFromOprTrait recovers the enum tag for a given
//! operator (used when building Algorithm::SearchItem entries below).
template <Algorithm::OprType>
struct OprFromOprTypeTrait;
template <typename Opr>
struct OprTypeFromOprTrait;
//! cb(_opr_type, _opr): instantiate both trait specializations for one
//! (enum tag, operator class) pair.
#define cb(_opr_type, _opr) \
    template <> \
    struct OprFromOprTypeTrait<Algorithm::OprType::_opr_type> { \
        using Opr = megdnn::_opr; \
    }; \
    template <> \
    struct OprTypeFromOprTrait<megdnn::_opr> { \
        constexpr static Algorithm::OprType opr_type = Algorithm::OprType::_opr_type; \
    }
cb(MATRIX_MUL_FORWARD, MatrixMulForward);
cb(BATCHED_MATRIX_MUL_FORWARD, BatchedMatrixMulForward);
cb(CONVOLUTION_FORWARD, ConvolutionForward);
cb(CONVOLUTION_BACKWARD_DATA, ConvolutionBackwardData);
cb(CONVOLUTION_BACKWARD_FILTER, ConvolutionBackwardFilter);
cb(CONVOLUTION3D_FORWARD, Convolution3DForward);
cb(CONVOLUTION3D_BACKWARD_DATA, Convolution3DBackwardData);
cb(CONVOLUTION3D_BACKWARD_FILTER, Convolution3DBackwardFilter);
cb(LOCAL_SHARE_FORWARD, LocalShareForward);
cb(LOCAL_SHARE_BACKWARD_DATA, LocalShareBackwardData);
cb(LOCAL_SHARE_BACKWARD_FILTER, LocalShareBackwardFilter);
cb(DEFORMABLE_CONV_FORWARD, DeformableConvForward);
cb(DEFORMABLE_CONV_BACKWARD_DATA, DeformableConvBackwardData);
cb(DEFORMABLE_CONV_BACKWARD_FILTER, DeformableConvBackwardFilter);
cb(BATCH_CONV_FORWARD, BatchConvBiasForward);
cb(CONVBIAS_FORWARD, ConvBiasForward);
#undef cb
// clang-format off
//! Apply cb(OPR_TYPE) to every operator type registered above. Keep this list
//! in sync with the cb(...) instantiations of the trait mapping.
#define FOREACH_OPR_TYPE(cb) \
    cb(MATRIX_MUL_FORWARD) \
    cb(BATCHED_MATRIX_MUL_FORWARD) \
    cb(CONVOLUTION_FORWARD) \
    cb(CONVOLUTION_BACKWARD_DATA) \
    cb(CONVOLUTION_BACKWARD_FILTER) \
    cb(CONVOLUTION3D_FORWARD) \
    cb(CONVOLUTION3D_BACKWARD_DATA) \
    cb(CONVOLUTION3D_BACKWARD_FILTER) \
    cb(LOCAL_SHARE_FORWARD) \
    cb(LOCAL_SHARE_BACKWARD_DATA) \
    cb(LOCAL_SHARE_BACKWARD_FILTER) \
    cb(DEFORMABLE_CONV_FORWARD) \
    cb(DEFORMABLE_CONV_BACKWARD_DATA) \
    cb(DEFORMABLE_CONV_BACKWARD_FILTER) \
    cb(BATCH_CONV_FORWARD) \
    cb(CONVBIAS_FORWARD)
//! Same as FOREACH_OPR_TYPE but forwards an extra statement argument to cb.
#define FOREACH_OPR_TYPE_WITH_STMT(cb, stmt) \
    cb(MATRIX_MUL_FORWARD, stmt) \
    cb(BATCHED_MATRIX_MUL_FORWARD, stmt) \
    cb(CONVOLUTION_FORWARD, stmt) \
    cb(CONVOLUTION_BACKWARD_DATA, stmt) \
    cb(CONVOLUTION_BACKWARD_FILTER, stmt) \
    cb(CONVOLUTION3D_FORWARD, stmt) \
    cb(CONVOLUTION3D_BACKWARD_DATA, stmt) \
    cb(CONVOLUTION3D_BACKWARD_FILTER, stmt) \
    cb(LOCAL_SHARE_FORWARD, stmt) \
    cb(LOCAL_SHARE_BACKWARD_DATA, stmt) \
    cb(LOCAL_SHARE_BACKWARD_FILTER, stmt) \
    cb(DEFORMABLE_CONV_FORWARD, stmt) \
    cb(DEFORMABLE_CONV_BACKWARD_DATA, stmt) \
    cb(DEFORMABLE_CONV_BACKWARD_FILTER, stmt) \
    cb(BATCH_CONV_FORWARD, stmt) \
    cb(CONVBIAS_FORWARD, stmt)
// clang-format on
//! One switch case: binds _Opr to the concrete operator type for _opr_type,
//! then runs _stmt with _Opr in scope.
#define _OPR_TYPE_CASE(_opr_type, _stmt) \
    case Algorithm::OprType::_opr_type: { \
        using _Opr = typename OprFromOprTypeTrait<Algorithm::OprType::_opr_type>::Opr; \
        _stmt; \
        break; \
    }
//! Iterate over a vector of Algorithm::SearchItem and dispatch _stmt on each
//! item's opr_type. Inside _stmt, `_item` is the current SearchItem and `_Opr`
//! is the statically-resolved operator type.
#define FOREACH_OPR_TYPE_DISPATCH(_search_items, _stmt) \
    for (size_t _item_idx = 0; _item_idx < _search_items.size(); _item_idx++) { \
        auto&& _item = _search_items[_item_idx]; \
        switch (_item.opr_type) { \
            FOREACH_OPR_TYPE_WITH_STMT(_OPR_TYPE_CASE, _stmt) \
            default: \
                megdnn_throw("unknown opr_type"); \
        } \
    }
//! Default proxy: combines layout deduction and execution behavior selected by
//! the operator's traits (arity, workspace requirement, deducibility).
template <
        typename Opr, size_t arity = OprTrait<Opr>::arity,
        bool has_workspace = OprTrait<Opr>::has_workspace,
        bool can_deduce_layout = OprTrait<Opr>::can_deduce_layout>
struct OprProxyDefaultImpl : public DeduceLayoutProxy<Opr, arity, can_deduce_layout>,
                             public ExecProxy<Opr, arity, has_workspace> {};
//! Primary proxy template; specialized below for operators that need custom
//! exec/deduce_layout handling or profiling support.
template <typename Opr>
struct OprProxy : public OprProxyDefaultImpl<Opr> {};
//! Proxy variant for operators supporting weight pre-processing; specialized
//! below via OprWeightPreprocessProxyImpl.
template <typename Opr>
struct OprWeightPreprocessProxy : public OprProxyDefaultImpl<Opr> {};
//! Placeholder tag for oprs taking a vector of inputs and a single output;
//! no members required here.
template <typename Opr>
struct OprProxyVectorToSingle {};
  109. template <>
  110. struct OprProxy<ElemwiseForward> {
  111. static void deduce_layout(ElemwiseForward* opr, TensorLayoutArray& layouts) {
  112. megdnn_assert(layouts.size() >= 2);
  113. auto inp = layouts;
  114. inp.pop_back();
  115. opr->deduce_layout(inp, layouts.back());
  116. }
  117. static void exec(ElemwiseForward* opr, const TensorNDArray& tensors) {
  118. megdnn_assert(tensors.size() >= 2);
  119. auto inp = tensors;
  120. inp.pop_back();
  121. opr->exec(inp, tensors.back());
  122. }
  123. };
  124. template <>
  125. struct OprProxy<ElemwiseMultiType> {
  126. static void deduce_layout(ElemwiseMultiType* opr, TensorLayoutArray& layouts) {
  127. megdnn_assert(layouts.size() >= 2);
  128. auto inp = layouts;
  129. inp.pop_back();
  130. opr->deduce_layout(inp, layouts.back());
  131. }
  132. static void exec(ElemwiseMultiType* opr, const TensorNDArray& tensors) {
  133. megdnn_assert(tensors.size() >= 2);
  134. auto inp = tensors;
  135. inp.pop_back();
  136. opr->exec(inp, tensors.back());
  137. }
  138. };
  139. template <>
  140. struct OprProxy<ConcatForward> {
  141. WorkspaceWrapper W;
  142. static void deduce_layout(ConcatForward* opr, TensorLayoutArray& layouts) {
  143. megdnn_assert(layouts.size() >= 2);
  144. auto inp = layouts;
  145. inp.pop_back();
  146. opr->deduce_layout(inp, layouts.back());
  147. }
  148. void exec(ConcatForward* opr, const TensorNDArray& tensors) {
  149. if (!W.valid()) {
  150. W = WorkspaceWrapper(opr->handle(), 0);
  151. }
  152. megdnn_assert(tensors.size() >= 2);
  153. auto inp = tensors;
  154. inp.pop_back();
  155. TensorLayoutArray layouts(tensors.size());
  156. std::transform(
  157. tensors.begin(), tensors.end(), layouts.begin(),
  158. [](const TensorND& tensor) { return tensor.layout; });
  159. auto inp_layouts = layouts;
  160. inp_layouts.pop_back();
  161. W.update(opr->get_workspace_in_bytes(inp_layouts, layouts.back()));
  162. auto inp_tensors = tensors;
  163. inp_tensors.pop_back();
  164. opr->exec(inp_tensors, tensors.back(), W.workspace());
  165. }
  166. };
  167. template <>
  168. struct OprProxy<CheckNonFinite> {
  169. static void deduce_layout(CheckNonFinite* opr, TensorLayoutArray& layouts) {
  170. megdnn_assert(layouts.size() >= 2);
  171. auto inp = layouts;
  172. inp.pop_back();
  173. opr->deduce_layout(inp, layouts.back());
  174. }
  175. static void exec(CheckNonFinite* opr, const TensorNDArray& tensors) {
  176. megdnn_assert(tensors.size() >= 2);
  177. auto inps = tensors;
  178. inps.pop_back();
  179. WorkspaceWrapper W(
  180. opr->handle(),
  181. opr->get_workspace_in_bytes(inps, tensors.back().layout));
  182. opr->exec(inps, tensors.back(), W.workspace());
  183. }
  184. };
  185. template <>
  186. struct OprProxy<SplitForward> : DeduceLayoutProxy<SplitForward, 0, false> {
  187. WorkspaceWrapper W;
  188. void exec(SplitForward* opr, const TensorNDArray& tensors) {
  189. megdnn_assert(tensors.size() >= 2);
  190. if (!W.valid()) {
  191. W = WorkspaceWrapper(opr->handle(), 0);
  192. }
  193. auto out = tensors;
  194. out.erase(out.begin());
  195. TensorLayoutArray layouts(tensors.size());
  196. std::transform(
  197. tensors.begin(), tensors.end(), layouts.begin(),
  198. [](const TensorND& tensor) { return tensor.layout; });
  199. auto out_layouts = layouts;
  200. out_layouts.erase(out_layouts.begin());
  201. W.update(opr->get_workspace_in_bytes(layouts.front(), out_layouts));
  202. auto out_tensors = tensors;
  203. out_tensors.erase(out_tensors.begin());
  204. opr->exec(tensors.front(), out_tensors, W.workspace());
  205. }
  206. };
  207. //! OprProxy impl for tenary oprs with profiling support
  208. template <class Opr>
  209. struct OprProxyProfilingBase
  210. : public DeduceLayoutProxy<
  211. Opr, OprTrait<Opr>::arity, OprTrait<Opr>::can_deduce_layout> {
  212. static constexpr int arity = OprTrait<Opr>::arity;
  213. size_t warmup_times = 10, exec_times = 100;
  214. //! whether to enable profiling
  215. bool m_profiling;
  216. WorkspaceWrapper W;
  217. //! target algo setup by profiler; it can also be directly specified by the
  218. //! caller
  219. ExecutionPolicy target_execution_policy;
  220. OprProxyProfilingBase(bool profile = false) { m_profiling = profile; }
  221. //! used for alloc tensor for weight preprocess
  222. static std::shared_ptr<TensorNDArray> alloc_tensors(
  223. Handle* handle, const TensorLayoutArray& layouts) {
  224. auto deleter = [handle](TensorNDArray* ptr) {
  225. for (auto&& i : *ptr) {
  226. auto pdata =
  227. static_cast<dt_byte*>(i.raw_ptr()) + i.layout.span().low_byte;
  228. megdnn_free(handle, pdata);
  229. }
  230. delete ptr;
  231. };
  232. std::shared_ptr<TensorNDArray> ret{new TensorNDArray, deleter};
  233. for (size_t i = 0; i < layouts.size(); ++i) {
  234. auto span = layouts[i].span();
  235. ret->emplace_back(
  236. static_cast<dt_byte*>(megdnn_malloc(handle, span.dist_byte())) -
  237. span.low_byte,
  238. layouts[i]);
  239. }
  240. return ret;
  241. }
  242. /**
  243. * flatten search space in postorder traversal
  244. * The subopr search construct a search tree
  245. *
  246. * A
  247. * / \
  248. * B1B2 C
  249. * / \
  250. * D1D2D3 E
  251. * We use postorder traverse the search tree.
  252. * D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A
  253. */
  254. static std::vector<Algorithm::SearchItem> flatten_search_space(
  255. const TensorLayoutArray layouts, const std::string& param, Handle* handle) {
  256. megdnn_assert(layouts.size() == arity);
  257. auto opr = handle->create_operator<Opr>();
  258. opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
  259. std::vector<Algorithm::SearchItem> ret;
  260. for (auto algo_info :
  261. AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr.get(), layouts)) {
  262. Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc);
  263. std::vector<Algorithm::SearchItem>&& sub_items =
  264. algo->get_subopr_list(layouts, opr.get());
  265. FOREACH_OPR_TYPE_DISPATCH(sub_items, {
  266. auto space = OprProxyProfilingBase<_Opr>::flatten_search_space(
  267. _item.layouts, _item.param, handle);
  268. ret.insert(ret.end(), space.begin(), space.end());
  269. });
  270. }
  271. ret.push_back({OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
  272. return ret;
  273. }
  274. static void construct_execution_policy(
  275. const TensorLayoutArray& layouts, const std::string& param, Handle* handle,
  276. FastRunCache& cache, ExecutionPolicy& policy) {
  277. megdnn_assert(layouts.size() == arity);
  278. auto opr = handle->create_operator<Opr>();
  279. opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
  280. if (!policy.algo.valid()) {
  281. policy.algo = cache.get(Algorithm::SearchItem{
  282. OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
  283. megdnn_assert(
  284. policy.algo.valid(),
  285. "No cache found, maybe some error occured in "
  286. "flatten_search_space or get_subopr_list");
  287. }
  288. policy.sub_policy.clear();
  289. Algorithm* algo = opr->get_algorithm_from_desc(policy.algo);
  290. std::vector<Algorithm::SearchItem>&& sub_items =
  291. algo->get_subopr_list(layouts, opr.get());
  292. FOREACH_OPR_TYPE_DISPATCH(sub_items, {
  293. policy.sub_policy.push_back({});
  294. OprProxyProfilingBase<_Opr>::construct_execution_policy(
  295. _item.layouts, _item.param, handle, cache,
  296. policy.sub_policy.back());
  297. });
  298. return;
  299. }
  300. /**
  301. * \brief search and get the best execution_policy
  302. */
  303. static void search(
  304. const TensorLayoutArray& layouts, const std::string& param,
  305. WorkspaceWrapper& workspace_wrapper, Handle* handle, size_t warmup_times,
  306. size_t exec_times, FastRunCache& cache) {
  307. megdnn_assert(layouts.size() == arity);
  308. auto opr = handle->create_operator<Opr>();
  309. opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
  310. SmallVector<size_t> sizes_in_bytes;
  311. for (const auto& layout : layouts) {
  312. sizes_in_bytes.push_back(layout.span().dist_byte());
  313. }
  314. float min_time = std::numeric_limits<float>::max();
  315. Algorithm::Info::Desc best_algo;
  316. std::string log_info = "Profiling start: ";
  317. for (auto&& layout : layouts) {
  318. log_info += layout.to_string() + " ";
  319. }
  320. megdnn_log("%s", log_info.c_str());
  321. best_algo = cache.get(Algorithm::SearchItem{
  322. OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
  323. if (best_algo.valid()) {
  324. auto&& algo = opr->get_algorithm_from_desc(best_algo);
  325. MEGDNN_MARK_USED_VAR(algo);
  326. megdnn_log("Find best algo %s in cache", algo->name());
  327. return;
  328. }
  329. for (auto algo :
  330. AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr.get(), layouts)) {
  331. //! construct execution_policy
  332. opr->execution_policy().algo = algo.desc;
  333. construct_execution_policy(
  334. layouts, param, handle, cache, opr->execution_policy());
  335. auto workspace_size =
  336. AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr.get(), layouts);
  337. sizes_in_bytes.push_back(workspace_size);
  338. WorkspaceBundle wb(nullptr, sizes_in_bytes);
  339. workspace_wrapper.update(wb.total_size_in_bytes());
  340. wb.set(workspace_wrapper.workspace().raw_ptr);
  341. TensorNDArray tensors;
  342. for (size_t i = 0; i < arity; i++) {
  343. tensors.push_back({wb.get(i), layouts[i]});
  344. }
  345. for (size_t times = 0; times < warmup_times; ++times) {
  346. AlgoProxy<Opr, arity>::exec(
  347. opr.get(), tensors, wb.get_workspace(arity));
  348. }
  349. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  350. Timer timer;
  351. timer.start();
  352. for (size_t times = 0; times < exec_times; ++times) {
  353. AlgoProxy<Opr, arity>::exec(
  354. opr.get(), tensors, wb.get_workspace(arity));
  355. }
  356. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  357. timer.stop();
  358. megdnn_log(
  359. "%.3fms %s", timer.get_time_in_us() / 1e3, algo.desc.name.c_str());
  360. if (min_time > timer.get_time_in_us()) {
  361. min_time = timer.get_time_in_us();
  362. best_algo = algo.desc;
  363. }
  364. sizes_in_bytes.pop_back();
  365. }
  366. auto&& algo = opr->get_algorithm_from_desc(best_algo);
  367. MEGDNN_MARK_USED_VAR(algo);
  368. megdnn_log("Profiling end, got best algo: %s", algo->name());
  369. cache.put(
  370. Algorithm::SearchItem{
  371. OprTypeFromOprTrait<Opr>::opr_type, param, layouts},
  372. best_algo);
  373. }
  374. void exec(Opr* opr, const TensorNDArray& tensors) {
  375. megdnn_assert(tensors.size() == arity);
  376. if (!W.valid()) {
  377. W = WorkspaceWrapper(opr->handle(), 0);
  378. }
  379. TensorLayoutArray layouts;
  380. for (auto&& tensor : tensors) {
  381. layouts.push_back(tensor.layout);
  382. }
  383. if (m_profiling && !target_execution_policy.algo.valid()) {
  384. FastRunCache cache;
  385. std::string param_str;
  386. Algorithm::serialize_write_pod(opr->param(), param_str);
  387. auto&& search_items =
  388. flatten_search_space(layouts, param_str, opr->handle());
  389. FOREACH_OPR_TYPE_DISPATCH(search_items, {
  390. OprProxyProfilingBase<_Opr>::search(
  391. _item.layouts, _item.param, W, opr->handle(), warmup_times,
  392. exec_times, cache);
  393. });
  394. construct_execution_policy(
  395. layouts, param_str, opr->handle(), cache, opr->execution_policy());
  396. target_execution_policy = opr->execution_policy();
  397. auto workspace_size =
  398. AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr, layouts);
  399. W.update(workspace_size);
  400. }
  401. if (!target_execution_policy.algo.valid()) {
  402. auto workspace_size =
  403. AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr, layouts);
  404. W.update(workspace_size);
  405. }
  406. AlgoProxy<Opr, arity>::exec(opr, tensors, W.workspace());
  407. }
  408. };
//! Register an OprProxy specialization that inherits the profiling proxy for
//! operator class c (inherits its constructors too, so OprProxy<c>{true}
//! enables profiling).
#define DEF_PROF(c) \
    template <> \
    struct OprProxy<c> : public OprProxyProfilingBase<c> { \
        using OprProxyProfilingBase<c>::OprProxyProfilingBase; \
    }
DEF_PROF(MatrixMulForward);
DEF_PROF(ConvolutionForward);
DEF_PROF(ConvolutionBackwardData);
DEF_PROF(ConvolutionBackwardFilter);
DEF_PROF(LocalShareForward);
DEF_PROF(LocalShareBackwardData);
DEF_PROF(LocalShareBackwardFilter);
DEF_PROF(DeformableConvForward);
DEF_PROF(DeformableConvBackwardFilter);
DEF_PROF(BatchConvBiasForward);
DEF_PROF(ConvBiasForward);
DEF_PROF(DeformableConvBackwardData);
#undef DEF_PROF
  427. template <class Opr>
  428. struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase<Opr> {
  429. using Base = OprProxyProfilingBase<Opr>;
  430. static constexpr int arity = OprTrait<Opr>::arity;
  431. void exec(Opr* opr, const TensorNDArray& tensors) {
  432. megdnn_assert(tensors.size() == arity);
  433. if (!Base::W.valid()) {
  434. Base::W = WorkspaceWrapper(opr->handle(), 0);
  435. }
  436. TensorLayoutArray layouts;
  437. for (auto&& tensor : tensors) {
  438. layouts.push_back(tensor.layout);
  439. }
  440. if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) {
  441. size_t min_time = std::numeric_limits<size_t>::max();
  442. for (auto algo :
  443. AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr, layouts)) {
  444. opr->execution_policy().algo = algo.desc;
  445. auto preprocess_tensors = weight_prerocess(opr, tensors, algo.desc);
  446. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  447. typename Opr::PreprocessedFilter preprocessed_filter{
  448. nullptr, *preprocess_tensors};
  449. auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
  450. opr, layouts, &preprocessed_filter);
  451. Base::W.update(workspace_size);
  452. for (size_t times = 0; times < Base::warmup_times; ++times) {
  453. AlgoProxy<Opr, arity>::exec(
  454. opr, tensors, &preprocessed_filter, Base::W.workspace());
  455. }
  456. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  457. Timer timer;
  458. timer.start();
  459. for (size_t times = 0; times < Base::exec_times; ++times) {
  460. AlgoProxy<Opr, arity>::exec(
  461. opr, tensors, &preprocessed_filter, Base::W.workspace());
  462. }
  463. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  464. timer.stop();
  465. printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
  466. algo.desc.name.c_str());
  467. if (min_time > timer.get_time_in_us()) {
  468. min_time = timer.get_time_in_us();
  469. Base::target_execution_policy.algo = algo.desc;
  470. }
  471. }
  472. opr->execution_policy() = Base::target_execution_policy;
  473. auto preprocess_tensors =
  474. weight_prerocess(opr, tensors, Base::target_execution_policy.algo);
  475. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  476. typename Opr::PreprocessedFilter preprocessed_filter{
  477. nullptr, *preprocess_tensors};
  478. auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
  479. opr, layouts, &preprocessed_filter);
  480. Base::W.update(workspace_size);
  481. }
  482. auto preprocess_tensors =
  483. weight_prerocess(opr, tensors, Base::target_execution_policy.algo);
  484. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  485. typename Opr::PreprocessedFilter preprocessed_filter{
  486. nullptr, *preprocess_tensors};
  487. if (!Base::target_execution_policy.algo.valid()) {
  488. auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
  489. opr, layouts, &preprocessed_filter);
  490. Base::W.update(workspace_size);
  491. }
  492. AlgoProxy<Opr, arity>::exec(
  493. opr, tensors, &preprocessed_filter, Base::W.workspace());
  494. }
  495. //! handle weight preprocess
  496. std::shared_ptr<TensorNDArray> weight_prerocess(
  497. Opr* opr, const TensorNDArray& tensors,
  498. const typename Opr::AlgorithmDesc&) {
  499. TensorLayoutArray layouts;
  500. for (auto&& tensor : tensors) {
  501. layouts.push_back(tensor.layout);
  502. }
  503. auto weight_perprocess_layouts =
  504. AlgoProxy<Opr, arity>::deduce_preprocessed_filter_layout(opr, layouts);
  505. auto preprocessed_filter_tensors_ptr =
  506. Base::alloc_tensors(opr->handle(), weight_perprocess_layouts);
  507. typename Opr::PreprocessedFilter preprocessed_filter{
  508. nullptr, *preprocessed_filter_tensors_ptr};
  509. size_t preprocess_workspace_size =
  510. AlgoProxy<Opr, arity>::get_preprocess_workspace_in_bytes(opr, layouts);
  511. WorkspaceWrapper preprocess_workspace(opr->handle(), preprocess_workspace_size);
  512. AlgoProxy<Opr, arity>::exec_preprocess(
  513. opr, tensors, layouts, &preprocessed_filter,
  514. preprocess_workspace.workspace());
  515. return preprocessed_filter_tensors_ptr;
  516. }
  517. };
//! Register an OprWeightPreprocessProxy specialization backed by the
//! profiling + weight-preprocess implementation for operator class c.
#define DEF_PROF(c) \
    template <> \
    struct OprWeightPreprocessProxy<c> : public OprWeightPreprocessProxyImpl<c> { \
        using OprWeightPreprocessProxyImpl<c>::OprWeightPreprocessProxyImpl; \
    }
DEF_PROF(ConvolutionForward);
DEF_PROF(ConvBias);
#undef DEF_PROF
  526. } // namespace test
  527. } // namespace megdnn
  528. // vim: syntax=cpp.doxygen