
checker.h 23 kB

/**
 * \file dnn/test/common/checker.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once

#include "megdnn/basic_types.h"
#include "megdnn/tensor_iter.h"
#include "test/common/opr_algo_proxy.h"
#include "test/common/opr_proxy.h"
#include "test/common/rng.h"

#include <gtest/gtest.h>

#include <memory>
#include <regex>
#include <unordered_map>

// clang-format off
#if defined(__has_feature)
#if __has_feature(address_sanitizer)
#define MEGDNN_TEST_ASAN 1
#else
#define MEGDNN_TEST_ASAN 0
#endif
#elif defined(__SANITIZE_ADDRESS__)
#define MEGDNN_TEST_ASAN 1
#else
#define MEGDNN_TEST_ASAN 0
#endif
// clang-format on

namespace megdnn {
namespace test {
class CheckerHelper {
    // TensorLayoutArray and TensorValueArray should be protected in theory;
    // but a g++-4.9 bug mishandles access privileges, so we make them
    // public.
public:
    using TensorValueArray = TensorNDArray;
    using TensorsConstriant = std::function<void(TensorValueArray& tensors)>;
    using ExtraOprImpl = std::function<void(const TensorNDArray&)>;
    using OutputCanonizer = std::function<void(const TensorValueArray&)>;
    static std::shared_ptr<TensorValueArray> alloc_tensors(
            Handle* handle, const TensorLayoutArray& layouts, size_t offset);

    Handle* handle() const { return m_handle_cur; }

protected:
    //! whether to use physically contiguous (i.e. default layout) for naive
    //! impl
    bool m_enable_contig_naive = false;

    bool m_prev_succ = true;
    const char* m_input_tensors_fpath = nullptr;
    thin_function<void()> m_expect_exec_fail;
    std::unique_ptr<Handle> m_handle_naive;
    Handle* m_handle_cur;
    std::unique_ptr<RNG> m_default_rng;
    std::unordered_map<size_t, RNG*> m_rng;
    std::unordered_map<size_t, DType> m_dtype;
    std::unordered_map<size_t, TensorFormat> m_fmt;
    std::set<size_t> m_bypass;
    float_t m_epsilon = 1e-3, m_max_avg_error = 1e-3,
            m_max_avg_biased_error = 1e-3;
    float_t m_perf_check_threshold = -1;
    bool m_perf_check = false;
    ExtraOprImpl m_extra_opr_impl;
    OutputCanonizer m_output_canonizer;
    TensorsConstriant m_tensor_constraint;
    bool m_no_naive_and_check = false;
    bool m_stable_check = false;
    bool m_force_deduce_dst = true;
    bool m_allow_invalid_check = false;
    /**
     * the offset from the start of the malloc'd memory
     *
     * \note \p m_offset extra bytes are allocated when allocating memory for
     *      a tensor, so the tensor data starts at \p m_offset.
     * \warning currently only used for opencl
     */
    size_t m_offset = 0;

    CheckerHelper(Handle* handle, bool check_dispatch = true);
    ~CheckerHelper() noexcept;

    using OprExec = std::function<void(const TensorValueArray&)>;

    void do_exec_with_testcases(const TensorValueArray& testcase_in,
                                const TensorValueArray& testcase_out,
                                const OprExec& exec_opr);
    void do_exec(const TensorLayoutArray& user_layouts,
                 const TensorLayoutArray& deduced_layouts,
                 const OprExec& exec_naive, const OprExec& exec_opr);
    void enable_contig_naive() { m_enable_contig_naive = true; }

    void copy_tensors_to_device(const TensorValueArray& dest,
                                const TensorValueArray& src);
    void copy_tensors_from_device(const TensorValueArray& dest,
                                  const TensorValueArray& src);

private:
    std::shared_ptr<TensorValueArray> m_tensors_naive;

    void init_naive_values();
    void check_tensors(const TensorValueArray& expected,
                       const TensorValueArray& computed);
};
template <typename Opr, typename Proxy = OprProxy<Opr>>
class Checker : public CheckerHelper {
public:
    using Param = typename Opr::Param;
    using BeforeExecCallback =
            std::function<void(Opr*, const TensorValueArray&)>;

    Checker(Handle* handle, bool check_dispatch = true)
            : CheckerHelper(handle, check_dispatch), m_param(Param()) {}

    TensorLayoutArray make_layouts(const TensorShapeArray& shapes) {
        TensorLayoutArray layouts(shapes.size());
        for (size_t i = 0; i < shapes.size(); ++i) {
            DType dt = (m_dtype.find(i) != m_dtype.end() ? m_dtype[i]
                                                         : dtype::Float32());
            if (m_fmt.find(i) == m_fmt.end()) {
                layouts[i] = TensorLayout(shapes[i], dt);
            } else {
                layouts[i] = TensorLayout(shapes[i], dt, m_fmt[i]);
            }
        }
        return layouts;
    }

    /*!
     * \brief execute opr on current param/dtype/rng config
     * \param shapes input/output shapes, which would be passed as
     *      arguments to Opr::deduce_layout
     *
     * Checker would construct TensorLayout vectors from shapes and dtypes,
     * and call exec(TensorLayoutArray &).
     */
    Checker& exec(const TensorShapeArray& shapes) {
        exec(make_layouts(shapes));
        return *this;
    }

    void exec(TensorLayoutArray layouts);

    //! explicitly require argument to be TensorShape
    Checker& execs(const TensorShapeArray& shapes) { return exec(shapes); }

    //! explicitly require argument to be TensorLayout
    Checker& execl(const TensorLayoutArray& layouts) {
        exec(layouts);
        return *this;
    }
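
    /*!
     * Example (illustrative sketch only, not part of the public API): a
     * typical test passes the input shapes plus an empty trailing shape so
     * the dst layout is deduced. The operator type ConvolutionForward, the
     * param fields and the handle() call are assumed to come from the
     * surrounding test fixture.
     *
     * \code
     * Checker<ConvolutionForward> checker(handle());
     * ConvolutionForward::Param param;
     * param.pad_h = param.pad_w = 1;
     * checker.set_param(param)
     *         .set_dtype(0, dtype::Float32())
     *         .set_dtype(1, dtype::Float32())
     *         .execs({{2, 3, 16, 16}, {8, 3, 3, 3}, {}});
     * \endcode
     */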
    Checker& exect(const TensorValueArray& testcase_in,
                   const TensorValueArray& testcase_out);

    Checker& set_param(Param param) {
        m_param = param;
        opr()->param() = param;
        return *this;
    }
    Checker& set_dtype(size_t idx, DType dtype) {
        m_dtype[idx] = dtype;
        return *this;
    }
    Checker& set_fmt(size_t idx, TensorFormat fmt) {
        m_fmt[idx] = fmt;
        return *this;
    }
    Checker& set_rng(size_t idx, RNG* rng) {
        m_rng[idx] = rng;
        return *this;
    }
    Checker& set_bypass(size_t idx) {
        m_bypass.insert(idx);
        return *this;
    }
    //! max error of a single element
    Checker& set_epsilon(dt_float32 epsilon) {
        m_epsilon = epsilon;
        m_max_avg_error = epsilon;
        m_max_avg_biased_error = epsilon;
        return *this;
    }
    //! max average error; defaults to epsilon
    Checker& set_max_avg_error(dt_float32 error) {
        m_max_avg_error = error;
        return *this;
    }
    //! max average biased error; defaults to epsilon
    Checker& set_max_avg_biased_error(dt_float32 error) {
        m_max_avg_biased_error = error;
        return *this;
    }
    Checker& set_offset(size_t offset) {
        m_offset = offset;
        return *this;
    }
    Checker& set_proxy(const Proxy& proxy) {
        m_naive_proxy = proxy;
        m_cur_proxy = proxy;
        return *this;
    }

    //! set_perf_check and set_perf_check_threshold control the
    //! performance checking behavior.
    //!
    //! If perf_check is on (it defaults to off), the running times of the
    //! current operator and the naive operator are measured and compared
    //! when calling exec.
    //! The speedup ratio must be larger than perf_check_threshold,
    //! otherwise an error is reported.
    //! perf_check_threshold must be set in advance, since the default value
    //! (which is negative) is invalid.
    Checker& set_perf_check(bool perf_check) {
        m_perf_check = perf_check;
        return *this;
    }
    Checker& set_perf_check_threshold(float perf_check_threshold) {
        m_perf_check_threshold = perf_check_threshold;
        return *this;
    }
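
    /*!
     * Example (illustrative sketch only, continuing the sketch above):
     * require the tested backend to be at least 2x faster than the naive
     * implementation; the threshold and shapes are hypothetical.
     *
     * \code
     * checker.set_perf_check(true)
     *         .set_perf_check_threshold(2.0f)
     *         .execs({{64, 128}, {128, 32}, {}});
     * \endcode
     */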
    //! stable check runs many iterations and compares each result with the
    //! first iteration
    Checker& set_stable_check(bool stable_check) {
        m_stable_check = stable_check;
        return *this;
    }
    //! force deduce dst
    Checker& set_force_deduce_dst(bool force_deduce_dst) {
        m_force_deduce_dst = force_deduce_dst;
        return *this;
    }
    Checker& set_no_naive_check(bool no_naive_and_check) {
        m_no_naive_and_check = no_naive_and_check;
        return *this;
    }
    Checker& set_allow_invalid_check(bool allow_invalid_check) {
        m_allow_invalid_check = allow_invalid_check;
        return *this;
    }
    //! load input tensors from file for the next run
    Checker& load_input_tensors(const char* fpath) {
        m_input_tensors_fpath = fpath;
        return *this;
    }
    //! add another checker to ensure the naive implementation is correct
    Checker& set_extra_opr_impl(const ExtraOprImpl& chk) {
        m_extra_opr_impl = chk;
        return *this;
    }
    //! set a callback to be invoked before executing the operator
    Checker& set_before_exec_callback(const BeforeExecCallback& cb) {
        m_before_exec_callback = cb;
        return *this;
    }
    Checker& reset_before_exec_callback() {
        m_before_exec_callback = nullptr;
        return *this;
    }
    //! set a tensor constraint function, for the purpose of manipulating
    //! tensors when testing
    Checker& set_tensors_constraint(
            const TensorsConstriant& tensor_constraint) {
        m_tensor_constraint = tensor_constraint;
        return *this;
    }
    /*!
     * \brief set that exec() on opr should fail, so naive is not called and
     *      exec() returns directly after opr is called.
     *
     * This is only valid for the next exec() call. It is usually used for
     * testing megcore::AsyncErrorInfo.
     *
     * \param cb callback to be invoked after opr exec (so the error would
     *      not be passed to the destructor)
     */
    Checker& set_expect_exec_fail(const thin_function<void()>& cb) {
        m_expect_exec_fail = cb;
        return *this;
    }
    /*!
     * \brief set a function to canonize the outputs
     *
     * For some oprs, multiple outputs may be acceptable; we can use a
     * function to transform them into a canonical form before comparing.
     *
     * The arguments are tensors on CPU and should be modified in-place.
     */
    Checker& set_output_canonizer(OutputCanonizer canonizer) {
        m_output_canonizer = std::move(canonizer);
        return *this;
    }
    //! get the opr impl so settings other than param() can be modified
    Opr* opr() {
        if (!m_opr_cur) {
            m_opr_cur = m_handle_cur->create_operator<Opr>();
        }
        return m_opr_cur.get();
    }
    //! whether the previous exec succeeded
    bool prev_succ() const { return m_prev_succ; }

private:
    BeforeExecCallback m_before_exec_callback;
    Param m_param;
    Proxy m_naive_proxy, m_cur_proxy;
    std::unique_ptr<Opr> m_opr_cur;
};
::testing::AssertionResult __assert_tensor_eq(
        const char* expr0, const char* expr1, const char* expr_maxerr,
        const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased,
        const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg,
        float maxerr_avg_biased, bool allow_invalid = false);

::testing::AssertionResult __assert_tensor_eq_allow_invalid(
        const char* expr0, const char* expr1, const char* expr_maxerr,
        const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased,
        const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg,
        float maxerr_avg_biased);

#define MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(v0, v1, maxerr, maxerr_avg,         \
                                        maxerr_avg_biased)                  \
    ASSERT_PRED_FORMAT5(::megdnn::test::__assert_tensor_eq, v0, v1, maxerr, \
                        maxerr_avg, maxerr_avg_biased)

#define MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID(                        \
        v0, v1, maxerr, maxerr_avg, maxerr_avg_biased)                        \
    ASSERT_PRED_FORMAT5(::megdnn::test::__assert_tensor_eq_allow_invalid, v0, \
                        v1, maxerr, maxerr_avg, maxerr_avg_biased)

#define MEGDNN_ASSERT_TENSOR_EQ_EPS(v0, v1, maxerr) \
    MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(v0, v1, maxerr, maxerr, maxerr)

#define MEGDNN_ASSERT_TENSOR_EQ(v0, v1) \
    MEGDNN_ASSERT_TENSOR_EQ_EPS(v0, v1, 1e-3)
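
/*!
 * Example (illustrative sketch only): compare two host tensors, optionally
 * with a custom per-element tolerance. \p expected and \p computed are
 * assumed to be TensorND values accessible from the CPU.
 *
 * \code
 * MEGDNN_ASSERT_TENSOR_EQ(expected, computed);
 * MEGDNN_ASSERT_TENSOR_EQ_EPS(expected, computed, 1e-5);
 * \endcode
 */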
template <typename Opr, typename Proxy>
void Checker<Opr, Proxy>::exec(TensorLayoutArray layouts) {
    auto opr_naive = m_handle_naive->create_operator<Opr>();
    auto opr_relayout = m_handle_naive->create_operator<RelayoutForward>();
    auto opr_cur = this->opr();
    opr_naive->param() = m_param;
    opr_cur->param() = m_param;

    bool deduce_layout = layouts.back().ndim == 0;
    if (deduce_layout || m_force_deduce_dst) {
        m_naive_proxy.deduce_layout(opr_naive.get(), layouts);
    }
    auto exec_naive = [this, &opr_naive, &layouts,
                       &opr_relayout](const TensorValueArray& values) {
        TensorValueArray contig_values = values;
        TensorValueArray real_values = values;
        std::shared_ptr<TensorValueArray> tensors_naive_contig_storage;
        if (m_enable_contig_naive) {
            TensorLayoutArray contig_layouts;
            for (auto&& layout : layouts) {
                contig_layouts.emplace_back(TensorLayout{
                        static_cast<const TensorShape&>(layout),
                        layout.dtype});
            }
            m_naive_proxy.deduce_layout(opr_naive.get(), contig_layouts);
            tensors_naive_contig_storage = alloc_tensors(
                    m_handle_naive.get(), contig_layouts, m_offset);
            contig_values = *tensors_naive_contig_storage;
            //! relayout value to the contig_values
            for (size_t i = 0; i < contig_values.size(); ++i) {
                if (real_values[i].layout.ndim == 0)
                    continue;
                real_values[i].layout.format = {};
                opr_relayout->exec(real_values[i], contig_values[i],
                                   m_handle_naive.get());
            }
        }

        m_naive_proxy.exec(opr_naive.get(), contig_values);

        if (m_enable_contig_naive) {
            //! relayout back to the values
            for (size_t i = 0; i < contig_values.size(); ++i) {
                if (real_values[i].layout.ndim == 0)
                    continue;
                opr_relayout->exec(contig_values[i], real_values[i],
                                   m_handle_naive.get());
            }
        }
    };
    auto exec_opr = [this, opr_cur](const TensorValueArray& values) {
        if (m_before_exec_callback) {
            m_before_exec_callback(opr_cur, values);
        }
        m_cur_proxy.exec(opr_cur, values);
    };
    auto user_layouts = layouts;
    do_exec(user_layouts, layouts, exec_naive, exec_opr);
}
template <typename Opr, typename Proxy>
Checker<Opr, Proxy>& Checker<Opr, Proxy>::exect(
        const TensorValueArray& testcase_in,
        const TensorValueArray& testcase_out) {
    auto opr_cur = this->opr();
    opr_cur->param() = m_param;
    auto exec_opr = [this, opr_cur](const TensorValueArray& values) {
        if (m_before_exec_callback) {
            m_before_exec_callback(opr_cur, values);
        }
        m_cur_proxy.exec(opr_cur, values);
    };
    do_exec_with_testcases(testcase_in, testcase_out, exec_opr);
    return *this;
}
template <typename T, typename U>
TensorND TensorValue(const TensorShape& shape, T dtype,
                     std::initializer_list<U> values) {
    TensorND tensor;
    tensor.layout = {shape, dtype};
    tensor.raw_ptr =
            static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte()));
    megdnn_assert(values.size() == tensor.layout.total_nr_elems(),
                  "%zu == %zu", values.size(),
                  tensor.layout.total_nr_elems());
    auto ptr = tensor.ptr<typename DTypeTrait<T>::ctype>();
    for (const auto& v : values) {
        *ptr++ = typename DTypeTrait<T>::ctype(v);
    }
    return tensor;
}
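
/*!
 * Example (illustrative sketch only): build a 2x2 float tensor with explicit
 * values, useful together with Testcase and Checker::exect below. The caller
 * owns the malloc'd buffer; Testcase frees its elements automatically.
 *
 * \code
 * TensorND t = TensorValue({2, 2}, dtype::Float32(),
 *                          {1.f, 2.f, 3.f, 4.f});
 * \endcode
 */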
template <typename T, typename U>
TensorND TensorValueLowbit4(const TensorShape& shape, T dtype,
                            std::vector<U> values) {
    TensorND tensor;
    tensor.layout = {shape, dtype};
    tensor.raw_ptr =
            static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte()));
    megdnn_assert(values.size() == tensor.layout.total_nr_elems());
    auto ptr = tensor.ptr<typename DTypeTrait<T>::ctype>();
    auto layout = tensor.layout;
    auto dim_in = shape[layout.ndim - 1];
    auto elems = tensor.layout.total_nr_elems();
    auto dim_out = elems / dim_in;
    auto stride_out = div_ceil(dim_in, 2_z);
    size_t in_offset = 0;
    for (size_t i = 0; i < dim_out; ++i) {
        for (size_t j = 0; j < dim_in; j += 2) {
            U a = values[in_offset + j];
            U b = 0;
            if (j + 1 < dim_in)
                b = values[in_offset + j + 1];
            megdnn_assert(a >= DTypeTrait<T>::min());
            megdnn_assert(a <= DTypeTrait<T>::max());
            megdnn_assert(b >= DTypeTrait<T>::min());
            megdnn_assert(b <= DTypeTrait<T>::max());
            // pack two 4-bit values into one byte: a in the low nibble,
            // b in the high nibble
            ptr[j / 2] = (a & 0xF) | (b << 4);
        }
        in_offset += dim_in;
        ptr += stride_out;
    }
    return tensor;
}
class Testcase : public SmallVector<TensorND> {
public:
    using SmallVector<TensorND>::SmallVector;
    ~Testcase() {
        // free the tensor buffers owned by this testcase
        for (const auto& tensor : *this) {
            if (tensor.raw_ptr) {
                free(tensor.raw_ptr);
            }
        }
    }
    Testcase(const Testcase&) = delete;
    Testcase operator=(const Testcase&) = delete;
};
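
/*!
 * Example (illustrative sketch only): feed explicit input/output values to
 * Checker::exect. An empty TensorND ({}) marks a slot whose value is left to
 * the operator (in the input testcase) or not checked (in the output
 * testcase). The operator type, param mode and values here are hypothetical.
 *
 * \code
 * Checker<ElemwiseForward> checker(handle());
 * checker.set_param({ElemwiseForward::Param::Mode::ADD});
 * checker.exect(
 *         Testcase{TensorValue({2}, dtype::Float32(), {1.f, 2.f}),
 *                  TensorValue({2}, dtype::Float32(), {3.f, 4.f}),
 *                  {}},
 *         Testcase{{},
 *                  {},
 *                  TensorValue({2}, dtype::Float32(), {4.f, 6.f})});
 * \endcode
 */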
struct ExecutionPolicyAlgoName {
    std::string name;
    std::vector<ExecutionPolicyAlgoName> sub_policy_names;
    ExecutionPolicyAlgoName(const char* name) : name{name} {}
    ExecutionPolicyAlgoName(
            const char* name,
            const std::vector<ExecutionPolicyAlgoName>& sub_policy)
            : name{name}, sub_policy_names{sub_policy} {}
};

/*!
 * \brief a callable to check that the given algorithm is used for the
 *      heuristic
 * \param require_algo if its value is true, then requires
 *      get_algorithm_heuristic() to return the expected algo; otherwise the
 *      expected algo must exist in get_all_algorithms_safe() and it would be
 *      set to be used
 */
template <class Opr, typename OprAlgoProxy = OprAlgoProxy<Opr>>
class AlgoChecker {
public:
    AlgoChecker(ExecutionPolicyAlgoName name, bool* require_algo = nullptr)
            : m_policy_name{name}, m_require_algo{require_algo} {}

    AlgoChecker(ExecutionPolicy policy, bool* require_algo = nullptr)
            : m_policy{policy}, m_require_algo{require_algo} {}

    static ExecutionPolicy construct_execution_policy_from_name(
            const ExecutionPolicyAlgoName& policy_name,
            const TensorLayoutArray& layouts, const std::string& param,
            Handle* handle) {
        ExecutionPolicy ret;
        megdnn_assert(layouts.size() == OprTrait<Opr>::arity);
        auto opr = handle->create_operator<Opr>();
        opr->param() =
                Algorithm::deserialize_read_pod<typename Opr::Param>(param);
        for (auto algo_info :
             AlgoProxy<Opr, OprTrait<Opr>::arity>::
                     get_all_algorithms_info_safe(opr.get(), layouts)) {
            if (std::regex_match(
                        algo_info.desc.name,
                        std::regex("(" + policy_name.name + ")(.*)"))) {
                ret.algo = algo_info.desc;
            } else {
                continue;
            }

            Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc);
            std::vector<Algorithm::SearchItem>&& sub_items =
                    algo->get_subopr_list(layouts, opr.get());
            if (sub_items.size() != policy_name.sub_policy_names.size()) {
                printf("Invalid sub_policy_names in %s, expected %zu but got "
                       "%zu\n",
                       algo_info.desc.name.c_str(), sub_items.size(),
                       policy_name.sub_policy_names.size());
                return {};
            }
            FOREACH_OPR_TYPE_DISPATCH(sub_items, {
                ExecutionPolicy policy = AlgoChecker<_Opr>::
                        construct_execution_policy_from_name(
                                policy_name.sub_policy_names[_item_idx],
                                _item.layouts, _item.param, handle);
                ret.sub_policy.push_back(policy);
            });
            return ret;
        }
        return ret;
    }

    void operator()(Opr* opr, const CheckerHelper::TensorValueArray& arr) {
        TensorLayoutArray layouts;
        for (auto&& val : arr) {
            layouts.push_back(val.layout);
        }
        if (!m_policy_name.name.empty()) {
            std::string param_str;
            Algorithm::serialize_write_pod(opr->param(), param_str);
            m_policy = construct_execution_policy_from_name(
                    m_policy_name, layouts, param_str, opr->handle());
            ASSERT_TRUE(m_policy.algo.valid())
                    << "algorithm " << m_policy_name.name << " not found";
        }
        if (m_require_algo && *m_require_algo) {
            auto algo =
                    OprAlgoProxy::get_algorithm_info_heuristic(opr, layouts);
            ASSERT_STREQ(opr->get_algorithm_from_desc(m_policy.algo)->name(),
                         algo.desc.name.c_str());
        } else {
            opr->execution_policy() = m_policy;
        }
    }

private:
    ExecutionPolicyAlgoName m_policy_name;
    ExecutionPolicy m_policy;
    bool* m_require_algo;
};
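
/*!
 * Example (illustrative sketch only): force a specific algorithm, matched by
 * name prefix as a regex, for the next exec by installing the AlgoChecker as
 * the before-exec callback. The operator type and algorithm name are
 * hypothetical.
 *
 * \code
 * checker.set_before_exec_callback(
 *         AlgoChecker<ConvolutionForward>("IM2COL"));
 * checker.execs({{2, 3, 16, 16}, {8, 3, 3, 3}, {}});
 * \endcode
 */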
template <typename Opr>
void construct_sub_execution_policy_heuristic(ExecutionPolicy& policy,
                                              const TensorLayoutArray& layouts,
                                              const std::string& param,
                                              Handle* handle) {
    megdnn_assert(layouts.size() == OprTrait<Opr>::arity);
    auto opr = handle->create_operator<Opr>();
    opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param);
    if (!policy.algo.valid()) {
        policy.algo = AlgoProxy<Opr, OprTrait<Opr>::arity>::
                              get_algorithm_info_heuristic(opr.get(), layouts)
                                      .desc;
    }
    Algorithm* algo = opr->get_algorithm_from_desc(policy.algo);
    std::vector<Algorithm::SearchItem>&& sub_items =
            algo->get_subopr_list(layouts, opr.get());
    FOREACH_OPR_TYPE_DISPATCH(sub_items, {
        policy.sub_policy.push_back(ExecutionPolicy{});
        construct_sub_execution_policy_heuristic<_Opr>(
                policy.sub_policy.back(), _item.layouts, _item.param, handle);
    });
}
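
/*!
 * Example (illustrative sketch only): fill in the sub-policies of a
 * partially specified ExecutionPolicy by heuristic. The operator type and
 * the \p opr, \p layouts and \p handle() values are assumed to come from the
 * surrounding test.
 *
 * \code
 * ExecutionPolicy policy;  // algo left invalid: picked by heuristic
 * std::string param_str;
 * Algorithm::serialize_write_pod(opr->param(), param_str);
 * construct_sub_execution_policy_heuristic<ConvolutionForward>(
 *         policy, layouts, param_str, handle());
 * \endcode
 */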
} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and the driver installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.