You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

elemwise.cpp 39 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020
  1. /**
  2. * \file dnn/test/common/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/elemwise.h"
  12. #include "src/common/utils.cuh"
  13. #include "test/common/checker.h"
  14. #include "test/common/utils.h"
  15. #include "megdnn/oprs/general.h"
  16. #include "test/common/fix_gtest_on_platforms_without_exception.inl"
  17. using namespace megdnn;
  18. using namespace test;
  19. namespace {
  20. void fma3_extra_opr_impl(const TensorNDArray& data) {
  21. megdnn_assert(data.size() == 4);
  22. auto handle = create_cpu_handle(2);
  23. auto opr = handle->create_operator<Elemwise>();
  24. using Mode = Elemwise::Mode;
  25. opr->param().mode = Mode::MUL;
  26. opr->exec({data[0], data[1]}, data[3]);
  27. opr->param().mode = Mode::ADD;
  28. opr->exec({data[2], data[3]}, data[3]);
  29. }
  30. void fma4_extra_opr_impl(const TensorNDArray& data) {
  31. megdnn_assert(data.size() == 5);
  32. std::vector<uint8_t> tmp_storage(data[4].layout.span().dist_byte());
  33. TensorND tmp;
  34. tmp.reset_ptr(tmp_storage.data());
  35. tmp.layout = data[4].layout;
  36. tmp.layout.init_contiguous_stride();
  37. auto handle = create_cpu_handle(2);
  38. auto opr = handle->create_operator<Elemwise>();
  39. using Mode = Elemwise::Mode;
  40. opr->param().mode = Mode::MUL;
  41. opr->exec({data[0], data[1]}, data[4]);
  42. opr->exec({data[2], data[3]}, tmp);
  43. opr->param().mode = Mode::ADD;
  44. opr->exec({tmp, data[4]}, data[4]);
  45. }
  46. TensorLayout make_layout(
  47. const TensorShape& shp, std::initializer_list<ptrdiff_t> stride) {
  48. TensorLayout ret{shp, dtype::Float32()};
  49. megdnn_assert(stride.size() == shp.ndim);
  50. auto idx = 0;
  51. for (auto i : stride)
  52. ret.stride[idx++] = i;
  53. return ret;
  54. }
  55. } // anonymous namespace
  56. namespace megdnn {
  57. namespace test {
  58. namespace elemwise {
  //! Opens the explicit specialization run_test<_tag>; the brace-enclosed body
  //! written after the macro invocation receives the target as `handle`.
  #define DEF_TEST(_tag) template <> void run_test<_tag>(Handle * handle)
  62. DEF_TEST(unary) {
  63. using Mode = ElemwiseForward::Param::Mode;
  64. Checker<ElemwiseForward> checker(handle);
  65. checker.set_param(Mode::SIN);
  66. checker.set_dtype(0, dtype::Float32()).execs({{3, 4, 1}, {}});
  67. checker.set_dtype(0, dtype::Float16()).execs({{3, 4, 1}, {}});
  68. }
  69. DEF_TEST(binary_brdcst) {
  70. auto run = [&](DType dtype) {
  71. using Mode = ElemwiseForward::Param::Mode;
  72. Checker<ElemwiseForward> checker(handle);
  73. checker.set_param(Mode::ADD);
  74. checker.set_dtype(0, dtype);
  75. checker.set_dtype(1, dtype);
  76. checker.execs({{3, 1}, {1, 3}, {3, 3}});
  77. {
  78. checker.execs({{10, 11}, {10, 11}, {10, 11}});
  79. //
  80. checker.execs({{2, 3, 4, 5, 6, 7}, {1, 3, 1, 1, 6, 1}, {2, 3, 4, 5, 6, 7}});
  81. checker.execs({{1, 3, 1, 1, 6, 1}, {2, 3, 4, 5, 6, 7}, {2, 3, 4, 5, 6, 7}});
  82. //
  83. checker.execs({{256, 256, 3}, {1, 1, 3}, {256, 256, 3}});
  84. checker.execs({{1, 1, 3}, {256, 256, 3}, {256, 256, 3}});
  85. //
  86. checker.execs({{8, 1, 6, 1}, {1, 7, 1, 5}, {8, 7, 6, 5}});
  87. checker.execs({{1, 7, 1, 5}, {8, 1, 6, 1}, {8, 7, 6, 5}});
  88. //
  89. checker.execs({{5, 4}, {1, 1}, {5, 4}});
  90. checker.execs({{1, 1}, {5, 4}, {5, 4}});
  91. //
  92. checker.execs({{5, 4}, {1, 4}, {5, 4}});
  93. checker.execs({{1, 4}, {5, 4}, {5, 4}});
  94. //
  95. checker.execs({{15, 3, 5}, {15, 1, 5}, {15, 3, 5}});
  96. checker.execs({{15, 1, 5}, {15, 3, 5}, {15, 3, 5}});
  97. //
  98. checker.execs({{15, 3, 5}, {1, 3, 5}, {15, 3, 5}});
  99. checker.execs({{1, 3, 5}, {15, 3, 5}, {15, 3, 5}});
  100. //
  101. checker.execs({{15, 3, 5}, {1, 3, 1}, {15, 3, 5}});
  102. checker.execs({{1, 3, 1}, {15, 3, 5}, {15, 3, 5}});
  103. //
  104. checker.execs({{3, 1}, {1, 4}, {3, 4}});
  105. // numpy broadcast
  106. checker.execs({{2, 3, 1, 5}, {4, 5}, {2, 3, 4, 5}});
  107. checker.execs({{3, 1, 1}, {4, 5}, {3, 4, 5}});
  108. }
  109. {
  110. // 1d
  111. {
  112. auto n = 1000u;
  113. checker.execs({{n}, {n}, {n}});
  114. checker.execs({{1}, {n}, {n}});
  115. checker.execs({{n}, {1}, {n}});
  116. }
  117. // 2d
  118. {
  119. auto m = 200u, n = 100u;
  120. auto collapse = [](size_t n, bool is_collapsed) {
  121. return is_collapsed ? 1u : n;
  122. };
  123. for (auto msk = 0u; msk < 16; ++msk) {
  124. checker.execs(
  125. {{collapse(m, msk & 1), collapse(n, msk & 2)},
  126. {collapse(m, msk & 4), collapse(n, msk & 8)},
  127. {}});
  128. }
  129. }
  130. // nd
  131. {
  132. checker.execs({{2, 3, 4, 5, 6}, {1, 3, 1, 5, 6}, {2, 3, 4, 5, 6}});
  133. checker.execs({{2, 3, 4, 5, 6}, {2, 1, 4, 1, 6}, {2, 3, 4, 5, 6}});
  134. }
  135. }
  136. };
  137. run(dtype::Float32());
  138. // run(dtype::Float16());
  139. }
  140. DEF_TEST(binary_non_contig) {
  141. using Mode = ElemwiseForward::Param::Mode;
  142. Checker<ElemwiseForward> checker(handle);
  143. checker.set_param(Mode::ADD);
  144. TensorLayout ly{{2, 3}, dtype::Float32()};
  145. ly.stride[0] = 4;
  146. checker.execl({ly, ly, {{2, 3}, dtype::Float32()}});
  147. }
  148. DEF_TEST(ternary) {
  149. using Mode = ElemwiseForward::Param::Mode;
  150. Checker<ElemwiseForward> checker(handle);
  151. checker.set_param(Mode::COND_LEQ_MOV);
  152. checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  153. checker.set_dtype(0, dtype::Float32())
  154. .set_dtype(1, dtype::Float32())
  155. .set_dtype(2, dtype::Float32())
  156. .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  157. checker.set_dtype(0, dtype::Float16())
  158. .set_dtype(1, dtype::Float16())
  159. .set_dtype(2, dtype::Float16())
  160. .set_dtype(3, dtype::Float16())
  161. .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  162. checker.execs({{2, 1, 1, 5}, {4, 5}, {3, 1, 1}, {2, 3, 4, 5}});
  163. checker.execs({{3, 1, 1}, {5}, {4, 1}, {3, 4, 5}});
  164. ASSERT_THROW(checker.execs({{2, 3, 4}, {4, 1}, {1}, {2, 3, 4}}), MegDNNError);
  165. ASSERT_THROW(checker.execs({{2, 4, 4}, {4, 1}, {3, 1, 1}, {2, 3, 4}}), MegDNNError);
  166. }
  167. DEF_TEST(ternary_non_contig) {
  168. using Mode = ElemwiseForward::Param::Mode;
  169. Checker<ElemwiseForward> checker(handle);
  170. checker.set_param(Mode::COND_LEQ_MOV);
  171. TensorLayout ly{{2, 3}, dtype::Float32()};
  172. ly.stride[0] = 4;
  173. checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
  174. }
  175. DEF_TEST(fuse_mul_add3) {
  176. using Mode = ElemwiseForward::Param::Mode;
  177. Checker<ElemwiseForward> checker(handle);
  178. checker.set_param(Mode::FUSE_MUL_ADD3).set_extra_opr_impl(fma3_extra_opr_impl);
  179. auto make_shape = [](const TensorShape& s0, const TensorShape& s1,
  180. const TensorShape& s2) {
  181. TensorShape dest;
  182. dest.ndim = s0.ndim;
  183. for (size_t i = 0; i < dest.ndim; ++i) {
  184. auto a = i < s0.ndim ? s0[i] : 1;
  185. auto b = i < s1.ndim ? s1[i] : 1;
  186. dest[i] = std::max(a, b);
  187. }
  188. return TensorShapeArray{s0, s1, s2, dest};
  189. };
  190. checker.exec(make_shape({2, 1}, {2, 2}, {2, 2}));
  191. checker.exec(make_shape({2, 2}, {2, 1}, {2, 2}));
  192. checker.exec(make_shape({2, 2}, {2, 2}, {1}));
  193. checker.exec(make_shape({3, 1}, {1, 3}, {3, 1}));
  194. checker.exec(make_shape({2, 1, 2, 1, 2, 1}, {1, 2, 1, 2, 1, 2}, {1}));
  195. checker.exec(make_shape({1, 1, 3}, {5, 8, 1}, {5, 8, 1}));
  196. checker.exec(make_shape({1, 192, 9, 16}, {1}, {1, 192, 9, 16}));
  197. }
  198. DEF_TEST(fuse_mul_add3_non_contig) {
  199. using Mode = ElemwiseForward::Param::Mode;
  200. Checker<ElemwiseForward> checker(handle);
  201. checker.set_param(Mode::FUSE_MUL_ADD3).set_extra_opr_impl(fma3_extra_opr_impl);
  202. TensorLayout ly{{2, 3}, dtype::Float32()};
  203. ly.stride[0] = 4;
  204. checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
  205. }
  206. DEF_TEST(fuse_mul_add4) {
  207. using Mode = ElemwiseForward::Param::Mode;
  208. Checker<ElemwiseForward> checker(handle);
  209. checker.set_param(Mode::FUSE_MUL_ADD4).set_extra_opr_impl(fma4_extra_opr_impl);
  210. auto make_shape = [](const TensorShape& s0, const TensorShape& s1,
  211. bool swap = false) {
  212. TensorShape dest;
  213. dest.ndim = s0.ndim;
  214. for (size_t i = 0; i < dest.ndim; ++i) {
  215. auto a = i < s0.ndim ? s0[i] : 1;
  216. auto b = i < s1.ndim ? s1[i] : 1;
  217. dest[i] = std::max(a, b);
  218. }
  219. TensorShapeArray ret{s0, s1, s0, s1, dest};
  220. if (swap)
  221. std::swap(ret[2], ret[3]);
  222. return ret;
  223. };
  224. checker.exec(make_shape({2, 2}, {2, 2}));
  225. checker.exec(make_shape({3, 1}, {1, 3}));
  226. checker.exec(make_shape({2, 1, 2, 1, 2, 1}, {1, 2, 1, 2, 1, 2}));
  227. checker.exec(make_shape({4, 2}, {1, 2}, true));
  228. }
  229. DEF_TEST(rmulh) {
  230. using Mode = ElemwiseForward::Param::Mode;
  231. Checker<ElemwiseForward> checker(handle);
  232. auto run_for_dtype = [&checker](auto dtype) {
  233. auto minv = DTypeTrait<decltype(dtype)>::min();
  234. auto maxv = DTypeTrait<decltype(dtype)>::max();
  235. UniformIntRNG rng0{minv, maxv};
  236. UniformIntRNG rngM{(maxv >> 1) + 1, maxv};
  237. checker.set_param({Mode::RMULH})
  238. .set_dtype(0, dtype)
  239. .set_dtype(1, dtype)
  240. .set_dtype(2, dtype)
  241. .set_rng(0, &rng0)
  242. .set_rng(1, &rngM);
  243. checker.execs({{7, 9, 11, 13}, {1}, {}})
  244. .execs({{16, 3, 256, 256}, {1}, {}})
  245. .execs({{2, 3, 1, 7}, {2, 3, 1, 7}, {}})
  246. .execs({{9, 5, 4}, {1, 5, 1}, {}})
  247. .execs({{233}, {1}, {}});
  248. };
  249. run_for_dtype(dtype::Int8());
  250. run_for_dtype(dtype::Int16());
  251. run_for_dtype(dtype::Int32());
  252. }
  253. /* ============= migrated from x86 tests ============= */
  254. #define UNARY_TEST_CASE(_optr) \
  255. checker.set_param(Mode::_optr).execs({{1, 127}, {}}); \
  256. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  257. #define BUILD_UNARY_TEST_CASE_INT \
  258. UNARY_TEST_CASE(RELU) \
  259. UNARY_TEST_CASE(ABS)
  260. #define BUILD_UNARY_TEST_CASE_FLOAT \
  261. UNARY_TEST_CASE(ABS) \
  262. UNARY_TEST_CASE(LOG) \
  263. UNARY_TEST_CASE(COS) \
  264. UNARY_TEST_CASE(SIN) \
  265. UNARY_TEST_CASE(FLOOR) \
  266. UNARY_TEST_CASE(CEIL) \
  267. UNARY_TEST_CASE(SIGMOID) \
  268. UNARY_TEST_CASE(EXP) \
  269. UNARY_TEST_CASE(TANH) \
  270. UNARY_TEST_CASE(FAST_TANH) \
  271. UNARY_TEST_CASE(RELU) \
  272. UNARY_TEST_CASE(ROUND)
  273. DEF_TEST(unary1) {
  274. using Mode = ElemwiseForward::Param::Mode;
  275. Checker<ElemwiseForward> checker(handle);
  276. // case int
  277. checker.set_dtype(0, dtype::Int8());
  278. BUILD_UNARY_TEST_CASE_INT
  279. checker.set_dtype(0, dtype::Int16());
  280. BUILD_UNARY_TEST_CASE_INT
  281. checker.set_dtype(0, dtype::Int32());
  282. BUILD_UNARY_TEST_CASE_INT
  283. // case float
  284. UniformFloatRNG rng(1e-2, 6e1);
  285. checker.set_rng(0, &rng);
  286. checker.set_epsilon(1e-5);
  287. checker.set_dtype(0, dtype::Float32());
  288. BUILD_UNARY_TEST_CASE_FLOAT
  289. }
  290. #undef UNARY_TEST_CASE
  291. #undef BUILD_UNARY_TEST_CASE_INT
  292. #undef BUILD_UNARY_TEST_CASE_FLOAT
  293. #define BINARY_TEST_CASE(_optr) \
  294. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  295. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  296. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  297. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  298. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  299. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});
  300. #define BUILD_BINARY_TEST_CASE \
  301. BINARY_TEST_CASE(MIN) \
  302. BINARY_TEST_CASE(MAX)
  303. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  304. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  305. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  306. checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
  307. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  308. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}}); \
  309. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  310. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  311. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  312. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  313. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}}); \
  314. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  315. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  316. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  317. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  318. BINARY_COMPLATE_TEST_CASE(ADD) \
  319. BINARY_COMPLATE_TEST_CASE(MUL) \
  320. BINARY_COMPLATE_TEST_CASE(MAX) \
  321. BINARY_COMPLATE_TEST_CASE(MIN) \
  322. BINARY_COMPLATE_TEST_CASE(SUB)
  323. #define BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32 \
  324. BINARY_COMPLATE_TEST_CASE(POW) \
  325. BINARY_COMPLATE_TEST_CASE(TRUE_DIV) \
  326. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID) \
  327. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH) \
  328. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU) \
  329. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_H_SWISH) \
  330. BINARY_COMPLATE_TEST_CASE(FAST_TANH_GRAD) \
  331. BINARY_COMPLATE_TEST_CASE(H_SWISH_GRAD)
  332. DEF_TEST(binary1) {
  333. using Mode = ElemwiseForward::Param::Mode;
  334. Checker<ElemwiseForward> checker(handle);
  335. // case float
  336. UniformFloatRNG rng(1e-5, 7e1);
  337. checker.set_rng(0, &rng);
  338. checker.set_epsilon(1e-5);
  339. checker.set_dtype(0, dtype::Float32());
  340. checker.set_dtype(1, dtype::Float32());
  341. BUILD_BINARY_COMPLATE_TEST_CASE
  342. BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  343. // case int
  344. checker.set_dtype(0, dtype::Int8());
  345. checker.set_dtype(1, dtype::Int8());
  346. BUILD_BINARY_TEST_CASE
  347. BUILD_BINARY_COMPLATE_TEST_CASE
  348. checker.set_dtype(0, dtype::Int16());
  349. checker.set_dtype(1, dtype::Int16());
  350. BUILD_BINARY_TEST_CASE
  351. BUILD_BINARY_COMPLATE_TEST_CASE
  352. checker.set_dtype(0, dtype::Int32());
  353. checker.set_dtype(1, dtype::Int32());
  354. BUILD_BINARY_TEST_CASE
  355. BUILD_BINARY_COMPLATE_TEST_CASE
  356. }
  357. #undef BINARY_TEST_CASE
  358. #undef BUILD_BINARY_TEST_CASE
  359. #undef BINARY_COMPLATE_TEST_CASE
  360. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  361. #undef BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  362. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  363. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  364. checker.set_param(Mode::_optr) \
  365. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  366. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  367. checker.set_param(Mode::_optr) \
  368. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  369. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  370. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  371. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  372. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  373. #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  374. DEF_TEST(ternary1) {
  375. using Mode = ElemwiseForward::Param::Mode;
  376. Checker<ElemwiseForward> checker(handle);
  377. // case int
  378. checker.set_dtype(0, dtype::Int8());
  379. checker.set_dtype(1, dtype::Int8());
  380. checker.set_dtype(2, dtype::Int8());
  381. // BUILD_TERNARY_TEST_CASE
  382. BUILD_TERNARY_COMPLATE_TEST_CASE
  383. checker.set_dtype(0, dtype::Int16());
  384. checker.set_dtype(1, dtype::Int16());
  385. checker.set_dtype(2, dtype::Int16());
  386. // BUILD_TERNARY_TEST_CASE
  387. BUILD_TERNARY_COMPLATE_TEST_CASE
  388. checker.set_dtype(0, dtype::Int32());
  389. checker.set_dtype(1, dtype::Int32());
  390. checker.set_dtype(2, dtype::Int32());
  391. // BUILD_TERNARY_TEST_CASE
  392. BUILD_TERNARY_COMPLATE_TEST_CASE
  393. // case float
  394. UniformFloatRNG rng(1e-5, 7e1);
  395. checker.set_rng(0, &rng);
  396. checker.set_epsilon(1e-5);
  397. checker.set_dtype(0, dtype::Float32());
  398. checker.set_dtype(1, dtype::Float32());
  399. checker.set_dtype(2, dtype::Float32());
  400. // BUILD_TERNARY_TEST_CASE
  401. BUILD_TERNARY_COMPLATE_TEST_CASE
  402. // TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  403. }
  404. #undef TERNARY_COMPLATE_TEST_CASE
  405. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  406. /* ============= migrated from arm tests ============= */
  407. #define UNARY_TEST_CASE(_optr) \
  408. checker.set_param(Mode::_optr).execs({{1, 129}, {}}); \
  409. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  410. #define BUILD_UNARY_TEST_CASE_INT \
  411. UNARY_TEST_CASE(RELU) \
  412. UNARY_TEST_CASE(ABS) \
  413. UNARY_TEST_CASE(NEGATE)
  414. #define BUILD_UNARY_TEST_CASE_FLOAT \
  415. BUILD_UNARY_TEST_CASE_INT \
  416. UNARY_TEST_CASE(SIGMOID) \
  417. UNARY_TEST_CASE(EXP) \
  418. UNARY_TEST_CASE(TANH) \
  419. UNARY_TEST_CASE(FAST_TANH) \
  420. UNARY_TEST_CASE(H_SWISH)
  421. DEF_TEST(unary2) {
  422. using Mode = ElemwiseForward::Param::Mode;
  423. Checker<ElemwiseForward> checker(handle);
  424. // case int
  425. checker.set_dtype(0, dtype::Int8());
  426. BUILD_UNARY_TEST_CASE_INT
  427. checker.set_dtype(0, dtype::Int16());
  428. BUILD_UNARY_TEST_CASE_INT
  429. checker.set_dtype(0, dtype::Int32());
  430. BUILD_UNARY_TEST_CASE_INT
  431. // case float
  432. {
  433. UniformFloatRNG rng(1e-5, 7e1);
  434. checker.set_rng(0, &rng);
  435. checker.set_epsilon(1e-5);
  436. checker.set_dtype(0, dtype::Float32());
  437. BUILD_UNARY_TEST_CASE_FLOAT
  438. }
  439. {
  440. UniformFloatRNG rng(1e-2, 1e1);
  441. checker.set_rng(0, &rng);
  442. checker.set_epsilon(6e-3);
  443. checker.set_dtype(0, dtype::Float16());
  444. BUILD_UNARY_TEST_CASE_FLOAT
  445. }
  446. // tanh NaN bug case
  447. {
  448. UniformFloatRNG rng(100, 200);
  449. checker.set_rng(0, &rng);
  450. checker.set_epsilon(1e-5);
  451. checker.set_dtype(0, dtype::Float32());
  452. checker.set_param(Mode::TANH).execs({{1, 1025}, {}});
  453. checker.set_param(Mode::TANH).execs({{1, 7}, {}});
  454. }
  455. }
  456. #undef UNARY_TEST_CASE
  457. #undef BUILD_UNARY_TEST_CASE_INT
  458. #undef BUILD_UNARY_TEST_CASE_FLOAT
  459. #define BINARY_TEST_CASE(_optr) \
  460. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  461. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  462. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  463. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  464. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  465. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});
  466. #define BUILD_BINARY_TEST_CASE \
  467. BINARY_TEST_CASE(MIN) \
  468. BINARY_TEST_CASE(MAX)
  469. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  470. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  471. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  472. checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
  473. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  474. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}}); \
  475. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  476. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  477. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  478. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  479. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}}); \
  480. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  481. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  482. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  483. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  484. BINARY_COMPLATE_TEST_CASE(ADD) \
  485. BINARY_COMPLATE_TEST_CASE(MUL) \
  486. BINARY_COMPLATE_TEST_CASE(MAX) \
  487. BINARY_COMPLATE_TEST_CASE(MIN) \
  488. BINARY_COMPLATE_TEST_CASE(SUB) \
  489. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU)
  490. DEF_TEST(binary2) {
  491. using Mode = ElemwiseForward::Param::Mode;
  492. Checker<ElemwiseForward> checker(handle);
  493. // case float
  494. UniformFloatRNG rng(1e-5, 7e1);
  495. checker.set_rng(0, &rng);
  496. checker.set_epsilon(1e-5);
  497. checker.set_dtype(0, dtype::Float32());
  498. checker.set_dtype(1, dtype::Float32());
  499. BUILD_BINARY_COMPLATE_TEST_CASE
  500. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID)
  501. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH)
  502. // case int
  503. checker.set_dtype(0, dtype::Int8());
  504. checker.set_dtype(1, dtype::Int8());
  505. // BUILD_BINARY_TEST_CASE
  506. BUILD_BINARY_COMPLATE_TEST_CASE
  507. checker.set_dtype(0, dtype::Int16());
  508. checker.set_dtype(1, dtype::Int16());
  509. // BUILD_BINARY_TEST_CASE
  510. BUILD_BINARY_COMPLATE_TEST_CASE
  511. checker.set_dtype(0, dtype::Int32());
  512. checker.set_dtype(1, dtype::Int32());
  513. BUILD_BINARY_TEST_CASE
  514. BUILD_BINARY_COMPLATE_TEST_CASE
  515. // case float
  516. checker.set_rng(0, &rng);
  517. checker.set_epsilon(1e-5);
  518. checker.set_dtype(0, dtype::Float32());
  519. checker.set_dtype(1, dtype::Float32());
  520. checker.set_param(Mode::FUSE_ADD_SIGMOID).execs({{3, 4, 7}, {1}, {}});
  521. checker.set_param(Mode::FUSE_ADD_TANH).execs({{3, 4, 7}, {1}, {}});
  522. // commutable
  523. checker.set_param(Mode::TRUE_DIV).execs({{1}, {4}, {}});
  524. BUILD_BINARY_TEST_CASE
  525. BUILD_BINARY_COMPLATE_TEST_CASE
  526. BINARY_COMPLATE_TEST_CASE(TRUE_DIV)
  527. {
  528. UniformFloatRNG rng(1e-3, 3e1);
  529. checker.set_rng(0, &rng);
  530. checker.set_rng(1, &rng);
  531. checker.set_epsilon(1e-3);
  532. checker.set_dtype(0, dtype::Float16());
  533. checker.set_dtype(1, dtype::Float16());
  534. checker.set_param(Mode::FUSE_ADD_SIGMOID).execs({{3, 4, 7}, {1}, {}});
  535. checker.set_param(Mode::FUSE_ADD_TANH).execs({{3, 4, 7}, {1}, {}});
  536. BUILD_BINARY_TEST_CASE
  537. BUILD_BINARY_COMPLATE_TEST_CASE
  538. BINARY_COMPLATE_TEST_CASE(TRUE_DIV)
  539. // commutable
  540. checker.set_param(Mode::TRUE_DIV).execs({{1}, {4}, {}});
  541. }
  542. }
  543. #undef BINARY_TEST_CASE
  544. #undef BUILD_BINARY_TEST_CASE
  545. #undef BINARY_COMPLATE_TEST_CASE
  546. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  547. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  548. checker.set_param(Mode::_optr) \
  549. .execs({{1, 123, 1}, {300, 123, 253}, {300, 123, 253}, {}}); \
  550. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  551. checker.set_param(Mode::_optr) \
  552. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  553. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  554. checker.set_param(Mode::_optr) \
  555. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  556. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  557. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  558. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  559. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}}); \
  560. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {1, 1, 1}, {3, 4, 1}, {}});
  561. #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  562. DEF_TEST(ternary2) {
  563. using Mode = ElemwiseForward::Param::Mode;
  564. Checker<ElemwiseForward> checker(handle);
  565. // case int
  566. checker.set_dtype(0, dtype::Int8());
  567. checker.set_dtype(1, dtype::Int8());
  568. checker.set_dtype(2, dtype::Int8());
  569. BUILD_TERNARY_COMPLATE_TEST_CASE
  570. checker.set_dtype(0, dtype::Int16());
  571. checker.set_dtype(1, dtype::Int16());
  572. checker.set_dtype(2, dtype::Int16());
  573. BUILD_TERNARY_COMPLATE_TEST_CASE
  574. checker.set_dtype(0, dtype::Int32());
  575. checker.set_dtype(1, dtype::Int32());
  576. checker.set_dtype(2, dtype::Int32());
  577. BUILD_TERNARY_COMPLATE_TEST_CASE
  578. // case float
  579. UniformFloatRNG rng(1e-5, 7e1);
  580. checker.set_rng(0, &rng);
  581. checker.set_epsilon(1e-5);
  582. checker.set_dtype(0, dtype::Float32());
  583. checker.set_dtype(1, dtype::Float32());
  584. checker.set_dtype(2, dtype::Float32());
  585. BUILD_TERNARY_COMPLATE_TEST_CASE
  586. {
  587. UniformFloatRNG rng(1e-3, 3e1);
  588. checker.set_rng(0, &rng);
  589. checker.set_rng(1, &rng);
  590. checker.set_rng(2, &rng);
  591. checker.set_epsilon(1e-3);
  592. checker.set_dtype(0, dtype::Float16());
  593. checker.set_dtype(1, dtype::Float16());
  594. checker.set_dtype(2, dtype::Float16());
  595. BUILD_TERNARY_COMPLATE_TEST_CASE
  596. }
  597. }
  598. #undef TERNARY_COMPLATE_TEST_CASE
  599. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  600. /* ============= migrated from fallback tests ============= */
  601. DEF_TEST(unary3) {
  602. Checker<Elemwise> checker(handle);
  603. auto make_layouts =
  604. [](const TensorShape& shp,
  605. std::initializer_list<ptrdiff_t> stride) -> TensorLayoutArray {
  606. return {make_layout(shp, stride), {shp, dtype::Float32()}};
  607. };
  608. checker.set_param({Elemwise::Mode::SIN});
  609. checker.exec(make_layouts({2, 2}, {2, 1}));
  610. checker.exec(make_layouts({4}, {3}));
  611. }
  612. DEF_TEST(binary3) {
  613. Checker<Elemwise> checker(handle);
  614. checker.set_param({Elemwise::Mode::ADD});
  615. auto run = [&](const TensorShape& shp0, std::initializer_list<ptrdiff_t> stride0,
  616. const TensorShape& shp1, std::initializer_list<ptrdiff_t> stride1) {
  617. TensorShape shpo;
  618. Elemwise::deduce_shape({shp0, shp1}, shpo);
  619. auto ly0 = make_layout(shp0, stride0), ly1 = make_layout(shp1, stride1),
  620. lyo = TensorLayout{shpo, dtype::Float32()};
  621. checker.execl({ly0, ly1, lyo});
  622. checker.execl({ly1, ly0, lyo});
  623. };
  624. run({2, 2}, {2, 1}, {2, 2}, {2, 1});
  625. run({1}, {1}, {3, 3}, {1, 2});
  626. run({3, 4, 5}, {40, 10, 2}, {1, 4, 1}, {1, 1, 1});
  627. }
  628. DEF_TEST(all_modes) {
  629. // test correctness of all elemwise modes
  630. Checker<Elemwise> checker(handle);
  631. TensorShapeArray shapes;
  632. UniformFloatRNG default_rng_f32{-100.f, 100.f}, pos_rng_f32{.1f, 1000.f},
  633. small_pos_rng_f32{.1f, .10f}, small_rng_f32{-3.f, 3.f},
  634. abslt1_rng_f32{-1.f, 1.f}, uniform_0_2_rng{0.f, 2.f},
  635. tanh_rng_f32{-5.f, 5.f};
  636. UniformFloatNonZeroRNG nonzero_rng_f32{.1f, 1000.f},
  637. big_nonzero_rng_f32{100.f, 1000.f};
  638. UniformIntRNG default_rng_i32{-100, 100}, small_rng_i32{-2, 2},
  639. shift_rng_i32_i32{0, 31}, shift_rng_i32_i8{0, 7};
  640. UniformIntNonZeroRNG nonzero_rng_i32{1, 100};
  641. using Mode = Elemwise::Mode;
  642. auto should_ignore = [handle](Mode mode) {
  643. MEGDNN_MARK_USED_VAR(mode);
  644. return false;
  645. };
  646. for (int mode_nr = 0; mode_nr < static_cast<int>(Elemwise::Param::MODE_NR_MEMBER);
  647. ++mode_nr) {
  648. auto mode = static_cast<Mode>(mode_nr);
  649. // ignore unsupported modes
  650. if (should_ignore(mode)) {
  651. continue;
  652. }
  653. checker.set_param({mode});
  654. auto&& trait = Elemwise::ModeTrait::from_mode(mode);
  655. shapes.resize(trait.arity + 1);
  656. for (size_t i = 0; i < shapes.size() - 1; ++i) {
  657. shapes[i] = {3, 9, 7};
  658. }
  659. auto do_run = [&](DType dtype, float eps = 1e-3) {
  660. // limit value ranges for some modes
  661. if (mode == Mode::LOG || mode == Mode::LOG1P) {
  662. checker.set_rng(0, &pos_rng_f32);
  663. } else if (mode == Mode::POW) {
  664. checker.set_rng(0, &small_pos_rng_f32);
  665. checker.set_rng(1, &small_rng_f32);
  666. } else if (mode == Mode::EXP || mode == Mode::EXPM1) {
  667. checker.set_rng(0, &small_rng_f32);
  668. } else if (mode == Mode::FAST_TANH) {
  669. checker.set_rng(0, &tanh_rng_f32);
  670. } else if (mode == Mode::LOG_SUM_EXP) {
  671. // check numerical stability with large values
  672. checker.set_rng(0, &big_nonzero_rng_f32);
  673. checker.set_rng(1, &big_nonzero_rng_f32);
  674. } else if (
  675. mode == Mode::ASIN || mode == Mode::ACOS ||
  676. mode == Mode::SIGMOID_GRAD || mode == Mode::TANH_GRAD ||
  677. mode == Mode::ERFINV) {
  678. checker.set_rng(0, &abslt1_rng_f32);
  679. checker.set_rng(1, &default_rng_f32);
  680. } else if (mode == Mode::ERFCINV) {
  681. checker.set_rng(0, &uniform_0_2_rng);
  682. } else if (
  683. mode == Mode::MOD || mode == Mode::TRUE_DIV ||
  684. mode == Mode::FLOOR_DIV) {
  685. if (dtype.category() == DTypeCategory::INT) {
  686. checker.set_rng(0, &default_rng_i32);
  687. checker.set_rng(1, &nonzero_rng_i32);
  688. } else {
  689. checker.set_rng(0, &default_rng_f32);
  690. checker.set_rng(1, &nonzero_rng_f32);
  691. }
  692. } else if (mode == Mode::EQ) {
  693. checker.set_rng(0, &small_rng_i32);
  694. checker.set_rng(1, &small_rng_i32);
  695. } else if (mode == Mode::SHL || mode == Mode::SHR) {
  696. checker.set_rng(0, &default_rng_i32);
  697. if (dtype.size() == 4) {
  698. checker.set_rng(1, &shift_rng_i32_i32);
  699. } else {
  700. megdnn_assert(dtype.size() == 1);
  701. checker.set_rng(1, &shift_rng_i32_i8);
  702. }
  703. } else if (mode == Mode::ATAN2) {
  704. checker.set_rng(0, &nonzero_rng_f32);
  705. checker.set_rng(1, &nonzero_rng_f32);
  706. } else {
  707. RNG* rng;
  708. if (dtype.category() == DTypeCategory::INT) {
  709. rng = &default_rng_i32;
  710. } else {
  711. rng = &default_rng_f32;
  712. }
  713. for (size_t i = 0; i < shapes.size(); ++i) {
  714. checker.set_rng(i, rng);
  715. }
  716. }
  717. checker.set_epsilon(eps);
  718. for (size_t i = 0; i < shapes.size(); ++i) {
  719. checker.set_dtype(i, dtype);
  720. }
  721. EXPECT_NO_THROW(checker.execs(shapes));
  722. if (!::testing::Test::HasFailure() && shapes.size() == 3) {
  723. // channel bcast
  724. shapes[1][0] = 1;
  725. shapes[1][2] = 1;
  726. EXPECT_NO_THROW(checker.execs(shapes));
  727. if (!::testing::Test::HasFailure()) {
  728. // scalar bcast
  729. shapes[1][1] = 1;
  730. EXPECT_NO_THROW(checker.execs(shapes));
  731. }
  732. }
  733. if (::testing::Test::HasFailure()) {
  734. printf("failed on mode=%d(%s) dtype=%s\n", mode_nr, trait.name,
  735. dtype.name());
  736. for (auto&& i : shapes) {
  737. printf("ishape: %s\n", i.to_string().c_str());
  738. }
  739. return false;
  740. }
  741. return true;
  742. };
  743. #define run(args...) \
  744. do { \
  745. if (!do_run(args)) { \
  746. return; \
  747. } \
  748. } while (0)
  749. if (trait.allow_int) {
  750. run(dtype::Int8{});
  751. run(dtype::Int32{});
  752. }
  753. if (trait.allow_float) {
  754. DNN_FLOAT16_SELECT(
  755. run(dtype::Float16{}, mode == Mode::FAST_TANH_GRAD ? 0.5 : 0.05), );
  756. run(dtype::Float32{});
  757. }
  758. }
  759. #undef run
  760. }
  761. #define UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
  762. checker.set_param(Mode::_optr) \
  763. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int8()}, {}}); \
  764. checker.set_param(Mode::_optr) \
  765. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int16()}, {}}); \
  766. checker.set_param(Mode::_optr) \
  767. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int32()}, {}});
  768. #define UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(_optr) \
  769. checker.set_param(Mode::_optr) \
  770. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Float32()}, {}});
  771. #define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT \
  772. UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(RELU); \
  773. UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(ABS);
  774. #define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT \
  775. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ABS) \
  776. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(LOG) \
  777. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(COS) \
  778. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIN) \
  779. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FLOOR) \
  780. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(CEIL) \
  781. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIGMOID) \
  782. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(EXP) \
  783. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(RELU) \
  784. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ROUND) \
  785. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(TANH) \
  786. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FAST_TANH)
  787. DEF_TEST(unary_negative_stride) {
  788. using Mode = ElemwiseForward::Param::Mode;
  789. Checker<ElemwiseForward> checker(handle);
  790. BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT;
  791. UniformFloatRNG rng(1e-2, 6e1);
  792. checker.set_rng(0, &rng);
  793. checker.set_epsilon(1e-5);
  794. BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT;
  795. }
  796. #undef UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
  797. #undef UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
  798. #undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
  799. #undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
  800. #define BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
  801. checker.set_param(Mode::_optr) \
  802. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int8()}, \
  803. {{1, 4, 1}, dtype::Int8()}, \
  804. {}}); \
  805. checker.set_param(Mode::_optr) \
  806. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int16()}, \
  807. {{1, 4, 1}, dtype::Int16()}, \
  808. {}}); \
  809. checker.set_param(Mode::_optr) \
  810. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int32()}, \
  811. {{1, 4, 1}, dtype::Int32()}, \
  812. {}});
  813. #define BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(_optr) \
  814. checker.set_param(Mode::_optr) \
  815. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Float32()}, \
  816. {{1, 4, 1}, dtype::Float32()}, \
  817. {}});
  818. #define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT \
  819. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(ADD) \
  820. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MUL) \
  821. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MAX) \
  822. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MIN) \
  823. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(SUB)
  824. #define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 \
  825. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(POW) \
  826. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(TRUE_DIV) \
  827. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_SIGMOID) \
  828. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_TANH) \
  829. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_RELU) \
  830. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_H_SWISH) \
  831. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FAST_TANH_GRAD) \
  832. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(H_SWISH_GRAD)
  833. DEF_TEST(binary_negative_stride) {
  834. using Mode = ElemwiseForward::Param::Mode;
  835. Checker<ElemwiseForward> checker(handle);
  836. BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT;
  837. UniformFloatRNG rng(1e-2, 2e1);
  838. checker.set_rng(0, &rng);
  839. checker.set_epsilon(1e-5);
  840. BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32;
  841. }
  842. #undef BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
  843. #undef BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
  844. #undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
  845. #undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
  846. DEF_TEST(ternary_negative_stride) {
  847. using Mode = ElemwiseForward::Param::Mode;
  848. Checker<ElemwiseForward> checker(handle);
  849. checker.set_param(Mode::FUSE_MUL_ADD3);
  850. checker.execl(
  851. {{{1, 7}, {-7, -1}, dtype::Int8()},
  852. {{1, 7}, {-3, -1}, dtype::Int8()},
  853. {{1, 7}, {-7, -1}, dtype::Int8()},
  854. {}});
  855. checker.execl(
  856. {{{1, 7}, {-7, -1}, dtype::Int16()},
  857. {{1, 7}, {-3, -1}, dtype::Int16()},
  858. {{1, 7}, {-7, -1}, dtype::Int16()},
  859. {}});
  860. checker.execl(
  861. {{{1, 7}, {-7, -1}, dtype::Int32()},
  862. {{1, 7}, {-3, -1}, dtype::Int32()},
  863. {{1, 7}, {-7, -1}, dtype::Int32()},
  864. {}});
  865. UniformFloatRNG rng(1e-2, 2e1);
  866. checker.set_rng(0, &rng);
  867. checker.set_epsilon(1e-5);
  868. checker.execl(
  869. {{{1, 7}, {-7, -1}, dtype::Float32()},
  870. {{1, 7}, {-3, -1}, dtype::Float32()},
  871. {{1, 7}, {-7, -1}, dtype::Float32()},
  872. {}});
  873. }
  874. TEST(TEST_ELEMWISE, MODE_TRAIT) {
  875. using M = Elemwise::Mode;
  876. using T = Elemwise::ModeTrait;
  877. ASSERT_EQ(1u, T::from_mode(M::RELU).arity);
  878. ASSERT_EQ(2u, T::from_mode(M::ADD).arity);
  879. ASSERT_EQ(3u, T::from_mode(M::FUSE_MUL_ADD3).arity);
  880. ASSERT_EQ(4u, T::from_mode(M::FUSE_MUL_ADD4).arity);
  881. ASSERT_TRUE(T::from_mode(M::ADD).commutable);
  882. ASSERT_FALSE(T::from_mode(M::TRUE_DIV).commutable);
  883. ASSERT_TRUE(T::from_mode(M::ADD).allow_int);
  884. ASSERT_FALSE(T::from_mode(M::EXP).allow_int);
  885. ASSERT_TRUE(T::from_mode(M::ADD).allow_float);
  886. ASSERT_FALSE(T::from_mode(M::SHL).allow_float);
  887. ASSERT_TRUE(T::from_mode(M::RMULH).commutable);
  888. ASSERT_FALSE(T::from_mode(M::RMULH).allow_float);
  889. ASSERT_TRUE(T::from_mode(M::XOR).allow_bool);
  890. }
  891. } // namespace elemwise
  892. } // namespace test
  893. } // namespace megdnn
  894. // vim: syntax=cpp.doxygen