You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

opr_proxy.h 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. /**
  2. * \file dnn/test/common/opr_proxy.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "src/common/opr_trait.h"
  14. #include "test/common/deduce_layout_proxy.h"
  15. #include "test/common/exec_proxy.h"
  16. #include "test/common/fast_run_cache.h"
  17. #include "test/common/inspect_type.h"
  18. #include "test/common/opr_algo_proxy.h"
  19. #include "test/common/timer.h"
  20. #include "test/common/workspace_wrapper.h"
  21. #include <algorithm>
  22. #include <limits>
  23. #include <memory>
  24. #include <unordered_map>
  25. namespace megdnn {
  26. namespace test {
//! Bidirectional mapping between Algorithm::OprType enum values and the
//! corresponding megdnn operator classes. The profiling proxies below use
//! these traits to dispatch on runtime opr_type values.
template <Algorithm::OprType>
struct OprFromOprTypeTrait;

template <typename Opr>
struct OprTypeFromOprTrait;

//! Instantiate both trait specializations for one (enum value, class) pair.
#define cb(_opr_type, _opr)                                                       \
    template <>                                                                   \
    struct OprFromOprTypeTrait<Algorithm::OprType::_opr_type> {                   \
        using Opr = megdnn::_opr;                                                 \
    };                                                                            \
    template <>                                                                   \
    struct OprTypeFromOprTrait<megdnn::_opr> {                                    \
        constexpr static Algorithm::OprType opr_type = Algorithm::OprType::_opr_type; \
    }

cb(MATRIX_MUL_FORWARD, MatrixMulForward);
cb(BATCHED_MATRIX_MUL_FORWARD, BatchedMatrixMulForward);
cb(CONVOLUTION_FORWARD, ConvolutionForward);
cb(CONVOLUTION_BACKWARD_DATA, ConvolutionBackwardData);
cb(CONVOLUTION_BACKWARD_FILTER, ConvolutionBackwardFilter);
cb(CONVOLUTION3D_FORWARD, Convolution3DForward);
cb(CONVOLUTION3D_BACKWARD_DATA, Convolution3DBackwardData);
cb(CONVOLUTION3D_BACKWARD_FILTER, Convolution3DBackwardFilter);
cb(LOCAL_SHARE_FORWARD, LocalShareForward);
cb(LOCAL_SHARE_BACKWARD_DATA, LocalShareBackwardData);
cb(LOCAL_SHARE_BACKWARD_FILTER, LocalShareBackwardFilter);
cb(DEFORMABLE_CONV_FORWARD, DeformableConvForward);
cb(DEFORMABLE_CONV_BACKWARD_DATA, DeformableConvBackwardData);
cb(DEFORMABLE_CONV_BACKWARD_FILTER, DeformableConvBackwardFilter);
cb(BATCH_CONV_FORWARD, BatchConvBiasForward);
cb(CONVBIAS_FORWARD, ConvBiasForward);
#undef cb
// clang-format off
//! Expand cb(OPR_TYPE) once for every operator type registered above.
#define FOREACH_OPR_TYPE(cb)             \
    cb(MATRIX_MUL_FORWARD)               \
    cb(BATCHED_MATRIX_MUL_FORWARD)       \
    cb(CONVOLUTION_FORWARD)              \
    cb(CONVOLUTION_BACKWARD_DATA)        \
    cb(CONVOLUTION_BACKWARD_FILTER)      \
    cb(CONVOLUTION3D_FORWARD)            \
    cb(CONVOLUTION3D_BACKWARD_DATA)      \
    cb(CONVOLUTION3D_BACKWARD_FILTER)    \
    cb(LOCAL_SHARE_FORWARD)              \
    cb(LOCAL_SHARE_BACKWARD_DATA)        \
    cb(LOCAL_SHARE_BACKWARD_FILTER)      \
    cb(DEFORMABLE_CONV_FORWARD)          \
    cb(DEFORMABLE_CONV_BACKWARD_DATA)    \
    cb(DEFORMABLE_CONV_BACKWARD_FILTER)  \
    cb(BATCH_CONV_FORWARD)               \
    cb(CONVBIAS_FORWARD)

//! Same expansion as FOREACH_OPR_TYPE, but forwards an extra statement
//! argument to each invocation (used to inject code into switch cases).
#define FOREACH_OPR_TYPE_WITH_STMT(cb, stmt)  \
    cb(MATRIX_MUL_FORWARD, stmt)              \
    cb(BATCHED_MATRIX_MUL_FORWARD, stmt)      \
    cb(CONVOLUTION_FORWARD, stmt)             \
    cb(CONVOLUTION_BACKWARD_DATA, stmt)       \
    cb(CONVOLUTION_BACKWARD_FILTER, stmt)     \
    cb(CONVOLUTION3D_FORWARD, stmt)           \
    cb(CONVOLUTION3D_BACKWARD_DATA, stmt)     \
    cb(CONVOLUTION3D_BACKWARD_FILTER, stmt)   \
    cb(LOCAL_SHARE_FORWARD, stmt)             \
    cb(LOCAL_SHARE_BACKWARD_DATA, stmt)       \
    cb(LOCAL_SHARE_BACKWARD_FILTER, stmt)     \
    cb(DEFORMABLE_CONV_FORWARD, stmt)         \
    cb(DEFORMABLE_CONV_BACKWARD_DATA, stmt)   \
    cb(DEFORMABLE_CONV_BACKWARD_FILTER, stmt) \
    cb(BATCH_CONV_FORWARD, stmt)              \
    cb(CONVBIAS_FORWARD, stmt)
// clang-format on

//! One switch case: alias _Opr to the operator class mapped from _opr_type,
//! then run _stmt with _Opr in scope.
#define _OPR_TYPE_CASE(_opr_type, _stmt)                                               \
    case Algorithm::OprType::_opr_type: {                                              \
        using _Opr = typename OprFromOprTypeTrait<Algorithm::OprType::_opr_type>::Opr; \
        _stmt;                                                                         \
        break;                                                                         \
    }

//! Iterate _search_items, executing _stmt for each entry with `_item` bound
//! to the current item and `_Opr` bound to its operator class; throws on an
//! opr_type that has no registered trait.
#define FOREACH_OPR_TYPE_DISPATCH(_search_items, _stmt)                         \
    for (size_t _item_idx = 0; _item_idx < _search_items.size(); _item_idx++) { \
        auto&& _item = _search_items[_item_idx];                                \
        switch (_item.opr_type) {                                               \
            FOREACH_OPR_TYPE_WITH_STMT(_OPR_TYPE_CASE, _stmt)                   \
            default:                                                            \
                megdnn_throw("unknown opr_type");                               \
        }                                                                       \
    }
//! Default proxy: combines layout deduction and execution behavior selected
//! by the operator's compile-time traits (arity, workspace requirement,
//! whether layouts can be deduced).
template <
        typename Opr, size_t arity = OprTrait<Opr>::arity,
        bool has_workspace = OprTrait<Opr>::has_workspace,
        bool can_deduce_layout = OprTrait<Opr>::can_deduce_layout>
struct OprProxyDefaultImpl : public DeduceLayoutProxy<Opr, arity, can_deduce_layout>,
                             public ExecProxy<Opr, arity, has_workspace> {};

//! Primary proxy template; specialized below for operators that need custom
//! handling (variadic input lists, workspaces, profiling).
template <typename Opr>
struct OprProxy : public OprProxyDefaultImpl<Opr> {};

//! Proxy variant that additionally exercises the weight-preprocess path;
//! specialized further below.
template <typename Opr>
struct OprWeightPreprocessProxy : public OprProxyDefaultImpl<Opr> {};

//! Placeholder for proxies mapping a tensor vector onto a single output.
template <typename Opr>
struct OprProxyVectorToSingle {};
  120. template <>
  121. struct OprProxy<ElemwiseForward> {
  122. static void deduce_layout(ElemwiseForward* opr, TensorLayoutArray& layouts) {
  123. megdnn_assert(layouts.size() >= 2);
  124. auto inp = layouts;
  125. inp.pop_back();
  126. opr->deduce_layout(inp, layouts.back());
  127. }
  128. static void exec(ElemwiseForward* opr, const TensorNDArray& tensors) {
  129. megdnn_assert(tensors.size() >= 2);
  130. auto inp = tensors;
  131. inp.pop_back();
  132. opr->exec(inp, tensors.back());
  133. }
  134. };
  135. template <>
  136. struct OprProxy<ElemwiseMultiType> {
  137. static void deduce_layout(ElemwiseMultiType* opr, TensorLayoutArray& layouts) {
  138. megdnn_assert(layouts.size() >= 2);
  139. auto inp = layouts;
  140. inp.pop_back();
  141. opr->deduce_layout(inp, layouts.back());
  142. }
  143. static void exec(ElemwiseMultiType* opr, const TensorNDArray& tensors) {
  144. megdnn_assert(tensors.size() >= 2);
  145. auto inp = tensors;
  146. inp.pop_back();
  147. opr->exec(inp, tensors.back());
  148. }
  149. };
  150. template <>
  151. struct OprProxy<ConcatForward> {
  152. WorkspaceWrapper W;
  153. static void deduce_layout(ConcatForward* opr, TensorLayoutArray& layouts) {
  154. megdnn_assert(layouts.size() >= 2);
  155. auto inp = layouts;
  156. inp.pop_back();
  157. opr->deduce_layout(inp, layouts.back());
  158. }
  159. void exec(ConcatForward* opr, const TensorNDArray& tensors) {
  160. if (!W.valid()) {
  161. W = WorkspaceWrapper(opr->handle(), 0);
  162. }
  163. megdnn_assert(tensors.size() >= 2);
  164. auto inp = tensors;
  165. inp.pop_back();
  166. TensorLayoutArray layouts(tensors.size());
  167. std::transform(
  168. tensors.begin(), tensors.end(), layouts.begin(),
  169. [](const TensorND& tensor) { return tensor.layout; });
  170. auto inp_layouts = layouts;
  171. inp_layouts.pop_back();
  172. W.update(opr->get_workspace_in_bytes(inp_layouts, layouts.back()));
  173. auto inp_tensors = tensors;
  174. inp_tensors.pop_back();
  175. opr->exec(inp_tensors, tensors.back(), W.workspace());
  176. }
  177. };
  178. template <>
  179. struct OprProxy<CheckNonFinite> {
  180. static void deduce_layout(CheckNonFinite* opr, TensorLayoutArray& layouts) {
  181. megdnn_assert(layouts.size() >= 2);
  182. auto inp = layouts;
  183. inp.pop_back();
  184. opr->deduce_layout(inp, layouts.back());
  185. }
  186. static void exec(CheckNonFinite* opr, const TensorNDArray& tensors) {
  187. megdnn_assert(tensors.size() >= 2);
  188. auto inps = tensors;
  189. inps.pop_back();
  190. WorkspaceWrapper W(
  191. opr->handle(),
  192. opr->get_workspace_in_bytes(inps, tensors.back().layout));
  193. opr->exec(inps, tensors.back(), W.workspace());
  194. }
  195. };
  196. template <>
  197. struct OprProxy<SplitForward> : DeduceLayoutProxy<SplitForward, 0, false> {
  198. WorkspaceWrapper W;
  199. void exec(SplitForward* opr, const TensorNDArray& tensors) {
  200. megdnn_assert(tensors.size() >= 2);
  201. if (!W.valid()) {
  202. W = WorkspaceWrapper(opr->handle(), 0);
  203. }
  204. auto out = tensors;
  205. out.erase(out.begin());
  206. TensorLayoutArray layouts(tensors.size());
  207. std::transform(
  208. tensors.begin(), tensors.end(), layouts.begin(),
  209. [](const TensorND& tensor) { return tensor.layout; });
  210. auto out_layouts = layouts;
  211. out_layouts.erase(out_layouts.begin());
  212. W.update(opr->get_workspace_in_bytes(layouts.front(), out_layouts));
  213. auto out_tensors = tensors;
  214. out_tensors.erase(out_tensors.begin());
  215. opr->exec(tensors.front(), out_tensors, W.workspace());
  216. }
  217. };
//! OprProxy impl for ternary oprs with profiling support
template <class Opr>
struct OprProxyProfilingBase
        : public DeduceLayoutProxy<
                  Opr, OprTrait<Opr>::arity, OprTrait<Opr>::can_deduce_layout> {
    static constexpr int arity = OprTrait<Opr>::arity;
    //! iteration counts used when timing each candidate algorithm
    size_t warmup_times = 10, exec_times = 100;

    //! whether to enable profiling
    bool m_profiling;
    //! lazily created workspace, reused across calls
    WorkspaceWrapper W;

    //! target algo setup by profiler; it can also be directly specified by the
    //! caller
    ExecutionPolicy target_execution_policy;

    OprProxyProfilingBase(bool profile = false) { m_profiling = profile; }

    //! used for alloc tensor for weight preprocess
    // Each tensor is allocated so that element 0 of the layout's span maps
    // into the allocation; the deleter re-applies the same low_byte offset to
    // recover the pointer originally returned by megdnn_malloc before freeing.
    static std::shared_ptr<TensorNDArray> alloc_tensors(
            Handle* handle, const TensorLayoutArray& layouts) {
        auto deleter = [handle](TensorNDArray* ptr) {
            for (auto&& i : *ptr) {
                auto pdata =
                        static_cast<dt_byte*>(i.raw_ptr()) + i.layout.span().low_byte;
                megdnn_free(handle, pdata);
            }
            delete ptr;
        };
        std::shared_ptr<TensorNDArray> ret{new TensorNDArray, deleter};
        for (size_t i = 0; i < layouts.size(); ++i) {
            auto span = layouts[i].span();
            ret->emplace_back(
                    static_cast<dt_byte*>(megdnn_malloc(handle, span.dist_byte())) -
                            span.low_byte,
                    layouts[i]);
        }
        return ret;
    }

    /**
     * flatten search space in postorder traversal
     * The subopr search construct a search tree
     *
     *             A
     *           /   \
     *       B1 B2    C
     *      /      \
     *  D1 D2 D3    E
     *
     * We use postorder traverse the search tree.
     * D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A
     */
    static std::vector<Algorithm::SearchItem> flatten_search_space(
            const TensorLayoutArray layouts, const std::string& param, Handle* handle) {
        megdnn_assert(layouts.size() == arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        std::vector<Algorithm::SearchItem> ret;
        for (auto algo_info :
             AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr.get(), layouts)) {
            Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc);
            std::vector<Algorithm::SearchItem>&& sub_items =
                    algo->get_subopr_list(layouts, opr.get());
            // recurse into sub-operators first (postorder)
            FOREACH_OPR_TYPE_DISPATCH(sub_items, {
                auto space = OprProxyProfilingBase<_Opr>::flatten_search_space(
                        _item.layouts, _item.param, handle);
                ret.insert(ret.end(), space.begin(), space.end());
            });
        }
        // the current operator itself comes last
        ret.push_back({OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
        return ret;
    }

    //! Fill `policy` (and, recursively, its sub_policy list) with the best
    //! algorithms previously recorded in `cache`. Asserts if a required cache
    //! entry is missing (i.e. search() was not run for this item).
    static void construct_execution_policy(
            const TensorLayoutArray& layouts, const std::string& param, Handle* handle,
            FastRunCache& cache, ExecutionPolicy& policy) {
        megdnn_assert(layouts.size() == arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        if (!policy.algo.valid()) {
            policy.algo = cache.get(Algorithm::SearchItem{
                    OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
            megdnn_assert(
                    policy.algo.valid(),
                    "No cache found, maybe some error occured in "
                    "flatten_search_space or get_subopr_list");
        }
        policy.sub_policy.clear();
        Algorithm* algo = opr->get_algorithm_from_desc(policy.algo);
        std::vector<Algorithm::SearchItem>&& sub_items =
                algo->get_subopr_list(layouts, opr.get());
        FOREACH_OPR_TYPE_DISPATCH(sub_items, {
            policy.sub_policy.push_back({});
            OprProxyProfilingBase<_Opr>::construct_execution_policy(
                    _item.layouts, _item.param, handle, cache,
                    policy.sub_policy.back());
        });
        return;
    }

    /**
     * \brief search and get the best execution_policy
     *
     * Times every available algorithm (warmup + timed runs) on freshly
     * allocated tensors carved out of a shared workspace bundle, and records
     * the fastest algorithm in `cache`. Returns early if the cache already
     * holds an entry for this (opr_type, param, layouts) item.
     */
    static void search(
            const TensorLayoutArray& layouts, const std::string& param,
            WorkspaceWrapper& workspace_wrapper, Handle* handle, size_t warmup_times,
            size_t exec_times, FastRunCache& cache) {
        megdnn_assert(layouts.size() == arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        // per-tensor byte sizes; a trailing entry for the algo workspace is
        // pushed/popped inside the loop below
        SmallVector<size_t> sizes_in_bytes;
        for (const auto& layout : layouts) {
            sizes_in_bytes.push_back(layout.span().dist_byte());
        }
        float min_time = std::numeric_limits<float>::max();
        Algorithm::Info::Desc best_algo;
        std::string log_info = "Profiling start: ";
        for (auto&& layout : layouts) {
            log_info += layout.to_string() + " ";
        }
        megdnn_log("%s", log_info.c_str());
        best_algo = cache.get(Algorithm::SearchItem{
                OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
        if (best_algo.valid()) {
            auto&& algo = opr->get_algorithm_from_desc(best_algo);
            MEGDNN_MARK_USED_VAR(algo);
            megdnn_log("Find best algo %s in cache", algo->name());
            return;
        }
        for (auto algo :
             AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr.get(), layouts)) {
            //! construct execution_policy
            opr->execution_policy().algo = algo.desc;
            construct_execution_policy(
                    layouts, param, handle, cache, opr->execution_policy());
            auto workspace_size =
                    AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr.get(), layouts);
            sizes_in_bytes.push_back(workspace_size);
            // one bundle holds all tensors plus the algo workspace (slot
            // `arity`); bound to the caller-provided workspace_wrapper
            WorkspaceBundle wb(nullptr, sizes_in_bytes);
            workspace_wrapper.update(wb.total_size_in_bytes());
            wb.set(workspace_wrapper.workspace().raw_ptr);
            TensorNDArray tensors;
            for (size_t i = 0; i < arity; i++) {
                tensors.push_back({wb.get(i), layouts[i]});
            }
            for (size_t times = 0; times < warmup_times; ++times) {
                AlgoProxy<Opr, arity>::exec(
                        opr.get(), tensors, wb.get_workspace(arity));
            }
            // synchronize before/after timing so only the exec loop is measured
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            Timer timer;
            timer.start();
            for (size_t times = 0; times < exec_times; ++times) {
                AlgoProxy<Opr, arity>::exec(
                        opr.get(), tensors, wb.get_workspace(arity));
            }
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            timer.stop();
            megdnn_log(
                    "%.3fms %s", timer.get_time_in_us() / 1e3, algo.desc.name.c_str());
            if (min_time > timer.get_time_in_us()) {
                min_time = timer.get_time_in_us();
                best_algo = algo.desc;
            }
            // drop this algo's workspace entry before trying the next one
            sizes_in_bytes.pop_back();
        }
        auto&& algo = opr->get_algorithm_from_desc(best_algo);
        MEGDNN_MARK_USED_VAR(algo);
        megdnn_log("Profiling end, got best algo: %s", algo->name());
        cache.put(
                Algorithm::SearchItem{
                        OprTypeFromOprTrait<Opr>::opr_type, param, layouts},
                best_algo);
    }

    //! Execute the operator. When profiling is enabled and no target algo has
    //! been fixed yet, first profile the flattened search space (suboprs in
    //! postorder) and install the winning execution policy on `opr`.
    void exec(Opr* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == arity);
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        TensorLayoutArray layouts;
        for (auto&& tensor : tensors) {
            layouts.push_back(tensor.layout);
        }
        if (m_profiling && !target_execution_policy.algo.valid()) {
            FastRunCache cache;
            std::string param_str;
            Algorithm::serialize_write_pod(opr->param(), param_str);
            auto&& search_items =
                    flatten_search_space(layouts, param_str, opr->handle());
            // profile leaves first so parent oprs find sub-results in cache
            FOREACH_OPR_TYPE_DISPATCH(search_items, {
                OprProxyProfilingBase<_Opr>::search(
                        _item.layouts, _item.param, W, opr->handle(), warmup_times,
                        exec_times, cache);
            });
            construct_execution_policy(
                    layouts, param_str, opr->handle(), cache, opr->execution_policy());
            target_execution_policy = opr->execution_policy();
            auto workspace_size =
                    AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr, layouts);
            W.update(workspace_size);
        }
        // no profiling requested/needed: size workspace for heuristic choice
        if (!target_execution_policy.algo.valid()) {
            auto workspace_size =
                    AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr, layouts);
            W.update(workspace_size);
        }
        AlgoProxy<Opr, arity>::exec(opr, tensors, W.workspace());
    }
};
//! Route OprProxy<c> to the profiling implementation for operators that
//! support algorithm search.
#define DEF_PROF(c)                                          \
    template <>                                              \
    struct OprProxy<c> : public OprProxyProfilingBase<c> {   \
        using OprProxyProfilingBase<c>::OprProxyProfilingBase; \
    }
DEF_PROF(MatrixMulForward);
DEF_PROF(ConvolutionForward);
DEF_PROF(ConvolutionBackwardData);
DEF_PROF(ConvolutionBackwardFilter);
DEF_PROF(LocalShareForward);
DEF_PROF(LocalShareBackwardData);
DEF_PROF(LocalShareBackwardFilter);
DEF_PROF(DeformableConvForward);
DEF_PROF(DeformableConvBackwardFilter);
DEF_PROF(BatchConvBiasForward);
DEF_PROF(ConvBiasForward);
DEF_PROF(DeformableConvBackwardData);
#undef DEF_PROF
//! Profiling proxy for operators supporting weight preprocessing: each
//! candidate algorithm is timed with its own preprocessed filter, and the
//! final exec also runs through the preprocessed-filter path.
template <class Opr>
struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase<Opr> {
    using Base = OprProxyProfilingBase<Opr>;
    static constexpr int arity = OprTrait<Opr>::arity;

    //! Execute with weight preprocess. When profiling is enabled and no target
    //! algo is fixed yet, time every algorithm (preprocessing weights per
    //! candidate) and keep the fastest in Base::target_execution_policy.
    void exec(Opr* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == arity);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        TensorLayoutArray layouts;
        for (auto&& tensor : tensors) {
            layouts.push_back(tensor.layout);
        }
        if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr, layouts)) {
                opr->execution_policy().algo = algo.desc;
                // preprocess weights for this candidate before timing it
                auto preprocess_tensors = weight_prerocess(opr, tensors, algo.desc);
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                typename Opr::PreprocessedFilter preprocessed_filter{
                        nullptr, *preprocess_tensors};
                auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
                        opr, layouts, &preprocessed_filter);
                Base::W.update(workspace_size);
                for (size_t times = 0; times < Base::warmup_times; ++times) {
                    AlgoProxy<Opr, arity>::exec(
                            opr, tensors, &preprocessed_filter, Base::W.workspace());
                }
                // synchronize around the timed loop so only exec is measured
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    AlgoProxy<Opr, arity>::exec(
                            opr, tensors, &preprocessed_filter, Base::W.workspace());
                }
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo.desc.name.c_str());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_execution_policy.algo = algo.desc;
                }
            }
            opr->execution_policy() = Base::target_execution_policy;
            // redo preprocessing with the winning algorithm and size the
            // workspace for it
            auto preprocess_tensors =
                    weight_prerocess(opr, tensors, Base::target_execution_policy.algo);
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            typename Opr::PreprocessedFilter preprocessed_filter{
                    nullptr, *preprocess_tensors};
            auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
                    opr, layouts, &preprocessed_filter);
            Base::W.update(workspace_size);
        }
        // final run: preprocess with whatever algo is now selected
        auto preprocess_tensors =
                weight_prerocess(opr, tensors, Base::target_execution_policy.algo);
        megcoreSynchronize(opr->handle()->megcore_computing_handle());
        typename Opr::PreprocessedFilter preprocessed_filter{
                nullptr, *preprocess_tensors};
        if (!Base::target_execution_policy.algo.valid()) {
            auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
                    opr, layouts, &preprocessed_filter);
            Base::W.update(workspace_size);
        }
        AlgoProxy<Opr, arity>::exec(
                opr, tensors, &preprocessed_filter, Base::W.workspace());
    }

    //! handle weight preprocess
    // NOTE(review): the name keeps the historical typo "prerocess"; renaming a
    // public member would break callers outside this file.
    std::shared_ptr<TensorNDArray> weight_prerocess(
            Opr* opr, const TensorNDArray& tensors,
            const typename Opr::AlgorithmDesc&) {
        TensorLayoutArray layouts;
        for (auto&& tensor : tensors) {
            layouts.push_back(tensor.layout);
        }
        auto weight_perprocess_layouts =
                AlgoProxy<Opr, arity>::deduce_preprocessed_filter_layout(opr, layouts);
        // alloc_tensors ties the allocations' lifetime to the returned ptr
        auto preprocessed_filter_tensors_ptr =
                Base::alloc_tensors(opr->handle(), weight_perprocess_layouts);
        typename Opr::PreprocessedFilter preprocessed_filter{
                nullptr, *preprocessed_filter_tensors_ptr};
        size_t preprocess_workspace_size =
                AlgoProxy<Opr, arity>::get_preprocess_workspace_in_bytes(opr, layouts);
        WorkspaceWrapper preprocess_workspace(opr->handle(), preprocess_workspace_size);
        AlgoProxy<Opr, arity>::exec_preprocess(
                opr, tensors, layouts, &preprocessed_filter,
                preprocess_workspace.workspace());
        return preprocessed_filter_tensors_ptr;
    }
};
//! Route OprWeightPreprocessProxy<c> to the weight-preprocess profiling
//! implementation for operators that support it.
#define DEF_PROF(c)                                                             \
    template <>                                                                 \
    struct OprWeightPreprocessProxy<c> : public OprWeightPreprocessProxyImpl<c> { \
        using OprWeightPreprocessProxyImpl<c>::OprWeightPreprocessProxyImpl;      \
    }
DEF_PROF(ConvolutionForward);
DEF_PROF(ConvBias);
#undef DEF_PROF
  537. } // namespace test
  538. } // namespace megdnn
  539. // vim: syntax=cpp.doxygen