
task_record_check.h

/**
 * \file dnn/test/common/task_record_check.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include <cstring>
#include <memory>
#include <vector>
#include "megdnn/oprs.h"
#include "src/common/conv_bias.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "test/common/checker.h"
#include "test/common/index.h"

namespace megdnn {
namespace test {
//! simulate the task dispatch process
class CpuRecordDispatcher : public MegcoreCPUDispatcher {
    std::vector<MegcoreCPUDispatcher::Task> tasks;
    bool execute_inplace = false;

public:
    void dispatch(MultiThreadingTask&& task, size_t parallelism) override {
        if (execute_inplace) {
            for (size_t i = 0; i < parallelism; i++) {
                task(i, 0);
            }
        } else {
            tasks.push_back([task, parallelism]() {
                for (size_t i = 0; i < parallelism; i++) {
                    task(i, 0);
                }
            });
        }
    }

    void dispatch(Task&& task) override {
        // printf("dispatch one task with execute_inplace = %d\n", execute_inplace);
        if (execute_inplace) {
            task();
        } else {
            tasks.push_back(task);
        }
    }

    size_t nr_threads() override { return 1_z; }

    void sync() override {}

    void enable_execute_inplace() { execute_inplace = true; }

    void disable_execute_inplace() { execute_inplace = false; }

    void run_task() {
        // printf("size of task : %zu\n", tasks.size());
        for (auto&& task : tasks) {
            task();
        }
    }

    void clear_task() { tasks.clear(); }
};
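
//! A minimal usage sketch of the recording dispatcher (illustrative only; the
//! function name below is hypothetical and not part of the original header).
//! While execute_inplace is disabled, dispatched tasks are merely queued and
//! must be replayed explicitly via run_task().
inline void example_record_and_replay_dispatch() {
    auto dispatcher = std::make_shared<CpuRecordDispatcher>();
    // operators created from this handle dispatch their kernels through the
    // recording dispatcher instead of running them directly
    auto handle = create_cpu_handle_with_dispatcher(0, dispatcher);
    dispatcher->dispatch([]() { /* kernel body would run here */ });
    dispatcher->run_task();    // replay every recorded task once
    dispatcher->clear_task();  // drop the recording before the next round
}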
template <typename Opr, typename Proxy = OprProxy<Opr>>
class TaskRecordChecker : public CheckerHelper {
    std::shared_ptr<CpuRecordDispatcher> m_dispatcher;
    std::unique_ptr<Handle> m_handle;
    Proxy m_naive_proxy, m_cur_proxy;

public:
    using Param = typename Opr::Param;
    using CheckerHelper::CheckerHelper;

    TaskRecordChecker(int debug_level = 0) {
        m_dispatcher = std::make_shared<CpuRecordDispatcher>();
        m_handle = create_cpu_handle_with_dispatcher(debug_level, m_dispatcher);
    }

    TensorLayoutArray make_layouts(const TensorShapeArray& shapes) {
        TensorLayoutArray layouts(shapes.size());
        for (size_t i = 0; i < shapes.size(); ++i) {
            DType dt =
                    (m_dtype.find(i) != m_dtype.end() ? m_dtype[i] : dtype::Float32());
            TensorFormat fmt =
                    (m_fmt.find(i) != m_fmt.end() ? m_fmt[i] : TensorFormat{});
            layouts[i] = TensorLayout(shapes[i], dt, fmt);
        }
        return layouts;
    }

    /*!
     * \brief execute opr on current param/dtype/rng config
     * \param shapes input/output shapes, which would be passed as
     *      arguments to Opr::deduce_layout
     *
     * Checker would construct TensorLayout vectors from shapes and dtypes,
     * and call exec(TensorLayoutArray&).
     */
    TaskRecordChecker& exec(const TensorShapeArray& shapes) {
        exec(make_layouts(shapes));
        return *this;
    }

    void exec(TensorLayoutArray layouts);

    //! explicitly require argument to be TensorShape
    TaskRecordChecker& execs(const TensorShapeArray& shapes) { return exec(shapes); }

    //! explicitly require argument to be TensorLayout
    TaskRecordChecker& execl(const TensorLayoutArray& layouts) {
        exec(layouts);
        return *this;
    }

    TaskRecordChecker& set_param(Param p) {
        m_param = p;
        opr()->param() = p;
        return *this;
    }

    TaskRecordChecker& set_dtype(size_t idx, DType dtype) {
        m_dtype[idx] = dtype;
        return *this;
    }

    TaskRecordChecker& set_rng(size_t idx, RNG* rng) {
        m_rng[idx] = rng;
        return *this;
    }

    TaskRecordChecker& set_epsilon(dt_float32 epsilon) {
        m_epsilon = epsilon;
        m_max_avg_error = epsilon;
        m_max_avg_biased_error = epsilon;
        return *this;
    }

    TaskRecordChecker& set_proxy(const Proxy& proxy) {
        m_naive_proxy = proxy;
        m_cur_proxy = proxy;
        return *this;
    }
    //! get the opr impl so that settings other than param() can be modified
    Opr* opr() {
        if (!m_opr_cur) {
            m_opr_cur = m_handle->create_operator<Opr>();
        }
        return m_opr_cur.get();
    }

    void free_opr() {
        if (m_opr_cur) {
            m_opr_cur.reset();
        }
    }

    Handle* get_handle() {
        megdnn_assert(m_handle);
        return m_handle.get();
    }

    void copy_tensors(
            const CheckerHelper::TensorValueArray& dest,
            const CheckerHelper::TensorValueArray& src) {
        megdnn_assert(dest.size() == src.size());
        for (size_t i = 0; i < src.size(); i++) {
            auto&& tensor = src[i];
            if (tensor.layout.ndim == 0)
                continue;
            auto layout = tensor.layout;
            auto span = layout.span();
            auto dst_ptr = static_cast<dt_byte*>(dest[i].raw_ptr()) + span.low_byte;
            auto src_ptr =
                    static_cast<const dt_byte*>(src[i].raw_ptr()) + span.low_byte;
            memcpy(dst_ptr, src_ptr, span.dist_byte());
        }
    }

private:
    Param m_param;
    Proxy m_proxy;
    std::unique_ptr<Opr> m_opr_cur;
    std::shared_ptr<TensorValueArray> m_tensors_first, m_tensors_second,
            m_tensors_truth;
    std::vector<void*> m_recovery_ptrs;

    void init_host_values();

    void change_tensor_ptr(
            std::shared_ptr<TensorValueArray> des,
            std::shared_ptr<TensorValueArray> src, std::vector<void*>&);

    void recovery_tensor_ptr(
            std::shared_ptr<TensorValueArray> src, const std::vector<void*>&);
};
template <typename Opr, typename Proxy>
void TaskRecordChecker<Opr, Proxy>::exec(TensorLayoutArray layouts) {
    auto opr_cur = this->opr();
    opr_cur->param() = m_param;
    m_proxy.deduce_layout(opr_cur, layouts);
    for (size_t i = 0; i < layouts.size(); ++i) {
        if (layouts[i].dtype == dtype::Byte()) {
            layouts[i] = TensorLayout(layouts[i], dtype::Int8());
        }
    }
    //! allocate the input/output tensors
    m_tensors_truth = alloc_tensors(m_handle.get(), layouts, 0);
    m_tensors_first = alloc_tensors(m_handle.get(), layouts, 0);
    m_tensors_second = alloc_tensors(m_handle.get(), layouts, 0);
    init_host_values();
    copy_tensors(*m_tensors_first, *m_tensors_truth);
    copy_tensors(*m_tensors_second, *m_tensors_truth);
    //! compute the ground truth with tasks executed inplace
    m_dispatcher->enable_execute_inplace();
    m_proxy.exec(opr_cur, *m_tensors_truth);
    m_dispatcher->clear_task();
    m_dispatcher->disable_execute_inplace();
    //! record the tasks, then replay them
    m_proxy.exec(opr_cur, *m_tensors_first);
    m_dispatcher->run_task();
    //! when checking record level 2, the opr should be freed
    // free_opr();
    check_tensors(*m_tensors_truth, *m_tensors_first);
    //! change the input/output pointers and replay the recorded tasks again
    change_tensor_ptr(m_tensors_first, m_tensors_second, m_recovery_ptrs);
    m_dispatcher->run_task();
    check_tensors(*m_tensors_truth, *m_tensors_second);
    m_dispatcher->clear_task();
    recovery_tensor_ptr(m_tensors_first, m_recovery_ptrs);
    m_recovery_ptrs.clear();
}
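
//! Illustrative sketch of how a test drives TaskRecordChecker (placeholder
//! only: `OprUnderTest`, the shapes and the dtypes below are hypothetical and
//! not defined in this header). A single execs() call performs the whole
//! sequence implemented in exec() above: compute the ground truth inplace,
//! record the tasks, replay them, swap the tensor pointers and replay again.
//!
//!     TaskRecordChecker<OprUnderTest> checker;
//!     checker.set_dtype(0, dtype::Float32())
//!             .set_dtype(1, dtype::Float32())
//!             .execs({{2, 3, 4}, {}});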
template <typename Opr, typename Proxy>
void TaskRecordChecker<Opr, Proxy>::init_host_values() {
    for (size_t i = 0; i < m_tensors_truth->size(); ++i) {
        auto&& tensor = (*m_tensors_truth)[i];
        auto rng = m_rng[i];
        if (!rng)
            rng = m_default_rng.get();
        rng->gen(tensor);
    }
}

template <typename Opr, typename Proxy>
void TaskRecordChecker<Opr, Proxy>::change_tensor_ptr(
        std::shared_ptr<TensorValueArray> des, std::shared_ptr<TensorValueArray> src,
        std::vector<void*>& recovery_ptrs) {
    for (size_t i = 0; i < des->size(); ++i) {
        auto&& tensor_dest = (*des)[i];
        auto&& tensor_src = (*src)[i];
        megdnn_assert(tensor_dest.layout.eq_layout(tensor_src.layout));
        recovery_ptrs.push_back(tensor_dest.raw_ptr());
        tensor_dest.reset_ptr(tensor_src.raw_ptr());
    }
}

template <typename Opr, typename Proxy>
void TaskRecordChecker<Opr, Proxy>::recovery_tensor_ptr(
        std::shared_ptr<TensorValueArray> src,
        const std::vector<void*>& recovery_ptrs) {
    megdnn_assert(src->size() == recovery_ptrs.size());
    for (size_t i = 0; i < src->size(); ++i) {
        auto&& tensor_src = (*src)[i];
        tensor_src.reset_ptr(recovery_ptrs[i]);
    }
}

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen