You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

opr_proxy.h 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600
  1. #pragma once
  2. #include "src/common/opr_trait.h"
  3. #include "test/common/deduce_layout_proxy.h"
  4. #include "test/common/exec_proxy.h"
  5. #include "test/common/fast_run_cache.h"
  6. #include "test/common/inspect_type.h"
  7. #include "test/common/opr_algo_proxy.h"
  8. #include "test/common/timer.h"
  9. #include "test/common/workspace_wrapper.h"
  10. #include <algorithm>
  11. #include <limits>
  12. #include <memory>
  13. #include <unordered_map>
  14. namespace megdnn {
  15. namespace test {
//! Bidirectional mapping between Algorithm::OprType enum values and the
//! corresponding operator classes. The cb() macro below specializes both
//! trait templates pairwise for every operator that supports algorithm
//! search, so code can go enum -> class (OprFromOprTypeTrait) or
//! class -> enum (OprTypeFromOprTrait).
template <Algorithm::OprType>
struct OprFromOprTypeTrait;
template <typename Opr>
struct OprTypeFromOprTrait;
#define cb(_opr_type, _opr)                                                           \
    template <>                                                                       \
    struct OprFromOprTypeTrait<Algorithm::OprType::_opr_type> {                       \
        using Opr = megdnn::_opr;                                                     \
    };                                                                                \
    template <>                                                                       \
    struct OprTypeFromOprTrait<megdnn::_opr> {                                        \
        constexpr static Algorithm::OprType opr_type = Algorithm::OprType::_opr_type; \
    }
cb(MATRIX_MUL_FORWARD, MatrixMulForward);
cb(BATCHED_MATRIX_MUL_FORWARD, BatchedMatrixMulForward);
cb(CONVOLUTION_FORWARD, ConvolutionForward);
cb(CONVOLUTION_BACKWARD_DATA, ConvolutionBackwardData);
cb(CONVOLUTION_BACKWARD_FILTER, ConvolutionBackwardFilter);
cb(CONVOLUTION3D_FORWARD, Convolution3DForward);
cb(CONVOLUTION3D_BACKWARD_DATA, Convolution3DBackwardData);
cb(CONVOLUTION3D_BACKWARD_FILTER, Convolution3DBackwardFilter);
cb(LOCAL_SHARE_FORWARD, LocalShareForward);
cb(LOCAL_SHARE_BACKWARD_DATA, LocalShareBackwardData);
cb(LOCAL_SHARE_BACKWARD_FILTER, LocalShareBackwardFilter);
cb(DEFORMABLE_CONV_FORWARD, DeformableConvForward);
cb(DEFORMABLE_CONV_BACKWARD_DATA, DeformableConvBackwardData);
cb(DEFORMABLE_CONV_BACKWARD_FILTER, DeformableConvBackwardFilter);
cb(BATCH_CONV_FORWARD, BatchConvBiasForward);
cb(CONVBIAS_FORWARD, ConvBiasForward);
#undef cb
// clang-format off
//! Invoke cb(OPR_TYPE) once for every operator type registered above.
#define FOREACH_OPR_TYPE(cb) \
    cb(MATRIX_MUL_FORWARD) \
    cb(BATCHED_MATRIX_MUL_FORWARD) \
    cb(CONVOLUTION_FORWARD) \
    cb(CONVOLUTION_BACKWARD_DATA) \
    cb(CONVOLUTION_BACKWARD_FILTER) \
    cb(CONVOLUTION3D_FORWARD) \
    cb(CONVOLUTION3D_BACKWARD_DATA) \
    cb(CONVOLUTION3D_BACKWARD_FILTER) \
    cb(LOCAL_SHARE_FORWARD) \
    cb(LOCAL_SHARE_BACKWARD_DATA) \
    cb(LOCAL_SHARE_BACKWARD_FILTER) \
    cb(DEFORMABLE_CONV_FORWARD) \
    cb(DEFORMABLE_CONV_BACKWARD_DATA) \
    cb(DEFORMABLE_CONV_BACKWARD_FILTER) \
    cb(BATCH_CONV_FORWARD) \
    cb(CONVBIAS_FORWARD)
//! Same as FOREACH_OPR_TYPE, but forwards an extra statement argument to cb.
#define FOREACH_OPR_TYPE_WITH_STMT(cb, stmt) \
    cb(MATRIX_MUL_FORWARD, stmt) \
    cb(BATCHED_MATRIX_MUL_FORWARD, stmt) \
    cb(CONVOLUTION_FORWARD, stmt) \
    cb(CONVOLUTION_BACKWARD_DATA, stmt) \
    cb(CONVOLUTION_BACKWARD_FILTER, stmt) \
    cb(CONVOLUTION3D_FORWARD, stmt) \
    cb(CONVOLUTION3D_BACKWARD_DATA, stmt) \
    cb(CONVOLUTION3D_BACKWARD_FILTER, stmt) \
    cb(LOCAL_SHARE_FORWARD, stmt) \
    cb(LOCAL_SHARE_BACKWARD_DATA, stmt) \
    cb(LOCAL_SHARE_BACKWARD_FILTER, stmt) \
    cb(DEFORMABLE_CONV_FORWARD, stmt) \
    cb(DEFORMABLE_CONV_BACKWARD_DATA, stmt) \
    cb(DEFORMABLE_CONV_BACKWARD_FILTER, stmt) \
    cb(BATCH_CONV_FORWARD, stmt) \
    cb(CONVBIAS_FORWARD, stmt)
// clang-format on
//! One switch case: bind the local alias _Opr to the operator class mapped
//! from _opr_type, then execute _stmt with that alias in scope.
#define _OPR_TYPE_CASE(_opr_type, _stmt)                                              \
    case Algorithm::OprType::_opr_type: {                                             \
        using _Opr = typename OprFromOprTypeTrait<Algorithm::OprType::_opr_type>::Opr; \
        _stmt;                                                                        \
        break;                                                                        \
    }
//! Run _stmt once per element of _search_items with _item bound to the
//! current SearchItem and _Opr bound to its operator class; throws on an
//! opr_type that has no cb() registration above.
#define FOREACH_OPR_TYPE_DISPATCH(_search_items, _stmt)                   \
    for (size_t _item_idx = 0; _item_idx < _search_items.size(); _item_idx++) { \
        auto&& _item = _search_items[_item_idx];                          \
        switch (_item.opr_type) {                                         \
            FOREACH_OPR_TYPE_WITH_STMT(_OPR_TYPE_CASE, _stmt)             \
            default:                                                      \
                megdnn_throw("unknown opr_type");                         \
        }                                                                 \
    }
//! Default proxy implementation: layout-deduction and execution behavior are
//! selected at compile time from the operator's OprTrait properties.
template <
        typename Opr, size_t arity = OprTrait<Opr>::arity,
        bool has_workspace = OprTrait<Opr>::has_workspace,
        bool can_deduce_layout = OprTrait<Opr>::can_deduce_layout>
struct OprProxyDefaultImpl : public DeduceLayoutProxy<Opr, arity, can_deduce_layout>,
                             public ExecProxy<Opr, arity, has_workspace> {};
//! Primary proxy template; specialized below for operators that need custom
//! handling (variadic inputs, workspaces, profiling, ...).
template <typename Opr>
struct OprProxy : public OprProxyDefaultImpl<Opr> {};
//! Proxy variant used by checkers that exercise weight pre-processing.
template <typename Opr>
struct OprWeightPreprocessProxy : public OprProxyDefaultImpl<Opr> {};
//! Placeholder for proxies adapting vector-of-tensors APIs to single-tensor
//! APIs; intentionally empty here.
template <typename Opr>
struct OprProxyVectorToSingle {};
  109. template <>
  110. struct OprProxy<ElemwiseForward> {
  111. static void deduce_layout(ElemwiseForward* opr, TensorLayoutArray& layouts) {
  112. megdnn_assert(layouts.size() >= 2);
  113. auto inp = layouts;
  114. inp.pop_back();
  115. opr->deduce_layout(inp, layouts.back());
  116. }
  117. static void exec(ElemwiseForward* opr, const TensorNDArray& tensors) {
  118. megdnn_assert(tensors.size() >= 2);
  119. auto inp = tensors;
  120. inp.pop_back();
  121. opr->exec(inp, tensors.back());
  122. }
  123. };
  124. template <>
  125. struct OprProxy<ElemwiseMultiType> {
  126. static void deduce_layout(ElemwiseMultiType* opr, TensorLayoutArray& layouts) {
  127. megdnn_assert(layouts.size() >= 2);
  128. auto inp = layouts;
  129. inp.pop_back();
  130. opr->deduce_layout(inp, layouts.back());
  131. }
  132. static void exec(ElemwiseMultiType* opr, const TensorNDArray& tensors) {
  133. megdnn_assert(tensors.size() >= 2);
  134. auto inp = tensors;
  135. inp.pop_back();
  136. opr->exec(inp, tensors.back());
  137. }
  138. };
  139. template <>
  140. struct OprProxy<ConcatForward> {
  141. WorkspaceWrapper W;
  142. static void deduce_layout(ConcatForward* opr, TensorLayoutArray& layouts) {
  143. megdnn_assert(layouts.size() >= 2);
  144. auto inp = layouts;
  145. inp.pop_back();
  146. opr->deduce_layout(inp, layouts.back());
  147. }
  148. void exec(ConcatForward* opr, const TensorNDArray& tensors) {
  149. if (!W.valid()) {
  150. W = WorkspaceWrapper(opr->handle(), 0);
  151. }
  152. megdnn_assert(tensors.size() >= 2);
  153. auto inp = tensors;
  154. inp.pop_back();
  155. TensorLayoutArray layouts(tensors.size());
  156. std::transform(
  157. tensors.begin(), tensors.end(), layouts.begin(),
  158. [](const TensorND& tensor) { return tensor.layout; });
  159. auto inp_layouts = layouts;
  160. inp_layouts.pop_back();
  161. W.update(opr->get_workspace_in_bytes(inp_layouts, layouts.back()));
  162. auto inp_tensors = tensors;
  163. inp_tensors.pop_back();
  164. opr->exec(inp_tensors, tensors.back(), W.workspace());
  165. }
  166. };
  167. template <>
  168. struct OprProxy<CheckNonFinite> {
  169. static void deduce_layout(CheckNonFinite* opr, TensorLayoutArray& layouts) {
  170. megdnn_assert(layouts.size() >= 2);
  171. auto inp = layouts;
  172. inp.pop_back();
  173. opr->deduce_layout(inp, layouts.back());
  174. }
  175. static void exec(CheckNonFinite* opr, const TensorNDArray& tensors) {
  176. megdnn_assert(tensors.size() >= 2);
  177. auto inps = tensors;
  178. inps.pop_back();
  179. TensorLayoutArray inp_layouts(inps.size());
  180. std::transform(
  181. inps.begin(), inps.end(), inp_layouts.begin(),
  182. [](const TensorND& tensor) { return tensor.layout; });
  183. WorkspaceWrapper W(
  184. opr->handle(),
  185. opr->get_workspace_in_bytes(inp_layouts, tensors.back().layout));
  186. opr->exec(inps, tensors.back(), W.workspace());
  187. }
  188. };
  189. template <>
  190. struct OprProxy<SplitForward> : DeduceLayoutProxy<SplitForward, 0, false> {
  191. WorkspaceWrapper W;
  192. void exec(SplitForward* opr, const TensorNDArray& tensors) {
  193. megdnn_assert(tensors.size() >= 2);
  194. if (!W.valid()) {
  195. W = WorkspaceWrapper(opr->handle(), 0);
  196. }
  197. auto out = tensors;
  198. out.erase(out.begin());
  199. TensorLayoutArray layouts(tensors.size());
  200. std::transform(
  201. tensors.begin(), tensors.end(), layouts.begin(),
  202. [](const TensorND& tensor) { return tensor.layout; });
  203. auto out_layouts = layouts;
  204. out_layouts.erase(out_layouts.begin());
  205. W.update(opr->get_workspace_in_bytes(layouts.front(), out_layouts));
  206. auto out_tensors = tensors;
  207. out_tensors.erase(out_tensors.begin());
  208. opr->exec(tensors.front(), out_tensors, W.workspace());
  209. }
  210. };
//! OprProxy impl for ternary oprs with profiling support
template <class Opr>
struct OprProxyProfilingBase
        : public DeduceLayoutProxy<
                  Opr, OprTrait<Opr>::arity, OprTrait<Opr>::can_deduce_layout> {
    static constexpr int arity = OprTrait<Opr>::arity;
    //! iteration counts used when timing each candidate algorithm
    size_t warmup_times = 10, exec_times = 100;
    //! whether to enable profiling
    bool m_profiling;
    //! workspace shared by profiling runs and the final execution
    WorkspaceWrapper W;
    //! target algo setup by profiler; it can also be directly specified by the
    //! caller
    ExecutionPolicy target_execution_policy;
    OprProxyProfilingBase(bool profile = false) { m_profiling = profile; }
    //! used for alloc tensor for weight preprocess
    //! The returned shared_ptr owns both the array and the device buffers:
    //! its custom deleter frees each buffer via the same handle that
    //! allocated it.
    static std::shared_ptr<TensorNDArray> alloc_tensors(
            Handle* handle, const TensorLayoutArray& layouts) {
        auto deleter = [handle](TensorNDArray* ptr) {
            for (auto&& i : *ptr) {
                // undo the low_byte offset applied at allocation time to
                // recover the pointer originally returned by megdnn_malloc
                auto pdata =
                        static_cast<dt_byte*>(i.raw_ptr()) + i.layout.span().low_byte;
                megdnn_free(handle, pdata);
            }
            delete ptr;
        };
        std::shared_ptr<TensorNDArray> ret{new TensorNDArray, deleter};
        for (size_t i = 0; i < layouts.size(); ++i) {
            auto span = layouts[i].span();
            // shift the base pointer by -low_byte so that the layout's span
            // maps exactly onto the allocated dist_byte() buffer
            ret->emplace_back(
                    static_cast<dt_byte*>(megdnn_malloc(handle, span.dist_byte())) -
                            span.low_byte,
                    layouts[i]);
        }
        return ret;
    }
    /**
     * flatten search space in postorder traversal
     * The subopr search construct a search tree
     *
     *           A
     *        /     \
     *      B1 B2    C
     *     /    \
     *  D1 D2 D3  E
     * We use postorder traverse the search tree.
     * D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A
     */
    static std::vector<Algorithm::SearchItem> flatten_search_space(
            const TensorLayoutArray layouts, const std::string& param, Handle* handle) {
        megdnn_assert(layouts.size() == arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        std::vector<Algorithm::SearchItem> ret;
        // recurse into every algorithm's sub-operators first (postorder) ...
        for (auto algo_info :
             AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr.get(), layouts)) {
            Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc);
            std::vector<Algorithm::SearchItem>&& sub_items =
                    algo->get_subopr_list(layouts, opr.get());
            FOREACH_OPR_TYPE_DISPATCH(sub_items, {
                auto space = OprProxyProfilingBase<_Opr>::flatten_search_space(
                        _item.layouts, _item.param, handle);
                ret.insert(ret.end(), space.begin(), space.end());
            });
        }
        // ... then append this operator itself, so dependencies are searched
        // before the operator that needs them
        ret.push_back({OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
        return ret;
    }
    //! Recursively fill \p policy (including sub-policies) with the best
    //! algorithms previously recorded in \p cache by search().
    static void construct_execution_policy(
            const TensorLayoutArray& layouts, const std::string& param, Handle* handle,
            FastRunCache& cache, ExecutionPolicy& policy) {
        megdnn_assert(layouts.size() == arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        if (!policy.algo.valid()) {
            // top-level algo not pinned by the caller: take it from the cache
            policy.algo = cache.get(Algorithm::SearchItem{
                    OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
            megdnn_assert(
                    policy.algo.valid(),
                    "No cache found, maybe some error occured in "
                    "flatten_search_space or get_subopr_list");
        }
        policy.sub_policy.clear();
        Algorithm* algo = opr->get_algorithm_from_desc(policy.algo);
        std::vector<Algorithm::SearchItem>&& sub_items =
                algo->get_subopr_list(layouts, opr.get());
        FOREACH_OPR_TYPE_DISPATCH(sub_items, {
            policy.sub_policy.push_back({});
            OprProxyProfilingBase<_Opr>::construct_execution_policy(
                    _item.layouts, _item.param, handle, cache,
                    policy.sub_policy.back());
        });
        return;
    }
    /**
     * \brief search and get the best execution_policy
     *
     * Times every available algorithm (warmup + timed runs) on scratch
     * tensors allocated inside \p workspace_wrapper, and stores the fastest
     * algorithm descriptor into \p cache. Returns early if \p cache already
     * holds a result for this (opr_type, param, layouts) key.
     */
    static void search(
            const TensorLayoutArray& layouts, const std::string& param,
            WorkspaceWrapper& workspace_wrapper, Handle* handle, size_t warmup_times,
            size_t exec_times, FastRunCache& cache) {
        megdnn_assert(layouts.size() == arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        // per-tensor buffer sizes; the candidate's workspace size is appended
        // as a trailing entry inside the loop below
        SmallVector<size_t> sizes_in_bytes;
        for (const auto& layout : layouts) {
            sizes_in_bytes.push_back(layout.span().dist_byte());
        }
        float min_time = std::numeric_limits<float>::max();
        Algorithm::Info::Desc best_algo;
        std::string log_info = "Profiling start: ";
        for (auto&& layout : layouts) {
            log_info += layout.to_string() + " ";
        }
        megdnn_log("%s", log_info.c_str());
        // skip profiling entirely if a previous run already cached a winner
        best_algo = cache.get(Algorithm::SearchItem{
                OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
        if (best_algo.valid()) {
            auto&& algo = opr->get_algorithm_from_desc(best_algo);
            MEGDNN_MARK_USED_VAR(algo);
            megdnn_log("Find best algo %s in cache", algo->name());
            return;
        }
        for (auto algo :
             AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr.get(), layouts)) {
            //! construct execution_policy
            opr->execution_policy().algo = algo.desc;
            construct_execution_policy(
                    layouts, param, handle, cache, opr->execution_policy());
            auto workspace_size =
                    AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr.get(), layouts);
            // trailing bundle entry = this candidate's workspace; popped below
            sizes_in_bytes.push_back(workspace_size);
            WorkspaceBundle wb(nullptr, sizes_in_bytes);
            workspace_wrapper.update(wb.total_size_in_bytes());
            wb.set(workspace_wrapper.workspace().raw_ptr);
            // build the operand tensors on top of the bundle's chunks
            TensorNDArray tensors;
            for (size_t i = 0; i < arity; i++) {
                tensors.push_back({wb.get(i), layouts[i]});
            }
            for (size_t times = 0; times < warmup_times; ++times) {
                AlgoProxy<Opr, arity>::exec(
                        opr.get(), tensors, wb.get_workspace(arity));
            }
            // synchronize so the timed region measures only the exec loop
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            Timer timer;
            timer.start();
            for (size_t times = 0; times < exec_times; ++times) {
                AlgoProxy<Opr, arity>::exec(
                        opr.get(), tensors, wb.get_workspace(arity));
            }
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            timer.stop();
            megdnn_log(
                    "%.3fms %s", timer.get_time_in_us() / 1e3, algo.desc.name.c_str());
            if (min_time > timer.get_time_in_us()) {
                min_time = timer.get_time_in_us();
                best_algo = algo.desc;
            }
            sizes_in_bytes.pop_back();
        }
        auto&& algo = opr->get_algorithm_from_desc(best_algo);
        MEGDNN_MARK_USED_VAR(algo);
        megdnn_log("Profiling end, got best algo: %s", algo->name());
        cache.put(
                Algorithm::SearchItem{
                        OprTypeFromOprTrait<Opr>::opr_type, param, layouts},
                best_algo);
    }
    //! Execute \p opr on \p tensors. When profiling is enabled and no target
    //! algorithm has been fixed yet, first search the flattened algorithm
    //! space and install the winning execution policy on the operator.
    void exec(Opr* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == arity);
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        TensorLayoutArray layouts;
        for (auto&& tensor : tensors) {
            layouts.push_back(tensor.layout);
        }
        if (m_profiling && !target_execution_policy.algo.valid()) {
            FastRunCache cache;
            std::string param_str;
            Algorithm::serialize_write_pod(opr->param(), param_str);
            // profile sub-oprs before the oprs that depend on them (postorder)
            auto&& search_items =
                    flatten_search_space(layouts, param_str, opr->handle());
            FOREACH_OPR_TYPE_DISPATCH(search_items, {
                OprProxyProfilingBase<_Opr>::search(
                        _item.layouts, _item.param, W, opr->handle(), warmup_times,
                        exec_times, cache);
            });
            construct_execution_policy(
                    layouts, param_str, opr->handle(), cache, opr->execution_policy());
            target_execution_policy = opr->execution_policy();
            auto workspace_size =
                    AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr, layouts);
            W.update(workspace_size);
        }
        if (!target_execution_policy.algo.valid()) {
            // non-profiling path (or profiling disabled): size the workspace
            // for whatever algorithm the operator picks by itself
            auto workspace_size =
                    AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr, layouts);
            W.update(workspace_size);
        }
        AlgoProxy<Opr, arity>::exec(opr, tensors, W.workspace());
    }
};
//! Specialize OprProxy with the profiling implementation for every operator
//! that participates in fast-run algorithm search.
//! NOTE(review): BatchedMatrixMulForward and the Convolution3D* oprs are in
//! FOREACH_OPR_TYPE but get no DEF_PROF here — confirm this is intentional.
#define DEF_PROF(c)                                            \
    template <>                                                \
    struct OprProxy<c> : public OprProxyProfilingBase<c> {     \
        using OprProxyProfilingBase<c>::OprProxyProfilingBase; \
    }
DEF_PROF(MatrixMulForward);
DEF_PROF(ConvolutionForward);
DEF_PROF(ConvolutionBackwardData);
DEF_PROF(ConvolutionBackwardFilter);
DEF_PROF(LocalShareForward);
DEF_PROF(LocalShareBackwardData);
DEF_PROF(LocalShareBackwardFilter);
DEF_PROF(DeformableConvForward);
DEF_PROF(DeformableConvBackwardFilter);
DEF_PROF(BatchConvBiasForward);
DEF_PROF(ConvBiasForward);
DEF_PROF(DeformableConvBackwardData);
#undef DEF_PROF
  431. template <class Opr>
  432. struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase<Opr> {
  433. using Base = OprProxyProfilingBase<Opr>;
  434. static constexpr int arity = OprTrait<Opr>::arity;
  435. void exec(Opr* opr, const TensorNDArray& tensors) {
  436. megdnn_assert(tensors.size() == arity);
  437. if (!Base::W.valid()) {
  438. Base::W = WorkspaceWrapper(opr->handle(), 0);
  439. }
  440. TensorLayoutArray layouts;
  441. for (auto&& tensor : tensors) {
  442. layouts.push_back(tensor.layout);
  443. }
  444. if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) {
  445. size_t min_time = std::numeric_limits<size_t>::max();
  446. for (auto algo :
  447. AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr, layouts)) {
  448. opr->execution_policy().algo = algo.desc;
  449. auto preprocess_tensors = weight_prerocess(opr, tensors, algo.desc);
  450. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  451. typename Opr::PreprocessedFilter preprocessed_filter{
  452. nullptr, *preprocess_tensors};
  453. auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
  454. opr, layouts, &preprocessed_filter);
  455. Base::W.update(workspace_size);
  456. for (size_t times = 0; times < Base::warmup_times; ++times) {
  457. AlgoProxy<Opr, arity>::exec(
  458. opr, tensors, &preprocessed_filter, Base::W.workspace());
  459. }
  460. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  461. Timer timer;
  462. timer.start();
  463. for (size_t times = 0; times < Base::exec_times; ++times) {
  464. AlgoProxy<Opr, arity>::exec(
  465. opr, tensors, &preprocessed_filter, Base::W.workspace());
  466. }
  467. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  468. timer.stop();
  469. printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
  470. algo.desc.name.c_str());
  471. if (min_time > timer.get_time_in_us()) {
  472. min_time = timer.get_time_in_us();
  473. Base::target_execution_policy.algo = algo.desc;
  474. }
  475. }
  476. opr->execution_policy() = Base::target_execution_policy;
  477. auto preprocess_tensors =
  478. weight_prerocess(opr, tensors, Base::target_execution_policy.algo);
  479. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  480. typename Opr::PreprocessedFilter preprocessed_filter{
  481. nullptr, *preprocess_tensors};
  482. auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
  483. opr, layouts, &preprocessed_filter);
  484. Base::W.update(workspace_size);
  485. }
  486. auto preprocess_tensors =
  487. weight_prerocess(opr, tensors, Base::target_execution_policy.algo);
  488. megcoreSynchronize(opr->handle()->megcore_computing_handle());
  489. typename Opr::PreprocessedFilter preprocessed_filter{
  490. nullptr, *preprocess_tensors};
  491. if (!Base::target_execution_policy.algo.valid()) {
  492. auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
  493. opr, layouts, &preprocessed_filter);
  494. Base::W.update(workspace_size);
  495. }
  496. AlgoProxy<Opr, arity>::exec(
  497. opr, tensors, &preprocessed_filter, Base::W.workspace());
  498. }
  499. //! handle weight preprocess
  500. std::shared_ptr<TensorNDArray> weight_prerocess(
  501. Opr* opr, const TensorNDArray& tensors,
  502. const typename Opr::AlgorithmDesc&) {
  503. TensorLayoutArray layouts;
  504. for (auto&& tensor : tensors) {
  505. layouts.push_back(tensor.layout);
  506. }
  507. auto weight_perprocess_layouts =
  508. AlgoProxy<Opr, arity>::deduce_preprocessed_filter_layout(opr, layouts);
  509. auto preprocessed_filter_tensors_ptr =
  510. Base::alloc_tensors(opr->handle(), weight_perprocess_layouts);
  511. typename Opr::PreprocessedFilter preprocessed_filter{
  512. nullptr, *preprocessed_filter_tensors_ptr};
  513. size_t preprocess_workspace_size =
  514. AlgoProxy<Opr, arity>::get_preprocess_workspace_in_bytes(opr, layouts);
  515. WorkspaceWrapper preprocess_workspace(opr->handle(), preprocess_workspace_size);
  516. AlgoProxy<Opr, arity>::exec_preprocess(
  517. opr, tensors, layouts, &preprocessed_filter,
  518. preprocess_workspace.workspace());
  519. return preprocessed_filter_tensors_ptr;
  520. }
  521. };
//! Specialize OprWeightPreprocessProxy with the profiling + weight
//! pre-processing implementation for the operators that support it.
#define DEF_PROF(c)                                                              \
    template <>                                                                  \
    struct OprWeightPreprocessProxy<c> : public OprWeightPreprocessProxyImpl<c> { \
        using OprWeightPreprocessProxyImpl<c>::OprWeightPreprocessProxyImpl;     \
    }
DEF_PROF(ConvolutionForward);
DEF_PROF(ConvBias);
#undef DEF_PROF
  530. } // namespace test
  531. } // namespace megdnn
  532. // vim: syntax=cpp.doxygen