You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

elemwise.cpp 40 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041
  1. #include "test/common/elemwise.h"
  2. #include "src/common/utils.cuh"
  3. #include "test/common/checker.h"
  4. #include "test/common/utils.h"
  5. #include "megdnn/oprs/general.h"
  6. #include "test/common/fix_gtest_on_platforms_without_exception.inl"
  7. using namespace megdnn;
  8. using namespace test;
  9. namespace {
  10. void fma3_extra_opr_impl(const TensorNDArray& data) {
  11. megdnn_assert(data.size() == 4);
  12. auto handle = create_cpu_handle(2);
  13. auto opr = handle->create_operator<Elemwise>();
  14. using Mode = Elemwise::Mode;
  15. opr->param().mode = Mode::MUL;
  16. opr->exec({data[0], data[1]}, data[3]);
  17. opr->param().mode = Mode::ADD;
  18. opr->exec({data[2], data[3]}, data[3]);
  19. }
  20. void fma4_extra_opr_impl(const TensorNDArray& data) {
  21. megdnn_assert(data.size() == 5);
  22. std::vector<uint8_t> tmp_storage(data[4].layout.span().dist_byte());
  23. TensorND tmp;
  24. tmp.reset_ptr(tmp_storage.data());
  25. tmp.layout = data[4].layout;
  26. tmp.layout.init_contiguous_stride();
  27. auto handle = create_cpu_handle(2);
  28. auto opr = handle->create_operator<Elemwise>();
  29. using Mode = Elemwise::Mode;
  30. opr->param().mode = Mode::MUL;
  31. opr->exec({data[0], data[1]}, data[4]);
  32. opr->exec({data[2], data[3]}, tmp);
  33. opr->param().mode = Mode::ADD;
  34. opr->exec({tmp, data[4]}, data[4]);
  35. }
  36. TensorLayout make_layout(
  37. const TensorShape& shp, std::initializer_list<ptrdiff_t> stride) {
  38. TensorLayout ret{shp, dtype::Float32()};
  39. megdnn_assert(stride.size() == shp.ndim);
  40. auto idx = 0;
  41. for (auto i : stride)
  42. ret.stride[idx++] = i;
  43. return ret;
  44. }
  45. } // anonymous namespace
  46. namespace megdnn {
  47. namespace test {
  48. namespace elemwise {
  //! Opens an explicit specialization of run_test for the tag type `name`; the
  //! braces following the macro invocation form the test body, which receives
  //! the handle under test as `handle`.
  #define DEF_TEST(name) template <> void run_test<name>(Handle * handle)
  52. DEF_TEST(unary) {
  53. using Mode = ElemwiseForward::Param::Mode;
  54. Checker<ElemwiseForward> checker(handle);
  55. checker.set_param(Mode::SIN);
  56. checker.set_dtype(0, dtype::Float32()).execs({{3, 4, 1}, {}});
  57. checker.set_dtype(0, dtype::Float16()).execs({{3, 4, 1}, {}});
  58. }
  59. DEF_TEST(binary_brdcst) {
  60. auto run = [&](DType dtype) {
  61. using Mode = ElemwiseForward::Param::Mode;
  62. Checker<ElemwiseForward> checker(handle);
  63. checker.set_param(Mode::ADD);
  64. checker.set_dtype(0, dtype);
  65. checker.set_dtype(1, dtype);
  66. checker.execs({{3, 1}, {1, 3}, {3, 3}});
  67. {
  68. checker.execs({{10, 11}, {10, 11}, {10, 11}});
  69. //
  70. checker.execs({{2, 3, 4, 5, 6, 7}, {1, 3, 1, 1, 6, 1}, {2, 3, 4, 5, 6, 7}});
  71. checker.execs({{1, 3, 1, 1, 6, 1}, {2, 3, 4, 5, 6, 7}, {2, 3, 4, 5, 6, 7}});
  72. //
  73. checker.execs({{256, 256, 3}, {1, 1, 3}, {256, 256, 3}});
  74. checker.execs({{1, 1, 3}, {256, 256, 3}, {256, 256, 3}});
  75. //
  76. checker.execs({{8, 1, 6, 1}, {1, 7, 1, 5}, {8, 7, 6, 5}});
  77. checker.execs({{1, 7, 1, 5}, {8, 1, 6, 1}, {8, 7, 6, 5}});
  78. //
  79. checker.execs({{5, 4}, {1, 1}, {5, 4}});
  80. checker.execs({{1, 1}, {5, 4}, {5, 4}});
  81. //
  82. checker.execs({{5, 4}, {1, 4}, {5, 4}});
  83. checker.execs({{1, 4}, {5, 4}, {5, 4}});
  84. //
  85. checker.execs({{15, 3, 5}, {15, 1, 5}, {15, 3, 5}});
  86. checker.execs({{15, 1, 5}, {15, 3, 5}, {15, 3, 5}});
  87. //
  88. checker.execs({{15, 3, 5}, {1, 3, 5}, {15, 3, 5}});
  89. checker.execs({{1, 3, 5}, {15, 3, 5}, {15, 3, 5}});
  90. //
  91. checker.execs({{15, 3, 5}, {1, 3, 1}, {15, 3, 5}});
  92. checker.execs({{1, 3, 1}, {15, 3, 5}, {15, 3, 5}});
  93. //
  94. checker.execs({{3, 1}, {1, 4}, {3, 4}});
  95. // numpy broadcast
  96. checker.execs({{2, 3, 1, 5}, {4, 5}, {2, 3, 4, 5}});
  97. checker.execs({{3, 1, 1}, {4, 5}, {3, 4, 5}});
  98. }
  99. {
  100. // 1d
  101. {
  102. auto n = 1000u;
  103. checker.execs({{n}, {n}, {n}});
  104. checker.execs({{1}, {n}, {n}});
  105. checker.execs({{n}, {1}, {n}});
  106. }
  107. // 2d
  108. {
  109. auto m = 200u, n = 100u;
  110. auto collapse = [](size_t n, bool is_collapsed) {
  111. return is_collapsed ? 1u : n;
  112. };
  113. for (auto msk = 0u; msk < 16; ++msk) {
  114. checker.execs(
  115. {{collapse(m, msk & 1), collapse(n, msk & 2)},
  116. {collapse(m, msk & 4), collapse(n, msk & 8)},
  117. {}});
  118. }
  119. }
  120. // nd
  121. {
  122. checker.execs({{2, 3, 4, 5, 6}, {1, 3, 1, 5, 6}, {2, 3, 4, 5, 6}});
  123. checker.execs({{2, 3, 4, 5, 6}, {2, 1, 4, 1, 6}, {2, 3, 4, 5, 6}});
  124. }
  125. }
  126. };
  127. run(dtype::Float32());
  128. // run(dtype::Float16());
  129. }
  130. DEF_TEST(binary_non_contig) {
  131. using Mode = ElemwiseForward::Param::Mode;
  132. Checker<ElemwiseForward> checker(handle);
  133. checker.set_param(Mode::ADD);
  134. TensorLayout ly{{2, 3}, dtype::Float32()};
  135. ly.stride[0] = 4;
  136. checker.execl({ly, ly, {{2, 3}, dtype::Float32()}});
  137. }
  138. DEF_TEST(ternary) {
  139. using Mode = ElemwiseForward::Param::Mode;
  140. Checker<ElemwiseForward> checker(handle);
  141. checker.set_param(Mode::COND_LEQ_MOV);
  142. checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  143. checker.set_dtype(0, dtype::Float32())
  144. .set_dtype(1, dtype::Float32())
  145. .set_dtype(2, dtype::Float32())
  146. .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  147. checker.set_dtype(0, dtype::Float16())
  148. .set_dtype(1, dtype::Float16())
  149. .set_dtype(2, dtype::Float16())
  150. .set_dtype(3, dtype::Float16())
  151. .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  152. checker.execs({{2, 1, 1, 5}, {4, 5}, {3, 1, 1}, {2, 3, 4, 5}});
  153. checker.execs({{3, 1, 1}, {5}, {4, 1}, {3, 4, 5}});
  154. ASSERT_THROW(checker.execs({{2, 3, 4}, {4, 1}, {1}, {2, 3, 4}}), MegDNNError);
  155. ASSERT_THROW(checker.execs({{2, 4, 4}, {4, 1}, {3, 1, 1}, {2, 3, 4}}), MegDNNError);
  156. }
  157. DEF_TEST(ternary_non_contig) {
  158. using Mode = ElemwiseForward::Param::Mode;
  159. Checker<ElemwiseForward> checker(handle);
  160. checker.set_param(Mode::COND_LEQ_MOV);
  161. TensorLayout ly{{2, 3}, dtype::Float32()};
  162. ly.stride[0] = 4;
  163. checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
  164. }
  165. DEF_TEST(ternary_lt) {
  166. using Mode = ElemwiseForward::Param::Mode;
  167. Checker<ElemwiseForward> checker(handle);
  168. checker.set_param(Mode::COND_LT_MOV);
  169. checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  170. checker.set_dtype(0, dtype::Float32())
  171. .set_dtype(1, dtype::Float32())
  172. .set_dtype(2, dtype::Float32())
  173. .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  174. checker.set_dtype(0, dtype::Float16())
  175. .set_dtype(1, dtype::Float16())
  176. .set_dtype(2, dtype::Float16())
  177. .set_dtype(3, dtype::Float16())
  178. .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  179. checker.execs({{2, 1, 1, 5}, {4, 5}, {3, 1, 1}, {2, 3, 4, 5}});
  180. checker.execs({{3, 1, 1}, {5}, {4, 1}, {3, 4, 5}});
  181. ASSERT_THROW(checker.execs({{2, 3, 4}, {4, 1}, {1}, {2, 3, 4}}), MegDNNError);
  182. ASSERT_THROW(checker.execs({{2, 4, 4}, {4, 1}, {3, 1, 1}, {2, 3, 4}}), MegDNNError);
  183. }
  184. DEF_TEST(ternary_lt_non_contig) {
  185. using Mode = ElemwiseForward::Param::Mode;
  186. Checker<ElemwiseForward> checker(handle);
  187. checker.set_param(Mode::COND_LT_MOV);
  188. TensorLayout ly{{2, 3}, dtype::Float32()};
  189. ly.stride[0] = 4;
  190. checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
  191. }
  192. DEF_TEST(fuse_mul_add3) {
  193. using Mode = ElemwiseForward::Param::Mode;
  194. Checker<ElemwiseForward> checker(handle);
  195. checker.set_param(Mode::FUSE_MUL_ADD3).set_extra_opr_impl(fma3_extra_opr_impl);
  196. auto make_shape = [](const TensorShape& s0, const TensorShape& s1,
  197. const TensorShape& s2) {
  198. TensorShape dest;
  199. dest.ndim = s0.ndim;
  200. for (size_t i = 0; i < dest.ndim; ++i) {
  201. auto a = i < s0.ndim ? s0[i] : 1;
  202. auto b = i < s1.ndim ? s1[i] : 1;
  203. dest[i] = std::max(a, b);
  204. }
  205. return TensorShapeArray{s0, s1, s2, dest};
  206. };
  207. checker.exec(make_shape({2, 1}, {2, 2}, {2, 2}));
  208. checker.exec(make_shape({2, 2}, {2, 1}, {2, 2}));
  209. checker.exec(make_shape({2, 2}, {2, 2}, {1}));
  210. checker.exec(make_shape({3, 1}, {1, 3}, {3, 1}));
  211. checker.exec(make_shape({2, 1, 2, 1, 2, 1}, {1, 2, 1, 2, 1, 2}, {1}));
  212. checker.exec(make_shape({1, 1, 3}, {5, 8, 1}, {5, 8, 1}));
  213. checker.exec(make_shape({1, 192, 9, 16}, {1}, {1, 192, 9, 16}));
  214. }
  215. DEF_TEST(fuse_mul_add3_non_contig) {
  216. using Mode = ElemwiseForward::Param::Mode;
  217. Checker<ElemwiseForward> checker(handle);
  218. checker.set_param(Mode::FUSE_MUL_ADD3).set_extra_opr_impl(fma3_extra_opr_impl);
  219. TensorLayout ly{{2, 3}, dtype::Float32()};
  220. ly.stride[0] = 4;
  221. checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
  222. }
  223. DEF_TEST(fuse_mul_add4) {
  224. using Mode = ElemwiseForward::Param::Mode;
  225. Checker<ElemwiseForward> checker(handle);
  226. checker.set_param(Mode::FUSE_MUL_ADD4).set_extra_opr_impl(fma4_extra_opr_impl);
  227. auto make_shape = [](const TensorShape& s0, const TensorShape& s1,
  228. bool swap = false) {
  229. TensorShape dest;
  230. dest.ndim = s0.ndim;
  231. for (size_t i = 0; i < dest.ndim; ++i) {
  232. auto a = i < s0.ndim ? s0[i] : 1;
  233. auto b = i < s1.ndim ? s1[i] : 1;
  234. dest[i] = std::max(a, b);
  235. }
  236. TensorShapeArray ret{s0, s1, s0, s1, dest};
  237. if (swap)
  238. std::swap(ret[2], ret[3]);
  239. return ret;
  240. };
  241. checker.exec(make_shape({2, 2}, {2, 2}));
  242. checker.exec(make_shape({3, 1}, {1, 3}));
  243. checker.exec(make_shape({2, 1, 2, 1, 2, 1}, {1, 2, 1, 2, 1, 2}));
  244. checker.exec(make_shape({4, 2}, {1, 2}, true));
  245. }
  246. DEF_TEST(rmulh) {
  247. using Mode = ElemwiseForward::Param::Mode;
  248. Checker<ElemwiseForward> checker(handle);
  249. auto run_for_dtype = [&checker](auto dtype) {
  250. auto minv = DTypeTrait<decltype(dtype)>::min();
  251. auto maxv = DTypeTrait<decltype(dtype)>::max();
  252. UniformIntRNG rng0{minv, maxv};
  253. UniformIntRNG rngM{(maxv >> 1) + 1, maxv};
  254. checker.set_param({Mode::RMULH})
  255. .set_dtype(0, dtype)
  256. .set_dtype(1, dtype)
  257. .set_dtype(2, dtype)
  258. .set_rng(0, &rng0)
  259. .set_rng(1, &rngM);
  260. checker.execs({{7, 9, 11, 13}, {1}, {}})
  261. .execs({{16, 3, 256, 256}, {1}, {}})
  262. .execs({{2, 3, 1, 7}, {2, 3, 1, 7}, {}})
  263. .execs({{9, 5, 4}, {1, 5, 1}, {}})
  264. .execs({{233}, {1}, {}});
  265. };
  266. run_for_dtype(dtype::Int8());
  267. run_for_dtype(dtype::Int16());
  268. run_for_dtype(dtype::Int32());
  269. }
  270. /* ============= migrated from x86 tests ============= */
  271. #define UNARY_TEST_CASE(_optr) \
  272. checker.set_param(Mode::_optr).execs({{1, 127}, {}}); \
  273. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  274. #define BUILD_UNARY_TEST_CASE_INT \
  275. UNARY_TEST_CASE(RELU) \
  276. UNARY_TEST_CASE(ABS)
  277. #define BUILD_UNARY_TEST_CASE_FLOAT \
  278. UNARY_TEST_CASE(ABS) \
  279. UNARY_TEST_CASE(LOG) \
  280. UNARY_TEST_CASE(COS) \
  281. UNARY_TEST_CASE(SIN) \
  282. UNARY_TEST_CASE(FLOOR) \
  283. UNARY_TEST_CASE(CEIL) \
  284. UNARY_TEST_CASE(SIGMOID) \
  285. UNARY_TEST_CASE(EXP) \
  286. UNARY_TEST_CASE(TANH) \
  287. UNARY_TEST_CASE(FAST_TANH) \
  288. UNARY_TEST_CASE(RELU) \
  289. UNARY_TEST_CASE(ROUND)
  290. DEF_TEST(unary1) {
  291. using Mode = ElemwiseForward::Param::Mode;
  292. Checker<ElemwiseForward> checker(handle);
  293. // case int
  294. checker.set_dtype(0, dtype::Int8());
  295. BUILD_UNARY_TEST_CASE_INT
  296. checker.set_dtype(0, dtype::Int16());
  297. BUILD_UNARY_TEST_CASE_INT
  298. checker.set_dtype(0, dtype::Int32());
  299. BUILD_UNARY_TEST_CASE_INT
  300. // case float
  301. UniformFloatRNG rng(1e-2, 6e1);
  302. checker.set_rng(0, &rng);
  303. checker.set_epsilon(1e-5);
  304. checker.set_dtype(0, dtype::Float32());
  305. BUILD_UNARY_TEST_CASE_FLOAT
  306. }
  307. #undef UNARY_TEST_CASE
  308. #undef BUILD_UNARY_TEST_CASE_INT
  309. #undef BUILD_UNARY_TEST_CASE_FLOAT
  310. #define BINARY_TEST_CASE(_optr) \
  311. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  312. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  313. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  314. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  315. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  316. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});
  317. #define BUILD_BINARY_TEST_CASE \
  318. BINARY_TEST_CASE(MIN) \
  319. BINARY_TEST_CASE(MAX)
  320. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  321. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  322. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  323. checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
  324. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  325. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}}); \
  326. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  327. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  328. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  329. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  330. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}}); \
  331. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  332. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  333. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  334. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  335. BINARY_COMPLATE_TEST_CASE(ADD) \
  336. BINARY_COMPLATE_TEST_CASE(MUL) \
  337. BINARY_COMPLATE_TEST_CASE(MAX) \
  338. BINARY_COMPLATE_TEST_CASE(MIN) \
  339. BINARY_COMPLATE_TEST_CASE(SUB)
  340. #define BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32 \
  341. BINARY_COMPLATE_TEST_CASE(POW) \
  342. BINARY_COMPLATE_TEST_CASE(TRUE_DIV) \
  343. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID) \
  344. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH) \
  345. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU) \
  346. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_H_SWISH) \
  347. BINARY_COMPLATE_TEST_CASE(FAST_TANH_GRAD) \
  348. BINARY_COMPLATE_TEST_CASE(H_SWISH_GRAD)
  349. DEF_TEST(binary1) {
  350. using Mode = ElemwiseForward::Param::Mode;
  351. Checker<ElemwiseForward> checker(handle);
  352. // case float
  353. UniformFloatRNG rng(1e-5, 7e1);
  354. checker.set_rng(0, &rng);
  355. checker.set_epsilon(1e-5);
  356. checker.set_dtype(0, dtype::Float32());
  357. checker.set_dtype(1, dtype::Float32());
  358. BUILD_BINARY_COMPLATE_TEST_CASE
  359. BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  360. // case int
  361. checker.set_dtype(0, dtype::Int8());
  362. checker.set_dtype(1, dtype::Int8());
  363. BUILD_BINARY_TEST_CASE
  364. BUILD_BINARY_COMPLATE_TEST_CASE
  365. checker.set_dtype(0, dtype::Int16());
  366. checker.set_dtype(1, dtype::Int16());
  367. BUILD_BINARY_TEST_CASE
  368. BUILD_BINARY_COMPLATE_TEST_CASE
  369. checker.set_dtype(0, dtype::Int32());
  370. checker.set_dtype(1, dtype::Int32());
  371. BUILD_BINARY_TEST_CASE
  372. BUILD_BINARY_COMPLATE_TEST_CASE
  373. }
  374. #undef BINARY_TEST_CASE
  375. #undef BUILD_BINARY_TEST_CASE
  376. #undef BINARY_COMPLATE_TEST_CASE
  377. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  378. #undef BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  379. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  380. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  381. checker.set_param(Mode::_optr) \
  382. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  383. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  384. checker.set_param(Mode::_optr) \
  385. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  386. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  387. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  388. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  389. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  390. #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  391. DEF_TEST(ternary1) {
  392. using Mode = ElemwiseForward::Param::Mode;
  393. Checker<ElemwiseForward> checker(handle);
  394. // case int
  395. checker.set_dtype(0, dtype::Int8());
  396. checker.set_dtype(1, dtype::Int8());
  397. checker.set_dtype(2, dtype::Int8());
  398. // BUILD_TERNARY_TEST_CASE
  399. BUILD_TERNARY_COMPLATE_TEST_CASE
  400. checker.set_dtype(0, dtype::Int16());
  401. checker.set_dtype(1, dtype::Int16());
  402. checker.set_dtype(2, dtype::Int16());
  403. // BUILD_TERNARY_TEST_CASE
  404. BUILD_TERNARY_COMPLATE_TEST_CASE
  405. checker.set_dtype(0, dtype::Int32());
  406. checker.set_dtype(1, dtype::Int32());
  407. checker.set_dtype(2, dtype::Int32());
  408. // BUILD_TERNARY_TEST_CASE
  409. BUILD_TERNARY_COMPLATE_TEST_CASE
  410. // case float
  411. UniformFloatRNG rng(1e-5, 7e1);
  412. checker.set_rng(0, &rng);
  413. checker.set_epsilon(1e-5);
  414. checker.set_dtype(0, dtype::Float32());
  415. checker.set_dtype(1, dtype::Float32());
  416. checker.set_dtype(2, dtype::Float32());
  417. // BUILD_TERNARY_TEST_CASE
  418. BUILD_TERNARY_COMPLATE_TEST_CASE
  419. // TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  420. }
  421. #undef TERNARY_COMPLATE_TEST_CASE
  422. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  423. /* ============= migrated from arm tests ============= */
  424. #define UNARY_TEST_CASE(_optr) \
  425. checker.set_param(Mode::_optr).execs({{1, 129}, {}}); \
  426. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  427. #define BUILD_UNARY_TEST_CASE_INT \
  428. UNARY_TEST_CASE(RELU) \
  429. UNARY_TEST_CASE(ABS) \
  430. UNARY_TEST_CASE(NEGATE)
  431. #define BUILD_UNARY_TEST_CASE_FLOAT \
  432. BUILD_UNARY_TEST_CASE_INT \
  433. UNARY_TEST_CASE(SIGMOID) \
  434. UNARY_TEST_CASE(EXP) \
  435. UNARY_TEST_CASE(TANH) \
  436. UNARY_TEST_CASE(FAST_TANH) \
  437. UNARY_TEST_CASE(H_SWISH)
  438. DEF_TEST(unary2) {
  439. using Mode = ElemwiseForward::Param::Mode;
  440. Checker<ElemwiseForward> checker(handle);
  441. // case int
  442. checker.set_dtype(0, dtype::Int8());
  443. BUILD_UNARY_TEST_CASE_INT
  444. checker.set_dtype(0, dtype::Int16());
  445. BUILD_UNARY_TEST_CASE_INT
  446. checker.set_dtype(0, dtype::Int32());
  447. BUILD_UNARY_TEST_CASE_INT
  448. // case float
  449. {
  450. UniformFloatRNG rng(1e-5, 7e1);
  451. checker.set_rng(0, &rng);
  452. checker.set_epsilon(1e-5);
  453. checker.set_dtype(0, dtype::Float32());
  454. BUILD_UNARY_TEST_CASE_FLOAT
  455. }
  456. {
  457. UniformFloatRNG rng(1e-2, 1e1);
  458. checker.set_rng(0, &rng);
  459. checker.set_epsilon(6e-3);
  460. checker.set_dtype(0, dtype::Float16());
  461. BUILD_UNARY_TEST_CASE_FLOAT
  462. }
  463. // tanh NaN bug case
  464. {
  465. UniformFloatRNG rng(100, 200);
  466. checker.set_rng(0, &rng);
  467. checker.set_epsilon(1e-5);
  468. checker.set_dtype(0, dtype::Float32());
  469. checker.set_param(Mode::TANH).execs({{1, 1025}, {}});
  470. checker.set_param(Mode::TANH).execs({{1, 7}, {}});
  471. }
  472. }
  473. #undef UNARY_TEST_CASE
  474. #undef BUILD_UNARY_TEST_CASE_INT
  475. #undef BUILD_UNARY_TEST_CASE_FLOAT
  476. #define BINARY_TEST_CASE(_optr) \
  477. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  478. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  479. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  480. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  481. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  482. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});
  483. #define BUILD_BINARY_TEST_CASE \
  484. BINARY_TEST_CASE(MIN) \
  485. BINARY_TEST_CASE(MAX)
  486. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  487. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  488. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  489. checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
  490. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  491. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}}); \
  492. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  493. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  494. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  495. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  496. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}}); \
  497. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  498. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  499. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  500. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  501. BINARY_COMPLATE_TEST_CASE(ADD) \
  502. BINARY_COMPLATE_TEST_CASE(MUL) \
  503. BINARY_COMPLATE_TEST_CASE(MAX) \
  504. BINARY_COMPLATE_TEST_CASE(MIN) \
  505. BINARY_COMPLATE_TEST_CASE(SUB) \
  506. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU)
  507. DEF_TEST(binary2) {
  508. using Mode = ElemwiseForward::Param::Mode;
  509. Checker<ElemwiseForward> checker(handle);
  510. // case float
  511. UniformFloatRNG rng(1e-5, 7e1);
  512. checker.set_rng(0, &rng);
  513. checker.set_epsilon(1e-5);
  514. checker.set_dtype(0, dtype::Float32());
  515. checker.set_dtype(1, dtype::Float32());
  516. BUILD_BINARY_COMPLATE_TEST_CASE
  517. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID)
  518. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH)
  519. // case int
  520. checker.set_dtype(0, dtype::Int8());
  521. checker.set_dtype(1, dtype::Int8());
  522. // BUILD_BINARY_TEST_CASE
  523. BUILD_BINARY_COMPLATE_TEST_CASE
  524. checker.set_dtype(0, dtype::Int16());
  525. checker.set_dtype(1, dtype::Int16());
  526. // BUILD_BINARY_TEST_CASE
  527. BUILD_BINARY_COMPLATE_TEST_CASE
  528. checker.set_dtype(0, dtype::Int32());
  529. checker.set_dtype(1, dtype::Int32());
  530. BUILD_BINARY_TEST_CASE
  531. BUILD_BINARY_COMPLATE_TEST_CASE
  532. // case float
  533. checker.set_rng(0, &rng);
  534. checker.set_epsilon(1e-5);
  535. checker.set_dtype(0, dtype::Float32());
  536. checker.set_dtype(1, dtype::Float32());
  537. checker.set_param(Mode::FUSE_ADD_SIGMOID).execs({{3, 4, 7}, {1}, {}});
  538. checker.set_param(Mode::FUSE_ADD_TANH).execs({{3, 4, 7}, {1}, {}});
  539. // commutable
  540. checker.set_param(Mode::TRUE_DIV).execs({{1}, {4}, {}});
  541. BUILD_BINARY_TEST_CASE
  542. BUILD_BINARY_COMPLATE_TEST_CASE
  543. BINARY_COMPLATE_TEST_CASE(TRUE_DIV)
  544. {
  545. UniformFloatRNG rng(1e-3, 3e1);
  546. checker.set_rng(0, &rng);
  547. checker.set_rng(1, &rng);
  548. checker.set_epsilon(1e-3);
  549. checker.set_dtype(0, dtype::Float16());
  550. checker.set_dtype(1, dtype::Float16());
  551. checker.set_param(Mode::FUSE_ADD_SIGMOID).execs({{3, 4, 7}, {1}, {}});
  552. checker.set_param(Mode::FUSE_ADD_TANH).execs({{3, 4, 7}, {1}, {}});
  553. BUILD_BINARY_TEST_CASE
  554. BUILD_BINARY_COMPLATE_TEST_CASE
  555. BINARY_COMPLATE_TEST_CASE(TRUE_DIV)
  556. // commutable
  557. checker.set_param(Mode::TRUE_DIV).execs({{1}, {4}, {}});
  558. }
  559. }
  560. #undef BINARY_TEST_CASE
  561. #undef BUILD_BINARY_TEST_CASE
  562. #undef BINARY_COMPLATE_TEST_CASE
  563. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  564. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  565. checker.set_param(Mode::_optr) \
  566. .execs({{1, 123, 1}, {300, 123, 253}, {300, 123, 253}, {}}); \
  567. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  568. checker.set_param(Mode::_optr) \
  569. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  570. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  571. checker.set_param(Mode::_optr) \
  572. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  573. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  574. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  575. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  576. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}}); \
  577. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {1, 1, 1}, {3, 4, 1}, {}});
  578. #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  579. DEF_TEST(ternary2) {
  580. using Mode = ElemwiseForward::Param::Mode;
  581. Checker<ElemwiseForward> checker(handle);
  582. // case int
  583. checker.set_dtype(0, dtype::Int8());
  584. checker.set_dtype(1, dtype::Int8());
  585. checker.set_dtype(2, dtype::Int8());
  586. BUILD_TERNARY_COMPLATE_TEST_CASE
  587. checker.set_dtype(0, dtype::Int16());
  588. checker.set_dtype(1, dtype::Int16());
  589. checker.set_dtype(2, dtype::Int16());
  590. BUILD_TERNARY_COMPLATE_TEST_CASE
  591. checker.set_dtype(0, dtype::Int32());
  592. checker.set_dtype(1, dtype::Int32());
  593. checker.set_dtype(2, dtype::Int32());
  594. BUILD_TERNARY_COMPLATE_TEST_CASE
  595. // case float
  596. UniformFloatRNG rng(1e-5, 7e1);
  597. checker.set_rng(0, &rng);
  598. checker.set_epsilon(1e-5);
  599. checker.set_dtype(0, dtype::Float32());
  600. checker.set_dtype(1, dtype::Float32());
  601. checker.set_dtype(2, dtype::Float32());
  602. BUILD_TERNARY_COMPLATE_TEST_CASE
  603. {
  604. UniformFloatRNG rng(1e-3, 3e1);
  605. checker.set_rng(0, &rng);
  606. checker.set_rng(1, &rng);
  607. checker.set_rng(2, &rng);
  608. checker.set_epsilon(1e-3);
  609. checker.set_dtype(0, dtype::Float16());
  610. checker.set_dtype(1, dtype::Float16());
  611. checker.set_dtype(2, dtype::Float16());
  612. BUILD_TERNARY_COMPLATE_TEST_CASE
  613. }
  614. }
  615. #undef TERNARY_COMPLATE_TEST_CASE
  616. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  617. /* ============= migrated from fallback tests ============= */
  618. DEF_TEST(unary3) {
  619. Checker<Elemwise> checker(handle);
  620. auto make_layouts =
  621. [](const TensorShape& shp,
  622. std::initializer_list<ptrdiff_t> stride) -> TensorLayoutArray {
  623. return {make_layout(shp, stride), {shp, dtype::Float32()}};
  624. };
  625. checker.set_param({Elemwise::Mode::SIN});
  626. checker.exec(make_layouts({2, 2}, {2, 1}));
  627. checker.exec(make_layouts({4}, {3}));
  628. }
  629. DEF_TEST(binary3) {
  630. Checker<Elemwise> checker(handle);
  631. checker.set_param({Elemwise::Mode::ADD});
  632. auto run = [&](const TensorShape& shp0, std::initializer_list<ptrdiff_t> stride0,
  633. const TensorShape& shp1, std::initializer_list<ptrdiff_t> stride1) {
  634. TensorShape shpo;
  635. Elemwise::deduce_shape({shp0, shp1}, shpo);
  636. auto ly0 = make_layout(shp0, stride0), ly1 = make_layout(shp1, stride1),
  637. lyo = TensorLayout{shpo, dtype::Float32()};
  638. checker.execl({ly0, ly1, lyo});
  639. checker.execl({ly1, ly0, lyo});
  640. };
  641. run({2, 2}, {2, 1}, {2, 2}, {2, 1});
  642. run({1}, {1}, {3, 3}, {1, 2});
  643. run({3, 4, 5}, {40, 10, 2}, {1, 4, 1}, {1, 1, 1});
  644. }
  645. DEF_TEST(all_modes) {
  646. // test correctness of all elemwise modes
  647. Checker<Elemwise> checker(handle);
  648. TensorShapeArray shapes;
  649. UniformFloatRNG default_rng_f32{-100.f, 100.f}, pos_rng_f32{.1f, 1000.f},
  650. small_pos_rng_f32{.1f, .10f}, small_rng_f32{-3.f, 3.f},
  651. abslt1_rng_f32{-1.f, 1.f}, uniform_0_2_rng{0.f, 2.f},
  652. tanh_rng_f32{-5.f, 5.f};
  653. UniformFloatNonZeroRNG nonzero_rng_f32{.1f, 1000.f},
  654. big_nonzero_rng_f32{100.f, 1000.f};
  655. UniformIntRNG default_rng_i32{-100, 100}, small_rng_i32{-2, 2},
  656. shift_rng_i32_i32{0, 31}, shift_rng_i32_i8{0, 7};
  657. UniformIntNonZeroRNG nonzero_rng_i32{1, 100};
  658. using Mode = Elemwise::Mode;
  659. auto should_ignore = [handle](Mode mode) {
  660. MEGDNN_MARK_USED_VAR(mode);
  661. return false;
  662. };
  663. for (int mode_nr = 0; mode_nr < static_cast<int>(Elemwise::Param::MODE_NR_MEMBER);
  664. ++mode_nr) {
  665. auto mode = static_cast<Mode>(mode_nr);
  666. // ignore unsupported modes
  667. if (should_ignore(mode)) {
  668. continue;
  669. }
  670. checker.set_param({mode});
  671. auto&& trait = Elemwise::ModeTrait::from_mode(mode);
  672. shapes.resize(trait.arity + 1);
  673. for (size_t i = 0; i < shapes.size() - 1; ++i) {
  674. shapes[i] = {3, 9, 7};
  675. }
  676. //! NOTE: force set output layout to empty to trigger layout deduce
  677. shapes[shapes.size() - 1] = {};
  678. auto do_run = [&](DType dtype, float eps = 1e-3) {
  679. // limit value ranges for some modes
  680. if (mode == Mode::LOG || mode == Mode::LOG1P) {
  681. checker.set_rng(0, &pos_rng_f32);
  682. } else if (mode == Mode::POW) {
  683. checker.set_rng(0, &small_pos_rng_f32);
  684. checker.set_rng(1, &small_rng_f32);
  685. } else if (mode == Mode::EXP || mode == Mode::EXPM1) {
  686. checker.set_rng(0, &small_rng_f32);
  687. } else if (mode == Mode::FAST_TANH) {
  688. checker.set_rng(0, &tanh_rng_f32);
  689. } else if (mode == Mode::LOG_SUM_EXP) {
  690. // check numerical stability with large values
  691. checker.set_rng(0, &big_nonzero_rng_f32);
  692. checker.set_rng(1, &big_nonzero_rng_f32);
  693. } else if (
  694. mode == Mode::ASIN || mode == Mode::ACOS ||
  695. mode == Mode::SIGMOID_GRAD || mode == Mode::TANH_GRAD ||
  696. mode == Mode::ERFINV) {
  697. checker.set_rng(0, &abslt1_rng_f32);
  698. checker.set_rng(1, &default_rng_f32);
  699. } else if (mode == Mode::ERFCINV) {
  700. checker.set_rng(0, &uniform_0_2_rng);
  701. } else if (
  702. mode == Mode::MOD || mode == Mode::TRUE_DIV ||
  703. mode == Mode::FLOOR_DIV) {
  704. if (dtype.category() == DTypeCategory::INT) {
  705. checker.set_rng(0, &default_rng_i32);
  706. checker.set_rng(1, &nonzero_rng_i32);
  707. } else {
  708. checker.set_rng(0, &default_rng_f32);
  709. checker.set_rng(1, &nonzero_rng_f32);
  710. }
  711. } else if (mode == Mode::EQ) {
  712. checker.set_rng(0, &small_rng_i32);
  713. checker.set_rng(1, &small_rng_i32);
  714. } else if (mode == Mode::SHL || mode == Mode::SHR) {
  715. checker.set_rng(0, &default_rng_i32);
  716. if (dtype.size() == 4) {
  717. checker.set_rng(1, &shift_rng_i32_i32);
  718. } else {
  719. megdnn_assert(dtype.size() == 1);
  720. checker.set_rng(1, &shift_rng_i32_i8);
  721. }
  722. } else if (mode == Mode::ATAN2) {
  723. checker.set_rng(0, &nonzero_rng_f32);
  724. checker.set_rng(1, &nonzero_rng_f32);
  725. } else {
  726. RNG* rng;
  727. if (dtype.category() == DTypeCategory::INT) {
  728. rng = &default_rng_i32;
  729. } else {
  730. rng = &default_rng_f32;
  731. }
  732. for (size_t i = 0; i < shapes.size(); ++i) {
  733. checker.set_rng(i, rng);
  734. }
  735. }
  736. checker.set_epsilon(eps);
  737. for (size_t i = 0; i < shapes.size(); ++i) {
  738. checker.set_dtype(i, dtype);
  739. }
  740. EXPECT_NO_THROW(checker.execs(shapes));
  741. if (!::testing::Test::HasFailure() && shapes.size() == 3) {
  742. // channel bcast
  743. shapes[1][0] = 1;
  744. shapes[1][2] = 1;
  745. EXPECT_NO_THROW(checker.execs(shapes));
  746. if (!::testing::Test::HasFailure()) {
  747. // scalar bcast
  748. shapes[1][1] = 1;
  749. EXPECT_NO_THROW(checker.execs(shapes));
  750. }
  751. }
  752. if (::testing::Test::HasFailure()) {
  753. printf("failed on mode=%d(%s) dtype=%s\n", mode_nr, trait.name,
  754. dtype.name());
  755. for (auto&& i : shapes) {
  756. printf("ishape: %s\n", i.to_string().c_str());
  757. }
  758. return false;
  759. }
  760. return true;
  761. };
  762. #define run(args...) \
  763. do { \
  764. if (!do_run(args)) { \
  765. return; \
  766. } \
  767. } while (0)
  768. if (trait.allow_int) {
  769. run(dtype::Int8{});
  770. run(dtype::Int32{});
  771. }
  772. if (trait.allow_float) {
  773. DNN_FLOAT16_SELECT(
  774. run(dtype::Float16{}, mode == Mode::FAST_TANH_GRAD ? 0.5 : 0.05), );
  775. run(dtype::Float32{});
  776. }
  777. }
  778. #undef run
  779. }
  780. #define UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
  781. checker.set_param(Mode::_optr) \
  782. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int8()}, {}}); \
  783. checker.set_param(Mode::_optr) \
  784. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int16()}, {}}); \
  785. checker.set_param(Mode::_optr) \
  786. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int32()}, {}});
  787. #define UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(_optr) \
  788. checker.set_param(Mode::_optr) \
  789. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Float32()}, {}});
  790. #define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT \
  791. UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(RELU); \
  792. UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(ABS);
  793. #define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT \
  794. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ABS) \
  795. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(LOG) \
  796. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(COS) \
  797. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIN) \
  798. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FLOOR) \
  799. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(CEIL) \
  800. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIGMOID) \
  801. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(EXP) \
  802. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(RELU) \
  803. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ROUND) \
  804. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(TANH) \
  805. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FAST_TANH)
  806. DEF_TEST(unary_negative_stride) {
  807. using Mode = ElemwiseForward::Param::Mode;
  808. Checker<ElemwiseForward> checker(handle);
  809. BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT;
  810. UniformFloatRNG rng(1e-2, 6e1);
  811. checker.set_rng(0, &rng);
  812. checker.set_epsilon(1e-5);
  813. BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT;
  814. }
  815. #undef UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
  816. #undef UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
  817. #undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
  818. #undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
  819. #define BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
  820. checker.set_param(Mode::_optr) \
  821. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int8()}, \
  822. {{1, 4, 1}, dtype::Int8()}, \
  823. {}}); \
  824. checker.set_param(Mode::_optr) \
  825. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int16()}, \
  826. {{1, 4, 1}, dtype::Int16()}, \
  827. {}}); \
  828. checker.set_param(Mode::_optr) \
  829. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int32()}, \
  830. {{1, 4, 1}, dtype::Int32()}, \
  831. {}});
  832. #define BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(_optr) \
  833. checker.set_param(Mode::_optr) \
  834. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Float32()}, \
  835. {{1, 4, 1}, dtype::Float32()}, \
  836. {}});
  837. #define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT \
  838. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(ADD) \
  839. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MUL) \
  840. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MAX) \
  841. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MIN) \
  842. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(SUB)
  843. #define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 \
  844. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(POW) \
  845. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(TRUE_DIV) \
  846. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_SIGMOID) \
  847. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_TANH) \
  848. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_RELU) \
  849. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_H_SWISH) \
  850. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FAST_TANH_GRAD) \
  851. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(H_SWISH_GRAD)
  852. DEF_TEST(binary_negative_stride) {
  853. using Mode = ElemwiseForward::Param::Mode;
  854. Checker<ElemwiseForward> checker(handle);
  855. BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT;
  856. UniformFloatRNG rng(1e-2, 2e1);
  857. checker.set_rng(0, &rng);
  858. checker.set_epsilon(1e-5);
  859. BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32;
  860. }
  861. #undef BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
  862. #undef BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
  863. #undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
  864. #undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
  865. DEF_TEST(ternary_negative_stride) {
  866. using Mode = ElemwiseForward::Param::Mode;
  867. Checker<ElemwiseForward> checker(handle);
  868. checker.set_param(Mode::FUSE_MUL_ADD3);
  869. checker.execl(
  870. {{{1, 7}, {-7, -1}, dtype::Int8()},
  871. {{1, 7}, {-3, -1}, dtype::Int8()},
  872. {{1, 7}, {-7, -1}, dtype::Int8()},
  873. {}});
  874. checker.execl(
  875. {{{1, 7}, {-7, -1}, dtype::Int16()},
  876. {{1, 7}, {-3, -1}, dtype::Int16()},
  877. {{1, 7}, {-7, -1}, dtype::Int16()},
  878. {}});
  879. checker.execl(
  880. {{{1, 7}, {-7, -1}, dtype::Int32()},
  881. {{1, 7}, {-3, -1}, dtype::Int32()},
  882. {{1, 7}, {-7, -1}, dtype::Int32()},
  883. {}});
  884. UniformFloatRNG rng(1e-2, 2e1);
  885. checker.set_rng(0, &rng);
  886. checker.set_epsilon(1e-5);
  887. checker.execl(
  888. {{{1, 7}, {-7, -1}, dtype::Float32()},
  889. {{1, 7}, {-3, -1}, dtype::Float32()},
  890. {{1, 7}, {-7, -1}, dtype::Float32()},
  891. {}});
  892. }
  893. TEST(TEST_ELEMWISE, MODE_TRAIT) {
  894. using M = Elemwise::Mode;
  895. using T = Elemwise::ModeTrait;
  896. ASSERT_EQ(1u, T::from_mode(M::RELU).arity);
  897. ASSERT_EQ(2u, T::from_mode(M::ADD).arity);
  898. ASSERT_EQ(3u, T::from_mode(M::FUSE_MUL_ADD3).arity);
  899. ASSERT_EQ(4u, T::from_mode(M::FUSE_MUL_ADD4).arity);
  900. ASSERT_TRUE(T::from_mode(M::ADD).commutable);
  901. ASSERT_FALSE(T::from_mode(M::TRUE_DIV).commutable);
  902. ASSERT_TRUE(T::from_mode(M::ADD).allow_int);
  903. ASSERT_FALSE(T::from_mode(M::EXP).allow_int);
  904. ASSERT_TRUE(T::from_mode(M::ADD).allow_float);
  905. ASSERT_FALSE(T::from_mode(M::SHL).allow_float);
  906. ASSERT_TRUE(T::from_mode(M::RMULH).commutable);
  907. ASSERT_FALSE(T::from_mode(M::RMULH).allow_float);
  908. ASSERT_TRUE(T::from_mode(M::XOR).allow_bool);
  909. }
  910. } // namespace elemwise
  911. } // namespace test
  912. } // namespace megdnn
  913. // vim: syntax=cpp.doxygen