You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they may include dashes ('-') and can be up to 35 characters long.

task_record_check.h — 8.5 kB (276 lines)
  1. #pragma once
  2. #include <memory>
  3. #include <vector>
  4. #include "megdnn/oprs.h"
  5. #include "src/common/conv_bias.h"
  6. #include "src/common/utils.h"
  7. #include "src/naive/handle.h"
  8. #include "test/common/checker.h"
  9. #include "test/common/index.h"
  10. namespace megdnn {
  11. namespace test {
  12. //! simulation the task dispatch progress
  13. class CpuRecordDispatcher : public MegcoreCPUDispatcher {
  14. std::vector<MegcoreCPUDispatcher::Task> tasks;
  15. bool execute_inplace = false;
  16. public:
  17. void dispatch(MultiThreadingTask&& task, size_t parallelism) override {
  18. if (execute_inplace) {
  19. for (size_t i = 0; i < parallelism; i++) {
  20. task(i, 0);
  21. }
  22. } else {
  23. tasks.push_back([task, parallelism]() {
  24. for (size_t i = 0; i < parallelism; i++) {
  25. task(i, 0);
  26. }
  27. });
  28. }
  29. }
  30. void dispatch(Task&& task) override {
  31. // printf("dispatch one task with execute_inplace = %d\n", execute_inplace);
  32. if (execute_inplace) {
  33. task();
  34. } else {
  35. tasks.push_back(task);
  36. };
  37. }
  38. size_t nr_threads() override { return 1_z; }
  39. void sync() override {}
  40. void enable_execute_inplace() { execute_inplace = true; }
  41. void disable_execute_inplace() { execute_inplace = false; }
  42. void run_task() {
  43. // printf("size of task : %zu\n", tasks.size());
  44. for (auto&& task : tasks) {
  45. task();
  46. }
  47. }
  48. void clear_task() { tasks.clear(); }
  49. };
//! Checker for the task record-and-replay execution mode: the tasks an
//! operator dispatches are captured by CpuRecordDispatcher, replayed against
//! copies of the inputs (including after the tensor pointers have been
//! swapped), and the results compared with a directly executed ground truth
//! (see exec() definition below).
template <typename Opr, typename Proxy = OprProxy<Opr>>
class TaskRecordChecker : public CheckerHelper {
    //! shared with m_handle; records tasks or executes them inplace
    std::shared_ptr<CpuRecordDispatcher> m_dispatcher;
    //! CPU handle whose operators dispatch through m_dispatcher
    std::unique_ptr<Handle> m_handle;
    Proxy m_naive_proxy, m_cur_proxy;

public:
    using Param = typename Opr::Param;
    using CheckerHelper::CheckerHelper;

    TaskRecordChecker(int debug_level = 0) {
        m_dispatcher = std::make_shared<CpuRecordDispatcher>();
        m_handle = create_cpu_handle_with_dispatcher(debug_level, m_dispatcher);
    }

    //! build TensorLayouts from \p shapes, taking per-index dtype/format from
    //! the configured maps (Float32 / default TensorFormat when unset)
    TensorLayoutArray make_layouts(const TensorShapeArray& shapes) {
        TensorLayoutArray layouts(shapes.size());
        for (size_t i = 0; i < shapes.size(); ++i) {
            DType dt =
                    (m_dtype.find(i) != m_dtype.end() ? m_dtype[i] : dtype::Float32());
            TensorFormat fmt =
                    (m_fmt.find(i) != m_fmt.end() ? m_fmt[i] : TensorFormat{});
            layouts[i] = TensorLayout(shapes[i], dt, fmt);
        }
        return layouts;
    }

    /*!
     * \brief execute opr on current param/dtype/rng config
     * \param shapes input/output shapes, which would be passed as
     *      arguments to Opr::deduce_layout
     *
     * Checker would construct TensorLayout vectors from shapes and dtypes,
     * and call exec(TensorLayoutArray &).
     */
    TaskRecordChecker& exec(const TensorShapeArray& shapes) {
        exec(make_layouts(shapes));
        return *this;
    }

    void exec(TensorLayoutArray layouts);

    //! explicitly require argument to be TensorShape
    TaskRecordChecker& execs(const TensorShapeArray& shapes) { return exec(shapes); }

    //! explicitly require argument to be TensorLayout
    TaskRecordChecker& execl(const TensorLayoutArray& layouts) {
        exec(layouts);
        return *this;
    }

    //! set the opr param used for all subsequent exec() calls
    TaskRecordChecker& set_param(Param p) {
        m_param = p;
        opr()->param() = p;
        return *this;
    }
    TaskRecordChecker& set_dtype(size_t idx, DType dtype) {
        m_dtype[idx] = dtype;
        return *this;
    }
    TaskRecordChecker& set_rng(size_t idx, RNG* rng) {
        m_rng[idx] = rng;
        return *this;
    }
    //! set all three error tolerances (max, avg, avg-biased) to \p epsilon
    TaskRecordChecker& set_epsilon(dt_float32 epsilon) {
        m_epsilon = epsilon;
        m_max_avg_error = epsilon;
        m_max_avg_biased_error = epsilon;
        return *this;
    }
    //! NOTE(review): this writes m_naive_proxy/m_cur_proxy, but exec() below
    //! runs through the private m_proxy — confirm set_proxy takes effect
    TaskRecordChecker& set_proxy(const Proxy& proxy) {
        m_naive_proxy = proxy;
        m_cur_proxy = proxy;
        return *this;
    }

    //! get the opr impl so setting other than param() can be modified
    Opr* opr() {
        if (!m_opr_cur) {
            m_opr_cur = m_handle->create_operator<Opr>();
        }
        return m_opr_cur.get();
    }

    //! release the cached operator (no-op when none has been created)
    void free_opr() {
        if (m_opr_cur) {
            m_opr_cur.reset();
        }
    }

    Handle* get_handle() {
        megdnn_assert(m_handle);
        return m_handle.get();
    }

    //! memcpy the value span of every tensor in \p src into the matching
    //! tensor of \p dest; tensors with ndim == 0 are skipped
    //! (assumes dest[i] has the same layout/span as src[i] — TODO confirm,
    //! only the array sizes are asserted here)
    void copy_tensors(
            const CheckerHelper::TensorValueArray& dest,
            const CheckerHelper::TensorValueArray& src) {
        megdnn_assert(dest.size() == src.size());
        for (size_t i = 0; i < src.size(); i++) {
            auto&& tensor = src[i];
            if (tensor.layout.ndim == 0)
                continue;
            auto layout = tensor.layout;
            auto span = layout.span();
            auto dst_ptr = static_cast<dt_byte*>(dest[i].raw_ptr()) + span.low_byte;
            auto src_ptr =
                    static_cast<const dt_byte*>(src[i].raw_ptr()) + span.low_byte;
            memcpy(dst_ptr, src_ptr, span.dist_byte());
        }
    }

private:
    Param m_param;
    Proxy m_proxy;  //!< proxy actually used by exec()
    std::unique_ptr<Opr> m_opr_cur;
    //! truth: executed inplace; first: recorded run; second: replayed run
    std::shared_ptr<TensorValueArray> m_tensors_first, m_tensors_second,
            m_tensors_truth;
    //! original pointers saved by change_tensor_ptr for later restoration
    std::vector<void*> m_recovery_ptrs;

    //! fill m_tensors_truth with data from the configured RNGs
    void init_host_values();

    //! redirect tensors in \p des to the storage of \p src, saving the old
    //! pointers into the vector argument
    void change_tensor_ptr(
            std::shared_ptr<TensorValueArray> des,
            std::shared_ptr<TensorValueArray> src, std::vector<void*>&);

    //! restore pointers previously saved by change_tensor_ptr
    void recovery_tensor_ptr(
            std::shared_ptr<TensorValueArray> src, const std::vector<void*>&);
};
//! Run the record-and-replay check:
//!   1. compute a ground truth with tasks executed inplace;
//!   2. record the task list while running on a first copy of the inputs,
//!      replay it, and compare against the truth;
//!   3. swap the tensor pointers to a second copy and replay the SAME
//!      recorded tasks again — this verifies the recorded tasks read their
//!      pointers indirectly rather than having captured stale addresses.
template <typename Opr, typename Proxy>
void TaskRecordChecker<Opr, Proxy>::exec(TensorLayoutArray layouts) {
    auto opr_cur = this->opr();
    opr_cur->param() = m_param;

    m_proxy.deduce_layout(opr_cur, layouts);
    //! reinterpret Byte tensors as Int8 — presumably so the value comparison
    //! in check_tensors works on them; TODO confirm
    for (size_t i = 0; i < layouts.size(); ++i) {
        if (layouts[i].dtype == dtype::Byte()) {
            layouts[i] = TensorLayout(layouts[i], dtype::Int8());
        }
    }

    // allocate input: three identical tensor sets
    m_tensors_truth = alloc_tensors(m_handle.get(), layouts, 0);
    m_tensors_first = alloc_tensors(m_handle.get(), layouts, 0);
    m_tensors_second = alloc_tensors(m_handle.get(), layouts, 0);

    init_host_values();
    copy_tensors(*m_tensors_first, *m_tensors_truth);
    copy_tensors(*m_tensors_second, *m_tensors_truth);

    //! ground truth: execute tasks immediately, nothing is recorded
    m_dispatcher->enable_execute_inplace();
    m_proxy.exec(opr_cur, *m_tensors_truth);
    m_dispatcher->clear_task();

    m_dispatcher->disable_execute_inplace();
    //! record the task
    m_proxy.exec(opr_cur, *m_tensors_first);
    m_dispatcher->run_task();
    //! if check record2, the opr should be free
    // free_opr();
    check_tensors(*m_tensors_truth, *m_tensors_first);

    //! change the src and out ptr and run again
    change_tensor_ptr(m_tensors_first, m_tensors_second, m_recovery_ptrs);
    m_dispatcher->run_task();
    check_tensors(*m_tensors_truth, *m_tensors_second);

    //! clean up: drop the recorded tasks and restore the original pointers
    m_dispatcher->clear_task();
    recovery_tensor_ptr(m_tensors_first, m_recovery_ptrs);
    m_recovery_ptrs.clear();
}
  198. template <typename Opr, typename Proxy>
  199. void TaskRecordChecker<Opr, Proxy>::init_host_values() {
  200. for (size_t i = 0; i < m_tensors_truth->size(); ++i) {
  201. auto&& tensor = (*m_tensors_truth)[i];
  202. auto rng = m_rng[i];
  203. if (!rng)
  204. rng = m_default_rng.get();
  205. rng->gen(tensor);
  206. }
  207. }
  208. template <typename Opr, typename Proxy>
  209. void TaskRecordChecker<Opr, Proxy>::change_tensor_ptr(
  210. std::shared_ptr<TensorValueArray> des, std::shared_ptr<TensorValueArray> src,
  211. std::vector<void*>& recovery_ptrs) {
  212. for (size_t i = 0; i < des->size(); ++i) {
  213. auto&& tensor_dest = (*des)[i];
  214. auto&& tensor_src = (*src)[i];
  215. megdnn_assert(tensor_dest.layout.eq_layout(tensor_src.layout));
  216. recovery_ptrs.push_back(tensor_dest.raw_ptr());
  217. tensor_dest.reset_ptr(tensor_src.raw_ptr());
  218. }
  219. }
  220. template <typename Opr, typename Proxy>
  221. void TaskRecordChecker<Opr, Proxy>::recovery_tensor_ptr(
  222. std::shared_ptr<TensorValueArray> src,
  223. const std::vector<void*>& recovery_ptrs) {
  224. megdnn_assert(src->size() == recovery_ptrs.size());
  225. for (size_t i = 0; i < src->size(); ++i) {
  226. auto&& tensor_src = (*src)[i];
  227. tensor_src.reset_ptr(recovery_ptrs[i]);
  228. }
  229. }
  230. } // namespace test
  231. } // namespace megdnn
  232. // vim: syntax=cpp.doxygen