You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
#pragma once

#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "src/common/utils.h"

#include <gtest/gtest.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#if MEGDNN_ENABLE_MULTI_THREADS
#include <atomic>
#include <thread>
#endif
  13. #define megcore_check(x) \
  14. do { \
  15. auto status = (x); \
  16. if (status != megcoreSuccess) { \
  17. std::cerr << "megcore_check error: " << megcoreGetErrorName(status) \
  18. << std::endl; \
  19. megdnn_trap(); \
  20. } \
  21. } while (0)
  22. namespace megdnn {
  23. namespace test {
  24. struct TaskExecutorConfig {
  25. //! Number of threads.
  26. size_t nr_thread;
  27. //! The core id to bind. The size of affinity_core_set should be equal to
  28. //! nr_thread.
  29. std::vector<size_t> affinity_core_set;
  30. };
  31. class CpuDispatchChecker final : MegcoreCPUDispatcher {
  32. class TaskExecutor {
  33. using Task = megcore::CPUDispatcher::Task;
  34. using MultiThreadingTask = megcore::CPUDispatcher::MultiThreadingTask;
  35. #if MEGDNN_ENABLE_MULTI_THREADS
  36. #if defined(WIN32)
  37. using thread_affinity_type = DWORD;
  38. #else // not WIN32
  39. #if defined(__APPLE__)
  40. using thread_affinity_type = int;
  41. #else
  42. using thread_affinity_type = cpu_set_t;
  43. #endif
  44. #endif
  45. #endif
  46. public:
  47. TaskExecutor(TaskExecutorConfig* config = nullptr);
  48. ~TaskExecutor();
  49. /*!
  50. * Sync all workers.
  51. */
  52. void sync();
  53. /*!
  54. * Number of threads in this thread pool, including the main thread.
  55. */
  56. size_t nr_threads() const { return m_nr_threads; }
  57. void add_task(const MultiThreadingTask& task, size_t parallelism);
  58. void add_task(const Task& task);
  59. private:
  60. #if MEGDNN_ENABLE_MULTI_THREADS
  61. size_t m_all_task_iter = 0;
  62. std::atomic_int m_current_task_iter{0};
  63. //! Indicate whether the thread should work, used for main thread sync
  64. std::vector<std::atomic_bool*> m_workers_flag;
  65. //! Whether the main thread affinity has been set.
  66. bool m_main_thread_affinity = false;
  67. //! Stop the worker threads.
  68. bool m_stop{false};
  69. MultiThreadingTask m_task;
  70. //! The cpuids to be bound.
  71. //! If the m_cpu_ids is empty, then none of the threads will be bound to
  72. //! cpus, else the size of m_cpu_ids should equal to m_nr_threads.
  73. std::vector<size_t> m_cpu_ids;
  74. //! The previous affinity mask of the main thread.
  75. thread_affinity_type m_main_thread_prev_affinity_mask;
  76. std::vector<std::thread> m_workers;
  77. #endif
  78. //! Total number of threads, including main thread.
  79. size_t m_nr_threads = 1;
  80. };
  81. //! track number of CpuDispatchChecker instances to avoid leaking
  82. class InstCounter {
  83. bool m_used = false;
  84. int m_cnt = 0, m_max_cnt = 0;
  85. public:
  86. ~InstCounter() {
  87. auto check = [this]() {
  88. ASSERT_NE(0, m_max_cnt) << "no kernel dispatched on CPU";
  89. ASSERT_EQ(0, m_cnt) << "leaked CpuDispatchChecker object";
  90. };
  91. if (m_used) {
  92. check();
  93. }
  94. }
  95. int& cnt() {
  96. m_used = true;
  97. m_max_cnt = std::max(m_cnt, m_max_cnt);
  98. return m_cnt;
  99. }
  100. };
  101. static InstCounter sm_inst_counter;
  102. bool m_recursive_dispatch = false;
  103. #if MEGDNN_ENABLE_MULTI_THREADS
  104. std::atomic_size_t m_nr_call{0};
  105. #else
  106. size_t m_nr_call = 0;
  107. #endif
  108. std::unique_ptr<TaskExecutor> m_task_executor;
  109. CpuDispatchChecker(TaskExecutorConfig* config) {
  110. ++sm_inst_counter.cnt();
  111. megdnn_assert(sm_inst_counter.cnt() < 10);
  112. m_task_executor = std::make_unique<TaskExecutor>(config);
  113. }
  114. void dispatch(Task&& task) override {
  115. megdnn_assert(!m_recursive_dispatch);
  116. m_recursive_dispatch = true;
  117. ++m_nr_call;
  118. m_task_executor->add_task(std::move(task));
  119. m_recursive_dispatch = false;
  120. }
  121. void dispatch(MultiThreadingTask&& task, size_t parallelism) override {
  122. megdnn_assert(!m_recursive_dispatch);
  123. m_recursive_dispatch = true;
  124. ++m_nr_call;
  125. m_task_executor->add_task(std::move(task), parallelism);
  126. m_recursive_dispatch = false;
  127. }
  128. size_t nr_threads() override { return m_task_executor->nr_threads(); }
  129. CpuDispatchChecker() {
  130. ++sm_inst_counter.cnt();
  131. megdnn_assert(sm_inst_counter.cnt() < 10);
  132. }
  133. void sync() override {}
  134. public:
  135. ~CpuDispatchChecker() {
  136. if (!std::uncaught_exception()) {
  137. megdnn_assert(!m_recursive_dispatch);
  138. } else {
  139. if (m_recursive_dispatch) {
  140. fprintf(stderr,
  141. "CpuDispatchChecker: "
  142. "detected recursive dispatch\n");
  143. }
  144. if (!m_nr_call) {
  145. fprintf(stderr, "CpuDispatchChecker: dispatch not called\n");
  146. }
  147. }
  148. --sm_inst_counter.cnt();
  149. }
  150. static std::unique_ptr<MegcoreCPUDispatcher> make(TaskExecutorConfig* config) {
  151. return std::unique_ptr<MegcoreCPUDispatcher>(new CpuDispatchChecker(config));
  152. }
  153. };
  154. std::unique_ptr<Handle> create_cpu_handle(
  155. int debug_level, bool check_dispatch = true,
  156. TaskExecutorConfig* config = nullptr);
  157. std::unique_ptr<Handle> create_cpu_handle_with_dispatcher(
  158. int debug_level, const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher);
  159. static inline dt_float32 diff(dt_float32 x, dt_float32 y) {
  160. auto numerator = x - y;
  161. auto denominator = std::max(std::max(std::abs(x), std::abs(y)), 1.f);
  162. return numerator / denominator;
  163. }
  164. static inline int diff(int x, int y) {
  165. return x - y;
  166. }
  167. static inline int diff(dt_quint8 x, dt_quint8 y) {
  168. return x.as_uint8() - y.as_uint8();
  169. }
  170. static inline int diff(dt_qint32 x, dt_qint32 y) {
  171. return x.as_int32() - y.as_int32();
  172. }
  173. static inline int diff(dt_qint16 x, dt_qint16 y) {
  174. return x.as_int16() - y.as_int16();
  175. }
  176. static inline int diff(dt_qint8 x, dt_qint8 y) {
  177. return x.as_int8() - y.as_int8();
  178. }
  179. static inline int diff(dt_qint4 x, dt_qint4 y) {
  180. return x.as_int8() - y.as_int8();
  181. }
  182. static inline int diff(dt_qint1 x, dt_qint1 y) {
  183. return x.as_int8() - y.as_int8();
  184. }
  185. static inline int diff(dt_quint4 x, dt_quint4 y) {
  186. return x.as_uint8() - y.as_uint8();
  187. }
  188. inline TensorShape cvt_src_or_dst_nchw2nhwc(const TensorShape& shape) {
  189. megdnn_assert(shape.ndim == 4);
  190. auto N = shape[0], C = shape[1], H = shape[2], W = shape[3];
  191. return TensorShape{N, H, W, C};
  192. }
  193. inline TensorShape cvt_src_or_dst_ncdhw2ndhwc(const TensorShape& shape) {
  194. megdnn_assert(shape.ndim == 5);
  195. auto N = shape[0], C = shape[1], D = shape[2], H = shape[3], W = shape[4];
  196. return TensorShape{N, D, H, W, C};
  197. }
  198. inline TensorShape cvt_filter_nchw2nhwc(const TensorShape& shape) {
  199. if (shape.ndim == 4) {
  200. auto OC = shape[0], IC = shape[1], FH = shape[2], FW = shape[3];
  201. return TensorShape{OC, FH, FW, IC};
  202. } else {
  203. megdnn_assert(shape.ndim == 5);
  204. auto G = shape[0], OC = shape[1], IC = shape[2], FH = shape[3], FW = shape[4];
  205. return TensorShape{G, OC, FH, FW, IC};
  206. }
  207. }
  208. inline TensorShape cvt_filter_ncdhw2ndhwc(const TensorShape& shape) {
  209. if (shape.ndim == 5) {
  210. auto OC = shape[0], IC = shape[1], FD = shape[2], FH = shape[3], FW = shape[4];
  211. return TensorShape{OC, FD, FH, FW, IC};
  212. } else {
  213. megdnn_assert(shape.ndim == 6);
  214. auto G = shape[0], OC = shape[1], IC = shape[2], FD = shape[3], FH = shape[4],
  215. FW = shape[5];
  216. return TensorShape{G, OC, FD, FH, FW, IC};
  217. }
  218. }
  219. void megdnn_sync(Handle* handle);
  220. void* megdnn_malloc(Handle* handle, size_t size_in_bytes);
  221. void megdnn_free(Handle* handle, void* ptr);
  222. void megdnn_memcpy_D2H(
  223. Handle* handle, void* dst, const void* src, size_t size_in_bytes);
  224. void megdnn_memcpy_H2D(
  225. Handle* handle, void* dst, const void* src, size_t size_in_bytes);
  226. void megdnn_memcpy_D2D(
  227. Handle* handle, void* dst, const void* src, size_t size_in_bytes);
  228. //! default implementation for DynOutMallocPolicy
  229. class DynOutMallocPolicyImpl final : public DynOutMallocPolicy {
  230. Handle* m_handle;
  231. public:
  232. DynOutMallocPolicyImpl(Handle* handle) : m_handle{handle} {}
  233. TensorND alloc_output(
  234. size_t id, DType dtype, const TensorShape& shape, void* user_data) override;
  235. void* alloc_workspace(size_t sz, void* user_data) override;
  236. void free_workspace(void* ptr, void* user_data) override;
  237. /*!
  238. * \brief make a shared_ptr which would release output memory when
  239. * deleted
  240. * \param out output tensor allocated by alloc_output()
  241. */
  242. std::shared_ptr<void> make_output_refholder(const TensorND& out);
  243. };
  244. //! replace ErrorHandler::on_megdnn_error
  245. class MegDNNError : public std::exception {
  246. std::string m_msg;
  247. public:
  248. MegDNNError(const std::string& msg) : m_msg{msg} {}
  249. const char* what() const noexcept { return m_msg.c_str(); }
  250. };
  251. class TensorReshapeError : public MegDNNError {
  252. public:
  253. using MegDNNError::MegDNNError;
  254. };
  255. size_t get_cpu_count();
  256. static inline bool good_float(float val) {
  257. return std::isfinite(val);
  258. }
  259. static inline bool good_float(int) {
  260. return true;
  261. }
  262. static inline bool good_float(dt_qint8) {
  263. return true;
  264. }
  265. static inline bool good_float(dt_qint16) {
  266. return true;
  267. }
  268. static inline bool good_float(dt_quint8) {
  269. return true;
  270. }
  271. static inline bool good_float(dt_qint32) {
  272. return true;
  273. }
  274. static inline bool good_float(dt_qint4) {
  275. return true;
  276. }
  277. static inline bool good_float(dt_qint1) {
  278. return true;
  279. }
  280. static inline bool good_float(dt_quint4) {
  281. return true;
  282. }
  283. // A hack for the (x+0) promote to int trick on dt_quint8.
  284. static inline int operator+(dt_quint8 lhs, int rhs) {
  285. megdnn_assert(rhs == 0, "unexpected rhs");
  286. return lhs.as_uint8();
  287. }
  288. static inline int operator+(dt_qint32 lhs, int rhs) {
  289. megdnn_assert(rhs == 0, "unexpected rhs");
  290. return lhs.as_int32();
  291. }
  292. static inline int operator+(dt_qint8 lhs, int rhs) {
  293. megdnn_assert(rhs == 0, "unexpected rhs");
  294. return int8_t(lhs);
  295. }
  296. static inline int operator+(dt_qint16 lhs, int rhs) {
  297. megdnn_assert(rhs == 0, "unexpected rhs");
  298. return lhs.as_int16();
  299. }
  300. static inline int operator+(dt_quint4 lhs, int rhs) {
  301. megdnn_assert(rhs == 0, "unexpected rhs");
  302. return lhs.as_uint8();
  303. }
  304. static inline int operator+(dt_qint4 lhs, int rhs) {
  305. megdnn_assert(rhs == 0, "unexpected rhs");
  306. return lhs.as_int8();
  307. }
  308. static inline int operator+(dt_qint1 lhs, int rhs) {
  309. megdnn_assert(rhs == 0, "unexpected rhs");
  310. return lhs.as_int8();
  311. }
  312. } // namespace test
  313. static inline bool operator==(const TensorLayout& a, const TensorLayout& b) {
  314. return a.eq_layout(b);
  315. }
  316. static inline std::ostream& operator<<(std::ostream& ostr, const TensorLayout& layout) {
  317. return ostr << layout.to_string();
  318. }
  319. //! change the image2d_pitch_alignment of naive handle in this scope
  320. class NaivePitchAlignmentScope {
  321. size_t m_orig_val, m_new_val;
  322. megdnn::Handle::HandleVendorType m_orig_vendor, m_new_vendor;
  323. public:
  324. NaivePitchAlignmentScope(size_t alignment, megdnn::Handle::HandleVendorType vendor);
  325. ~NaivePitchAlignmentScope();
  326. };
  327. } // namespace megdnn
  328. // vim: syntax=cpp.doxygen