
opr_proxy.h

#pragma once
#include "src/common/opr_trait.h"
#include "test/common/deduce_layout_proxy.h"
#include "test/common/exec_proxy.h"
#include "test/common/fast_run_cache.h"
#include "test/common/inspect_type.h"
#include "test/common/opr_algo_proxy.h"
#include "test/common/timer.h"
#include "test/common/workspace_wrapper.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <unordered_map>

namespace megdnn {
namespace test {

template <Algorithm::OprType>
struct OprFromOprTypeTrait;

template <typename Opr>
struct OprTypeFromOprTrait;

#define cb(_opr_type, _opr)                                     \
    template <>                                                 \
    struct OprFromOprTypeTrait<Algorithm::OprType::_opr_type> { \
        using Opr = megdnn::_opr;                               \
    };                                                          \
    template <>                                                 \
    struct OprTypeFromOprTrait<megdnn::_opr> {                  \
        constexpr static Algorithm::OprType opr_type =          \
                Algorithm::OprType::_opr_type;                  \
    }

cb(MATRIX_MUL_FORWARD, MatrixMulForward);
cb(BATCHED_MATRIX_MUL_FORWARD, BatchedMatrixMulForward);
cb(CONVOLUTION_FORWARD, ConvolutionForward);
cb(CONVOLUTION_BACKWARD_DATA, ConvolutionBackwardData);
cb(CONVOLUTION_BACKWARD_FILTER, ConvolutionBackwardFilter);
cb(CONVOLUTION3D_FORWARD, Convolution3DForward);
cb(CONVOLUTION3D_BACKWARD_DATA, Convolution3DBackwardData);
cb(CONVOLUTION3D_BACKWARD_FILTER, Convolution3DBackwardFilter);
cb(LOCAL_SHARE_FORWARD, LocalShareForward);
cb(LOCAL_SHARE_BACKWARD_DATA, LocalShareBackwardData);
cb(LOCAL_SHARE_BACKWARD_FILTER, LocalShareBackwardFilter);
cb(DEFORMABLE_CONV_FORWARD, DeformableConvForward);
cb(DEFORMABLE_CONV_BACKWARD_DATA, DeformableConvBackwardData);
cb(DEFORMABLE_CONV_BACKWARD_FILTER, DeformableConvBackwardFilter);
cb(BATCH_CONV_FORWARD, BatchConvBiasForward);
cb(CONVBIAS_FORWARD, ConvBiasForward);
#undef cb
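
//! Illustrative check (not part of the original header): the two traits above
//! are inverses of each other, e.g. for MatrixMulForward:
//! \code
//!     static_assert(
//!             OprTypeFromOprTrait<OprFromOprTypeTrait<
//!                     Algorithm::OprType::MATRIX_MUL_FORWARD>::Opr>::opr_type ==
//!                     Algorithm::OprType::MATRIX_MUL_FORWARD,
//!             "opr_type <-> Opr round trip");
//! \endcode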
// clang-format off
#define FOREACH_OPR_TYPE(cb)            \
    cb(MATRIX_MUL_FORWARD)              \
    cb(BATCHED_MATRIX_MUL_FORWARD)      \
    cb(CONVOLUTION_FORWARD)             \
    cb(CONVOLUTION_BACKWARD_DATA)       \
    cb(CONVOLUTION_BACKWARD_FILTER)     \
    cb(CONVOLUTION3D_FORWARD)           \
    cb(CONVOLUTION3D_BACKWARD_DATA)     \
    cb(CONVOLUTION3D_BACKWARD_FILTER)   \
    cb(LOCAL_SHARE_FORWARD)             \
    cb(LOCAL_SHARE_BACKWARD_DATA)       \
    cb(LOCAL_SHARE_BACKWARD_FILTER)     \
    cb(DEFORMABLE_CONV_FORWARD)         \
    cb(DEFORMABLE_CONV_BACKWARD_DATA)   \
    cb(DEFORMABLE_CONV_BACKWARD_FILTER) \
    cb(BATCH_CONV_FORWARD)              \
    cb(CONVBIAS_FORWARD)

#define FOREACH_OPR_TYPE_WITH_STMT(cb, stmt)  \
    cb(MATRIX_MUL_FORWARD, stmt)              \
    cb(BATCHED_MATRIX_MUL_FORWARD, stmt)      \
    cb(CONVOLUTION_FORWARD, stmt)             \
    cb(CONVOLUTION_BACKWARD_DATA, stmt)       \
    cb(CONVOLUTION_BACKWARD_FILTER, stmt)     \
    cb(CONVOLUTION3D_FORWARD, stmt)           \
    cb(CONVOLUTION3D_BACKWARD_DATA, stmt)     \
    cb(CONVOLUTION3D_BACKWARD_FILTER, stmt)   \
    cb(LOCAL_SHARE_FORWARD, stmt)             \
    cb(LOCAL_SHARE_BACKWARD_DATA, stmt)       \
    cb(LOCAL_SHARE_BACKWARD_FILTER, stmt)     \
    cb(DEFORMABLE_CONV_FORWARD, stmt)         \
    cb(DEFORMABLE_CONV_BACKWARD_DATA, stmt)   \
    cb(DEFORMABLE_CONV_BACKWARD_FILTER, stmt) \
    cb(BATCH_CONV_FORWARD, stmt)              \
    cb(CONVBIAS_FORWARD, stmt)
// clang-format on

#define _OPR_TYPE_CASE(_opr_type, _stmt)                              \
    case Algorithm::OprType::_opr_type: {                             \
        using _Opr = typename OprFromOprTypeTrait<                    \
                Algorithm::OprType::_opr_type>::Opr;                  \
        _stmt;                                                        \
        break;                                                        \
    }

#define FOREACH_OPR_TYPE_DISPATCH(_search_items, _stmt)                          \
    for (size_t _item_idx = 0; _item_idx < _search_items.size(); _item_idx++) { \
        auto&& _item = _search_items[_item_idx];                                 \
        switch (_item.opr_type) {                                                \
            FOREACH_OPR_TYPE_WITH_STMT(_OPR_TYPE_CASE, _stmt)                    \
            default:                                                             \
                megdnn_throw("unknown opr_type");                                \
        }                                                                        \
    }
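
//! Sketch of how the dispatch macro is used (mirrors the uses further down in
//! this file): inside the statement body, _item is the current SearchItem and
//! _Opr is the concrete operator type matching _item.opr_type.
//! \code
//!     std::vector<Algorithm::SearchItem> items;  // e.g. from flatten_search_space
//!     FOREACH_OPR_TYPE_DISPATCH(items, {
//!         printf("arity=%d\n", int(OprTrait<_Opr>::arity));
//!     });
//! \endcode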

template <
        typename Opr, size_t arity = OprTrait<Opr>::arity,
        bool has_workspace = OprTrait<Opr>::has_workspace,
        bool can_deduce_layout = OprTrait<Opr>::can_deduce_layout>
struct OprProxyDefaultImpl : public DeduceLayoutProxy<Opr, arity, can_deduce_layout>,
                             public ExecProxy<Opr, arity, has_workspace> {
    virtual void init(Opr*, const TensorNDArray&) {}
    virtual ~OprProxyDefaultImpl() {}
};

template <typename Opr>
struct OprProxy : public OprProxyDefaultImpl<Opr> {};

template <typename Opr>
struct OprWeightPreprocessProxy : public OprProxyDefaultImpl<Opr> {};

template <typename Opr>
struct OprWeightPreprocessBenchmarkProxy : OprProxyDefaultImpl<Opr> {};

template <typename Opr>
struct OprProxyVectorToSingle {};

template <>
struct OprProxy<ElemwiseForward> {
    static void deduce_layout(ElemwiseForward* opr, TensorLayoutArray& layouts) {
        megdnn_assert(layouts.size() >= 2);
        auto inp = layouts;
        inp.pop_back();
        opr->deduce_layout(inp, layouts.back());
    }
    static void init(ElemwiseForward*, const TensorNDArray&) {}
    static void exec(ElemwiseForward* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() >= 2);
        auto inp = tensors;
        inp.pop_back();
        opr->exec(inp, tensors.back());
    }
};
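
//! Usage sketch (illustrative; the handle and input layouts are assumptions):
//! the elemwise proxy packs all inputs plus the single output into one array,
//! with the output last.
//! \code
//!     auto opr = handle->create_operator<ElemwiseForward>();
//!     opr->param().mode = ElemwiseForward::Param::Mode::ADD;
//!     TensorLayoutArray layouts = {a_layout, b_layout, {}};  // output deduced
//!     OprProxy<ElemwiseForward>::deduce_layout(opr.get(), layouts);
//! \endcode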

template <>
struct OprProxy<ElemwiseMultiType> {
    static void deduce_layout(ElemwiseMultiType* opr, TensorLayoutArray& layouts) {
        megdnn_assert(layouts.size() >= 2);
        auto inp = layouts;
        inp.pop_back();
        opr->deduce_layout(inp, layouts.back());
    }
    static void init(ElemwiseMultiType*, const TensorNDArray&) {}
    static void exec(ElemwiseMultiType* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() >= 2);
        auto inp = tensors;
        inp.pop_back();
        opr->exec(inp, tensors.back());
    }
};

template <>
struct OprProxy<ConcatForward> {
    WorkspaceWrapper W;
    static void deduce_layout(ConcatForward* opr, TensorLayoutArray& layouts) {
        megdnn_assert(layouts.size() >= 2);
        auto inp = layouts;
        inp.pop_back();
        opr->deduce_layout(inp, layouts.back());
    }
    static void init(ConcatForward*, const TensorNDArray&) {}
    void exec(ConcatForward* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        megdnn_assert(tensors.size() >= 2);
        TensorLayoutArray layouts(tensors.size());
        std::transform(
                tensors.begin(), tensors.end(), layouts.begin(),
                [](const TensorND& tensor) { return tensor.layout; });
        auto inp_layouts = layouts;
        inp_layouts.pop_back();
        W.update(opr->get_workspace_in_bytes(inp_layouts, layouts.back()));
        auto inp_tensors = tensors;
        inp_tensors.pop_back();
        opr->exec(inp_tensors, tensors.back(), W.workspace());
    }
};

template <>
struct OprProxy<CheckNonFinite> {
    static void deduce_layout(CheckNonFinite* opr, TensorLayoutArray& layouts) {
        megdnn_assert(layouts.size() >= 2);
        auto inp = layouts;
        inp.pop_back();
        opr->deduce_layout(inp, layouts.back());
    }
    static void init(CheckNonFinite*, const TensorNDArray&) {}
    static void exec(CheckNonFinite* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() >= 2);
        auto inps = tensors;
        inps.pop_back();
        TensorLayoutArray inp_layouts(inps.size());
        std::transform(
                inps.begin(), inps.end(), inp_layouts.begin(),
                [](const TensorND& tensor) { return tensor.layout; });
        WorkspaceWrapper W(
                opr->handle(),
                opr->get_workspace_in_bytes(inp_layouts, tensors.back().layout));
        opr->exec(inps, tensors.back(), W.workspace());
    }
};

template <>
struct OprProxy<SplitForward> : DeduceLayoutProxy<SplitForward, 0, false> {
    WorkspaceWrapper W;
    void init(SplitForward*, const TensorNDArray&) {}
    void exec(SplitForward* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() >= 2);
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        TensorLayoutArray layouts(tensors.size());
        std::transform(
                tensors.begin(), tensors.end(), layouts.begin(),
                [](const TensorND& tensor) { return tensor.layout; });
        auto out_layouts = layouts;
        out_layouts.erase(out_layouts.begin());
        W.update(opr->get_workspace_in_bytes(layouts.front(), out_layouts));
        auto out_tensors = tensors;
        out_tensors.erase(out_tensors.begin());
        opr->exec(tensors.front(), out_tensors, W.workspace());
    }
};

//! OprProxy impl for ternary oprs with profiling support
template <class Opr>
struct OprProxyProfilingBase
        : public DeduceLayoutProxy<
                  Opr, OprTrait<Opr>::arity, OprTrait<Opr>::can_deduce_layout> {
    static constexpr int arity = OprTrait<Opr>::arity;
    size_t warmup_times = 10, exec_times = 100;

    //! whether to enable profiling
    bool m_profiling;
    WorkspaceWrapper W;

    //! target algo set up by the profiler; it can also be specified directly
    //! by the caller
    ExecutionPolicy target_execution_policy;

    OprProxyProfilingBase(bool profile = false) { m_profiling = profile; }

    //! allocate tensors for weight preprocessing; the raw pointer is shifted
    //! by span().low_byte so that the layout's element 0 maps into the
    //! allocated block, and shifted back before freeing
    static std::shared_ptr<TensorNDArray> alloc_tensors(
            Handle* handle, const TensorLayoutArray& layouts) {
        auto deleter = [handle](TensorNDArray* ptr) {
            for (auto&& i : *ptr) {
                auto pdata =
                        static_cast<dt_byte*>(i.raw_ptr()) + i.layout.span().low_byte;
                megdnn_free(handle, pdata);
            }
            delete ptr;
        };
        std::shared_ptr<TensorNDArray> ret{new TensorNDArray, deleter};
        for (size_t i = 0; i < layouts.size(); ++i) {
            auto span = layouts[i].span();
            ret->emplace_back(
                    static_cast<dt_byte*>(megdnn_malloc(handle, span.dist_byte())) -
                            span.low_byte,
                    layouts[i]);
        }
        return ret;
    }

    /**
     * \brief flatten the search space in postorder traversal
     *
     * The sub-opr search constructs a search tree:
     *
     *               A
     *            /     \
     *         B1B2      C
     *        /    \
     *    D1D2D3    E
     *
     * We traverse the search tree in postorder:
     * D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A
     */
    static std::vector<Algorithm::SearchItem> flatten_search_space(
            const TensorLayoutArray layouts, const std::string& param,
            Handle* handle) {
        megdnn_assert(layouts.size() == arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        std::vector<Algorithm::SearchItem> ret;
        for (auto algo_info :
             AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr.get(), layouts)) {
            Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc);
            std::vector<Algorithm::SearchItem>&& sub_items =
                    algo->get_subopr_list(layouts, opr.get());
            FOREACH_OPR_TYPE_DISPATCH(sub_items, {
                auto space = OprProxyProfilingBase<_Opr>::flatten_search_space(
                        _item.layouts, _item.param, handle);
                ret.insert(ret.end(), space.begin(), space.end());
            });
        }
        ret.push_back({OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
        return ret;
    }

    static void construct_execution_policy(
            const TensorLayoutArray& layouts, const std::string& param,
            Handle* handle, FastRunCache& cache, ExecutionPolicy& policy) {
        megdnn_assert(layouts.size() == arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        if (!policy.algo.valid()) {
            policy.algo = cache.get(Algorithm::SearchItem{
                    OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
            megdnn_assert(
                    policy.algo.valid(),
                    "No cache found, maybe some error occurred in "
                    "flatten_search_space or get_subopr_list");
        }
        policy.sub_policy.clear();
        Algorithm* algo = opr->get_algorithm_from_desc(policy.algo);
        std::vector<Algorithm::SearchItem>&& sub_items =
                algo->get_subopr_list(layouts, opr.get());
        FOREACH_OPR_TYPE_DISPATCH(sub_items, {
            policy.sub_policy.push_back({});
            OprProxyProfilingBase<_Opr>::construct_execution_policy(
                    _item.layouts, _item.param, handle, cache,
                    policy.sub_policy.back());
        });
    }
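
    //! Illustrative sketch (not part of the original header): for a top-level
    //! algorithm whose get_subopr_list() reports one MATRIX_MUL_FORWARD
    //! sub-opr, the reconstructed policy forms a tree mirroring that list:
    //! \code
    //!     policy.algo                // top-level algo desc from the cache
    //!     policy.sub_policy[0].algo  // MatrixMulForward algo desc from the cache
    //! \endcode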

    /**
     * \brief search and get the best execution_policy
     */
    static void search(
            const TensorLayoutArray& layouts, const std::string& param,
            WorkspaceWrapper& workspace_wrapper, Handle* handle,
            size_t warmup_times, size_t exec_times, FastRunCache& cache) {
        megdnn_assert(layouts.size() == arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        SmallVector<size_t> sizes_in_bytes;
        for (const auto& layout : layouts) {
            sizes_in_bytes.push_back(layout.span().dist_byte());
        }
        float min_time = std::numeric_limits<float>::max();
        Algorithm::Info::Desc best_algo;
        std::string log_info = "Profiling start: ";
        for (auto&& layout : layouts) {
            log_info += layout.to_string() + " ";
        }
        megdnn_log("%s", log_info.c_str());
        best_algo = cache.get(Algorithm::SearchItem{
                OprTypeFromOprTrait<Opr>::opr_type, param, layouts});
        if (best_algo.valid()) {
            auto&& algo = opr->get_algorithm_from_desc(best_algo);
            MEGDNN_MARK_USED_VAR(algo);
            megdnn_log("Find best algo %s in cache", algo->name());
            return;
        }
        for (auto algo :
             AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr.get(), layouts)) {
            //! construct execution_policy
            opr->execution_policy().algo = algo.desc;
            construct_execution_policy(
                    layouts, param, handle, cache, opr->execution_policy());
            auto workspace_size =
                    AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr.get(), layouts);
            sizes_in_bytes.push_back(workspace_size);
            WorkspaceBundle wb(nullptr, sizes_in_bytes);
            workspace_wrapper.update(wb.total_size_in_bytes());
            wb.set(workspace_wrapper.workspace().raw_ptr);
            TensorNDArray tensors;
            for (size_t i = 0; i < arity; i++) {
                tensors.push_back({wb.get(i), layouts[i]});
            }
            for (size_t times = 0; times < warmup_times; ++times) {
                AlgoProxy<Opr, arity>::exec(
                        opr.get(), tensors, wb.get_workspace(arity));
            }
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            Timer timer;
            timer.start();
            for (size_t times = 0; times < exec_times; ++times) {
                AlgoProxy<Opr, arity>::exec(
                        opr.get(), tensors, wb.get_workspace(arity));
            }
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            timer.stop();
            megdnn_log(
                    "%.3fms %s", timer.get_time_in_us() / 1e3,
                    algo.desc.name.c_str());
            if (min_time > timer.get_time_in_us()) {
                min_time = timer.get_time_in_us();
                best_algo = algo.desc;
            }
            sizes_in_bytes.pop_back();
        }
        auto&& algo = opr->get_algorithm_from_desc(best_algo);
        MEGDNN_MARK_USED_VAR(algo);
        megdnn_log("Profiling end, got best algo: %s", algo->name());
        cache.put(
                Algorithm::SearchItem{
                        OprTypeFromOprTrait<Opr>::opr_type, param, layouts},
                best_algo);
    }

    virtual void init(Opr*, const TensorNDArray&) {}

    virtual void exec(Opr* opr, const TensorNDArray& tensors) {
        megdnn_assert(tensors.size() == arity);
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        TensorLayoutArray layouts;
        for (auto&& tensor : tensors) {
            layouts.push_back(tensor.layout);
        }
        if (m_profiling && !target_execution_policy.algo.valid()) {
            FastRunCache cache;
            std::string param_str;
            Algorithm::serialize_write_pod(opr->param(), param_str);
            auto&& search_items =
                    flatten_search_space(layouts, param_str, opr->handle());
            FOREACH_OPR_TYPE_DISPATCH(search_items, {
                OprProxyProfilingBase<_Opr>::search(
                        _item.layouts, _item.param, W, opr->handle(), warmup_times,
                        exec_times, cache);
            });
            construct_execution_policy(
                    layouts, param_str, opr->handle(), cache,
                    opr->execution_policy());
            target_execution_policy = opr->execution_policy();
            auto workspace_size =
                    AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr, layouts);
            W.update(workspace_size);
        }
        if (!target_execution_policy.algo.valid()) {
            auto workspace_size =
                    AlgoProxy<Opr, arity>::get_workspace_in_bytes(opr, layouts);
            W.update(workspace_size);
        }
        AlgoProxy<Opr, arity>::exec(opr, tensors, W.workspace());
    }

    virtual ~OprProxyProfilingBase() {}
};

#define DEF_PROF(c)                                             \
    template <>                                                 \
    struct OprProxy<c> : public OprProxyProfilingBase<c> {      \
        using OprProxyProfilingBase<c>::OprProxyProfilingBase;  \
    }

DEF_PROF(MatrixMulForward);
DEF_PROF(ConvolutionForward);
DEF_PROF(ConvolutionBackwardData);
DEF_PROF(ConvolutionBackwardFilter);
DEF_PROF(LocalShareForward);
DEF_PROF(LocalShareBackwardData);
DEF_PROF(LocalShareBackwardFilter);
DEF_PROF(DeformableConvForward);
DEF_PROF(DeformableConvBackwardFilter);
DEF_PROF(BatchConvBiasForward);
DEF_PROF(ConvBiasForward);
DEF_PROF(DeformableConvBackwardData);
#undef DEF_PROF
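
//! Minimal usage sketch (illustrative; the handle and a tensors array matching
//! the opr's arity are assumptions): with profiling enabled, exec() first
//! profiles all available algorithms, caches the fastest, then runs it.
//! \code
//!     auto opr = handle->create_operator<ConvolutionForward>();
//!     OprProxy<ConvolutionForward> proxy{true};  // true: enable profiling
//!     proxy.init(opr.get(), tensors);
//!     proxy.exec(opr.get(), tensors);  // profiles, then runs the best algo
//! \endcode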

template <class Opr>
struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase<Opr> {
    using Base = OprProxyProfilingBase<Opr>;
    static constexpr int arity = OprTrait<Opr>::arity;

    void exec(Opr* opr, const TensorNDArray& tensors) override {
        megdnn_assert(tensors.size() == arity);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        TensorLayoutArray layouts;
        for (auto&& tensor : tensors) {
            layouts.push_back(tensor.layout);
        }
        if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) {
            size_t min_time = std::numeric_limits<size_t>::max();
            for (auto algo :
                 AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr, layouts)) {
                opr->execution_policy().algo = algo.desc;
                auto preprocess_tensors = weight_prerocess(opr, tensors, algo.desc);
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                typename Opr::PreprocessedFilter preprocessed_filter{
                        nullptr, *preprocess_tensors};
                auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
                        opr, layouts, &preprocessed_filter);
                Base::W.update(workspace_size);
                for (size_t times = 0; times < Base::warmup_times; ++times) {
                    AlgoProxy<Opr, arity>::exec(
                            opr, tensors, &preprocessed_filter, Base::W.workspace());
                }
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                Timer timer;
                timer.start();
                for (size_t times = 0; times < Base::exec_times; ++times) {
                    AlgoProxy<Opr, arity>::exec(
                            opr, tensors, &preprocessed_filter, Base::W.workspace());
                }
                megcoreSynchronize(opr->handle()->megcore_computing_handle());
                timer.stop();
                printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
                       algo.desc.name.c_str());
                if (min_time > timer.get_time_in_us()) {
                    min_time = timer.get_time_in_us();
                    Base::target_execution_policy.algo = algo.desc;
                }
            }
            opr->execution_policy() = Base::target_execution_policy;
            auto preprocess_tensors =
                    weight_prerocess(opr, tensors, Base::target_execution_policy.algo);
            megcoreSynchronize(opr->handle()->megcore_computing_handle());
            typename Opr::PreprocessedFilter preprocessed_filter{
                    nullptr, *preprocess_tensors};
            auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
                    opr, layouts, &preprocessed_filter);
            Base::W.update(workspace_size);
        }
        auto preprocess_tensors =
                weight_prerocess(opr, tensors, Base::target_execution_policy.algo);
        megcoreSynchronize(opr->handle()->megcore_computing_handle());
        typename Opr::PreprocessedFilter preprocessed_filter{
                nullptr, *preprocess_tensors};
        if (!Base::target_execution_policy.algo.valid()) {
            auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
                    opr, layouts, &preprocessed_filter);
            Base::W.update(workspace_size);
        }
        AlgoProxy<Opr, arity>::exec(
                opr, tensors, &preprocessed_filter, Base::W.workspace());
        //! preprocess_tensors is destroyed at the end of this function, so
        //! synchronize to ensure the worker has finished consuming it and no
        //! use-after-free can happen
        megcoreSynchronize(opr->handle()->megcore_computing_handle());
    }

    //! handle weight preprocessing
    std::shared_ptr<TensorNDArray> weight_prerocess(
            Opr* opr, const TensorNDArray& tensors,
            const typename Opr::AlgorithmDesc&) {
        TensorLayoutArray layouts;
        for (auto&& tensor : tensors) {
            layouts.push_back(tensor.layout);
        }
        auto weight_perprocess_layouts =
                AlgoProxy<Opr, arity>::deduce_preprocessed_filter_layout(opr, layouts);
        auto preprocessed_filter_tensors_ptr =
                Base::alloc_tensors(opr->handle(), weight_perprocess_layouts);
        typename Opr::PreprocessedFilter preprocessed_filter{
                nullptr, *preprocessed_filter_tensors_ptr};
        size_t preprocess_workspace_size =
                AlgoProxy<Opr, arity>::get_preprocess_workspace_in_bytes(opr, layouts);
        WorkspaceWrapper preprocess_workspace(
                opr->handle(), preprocess_workspace_size);
        AlgoProxy<Opr, arity>::exec_preprocess(
                opr, tensors, layouts, &preprocessed_filter,
                preprocess_workspace.workspace());
        return preprocessed_filter_tensors_ptr;
    }
};

template <class Opr>
struct OprWeightPreprocessProxyBenchmarkImpl
        : public OprWeightPreprocessProxyImpl<Opr> {
    using Base = OprProxyProfilingBase<Opr>;
    static constexpr int arity = OprTrait<Opr>::arity;

    void init(Opr* opr, const TensorNDArray& tensors) override {
        megdnn_assert(tensors.size() == arity);
        if (!Base::W.valid()) {
            Base::W = WorkspaceWrapper(opr->handle(), 0);
        }
        TensorLayoutArray layouts;
        for (auto&& tensor : tensors) {
            layouts.push_back(tensor.layout);
        }
        m_preprocessed_tensors = this->weight_prerocess(
                opr, tensors, Base::target_execution_policy.algo);
        megcoreSynchronize(opr->handle()->megcore_computing_handle());
        typename Opr::PreprocessedFilter preprocessed_filter{
                nullptr, *m_preprocessed_tensors};
        if (!Base::target_execution_policy.algo.valid()) {
            auto workspace_size = AlgoProxy<Opr, arity>::get_workspace_in_bytes(
                    opr, layouts, &preprocessed_filter);
            Base::W.update(workspace_size);
        }
    }

    void exec(Opr* opr, const TensorNDArray& tensors) override {
        megdnn_assert(tensors.size() == arity);
        typename Opr::PreprocessedFilter preprocessed_filter{
                nullptr, *m_preprocessed_tensors};
        AlgoProxy<Opr, arity>::exec(
                opr, tensors, &preprocessed_filter, Base::W.workspace());
    }

public:
    std::shared_ptr<TensorNDArray> m_preprocessed_tensors;
};

#define DEF_PROF(c)                                                             \
    template <>                                                                 \
    struct OprWeightPreprocessProxy<c> : public OprWeightPreprocessProxyImpl<c> { \
        using OprWeightPreprocessProxyImpl<c>::OprWeightPreprocessProxyImpl;    \
    };                                                                          \
    template <>                                                                 \
    struct OprWeightPreprocessBenchmarkProxy<c>                                 \
            : public OprWeightPreprocessProxyBenchmarkImpl<c> {                 \
        using OprWeightPreprocessProxyBenchmarkImpl<                            \
                c>::OprWeightPreprocessProxyBenchmarkImpl;                      \
    };

DEF_PROF(ConvolutionForward);
DEF_PROF(ConvBias);
#undef DEF_PROF
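
//! Minimal usage sketch (illustrative; the handle and tensors are assumptions):
//! the benchmark proxy runs weight preprocessing once in init(), so repeated
//! exec() calls measure only the main kernel.
//! \code
//!     auto opr = handle->create_operator<ConvolutionForward>();
//!     OprWeightPreprocessBenchmarkProxy<ConvolutionForward> proxy;
//!     proxy.init(opr.get(), tensors);  // preprocesses the filter once
//!     proxy.exec(opr.get(), tensors);  // executes with the preprocessed filter
//! \endcode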

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen