
checker.h

#pragma once

#include "megdnn/basic_types.h"
#include "megdnn/tensor_iter.h"
#include "test/common/opr_algo_proxy.h"
#include "test/common/opr_proxy.h"
#include "test/common/rng.h"

#include <gtest/gtest.h>

#include <memory>
#include <regex>
#include <unordered_map>

// clang-format off
#if defined(__has_feature)
#if __has_feature(address_sanitizer)
#define MEGDNN_TEST_ASAN 1
#else
#define MEGDNN_TEST_ASAN 0
#endif
#elif defined(__SANITIZE_ADDRESS__)
#define MEGDNN_TEST_ASAN 1
#else
#define MEGDNN_TEST_ASAN 0
#endif
// clang-format on
namespace megdnn {
namespace test {

class CheckerHelper {
    // TensorLayoutArray and TensorValueArray should be protected in theory,
    // but a g++-4.9 bug mishandles access privileges, so we make them
    // public.
public:
    using TensorValueArray = TensorNDArray;
    using TensorsConstriant = std::function<void(TensorValueArray& tensors)>;
    using ExtraOprImpl = std::function<void(const TensorNDArray&)>;
    using OutputCanonizer = std::function<void(const TensorValueArray&)>;

    static std::shared_ptr<TensorValueArray> alloc_tensors(
            Handle* handle, const TensorLayoutArray& layouts, size_t offset);

    Handle* handle() const { return m_handle_cur; }

    CheckerHelper() {
        auto tmp_handle = create_cpu_handle(2, false);
        m_handle_naive = std::move(tmp_handle);
        m_default_rng = std::unique_ptr<RNG>(new NormalRNG());
    }
protected:
    //! whether to use a physically contiguous (i.e. default) layout for the
    //! naive impl
    bool m_enable_contig_naive = false;
    bool m_prev_succ = true;
    const char* m_input_tensors_fpath = nullptr;
    thin_function<void()> m_expect_exec_fail;
    std::unique_ptr<Handle> m_handle_naive;
    Handle* m_handle_cur;
    std::unique_ptr<RNG> m_default_rng;
    std::unordered_map<size_t, RNG*> m_rng;
    std::unordered_map<size_t, DType> m_dtype;
    std::unordered_map<size_t, TensorFormat> m_fmt;
    std::set<size_t> m_bypass;
    float_t m_epsilon = 1e-3, m_max_avg_error = 1e-3, m_max_avg_biased_error = 1e-3;
    float_t m_perf_check_threshold = -1;
    bool m_perf_check = false;
    ExtraOprImpl m_extra_opr_impl;
    OutputCanonizer m_output_canonizer;
    TensorsConstriant m_tensor_constraint;
    bool m_no_naive_and_check = false;
    bool m_stable_check = false;
    bool m_force_deduce_dst = false;
    bool m_allow_invalid_check = false;
    /**
     * the offset from the start of the malloc'd memory
     *
     * \note allocate \p m_offset extra bytes when allocating memory for a
     * tensor; the tensor data then starts at offset \p m_offset.
     * \warning currently only used for OpenCL
     */
    size_t m_offset = 0;
    CheckerHelper(Handle* handle, bool check_dispatch = true);
    ~CheckerHelper() noexcept;

    using OprExec = std::function<void(const TensorValueArray&)>;

    void do_exec_with_testcases(
            const TensorValueArray& testcase_in, const TensorValueArray& testcase_out,
            const OprExec& exec_opr);
    void do_exec(
            const TensorLayoutArray& user_layouts,
            const TensorLayoutArray& deduced_layouts, const OprExec& exec_naive,
            const OprExec& exec_opr);
    void enable_contig_naive() { m_enable_contig_naive = true; }

    void copy_tensors_to_device(
            const TensorValueArray& dest, const TensorValueArray& src);
    void copy_tensors_from_device(
            const TensorValueArray& dest, const TensorValueArray& src);
    void check_tensors(
            const TensorValueArray& expected, const TensorValueArray& computed);

private:
    std::shared_ptr<TensorValueArray> m_tensors_naive;

    void init_naive_values();
};
template <typename Opr, typename Proxy = OprProxy<Opr>>
class Checker : public CheckerHelper {
public:
    using Param = typename Opr::Param;
    using BeforeExecCallback = std::function<void(Opr*, const TensorValueArray&)>;

    Checker(Handle* handle, bool check_dispatch = true)
            : CheckerHelper(handle, check_dispatch), m_param(Param()) {}

    TensorLayoutArray make_layouts(const TensorShapeArray& shapes) {
        TensorLayoutArray layouts(shapes.size());
        for (size_t i = 0; i < shapes.size(); ++i) {
            DType dt =
                    (m_dtype.find(i) != m_dtype.end() ? m_dtype[i] : dtype::Float32());
            if (m_fmt.find(i) == m_fmt.end()) {
                layouts[i] = TensorLayout(shapes[i], dt);
            } else {
                layouts[i] = TensorLayout(shapes[i], dt, m_fmt[i]);
            }
        }
        return layouts;
    }
    /*!
     * \brief execute opr on current param/dtype/rng config
     * \param shapes input/output shapes, which would be passed as
     *      arguments to Opr::deduce_layout
     *
     * Checker would construct TensorLayout vectors from shapes and dtypes,
     * and call exec(TensorLayoutArray&).
     */
    Checker& exec(const TensorShapeArray& shapes) {
        exec(make_layouts(shapes));
        return *this;
    }

    void exec(TensorLayoutArray layouts);

    //! explicitly require argument to be TensorShape
    Checker& execs(const TensorShapeArray& shapes) { return exec(shapes); }

    //! explicitly require argument to be TensorLayout
    Checker& execl(const TensorLayoutArray& layouts) {
        exec(layouts);
        return *this;
    }
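
    //! Usage sketch (illustrative only; the opr type, handle variable and
    //! shapes below are assumptions, not part of this header):
    //!     Checker<Convolution> checker(handle);
    //!     checker.set_dtype(0, dtype::Float32())
    //!             .set_dtype(1, dtype::Float32())
    //!             .set_epsilon(1e-3)
    //!             .execs({{2, 3, 16, 16}, {4, 3, 3, 3}, {}});
    //! The trailing empty shape leaves the output layout to be deduced via
    //! Opr::deduce_layout.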
    Checker& exect(
            const TensorValueArray& testcase_in, const TensorValueArray& testcase_out);

    Checker& set_param(Param param) {
        m_param = param;
        opr()->param() = param;
        return *this;
    }
    Checker& set_dtype(size_t idx, DType dtype) {
        m_dtype[idx] = dtype;
        return *this;
    }
    Checker& set_fmt(size_t idx, TensorFormat fmt) {
        m_fmt[idx] = fmt;
        return *this;
    }
    Checker& set_rng(size_t idx, RNG* rng) {
        m_rng[idx] = rng;
        return *this;
    }
    Checker& set_bypass(size_t idx) {
        m_bypass.insert(idx);
        return *this;
    }
    //! max error of a single element
    Checker& set_epsilon(dt_float32 epsilon) {
        m_epsilon = epsilon;
        m_max_avg_error = epsilon;
        m_max_avg_biased_error = epsilon;
        return *this;
    }
    //! max average error; defaults to epsilon
    Checker& set_max_avg_error(dt_float32 error) {
        m_max_avg_error = error;
        return *this;
    }
    //! max average biased error; defaults to epsilon
    Checker& set_max_avg_biased_error(dt_float32 error) {
        m_max_avg_biased_error = error;
        return *this;
    }
    Checker& set_offset(size_t offset) {
        m_offset = offset;
        return *this;
    }
    Checker& set_proxy(const Proxy& proxy) {
        m_naive_proxy = proxy;
        m_cur_proxy = proxy;
        return *this;
    }
    //! set_perf_check and set_perf_check_threshold control the
    //! performance-checking behavior.
    //!
    //! If perf_check is on (defaults to off), the running times of the
    //! current operator and the naive operator are measured and compared
    //! when calling exec.
    //! The speedup ratio must be larger than perf_check_threshold;
    //! otherwise an error is reported.
    //! perf_check_threshold must be set in advance, since the default value
    //! (which is negative) is invalid.
    Checker& set_perf_check(bool perf_check) {
        m_perf_check = perf_check;
        return *this;
    }
    Checker& set_perf_check_threshold(float perf_check_threshold) {
        m_perf_check_threshold = perf_check_threshold;
        return *this;
    }
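
    //! Perf-check sketch (illustrative; the shapes and the 2x threshold are
    //! assumptions):
    //!     checker.set_perf_check(true)
    //!             .set_perf_check_threshold(2.0)  // require >2x speedup vs naive
    //!             .execs({{16, 16}, {16, 16}, {}});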
    //! stable check runs multiple iterations and compares each result with
    //! that of the first iteration
    Checker& set_stable_check(bool stable_check) {
        m_stable_check = stable_check;
        return *this;
    }
    //! force deducing the dst layout even if it is given
    Checker& set_force_deduce_dst(bool force_deduce_dst) {
        m_force_deduce_dst = force_deduce_dst;
        return *this;
    }
    Checker& set_no_naive_check(bool no_naive_and_check) {
        m_no_naive_and_check = no_naive_and_check;
        return *this;
    }
    Checker& set_allow_invalid_check(bool allow_invalid_check) {
        m_allow_invalid_check = allow_invalid_check;
        return *this;
    }

    //! load input tensors from file for next run
    Checker& load_input_tensors(const char* fpath) {
        m_input_tensors_fpath = fpath;
        return *this;
    }
    //! add another checker to ensure naive implementation is correct
    Checker& set_extra_opr_impl(const ExtraOprImpl& chk) {
        m_extra_opr_impl = chk;
        return *this;
    }

    //! set a callback to be invoked before executing the operator
    Checker& set_before_exec_callback(const BeforeExecCallback& cb) {
        m_before_exec_callback = cb;
        return *this;
    }
    Checker& reset_before_exec_callback() {
        m_before_exec_callback = nullptr;
        return *this;
    }
    //! set a tensor constraint function for manipulating tensor values
    //! before testing
    Checker& set_tensors_constraint(const TensorsConstriant& tensor_constraint) {
        m_tensor_constraint = tensor_constraint;
        return *this;
    }
    /*!
     * \brief declare that exec() on the opr should fail, so the naive impl
     *      is not called and exec() returns directly after the opr runs.
     *
     * This is only valid for the next exec() call. It is usually used for
     * testing megcore::AsyncErrorInfo.
     *
     * \param cb callback to be invoked after opr exec (so the error would
     *      not be passed to the destructor)
     */
    Checker& set_expect_exec_fail(const thin_function<void()>& cb) {
        m_expect_exec_fail = cb;
        return *this;
    }
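
    //! Sketch for an expected-failure test (illustrative; the callback body
    //! and the helper it calls are assumptions):
    //!     checker.set_expect_exec_fail([]() {
    //!         // consume the pending async error here, so it is not raised
    //!         // again from the destructor (hypothetical helper):
    //!         consume_async_error();
    //!     });
    //!     checker.execs({{2, 3}, {2, 3}, {}});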
    /*!
     * \brief set a function to canonize the outputs
     *
     * For some oprs multiple outputs may be acceptable; such a function can
     * transform them into a canonical form before comparison.
     *
     * The arguments are tensors on CPU and should be modified in-place.
     */
    Checker& set_output_canonizer(OutputCanonizer canonizer) {
        m_output_canonizer = std::move(canonizer);
        return *this;
    }
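
    //! Canonizer sketch (illustrative; assumes a float32 output whose element
    //! order is insignificant, so sorting in-place yields a canonical form;
    //! requires <algorithm>):
    //!     checker.set_output_canonizer([](const TensorValueArray& tensors) {
    //!         auto&& out = tensors.back();
    //!         auto ptr = out.ptr<dt_float32>();
    //!         std::sort(ptr, ptr + out.layout.total_nr_elems());
    //!     });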
    //! get the opr impl so that settings other than param() can be modified
    Opr* opr() {
        if (!m_opr_cur) {
            m_opr_cur = m_handle_cur->create_operator<Opr>();
        }
        return m_opr_cur.get();
    }

    //! whether the previous exec succeeded
    bool prev_succ() const { return m_prev_succ; }

private:
    BeforeExecCallback m_before_exec_callback;
    Param m_param;
    Proxy m_naive_proxy, m_cur_proxy;
    std::unique_ptr<Opr> m_opr_cur;
};
::testing::AssertionResult __assert_tensor_eq(
        const char* expr0, const char* expr1, const char* expr_maxerr,
        const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased,
        const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg,
        float maxerr_avg_biased, bool allow_invalid = false);

::testing::AssertionResult __assert_tensor_eq_allow_invalid(
        const char* expr0, const char* expr1, const char* expr_maxerr,
        const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased,
        const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg,
        float maxerr_avg_biased);

#define MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(v0, v1, maxerr, maxerr_avg, maxerr_avg_biased) \
    ASSERT_PRED_FORMAT5(                                                               \
            ::megdnn::test::__assert_tensor_eq, v0, v1, maxerr, maxerr_avg,            \
            maxerr_avg_biased)

#define MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID(                        \
        v0, v1, maxerr, maxerr_avg, maxerr_avg_biased)                        \
    ASSERT_PRED_FORMAT5(                                                      \
            ::megdnn::test::__assert_tensor_eq_allow_invalid, v0, v1, maxerr, \
            maxerr_avg, maxerr_avg_biased)

#define MEGDNN_ASSERT_TENSOR_EQ_EPS(v0, v1, maxerr) \
    MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(v0, v1, maxerr, maxerr, maxerr)

#define MEGDNN_ASSERT_TENSOR_EQ(v0, v1) MEGDNN_ASSERT_TENSOR_EQ_EPS(v0, v1, 1e-3)
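
// Usage sketch for the assertion macros (illustrative; t0 and t1 are assumed
// to be CPU-accessible TensorNDs with matching layouts):
//     MEGDNN_ASSERT_TENSOR_EQ(t0, t1);            // per-element max error 1e-3
//     MEGDNN_ASSERT_TENSOR_EQ_EPS(t0, t1, 1e-2);  // custom per-element max error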
template <typename Opr, typename Proxy>
void Checker<Opr, Proxy>::exec(TensorLayoutArray layouts) {
    auto opr_naive = m_handle_naive->create_operator<Opr>();
    auto opr_relayout = m_handle_naive->create_operator<RelayoutForward>();
    auto opr_cur = this->opr();
    opr_naive->param() = m_param;
    opr_cur->param() = m_param;

    bool deduce_layout = layouts.back().ndim == 0;
    if (deduce_layout || m_force_deduce_dst) {
        m_naive_proxy.deduce_layout(opr_naive.get(), layouts);
    }
    auto exec_naive = [this, &opr_naive, &layouts,
                       &opr_relayout](const TensorValueArray& values) {
        TensorValueArray contig_values = values;
        TensorValueArray real_values = values;
        std::shared_ptr<TensorValueArray> tensors_naive_contig_storage;
        if (m_enable_contig_naive) {
            TensorLayoutArray contig_layouts;
            for (auto&& layout : layouts) {
                contig_layouts.emplace_back(TensorLayout{
                        static_cast<const TensorShape&>(layout), layout.dtype});
            }
            m_naive_proxy.deduce_layout(opr_naive.get(), contig_layouts);
            tensors_naive_contig_storage =
                    alloc_tensors(m_handle_naive.get(), contig_layouts, m_offset);
            contig_values = *tensors_naive_contig_storage;
            //! relayout the values to contig_values
            for (size_t i = 0; i < contig_values.size(); ++i) {
                if (real_values[i].layout.ndim == 0)
                    continue;
                real_values[i].layout.format = {};
                opr_relayout->exec(
                        real_values[i], contig_values[i], m_handle_naive.get());
            }
        }

        m_naive_proxy.exec(opr_naive.get(), contig_values);

        if (m_enable_contig_naive) {
            //! relayout back to the values
            for (size_t i = 0; i < contig_values.size(); ++i) {
                if (real_values[i].layout.ndim == 0)
                    continue;
                opr_relayout->exec(
                        contig_values[i], real_values[i], m_handle_naive.get());
            }
        }
    };
    auto exec_opr = [this, opr_cur](const TensorValueArray& values) {
        if (m_before_exec_callback) {
            m_before_exec_callback(opr_cur, values);
        }
        m_cur_proxy.exec(opr_cur, values);
    };
    auto user_layouts = layouts;
    do_exec(user_layouts, layouts, exec_naive, exec_opr);
}
template <typename Opr, typename Proxy>
Checker<Opr, Proxy>& Checker<Opr, Proxy>::exect(
        const TensorValueArray& testcase_in, const TensorValueArray& testcase_out) {
    auto opr_cur = this->opr();
    opr_cur->param() = m_param;
    auto exec_opr = [this, opr_cur](const TensorValueArray& values) {
        if (m_before_exec_callback) {
            m_before_exec_callback(opr_cur, values);
        }
        m_cur_proxy.exec(opr_cur, values);
    };
    do_exec_with_testcases(testcase_in, testcase_out, exec_opr);
    return *this;
}
template <typename T, typename U>
TensorND TensorValue(
        const TensorShape& shape, T dtype, std::initializer_list<U> values) {
    TensorLayout layout{shape, dtype};
    auto buf = static_cast<dt_byte*>(malloc(layout.span().dist_byte()));
    TensorND tensor{buf, layout};
    megdnn_assert(
            values.size() == tensor.layout.total_nr_elems(), "%zu == %zu",
            values.size(), tensor.layout.total_nr_elems());
    auto ptr = tensor.ptr<typename DTypeTrait<T>::ctype>();
    for (const auto& v : values) {
        *ptr++ = typename DTypeTrait<T>::ctype(v);
    }
    return tensor;
}
template <typename T, typename U>
TensorND TensorValueLowbit4(const TensorShape& shape, T dtype, std::vector<U> values) {
    TensorLayout layout{shape, dtype};
    auto buf = static_cast<dt_byte*>(malloc(layout.span().dist_byte()));
    TensorND tensor{buf, layout};
    megdnn_assert(values.size() == tensor.layout.total_nr_elems());
    auto ptr = tensor.ptr<typename DTypeTrait<T>::ctype>();
    auto dim_in = shape[layout.ndim - 1];
    auto elems = tensor.layout.total_nr_elems();
    auto dim_out = elems / dim_in;
    auto stride_out = div_ceil(dim_in, 2_z);
    size_t in_offset = 0;
    for (size_t i = 0; i < dim_out; ++i) {
        // pack two 4-bit values into each output byte
        for (size_t j = 0; j < dim_in; j += 2) {
            U a = values[in_offset + j];
            U b = 0;
            if (j + 1 < dim_in)
                b = values[in_offset + j + 1];
            megdnn_assert(a >= DTypeTrait<T>::min());
            megdnn_assert(a <= DTypeTrait<T>::max());
            megdnn_assert(b >= DTypeTrait<T>::min());
            megdnn_assert(b <= DTypeTrait<T>::max());
            ptr[j / 2] = (a & 0xF) | (b << 4);
        }
        in_offset += dim_in;
        ptr += stride_out;
    }
    return tensor;
}
class Testcase : public SmallVector<TensorND> {
public:
    using SmallVector<TensorND>::SmallVector;
    ~Testcase() {
        // free the tensor buffers we own (allocated by the TensorValue helpers)
        for (const auto& tensor : *this) {
            if (tensor.raw_ptr()) {
                free(tensor.raw_ptr());
            }
        }
    }
    Testcase(const Testcase&) = delete;
    Testcase& operator=(const Testcase&) = delete;
};
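
// Usage sketch combining TensorValue with Checker::exect (illustrative; the
// opr type, shapes and values are assumptions). Empty TensorNDs mark slots
// whose values are irrelevant in the corresponding testcase:
//     checker.exect(
//             Testcase{TensorValue({1, 3}, dtype::Float32(), {1.f, 2.f, 3.f}), {}},
//             Testcase{{}, TensorValue({1, 3}, dtype::Float32(), {2.f, 3.f, 4.f})});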
struct ExecutionPolicyAlgoName {
    std::string name;
    std::vector<ExecutionPolicyAlgoName> sub_policy_names;

    ExecutionPolicyAlgoName(const char* name) : name{name} {}
    ExecutionPolicyAlgoName(
            const char* name, const std::vector<ExecutionPolicyAlgoName>& sub_policy)
            : name{name}, sub_policy_names{sub_policy} {}
};
/*!
 * \brief a callable to check that the given algorithm is used for the heuristic
 * \param require_algo if its value is true, requires
 *      get_algorithm_heuristic() to return the expected algo; otherwise the
 *      expected algo must exist in get_all_algorithms_safe() and it would be
 *      set to be used
 */
template <class Opr, typename OprAlgoProxy = OprAlgoProxy<Opr>>
class AlgoChecker {
public:
    AlgoChecker(ExecutionPolicyAlgoName name, bool* require_algo = nullptr)
            : m_policy_name{name}, m_require_algo{require_algo} {}

    AlgoChecker(ExecutionPolicy policy, bool* require_algo = nullptr)
            : m_policy{policy}, m_require_algo{require_algo} {}

    static ExecutionPolicy construct_execution_policy_from_name(
            const ExecutionPolicyAlgoName& policy_name,
            const TensorLayoutArray& layouts, const std::string& param,
            Handle* handle) {
        ExecutionPolicy ret;
        megdnn_assert(layouts.size() == OprTrait<Opr>::arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        for (auto algo_info :
             AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info_safe(
                     opr.get(), layouts)) {
            if (std::regex_match(
                        algo_info.desc.name,
                        std::regex("(" + policy_name.name + ")(.*)"))) {
                ret.algo = algo_info.desc;
            } else {
                continue;
            }

            Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc);
            std::vector<Algorithm::SearchItem>&& sub_items =
                    algo->get_subopr_list(layouts, opr.get());
            if (sub_items.size() != policy_name.sub_policy_names.size()) {
                printf("Invalid sub_policy_names in %s, expected %zu but got "
                       "%zu\n",
                       algo_info.desc.name.c_str(), sub_items.size(),
                       policy_name.sub_policy_names.size());
                return {};
            }
            FOREACH_OPR_TYPE_DISPATCH(sub_items, {
                ExecutionPolicy policy =
                        AlgoChecker<_Opr>::construct_execution_policy_from_name(
                                policy_name.sub_policy_names[_item_idx], _item.layouts,
                                _item.param, handle);
                ret.sub_policy.push_back(policy);
            });
            return ret;
        }
        megdnn_assert(false, "Expected algo not found: %s\n", policy_name.name.c_str());
        return ret;
    }

    void operator()(Opr* opr, const CheckerHelper::TensorValueArray& arr) {
        TensorLayoutArray layouts;
        for (auto&& val : arr) {
            layouts.push_back(val.layout);
        }
        if (!m_policy_name.name.empty()) {
            std::string param_str;
            Algorithm::serialize_write_pod(opr->param(), param_str);
            m_policy = construct_execution_policy_from_name(
                    m_policy_name, layouts, param_str, opr->handle());
            ASSERT_TRUE(m_policy.algo.valid())
                    << "algorithm " << m_policy_name.name << " not found";
        }
        if (m_require_algo && *m_require_algo) {
            auto algo = OprAlgoProxy::get_algorithm_info_heuristic(opr, layouts);
            ASSERT_STREQ(
                    opr->get_algorithm_from_desc(m_policy.algo)->name(),
                    algo.desc.name.c_str());
        } else {
            opr->execution_policy() = m_policy;
        }
    }

private:
    ExecutionPolicyAlgoName m_policy_name;
    ExecutionPolicy m_policy;
    bool* m_require_algo;
};
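
//! Usage sketch (illustrative; the opr type and the algorithm name pattern
//! are assumptions):
//!     checker.set_before_exec_callback(
//!             AlgoChecker<ConvolutionForward>("DIRECT.*"));
//!     checker.execs({{2, 3, 16, 16}, {4, 3, 3, 3}, {}});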
template <typename Opr>
void construct_sub_execution_policy_heuristic(
        ExecutionPolicy& policy, const TensorLayoutArray& layouts,
        const std::string& param, Handle* handle) {
    megdnn_assert(layouts.size() == OprTrait<Opr>::arity);
    auto opr = handle->create_operator<Opr>();
    opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
    if (!policy.algo.valid()) {
        policy.algo =
                AlgoProxy<Opr, OprTrait<Opr>::arity>::get_algorithm_info_heuristic(
                        opr.get(), layouts)
                        .desc;
    }
    Algorithm* algo = opr->get_algorithm_from_desc(policy.algo);
    std::vector<Algorithm::SearchItem>&& sub_items =
            algo->get_subopr_list(layouts, opr.get());
    FOREACH_OPR_TYPE_DISPATCH(sub_items, {
        policy.sub_policy.push_back(ExecutionPolicy{});
        construct_sub_execution_policy_heuristic<_Opr>(
                policy.sub_policy.back(), _item.layouts, _item.param, handle);
    });
}
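
//! Usage sketch (illustrative; layouts, param_str and handle are assumed to
//! be prepared by the caller, with param_str produced by
//! Algorithm::serialize_write_pod on the opr param):
//!     ExecutionPolicy policy;
//!     construct_sub_execution_policy_heuristic<ConvolutionForward>(
//!             policy, layouts, param_str, handle);
//!     opr->execution_policy() = policy;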
}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen