
accuracy_shake_checker.h

#pragma once

#include <cassert>
#include <cstring>
#include <regex>
#include <vector>

#include "megdnn/oprs.h"
#include "src/common/conv_bias.h"
#include "src/common/utils.h"
#include "test/common/checker.h"
#include "test/common/index.h"

namespace megdnn {
namespace test {
namespace {
template <class Opr>
struct BatchTrait {
    //! index of the batch dimension in the tensor layout, e.g. 3 for CHWN4
    static size_t index_of_batch(const typename Opr::Param&) { return 0; }
    //! indices of the inputs/outputs whose layout contains the batch
    //! dimension, e.g. src(0) and dst(2) for convolution; entries in the
    //! broadcast list (e.g. the bias of ConvBias) carry a batch dimension
    //! only when they are not shared across the batch
    static std::vector<size_t> indices_contain_batch;
    static std::vector<size_t> indices_contain_batch_broadcast;
};
template <class Opr>
std::vector<size_t> BatchTrait<Opr>::indices_contain_batch = {};
template <class Opr>
std::vector<size_t> BatchTrait<Opr>::indices_contain_batch_broadcast = {};
#define DEFAULT_INDEX_OF_BATCH(opr) \
    static size_t index_of_batch(const opr::Param&) { return 0; }

#define CONV_INDEX_OF_BATCH(opr)                        \
    static size_t index_of_batch(const opr::Param& p) { \
        if (p.format == opr::Param::Format::CHWN4) {    \
            return 3;                                   \
        }                                               \
        return 0;                                       \
    }

#define OPR_WITHOUT_INPUT_BROADCAST(INDEX_OF_BATCH, opr, idxs, idxs_brdcst) \
    template <>                                                             \
    struct BatchTrait<opr> {                                                \
        INDEX_OF_BATCH(opr)                                                 \
        static std::vector<size_t> indices_contain_batch;                   \
        static std::vector<size_t> indices_contain_batch_broadcast;         \
    };                                                                      \
    std::vector<size_t> BatchTrait<opr>::indices_contain_batch = idxs;      \
    std::vector<size_t> BatchTrait<opr>::indices_contain_batch_broadcast =  \
            idxs_brdcst;
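//! Each entry below expands to a full BatchTrait specialization; e.g. the
//! Convolution3DForward entry sets indices_contain_batch = {0, 2}, meaning
//! src(0) and dst(2) carry the batch dimension while filter(1) does not.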
OPR_WITHOUT_INPUT_BROADCAST(
        DEFAULT_INDEX_OF_BATCH, megdnn::Convolution3DForward,
        (std::initializer_list<size_t>{0, 2}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        DEFAULT_INDEX_OF_BATCH, megdnn::Convolution3DBackwardData,
        (std::initializer_list<size_t>{1, 2}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        DEFAULT_INDEX_OF_BATCH, megdnn::Convolution3DBackwardFilter,
        (std::initializer_list<size_t>{0, 1}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        DEFAULT_INDEX_OF_BATCH, megdnn::BatchedMatrixMul,
        (std::initializer_list<size_t>{0, 1, 2}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        CONV_INDEX_OF_BATCH, megdnn::ConvolutionForward,
        (std::initializer_list<size_t>{0, 2}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        CONV_INDEX_OF_BATCH, megdnn::ConvolutionBackwardData,
        (std::initializer_list<size_t>{1, 2}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        CONV_INDEX_OF_BATCH, megdnn::ConvolutionBackwardFilter,
        (std::initializer_list<size_t>{0, 1}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        CONV_INDEX_OF_BATCH, megdnn::LocalShareForward,
        (std::initializer_list<size_t>{0, 2}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        CONV_INDEX_OF_BATCH, megdnn::LocalShareBackwardData,
        (std::initializer_list<size_t>{1, 2}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        CONV_INDEX_OF_BATCH, megdnn::LocalShareBackwardFilter,
        (std::initializer_list<size_t>{0, 1}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        CONV_INDEX_OF_BATCH, megdnn::DeformableConvForward,
        (std::initializer_list<size_t>{0, 2, 3, 4}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        CONV_INDEX_OF_BATCH, megdnn::DeformableConvBackwardData,
        (std::initializer_list<size_t>{0, 2, 3, 4, 5, 6, 7}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        CONV_INDEX_OF_BATCH, megdnn::DeformableConvBackwardFilter,
        (std::initializer_list<size_t>{0, 1, 2, 3}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        CONV_INDEX_OF_BATCH, megdnn::BatchConvBiasForward,
        (std::initializer_list<size_t>{0, 1, 2, 3, 4}), {})
OPR_WITHOUT_INPUT_BROADCAST(
        CONV_INDEX_OF_BATCH, megdnn::ConvBiasForward,
        (std::initializer_list<size_t>{0, 3, 4}), {2})
#undef OPR_WITHOUT_INPUT_BROADCAST
#undef DEFAULT_INDEX_OF_BATCH
#undef CONV_INDEX_OF_BATCH
template <class Opr>
struct LayoutsModifier {
    static void on(
            TensorLayoutArray& layouts, const typename Opr::Param& p,
            size_t new_batch_size) {
        size_t batch_index = BatchTrait<Opr>::index_of_batch(p);
        for (size_t index : BatchTrait<Opr>::indices_contain_batch) {
            layouts.at(index)[batch_index] = new_batch_size;
        }
        for (size_t index : BatchTrait<Opr>::indices_contain_batch_broadcast) {
            if (!check_bias_share_in_channel(layouts.at(index), p.format)) {
                layouts.at(index)[batch_index] = new_batch_size;
            }
        }
    }
};
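//! Illustration (hypothetical ConvolutionForward layouts): with
//! indices_contain_batch = {0, 2} and new_batch_size = 1,
//!     {src = {64, 16, 32, 32}, filter = {16, 16, 3, 3}, dst = {64, 16, 30, 30}}
//! becomes
//!     {src = {1, 16, 32, 32}, filter = {16, 16, 3, 3}, dst = {1, 16, 30, 30}};
//! only the tensors listed in BatchTrait are touched.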
#define OPR_NO_BIAS(opr)                                                  \
    template <>                                                           \
    struct LayoutsModifier<opr> {                                         \
        static void on(                                                   \
                TensorLayoutArray& layouts, const typename opr::Param& p, \
                size_t new_batch_size) {                                  \
            size_t batch_index = BatchTrait<opr>::index_of_batch(p);      \
            for (size_t index : BatchTrait<opr>::indices_contain_batch) { \
                layouts.at(index)[batch_index] = new_batch_size;          \
            }                                                             \
        }                                                                 \
    };
OPR_NO_BIAS(megdnn::Convolution3D)
OPR_NO_BIAS(megdnn::BatchedMatrixMul)
#undef OPR_NO_BIAS
template <>
struct LayoutsModifier<megdnn::MatrixMul> {
public:
    static void on(
            TensorLayoutArray& layouts, const megdnn::MatrixMul::Param& p,
            size_t new_batch_size) {
        assert(!p.transposeA && !p.transposeB);
        MEGDNN_MARK_USED_VAR(p);
        layouts.at(0)[0] = new_batch_size;
        layouts.at(2)[0] = new_batch_size;
    }
};
template <class Opr, typename OprAlgoProxy = OprAlgoProxy<Opr>>
class AlgoGenerator {
public:
    AlgoGenerator(ExecutionPolicyAlgoName name) : m_policy_name{name} {}

    //! collect every reproducible algorithm whose accuracy does not depend
    //! on the batch size and whose name matches m_policy_name
    std::vector<Algorithm::Info::Desc> operator()(
            Opr* opr, const CheckerHelper::TensorValueArray& arr) {
        TensorLayoutArray layouts;
        for (auto&& val : arr) {
            layouts.push_back(val.layout);
        }
        std::vector<Algorithm::Info::Desc> ret;
        megdnn_assert(layouts.size() == OprTrait<Opr>::arity);
        auto vec = AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info_safe(
                opr, layouts);
        for (auto algo_info : vec) {
            if (!(algo_info.attribute & AlgoAttribute::ACCURACY_DEPEND_ON_BATCH) &&
                (algo_info.attribute & AlgoAttribute::REPRODUCIBLE) &&
                std::regex_match(
                        algo_info.desc.name,
                        std::regex("(.*)(" + m_policy_name.name + ")(.*)"))) {
                ret.push_back(algo_info.desc);
            }
        }
        return ret;
    }

private:
    ExecutionPolicyAlgoName m_policy_name;
};
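//! Usage sketch (illustrative; the "CUDNN" policy name is an assumption, any
//! substring of an algorithm name works): restrict the accuracy-shake check
//! to matching algorithms via
//!     checker.set_before_exec_callback(AlgoGenerator<ConvBiasForward>("CUDNN"));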
} // namespace

::testing::AssertionResult __assert_tensor_binary_eq(
        const char* expr0, const char* expr1, const char* expr2, const TensorND& v0,
        const TensorND& v1, const Algorithm::Info::Desc& algo);
template <typename Opr, typename Proxy = OprProxy<Opr>>
class AccuracyShakeChecker : public CheckerHelper {
public:
    static constexpr int arity_in = OprArityTrait<Opr>::arity_in;
    using Param = typename Opr::Param;
    using BeforeExecCallback = std::function<std::vector<Algorithm::Info::Desc>(
            Opr*, const TensorValueArray&)>;

    AccuracyShakeChecker(Handle* handle, bool check_dispatch = false)
            : CheckerHelper(handle, check_dispatch),
              m_before_exec_callback{AlgoGenerator<Opr>("")},
              m_param(Param()) {}

    TensorLayoutArray make_layouts(const TensorShapeArray& shapes) {
        TensorLayoutArray layouts(shapes.size());
        for (size_t i = 0; i < shapes.size(); ++i) {
            DType dt =
                    (m_dtype.find(i) != m_dtype.end() ? m_dtype[i] : dtype::Float32());
            TensorFormat fmt =
                    (m_fmt.find(i) != m_fmt.end() ? m_fmt[i] : TensorFormat{});
            layouts[i] = TensorLayout(shapes[i], dt, fmt);
        }
        return layouts;
    }

    /*!
     * \brief execute the operator with the current param/dtype/rng config
     * \param shapes input/output shapes, which are passed as arguments to
     *      Opr::deduce_layout
     *
     * The checker constructs TensorLayout vectors from the shapes and dtypes,
     * then calls exec(TensorLayoutArray&). See the example following this
     * method.
     */
    AccuracyShakeChecker& exec(const TensorShapeArray& shapes) {
        exec(make_layouts(shapes));
        return *this;
    }
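
    //! Example usage (a minimal sketch; the shapes and the `handle()` helper
    //! below are illustrative, not part of this header):
    //!     AccuracyShakeChecker<ConvolutionForward> checker(handle());
    //!     checker.set_dtype(0, dtype::Float32())
    //!             .set_dtype(1, dtype::Float32())
    //!             .exec({{64, 16, 32, 32}, {16, 16, 3, 3}, {}});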
    void exec(TensorLayoutArray layouts);

    AccuracyShakeChecker& set_param(Param p) {
        m_param = p;
        opr()->param() = p;
        return *this;
    }
    AccuracyShakeChecker& set_dtype(size_t idx, DType dtype) {
        m_dtype[idx] = dtype;
        return *this;
    }
    AccuracyShakeChecker& set_rng(size_t idx, RNG* rng) {
        m_rng[idx] = rng;
        return *this;
    }

    //! set a callback to be invoked before executing the operator
    AccuracyShakeChecker& set_before_exec_callback(const BeforeExecCallback& cb) {
        m_before_exec_callback = cb;
        return *this;
    }
    AccuracyShakeChecker& reset_before_exec_callback() {
        m_before_exec_callback = nullptr;
        return *this;
    }

    //! get the opr impl so that settings other than param() can be modified
    Opr* opr() {
        if (!m_opr_cur) {
            m_opr_cur = m_handle_cur->create_operator<Opr>();
        }
        return m_opr_cur.get();
    }

private:
    BeforeExecCallback m_before_exec_callback;
    Param m_param;
    Proxy m_proxy;
    std::unique_ptr<Opr> m_opr_cur;
    std::shared_ptr<TensorValueArray> m_tensors_cur_host, m_tensors_single_batch_host;

    void init_host_values();
    void check_tensors_ignore_batch(
            const TensorValueArray& tensors_single_batch,
            const TensorValueArray& tensors, const Algorithm::Info::Desc& desc);
};
template <typename Opr, typename Proxy>
void AccuracyShakeChecker<Opr, Proxy>::exec(TensorLayoutArray layouts) {
    auto opr_cur = this->opr();
    opr_cur->param() = m_param;
    m_proxy.deduce_layout(opr_cur, layouts);

    TensorLayoutArray layouts_single_batch = layouts;
    for (size_t i = 0; i < layouts_single_batch.size(); ++i) {
        ASSERT_TRUE(layouts[i].is_physical_contiguous())
                << "layouts should be physically contiguous "
                << layouts[i].to_string();
    }
    ASSERT_TRUE(0 == BatchTrait<Opr>::index_of_batch(opr_cur->param()))
            << "index of batch should be 0";
    LayoutsModifier<Opr>::on(layouts_single_batch, opr_cur->param(), 1);

    // allocate input
    auto tensors_single_batch_storage =
            alloc_tensors(m_handle_cur, layouts_single_batch, 0);
    m_tensors_single_batch_host =
            alloc_tensors(m_handle_naive.get(), layouts_single_batch, 0);
    auto tensors_cur_storage = alloc_tensors(m_handle_cur, layouts, 0);
    m_tensors_cur_host = alloc_tensors(m_handle_naive.get(), layouts, 0);
    auto&& tensors_single_batch = *tensors_single_batch_storage;
    auto&& tensors_single_batch_host = *m_tensors_single_batch_host;
    auto&& tensors_cur = *tensors_cur_storage;
    auto&& tensors_cur_host = *m_tensors_cur_host;

    // allocate output
    auto tensors_single_batch_storage_out =
            alloc_tensors(m_handle_naive.get(), layouts_single_batch, 0);
    auto tensors_cur_storage_out = alloc_tensors(m_handle_naive.get(), layouts, 0);
    auto&& tensors_single_batch_out = *tensors_single_batch_storage_out;
    auto&& tensors_cur_out = *tensors_cur_storage_out;

    init_host_values();
    copy_tensors_to_device(tensors_cur, tensors_cur_host);
    copy_tensors_to_device(tensors_single_batch, tensors_single_batch_host);

    std::vector<Algorithm::Info::Desc> algo_desc;
    if (m_before_exec_callback) {
        algo_desc = m_before_exec_callback(opr_cur, tensors_cur);
    } else {
        algo_desc.push_back({});
    }
    // run each candidate algorithm on both the full-batch and the
    // single-batch inputs, then compare the overlapping part of the outputs
    for (size_t i = 0; i < algo_desc.size(); ++i) {
        opr_cur->execution_policy().algo = algo_desc[i];
        m_proxy.exec(opr_cur, tensors_cur);
        m_proxy.exec(opr_cur, tensors_single_batch);

        copy_tensors_from_device(tensors_cur_out, tensors_cur);
        copy_tensors_from_device(tensors_single_batch_out, tensors_single_batch);

        check_tensors_ignore_batch(
                tensors_single_batch_out, tensors_cur_out, algo_desc[i]);
    }
}
template <typename Opr, typename Proxy>
void AccuracyShakeChecker<Opr, Proxy>::init_host_values() {
    size_t index_of_batch = 0;
    auto&& tensors_single_batch = *m_tensors_single_batch_host;
    auto&& tensors_cur = *m_tensors_cur_host;
    for (size_t i = 0; i < arity_in; ++i) {
        auto&& tensor_single_batch = tensors_single_batch[i];
        auto&& tensor_cur = tensors_cur[i];
        auto rng = m_rng[i];
        if (!rng)
            rng = m_default_rng.get();
        rng->gen(tensor_single_batch);

        dt_byte* raw_storage_cur = static_cast<dt_byte*>(tensor_cur.raw_ptr()) +
                                   tensor_cur.layout.span().low_byte;
        dt_byte* raw_storage_single_batch =
                static_cast<dt_byte*>(tensor_single_batch.raw_ptr()) +
                tensor_single_batch.layout.span().low_byte;
        const size_t step = tensor_single_batch.layout.span().dist_byte();
        if (tensor_cur.layout.eq_shape(tensor_single_batch.layout)) {
            memcpy(raw_storage_cur, raw_storage_single_batch, step);
        } else {
            ASSERT_TRUE(1 == tensor_single_batch.layout[index_of_batch])
                    << "bad batch size "
                    << tensor_single_batch.layout[index_of_batch];
            // replicate the single-batch data along the batch dimension so
            // every batch of the full-batch input holds identical values
            for (size_t b = 0; b < tensor_cur.layout[index_of_batch]; ++b) {
                memcpy(raw_storage_cur, raw_storage_single_batch, step);
                raw_storage_cur += step;
            }
        }
    }
}
template <typename Opr, typename Proxy>
void AccuracyShakeChecker<Opr, Proxy>::check_tensors_ignore_batch(
        const TensorValueArray& tensors_single_batch, const TensorValueArray& tensors,
        const Algorithm::Info::Desc& algo) {
    for (size_t i = 0; i < tensors_single_batch.size(); ++i) {
        if (tensors_single_batch[i].layout.ndim == 0 ||
            tensors_single_batch[i].layout.eq_shape(tensors[i].layout))
            continue;
        ASSERT_PRED_FORMAT3(
                ::megdnn::test::__assert_tensor_binary_eq, tensors_single_batch[i],
                tensors[i], algo);
    }
}

} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen