You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.cpp 39 kB


  1. #include "test/common/elemwise.h"
  2. #include "src/common/utils.cuh"
  3. #include "test/common/checker.h"
  4. #include "test/common/utils.h"
  5. #include "megdnn/oprs/general.h"
  6. #include "test/common/fix_gtest_on_platforms_without_exception.inl"
  7. using namespace megdnn;
  8. using namespace test;
  9. namespace {
  10. void fma3_extra_opr_impl(const TensorNDArray& data) {
  11. megdnn_assert(data.size() == 4);
  12. auto handle = create_cpu_handle(2);
  13. auto opr = handle->create_operator<Elemwise>();
  14. using Mode = Elemwise::Mode;
  15. opr->param().mode = Mode::MUL;
  16. opr->exec({data[0], data[1]}, data[3]);
  17. opr->param().mode = Mode::ADD;
  18. opr->exec({data[2], data[3]}, data[3]);
  19. }
  20. void fma4_extra_opr_impl(const TensorNDArray& data) {
  21. megdnn_assert(data.size() == 5);
  22. std::vector<uint8_t> tmp_storage(data[4].layout.span().dist_byte());
  23. TensorND tmp;
  24. tmp.reset_ptr(tmp_storage.data());
  25. tmp.layout = data[4].layout;
  26. tmp.layout.init_contiguous_stride();
  27. auto handle = create_cpu_handle(2);
  28. auto opr = handle->create_operator<Elemwise>();
  29. using Mode = Elemwise::Mode;
  30. opr->param().mode = Mode::MUL;
  31. opr->exec({data[0], data[1]}, data[4]);
  32. opr->exec({data[2], data[3]}, tmp);
  33. opr->param().mode = Mode::ADD;
  34. opr->exec({tmp, data[4]}, data[4]);
  35. }
  36. TensorLayout make_layout(
  37. const TensorShape& shp, std::initializer_list<ptrdiff_t> stride) {
  38. TensorLayout ret{shp, dtype::Float32()};
  39. megdnn_assert(stride.size() == shp.ndim);
  40. auto idx = 0;
  41. for (auto i : stride)
  42. ret.stride[idx++] = i;
  43. return ret;
  44. }
  45. } // anonymous namespace
  46. namespace megdnn {
  47. namespace test {
  48. namespace elemwise {
  49. #define DEF_TEST(name) \
  50. template <> \
  51. void run_test<name>(Handle * handle)
  52. DEF_TEST(unary) {
  53. using Mode = ElemwiseForward::Param::Mode;
  54. Checker<ElemwiseForward> checker(handle);
  55. checker.set_param(Mode::SIN);
  56. checker.set_dtype(0, dtype::Float32()).execs({{3, 4, 1}, {}});
  57. checker.set_dtype(0, dtype::Float16()).execs({{3, 4, 1}, {}});
  58. }
  59. DEF_TEST(binary_brdcst) {
  60. auto run = [&](DType dtype) {
  61. using Mode = ElemwiseForward::Param::Mode;
  62. Checker<ElemwiseForward> checker(handle);
  63. checker.set_param(Mode::ADD);
  64. checker.set_dtype(0, dtype);
  65. checker.set_dtype(1, dtype);
  66. checker.execs({{3, 1}, {1, 3}, {3, 3}});
  67. {
  68. checker.execs({{10, 11}, {10, 11}, {10, 11}});
  69. //
  70. checker.execs({{2, 3, 4, 5, 6, 7}, {1, 3, 1, 1, 6, 1}, {2, 3, 4, 5, 6, 7}});
  71. checker.execs({{1, 3, 1, 1, 6, 1}, {2, 3, 4, 5, 6, 7}, {2, 3, 4, 5, 6, 7}});
  72. //
  73. checker.execs({{256, 256, 3}, {1, 1, 3}, {256, 256, 3}});
  74. checker.execs({{1, 1, 3}, {256, 256, 3}, {256, 256, 3}});
  75. //
  76. checker.execs({{8, 1, 6, 1}, {1, 7, 1, 5}, {8, 7, 6, 5}});
  77. checker.execs({{1, 7, 1, 5}, {8, 1, 6, 1}, {8, 7, 6, 5}});
  78. //
  79. checker.execs({{5, 4}, {1, 1}, {5, 4}});
  80. checker.execs({{1, 1}, {5, 4}, {5, 4}});
  81. //
  82. checker.execs({{5, 4}, {1, 4}, {5, 4}});
  83. checker.execs({{1, 4}, {5, 4}, {5, 4}});
  84. //
  85. checker.execs({{15, 3, 5}, {15, 1, 5}, {15, 3, 5}});
  86. checker.execs({{15, 1, 5}, {15, 3, 5}, {15, 3, 5}});
  87. //
  88. checker.execs({{15, 3, 5}, {1, 3, 5}, {15, 3, 5}});
  89. checker.execs({{1, 3, 5}, {15, 3, 5}, {15, 3, 5}});
  90. //
  91. checker.execs({{15, 3, 5}, {1, 3, 1}, {15, 3, 5}});
  92. checker.execs({{1, 3, 1}, {15, 3, 5}, {15, 3, 5}});
  93. //
  94. checker.execs({{3, 1}, {1, 4}, {3, 4}});
  95. // numpy broadcast
  96. checker.execs({{2, 3, 1, 5}, {4, 5}, {2, 3, 4, 5}});
  97. checker.execs({{3, 1, 1}, {4, 5}, {3, 4, 5}});
  98. }
  99. {
  100. // 1d
  101. {
  102. auto n = 1000u;
  103. checker.execs({{n}, {n}, {n}});
  104. checker.execs({{1}, {n}, {n}});
  105. checker.execs({{n}, {1}, {n}});
  106. }
  107. // 2d
  108. {
  109. auto m = 200u, n = 100u;
  110. auto collapse = [](size_t n, bool is_collapsed) {
  111. return is_collapsed ? 1u : n;
  112. };
  113. for (auto msk = 0u; msk < 16; ++msk) {
  114. checker.execs(
  115. {{collapse(m, msk & 1), collapse(n, msk & 2)},
  116. {collapse(m, msk & 4), collapse(n, msk & 8)},
  117. {}});
  118. }
  119. }
  120. // nd
  121. {
  122. checker.execs({{2, 3, 4, 5, 6}, {1, 3, 1, 5, 6}, {2, 3, 4, 5, 6}});
  123. checker.execs({{2, 3, 4, 5, 6}, {2, 1, 4, 1, 6}, {2, 3, 4, 5, 6}});
  124. }
  125. }
  126. };
  127. run(dtype::Float32());
  128. // run(dtype::Float16());
  129. }
  130. DEF_TEST(binary_non_contig) {
  131. using Mode = ElemwiseForward::Param::Mode;
  132. Checker<ElemwiseForward> checker(handle);
  133. checker.set_param(Mode::ADD);
  134. TensorLayout ly{{2, 3}, dtype::Float32()};
  135. ly.stride[0] = 4;
  136. checker.execl({ly, ly, {{2, 3}, dtype::Float32()}});
  137. }
  138. DEF_TEST(ternary) {
  139. using Mode = ElemwiseForward::Param::Mode;
  140. Checker<ElemwiseForward> checker(handle);
  141. checker.set_param(Mode::COND_LEQ_MOV);
  142. checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  143. checker.set_dtype(0, dtype::Float32())
  144. .set_dtype(1, dtype::Float32())
  145. .set_dtype(2, dtype::Float32())
  146. .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  147. checker.set_dtype(0, dtype::Float16())
  148. .set_dtype(1, dtype::Float16())
  149. .set_dtype(2, dtype::Float16())
  150. .set_dtype(3, dtype::Float16())
  151. .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  152. checker.execs({{2, 1, 1, 5}, {4, 5}, {3, 1, 1}, {2, 3, 4, 5}});
  153. checker.execs({{3, 1, 1}, {5}, {4, 1}, {3, 4, 5}});
  154. ASSERT_THROW(checker.execs({{2, 3, 4}, {4, 1}, {1}, {2, 3, 4}}), MegDNNError);
  155. ASSERT_THROW(checker.execs({{2, 4, 4}, {4, 1}, {3, 1, 1}, {2, 3, 4}}), MegDNNError);
  156. }
  157. DEF_TEST(ternary_non_contig) {
  158. using Mode = ElemwiseForward::Param::Mode;
  159. Checker<ElemwiseForward> checker(handle);
  160. checker.set_param(Mode::COND_LEQ_MOV);
  161. TensorLayout ly{{2, 3}, dtype::Float32()};
  162. ly.stride[0] = 4;
  163. checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
  164. }
  165. DEF_TEST(fuse_mul_add3) {
  166. using Mode = ElemwiseForward::Param::Mode;
  167. Checker<ElemwiseForward> checker(handle);
  168. checker.set_param(Mode::FUSE_MUL_ADD3).set_extra_opr_impl(fma3_extra_opr_impl);
  169. auto make_shape = [](const TensorShape& s0, const TensorShape& s1,
  170. const TensorShape& s2) {
  171. TensorShape dest;
  172. dest.ndim = s0.ndim;
  173. for (size_t i = 0; i < dest.ndim; ++i) {
  174. auto a = i < s0.ndim ? s0[i] : 1;
  175. auto b = i < s1.ndim ? s1[i] : 1;
  176. dest[i] = std::max(a, b);
  177. }
  178. return TensorShapeArray{s0, s1, s2, dest};
  179. };
  180. checker.exec(make_shape({2, 1}, {2, 2}, {2, 2}));
  181. checker.exec(make_shape({2, 2}, {2, 1}, {2, 2}));
  182. checker.exec(make_shape({2, 2}, {2, 2}, {1}));
  183. checker.exec(make_shape({3, 1}, {1, 3}, {3, 1}));
  184. checker.exec(make_shape({2, 1, 2, 1, 2, 1}, {1, 2, 1, 2, 1, 2}, {1}));
  185. checker.exec(make_shape({1, 1, 3}, {5, 8, 1}, {5, 8, 1}));
  186. checker.exec(make_shape({1, 192, 9, 16}, {1}, {1, 192, 9, 16}));
  187. }
  188. DEF_TEST(fuse_mul_add3_non_contig) {
  189. using Mode = ElemwiseForward::Param::Mode;
  190. Checker<ElemwiseForward> checker(handle);
  191. checker.set_param(Mode::FUSE_MUL_ADD3).set_extra_opr_impl(fma3_extra_opr_impl);
  192. TensorLayout ly{{2, 3}, dtype::Float32()};
  193. ly.stride[0] = 4;
  194. checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
  195. }
  196. DEF_TEST(fuse_mul_add4) {
  197. using Mode = ElemwiseForward::Param::Mode;
  198. Checker<ElemwiseForward> checker(handle);
  199. checker.set_param(Mode::FUSE_MUL_ADD4).set_extra_opr_impl(fma4_extra_opr_impl);
  200. auto make_shape = [](const TensorShape& s0, const TensorShape& s1,
  201. bool swap = false) {
  202. TensorShape dest;
  203. dest.ndim = s0.ndim;
  204. for (size_t i = 0; i < dest.ndim; ++i) {
  205. auto a = i < s0.ndim ? s0[i] : 1;
  206. auto b = i < s1.ndim ? s1[i] : 1;
  207. dest[i] = std::max(a, b);
  208. }
  209. TensorShapeArray ret{s0, s1, s0, s1, dest};
  210. if (swap)
  211. std::swap(ret[2], ret[3]);
  212. return ret;
  213. };
  214. checker.exec(make_shape({2, 2}, {2, 2}));
  215. checker.exec(make_shape({3, 1}, {1, 3}));
  216. checker.exec(make_shape({2, 1, 2, 1, 2, 1}, {1, 2, 1, 2, 1, 2}));
  217. checker.exec(make_shape({4, 2}, {1, 2}, true));
  218. }
  219. DEF_TEST(rmulh) {
  220. using Mode = ElemwiseForward::Param::Mode;
  221. Checker<ElemwiseForward> checker(handle);
  222. auto run_for_dtype = [&checker](auto dtype) {
  223. auto minv = DTypeTrait<decltype(dtype)>::min();
  224. auto maxv = DTypeTrait<decltype(dtype)>::max();
  225. UniformIntRNG rng0{minv, maxv};
  226. UniformIntRNG rngM{(maxv >> 1) + 1, maxv};
  227. checker.set_param({Mode::RMULH})
  228. .set_dtype(0, dtype)
  229. .set_dtype(1, dtype)
  230. .set_dtype(2, dtype)
  231. .set_rng(0, &rng0)
  232. .set_rng(1, &rngM);
  233. checker.execs({{7, 9, 11, 13}, {1}, {}})
  234. .execs({{16, 3, 256, 256}, {1}, {}})
  235. .execs({{2, 3, 1, 7}, {2, 3, 1, 7}, {}})
  236. .execs({{9, 5, 4}, {1, 5, 1}, {}})
  237. .execs({{233}, {1}, {}});
  238. };
  239. run_for_dtype(dtype::Int8());
  240. run_for_dtype(dtype::Int16());
  241. run_for_dtype(dtype::Int32());
  242. }
  243. /* ============= migrated from x86 tests ============= */
  244. #define UNARY_TEST_CASE(_optr) \
  245. checker.set_param(Mode::_optr).execs({{1, 127}, {}}); \
  246. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  247. #define BUILD_UNARY_TEST_CASE_INT \
  248. UNARY_TEST_CASE(RELU) \
  249. UNARY_TEST_CASE(ABS)
  250. #define BUILD_UNARY_TEST_CASE_FLOAT \
  251. UNARY_TEST_CASE(ABS) \
  252. UNARY_TEST_CASE(LOG) \
  253. UNARY_TEST_CASE(COS) \
  254. UNARY_TEST_CASE(SIN) \
  255. UNARY_TEST_CASE(FLOOR) \
  256. UNARY_TEST_CASE(CEIL) \
  257. UNARY_TEST_CASE(SIGMOID) \
  258. UNARY_TEST_CASE(EXP) \
  259. UNARY_TEST_CASE(TANH) \
  260. UNARY_TEST_CASE(FAST_TANH) \
  261. UNARY_TEST_CASE(RELU) \
  262. UNARY_TEST_CASE(ROUND)
  263. DEF_TEST(unary1) {
  264. using Mode = ElemwiseForward::Param::Mode;
  265. Checker<ElemwiseForward> checker(handle);
  266. // case int
  267. checker.set_dtype(0, dtype::Int8());
  268. BUILD_UNARY_TEST_CASE_INT
  269. checker.set_dtype(0, dtype::Int16());
  270. BUILD_UNARY_TEST_CASE_INT
  271. checker.set_dtype(0, dtype::Int32());
  272. BUILD_UNARY_TEST_CASE_INT
  273. // case float
  274. UniformFloatRNG rng(1e-2, 6e1);
  275. checker.set_rng(0, &rng);
  276. checker.set_epsilon(1e-5);
  277. checker.set_dtype(0, dtype::Float32());
  278. BUILD_UNARY_TEST_CASE_FLOAT
  279. }
  280. #undef UNARY_TEST_CASE
  281. #undef BUILD_UNARY_TEST_CASE_INT
  282. #undef BUILD_UNARY_TEST_CASE_FLOAT
  283. #define BINARY_TEST_CASE(_optr) \
  284. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  285. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  286. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  287. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  288. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  289. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});
  290. #define BUILD_BINARY_TEST_CASE \
  291. BINARY_TEST_CASE(MIN) \
  292. BINARY_TEST_CASE(MAX)
  293. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  294. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  295. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  296. checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
  297. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  298. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}}); \
  299. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  300. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  301. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  302. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  303. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}}); \
  304. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  305. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  306. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  307. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  308. BINARY_COMPLATE_TEST_CASE(ADD) \
  309. BINARY_COMPLATE_TEST_CASE(MUL) \
  310. BINARY_COMPLATE_TEST_CASE(MAX) \
  311. BINARY_COMPLATE_TEST_CASE(MIN) \
  312. BINARY_COMPLATE_TEST_CASE(SUB)
  313. #define BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32 \
  314. BINARY_COMPLATE_TEST_CASE(POW) \
  315. BINARY_COMPLATE_TEST_CASE(TRUE_DIV) \
  316. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID) \
  317. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH) \
  318. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU) \
  319. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_H_SWISH) \
  320. BINARY_COMPLATE_TEST_CASE(FAST_TANH_GRAD) \
  321. BINARY_COMPLATE_TEST_CASE(H_SWISH_GRAD)
  322. DEF_TEST(binary1) {
  323. using Mode = ElemwiseForward::Param::Mode;
  324. Checker<ElemwiseForward> checker(handle);
  325. // case float
  326. UniformFloatRNG rng(1e-5, 7e1);
  327. checker.set_rng(0, &rng);
  328. checker.set_epsilon(1e-5);
  329. checker.set_dtype(0, dtype::Float32());
  330. checker.set_dtype(1, dtype::Float32());
  331. BUILD_BINARY_COMPLATE_TEST_CASE
  332. BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  333. // case int
  334. checker.set_dtype(0, dtype::Int8());
  335. checker.set_dtype(1, dtype::Int8());
  336. BUILD_BINARY_TEST_CASE
  337. BUILD_BINARY_COMPLATE_TEST_CASE
  338. checker.set_dtype(0, dtype::Int16());
  339. checker.set_dtype(1, dtype::Int16());
  340. BUILD_BINARY_TEST_CASE
  341. BUILD_BINARY_COMPLATE_TEST_CASE
  342. checker.set_dtype(0, dtype::Int32());
  343. checker.set_dtype(1, dtype::Int32());
  344. BUILD_BINARY_TEST_CASE
  345. BUILD_BINARY_COMPLATE_TEST_CASE
  346. }
  347. #undef BINARY_TEST_CASE
  348. #undef BUILD_BINARY_TEST_CASE
  349. #undef BINARY_COMPLATE_TEST_CASE
  350. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  351. #undef BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  352. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  353. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  354. checker.set_param(Mode::_optr) \
  355. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  356. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  357. checker.set_param(Mode::_optr) \
  358. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  359. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  360. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  361. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  362. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  363. #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  364. DEF_TEST(ternary1) {
  365. using Mode = ElemwiseForward::Param::Mode;
  366. Checker<ElemwiseForward> checker(handle);
  367. // case int
  368. checker.set_dtype(0, dtype::Int8());
  369. checker.set_dtype(1, dtype::Int8());
  370. checker.set_dtype(2, dtype::Int8());
  371. // BUILD_TERNARY_TEST_CASE
  372. BUILD_TERNARY_COMPLATE_TEST_CASE
  373. checker.set_dtype(0, dtype::Int16());
  374. checker.set_dtype(1, dtype::Int16());
  375. checker.set_dtype(2, dtype::Int16());
  376. // BUILD_TERNARY_TEST_CASE
  377. BUILD_TERNARY_COMPLATE_TEST_CASE
  378. checker.set_dtype(0, dtype::Int32());
  379. checker.set_dtype(1, dtype::Int32());
  380. checker.set_dtype(2, dtype::Int32());
  381. // BUILD_TERNARY_TEST_CASE
  382. BUILD_TERNARY_COMPLATE_TEST_CASE
  383. // case float
  384. UniformFloatRNG rng(1e-5, 7e1);
  385. checker.set_rng(0, &rng);
  386. checker.set_epsilon(1e-5);
  387. checker.set_dtype(0, dtype::Float32());
  388. checker.set_dtype(1, dtype::Float32());
  389. checker.set_dtype(2, dtype::Float32());
  390. // BUILD_TERNARY_TEST_CASE
  391. BUILD_TERNARY_COMPLATE_TEST_CASE
  392. // TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  393. }
  394. #undef TERNARY_COMPLATE_TEST_CASE
  395. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  396. /* ============= migrated from arm tests ============= */
  397. #define UNARY_TEST_CASE(_optr) \
  398. checker.set_param(Mode::_optr).execs({{1, 129}, {}}); \
  399. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  400. #define BUILD_UNARY_TEST_CASE_INT \
  401. UNARY_TEST_CASE(RELU) \
  402. UNARY_TEST_CASE(ABS) \
  403. UNARY_TEST_CASE(NEGATE)
  404. #define BUILD_UNARY_TEST_CASE_FLOAT \
  405. BUILD_UNARY_TEST_CASE_INT \
  406. UNARY_TEST_CASE(SIGMOID) \
  407. UNARY_TEST_CASE(EXP) \
  408. UNARY_TEST_CASE(TANH) \
  409. UNARY_TEST_CASE(FAST_TANH) \
  410. UNARY_TEST_CASE(H_SWISH)
  411. DEF_TEST(unary2) {
  412. using Mode = ElemwiseForward::Param::Mode;
  413. Checker<ElemwiseForward> checker(handle);
  414. // case int
  415. checker.set_dtype(0, dtype::Int8());
  416. BUILD_UNARY_TEST_CASE_INT
  417. checker.set_dtype(0, dtype::Int16());
  418. BUILD_UNARY_TEST_CASE_INT
  419. checker.set_dtype(0, dtype::Int32());
  420. BUILD_UNARY_TEST_CASE_INT
  421. // case float
  422. {
  423. UniformFloatRNG rng(1e-5, 7e1);
  424. checker.set_rng(0, &rng);
  425. checker.set_epsilon(1e-5);
  426. checker.set_dtype(0, dtype::Float32());
  427. BUILD_UNARY_TEST_CASE_FLOAT
  428. }
  429. {
  430. UniformFloatRNG rng(1e-2, 1e1);
  431. checker.set_rng(0, &rng);
  432. checker.set_epsilon(6e-3);
  433. checker.set_dtype(0, dtype::Float16());
  434. BUILD_UNARY_TEST_CASE_FLOAT
  435. }
  436. // tanh NaN bug case
  437. {
  438. UniformFloatRNG rng(100, 200);
  439. checker.set_rng(0, &rng);
  440. checker.set_epsilon(1e-5);
  441. checker.set_dtype(0, dtype::Float32());
  442. checker.set_param(Mode::TANH).execs({{1, 1025}, {}});
  443. checker.set_param(Mode::TANH).execs({{1, 7}, {}});
  444. }
  445. }
  446. #undef UNARY_TEST_CASE
  447. #undef BUILD_UNARY_TEST_CASE_INT
  448. #undef BUILD_UNARY_TEST_CASE_FLOAT
  449. #define BINARY_TEST_CASE(_optr) \
  450. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  451. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  452. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  453. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  454. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  455. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});
  456. #define BUILD_BINARY_TEST_CASE \
  457. BINARY_TEST_CASE(MIN) \
  458. BINARY_TEST_CASE(MAX)
  459. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  460. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  461. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  462. checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
  463. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  464. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}}); \
  465. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  466. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  467. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  468. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  469. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}}); \
  470. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  471. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  472. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  473. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  474. BINARY_COMPLATE_TEST_CASE(ADD) \
  475. BINARY_COMPLATE_TEST_CASE(MUL) \
  476. BINARY_COMPLATE_TEST_CASE(MAX) \
  477. BINARY_COMPLATE_TEST_CASE(MIN) \
  478. BINARY_COMPLATE_TEST_CASE(SUB) \
  479. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU)
  480. DEF_TEST(binary2) {
  481. using Mode = ElemwiseForward::Param::Mode;
  482. Checker<ElemwiseForward> checker(handle);
  483. // case float
  484. UniformFloatRNG rng(1e-5, 7e1);
  485. checker.set_rng(0, &rng);
  486. checker.set_epsilon(1e-5);
  487. checker.set_dtype(0, dtype::Float32());
  488. checker.set_dtype(1, dtype::Float32());
  489. BUILD_BINARY_COMPLATE_TEST_CASE
  490. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID)
  491. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH)
  492. // case int
  493. checker.set_dtype(0, dtype::Int8());
  494. checker.set_dtype(1, dtype::Int8());
  495. // BUILD_BINARY_TEST_CASE
  496. BUILD_BINARY_COMPLATE_TEST_CASE
  497. checker.set_dtype(0, dtype::Int16());
  498. checker.set_dtype(1, dtype::Int16());
  499. // BUILD_BINARY_TEST_CASE
  500. BUILD_BINARY_COMPLATE_TEST_CASE
  501. checker.set_dtype(0, dtype::Int32());
  502. checker.set_dtype(1, dtype::Int32());
  503. BUILD_BINARY_TEST_CASE
  504. BUILD_BINARY_COMPLATE_TEST_CASE
  505. // case float
  506. checker.set_rng(0, &rng);
  507. checker.set_epsilon(1e-5);
  508. checker.set_dtype(0, dtype::Float32());
  509. checker.set_dtype(1, dtype::Float32());
  510. checker.set_param(Mode::FUSE_ADD_SIGMOID).execs({{3, 4, 7}, {1}, {}});
  511. checker.set_param(Mode::FUSE_ADD_TANH).execs({{3, 4, 7}, {1}, {}});
  512. // commutable
  513. checker.set_param(Mode::TRUE_DIV).execs({{1}, {4}, {}});
  514. BUILD_BINARY_TEST_CASE
  515. BUILD_BINARY_COMPLATE_TEST_CASE
  516. BINARY_COMPLATE_TEST_CASE(TRUE_DIV)
  517. {
  518. UniformFloatRNG rng(1e-3, 3e1);
  519. checker.set_rng(0, &rng);
  520. checker.set_rng(1, &rng);
  521. checker.set_epsilon(1e-3);
  522. checker.set_dtype(0, dtype::Float16());
  523. checker.set_dtype(1, dtype::Float16());
  524. checker.set_param(Mode::FUSE_ADD_SIGMOID).execs({{3, 4, 7}, {1}, {}});
  525. checker.set_param(Mode::FUSE_ADD_TANH).execs({{3, 4, 7}, {1}, {}});
  526. BUILD_BINARY_TEST_CASE
  527. BUILD_BINARY_COMPLATE_TEST_CASE
  528. BINARY_COMPLATE_TEST_CASE(TRUE_DIV)
  529. // commutable
  530. checker.set_param(Mode::TRUE_DIV).execs({{1}, {4}, {}});
  531. }
  532. }
  533. #undef BINARY_TEST_CASE
  534. #undef BUILD_BINARY_TEST_CASE
  535. #undef BINARY_COMPLATE_TEST_CASE
  536. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  537. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  538. checker.set_param(Mode::_optr) \
  539. .execs({{1, 123, 1}, {300, 123, 253}, {300, 123, 253}, {}}); \
  540. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  541. checker.set_param(Mode::_optr) \
  542. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  543. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  544. checker.set_param(Mode::_optr) \
  545. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  546. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  547. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  548. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  549. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}}); \
  550. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {1, 1, 1}, {3, 4, 1}, {}});
  551. #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  552. DEF_TEST(ternary2) {
  553. using Mode = ElemwiseForward::Param::Mode;
  554. Checker<ElemwiseForward> checker(handle);
  555. // case int
  556. checker.set_dtype(0, dtype::Int8());
  557. checker.set_dtype(1, dtype::Int8());
  558. checker.set_dtype(2, dtype::Int8());
  559. BUILD_TERNARY_COMPLATE_TEST_CASE
  560. checker.set_dtype(0, dtype::Int16());
  561. checker.set_dtype(1, dtype::Int16());
  562. checker.set_dtype(2, dtype::Int16());
  563. BUILD_TERNARY_COMPLATE_TEST_CASE
  564. checker.set_dtype(0, dtype::Int32());
  565. checker.set_dtype(1, dtype::Int32());
  566. checker.set_dtype(2, dtype::Int32());
  567. BUILD_TERNARY_COMPLATE_TEST_CASE
  568. // case float
  569. UniformFloatRNG rng(1e-5, 7e1);
  570. checker.set_rng(0, &rng);
  571. checker.set_epsilon(1e-5);
  572. checker.set_dtype(0, dtype::Float32());
  573. checker.set_dtype(1, dtype::Float32());
  574. checker.set_dtype(2, dtype::Float32());
  575. BUILD_TERNARY_COMPLATE_TEST_CASE
  576. {
  577. UniformFloatRNG rng(1e-3, 3e1);
  578. checker.set_rng(0, &rng);
  579. checker.set_rng(1, &rng);
  580. checker.set_rng(2, &rng);
  581. checker.set_epsilon(1e-3);
  582. checker.set_dtype(0, dtype::Float16());
  583. checker.set_dtype(1, dtype::Float16());
  584. checker.set_dtype(2, dtype::Float16());
  585. BUILD_TERNARY_COMPLATE_TEST_CASE
  586. }
  587. }
  588. #undef TERNARY_COMPLATE_TEST_CASE
  589. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  590. /* ============= migrated from fallback tests ============= */
  591. DEF_TEST(unary3) {
  592. Checker<Elemwise> checker(handle);
  593. auto make_layouts =
  594. [](const TensorShape& shp,
  595. std::initializer_list<ptrdiff_t> stride) -> TensorLayoutArray {
  596. return {make_layout(shp, stride), {shp, dtype::Float32()}};
  597. };
  598. checker.set_param({Elemwise::Mode::SIN});
  599. checker.exec(make_layouts({2, 2}, {2, 1}));
  600. checker.exec(make_layouts({4}, {3}));
  601. }
  602. DEF_TEST(binary3) {
  603. Checker<Elemwise> checker(handle);
  604. checker.set_param({Elemwise::Mode::ADD});
  605. auto run = [&](const TensorShape& shp0, std::initializer_list<ptrdiff_t> stride0,
  606. const TensorShape& shp1, std::initializer_list<ptrdiff_t> stride1) {
  607. TensorShape shpo;
  608. Elemwise::deduce_shape({shp0, shp1}, shpo);
  609. auto ly0 = make_layout(shp0, stride0), ly1 = make_layout(shp1, stride1),
  610. lyo = TensorLayout{shpo, dtype::Float32()};
  611. checker.execl({ly0, ly1, lyo});
  612. checker.execl({ly1, ly0, lyo});
  613. };
  614. run({2, 2}, {2, 1}, {2, 2}, {2, 1});
  615. run({1}, {1}, {3, 3}, {1, 2});
  616. run({3, 4, 5}, {40, 10, 2}, {1, 4, 1}, {1, 1, 1});
  617. }
  618. DEF_TEST(all_modes) {
  619. // test correctness of all elemwise modes
  620. Checker<Elemwise> checker(handle);
  621. TensorShapeArray shapes;
  622. UniformFloatRNG default_rng_f32{-100.f, 100.f}, pos_rng_f32{.1f, 1000.f},
  623. small_pos_rng_f32{.1f, .10f}, small_rng_f32{-3.f, 3.f},
  624. abslt1_rng_f32{-1.f, 1.f}, uniform_0_2_rng{0.f, 2.f},
  625. tanh_rng_f32{-5.f, 5.f};
  626. UniformFloatNonZeroRNG nonzero_rng_f32{.1f, 1000.f},
  627. big_nonzero_rng_f32{100.f, 1000.f};
  628. UniformIntRNG default_rng_i32{-100, 100}, small_rng_i32{-2, 2},
  629. shift_rng_i32_i32{0, 31}, shift_rng_i32_i8{0, 7};
  630. UniformIntNonZeroRNG nonzero_rng_i32{1, 100};
  631. using Mode = Elemwise::Mode;
  632. auto should_ignore = [handle](Mode mode) {
  633. MEGDNN_MARK_USED_VAR(mode);
  634. return false;
  635. };
  636. for (int mode_nr = 0; mode_nr < static_cast<int>(Elemwise::Param::MODE_NR_MEMBER);
  637. ++mode_nr) {
  638. auto mode = static_cast<Mode>(mode_nr);
  639. // ignore unsupported modes
  640. if (should_ignore(mode)) {
  641. continue;
  642. }
  643. checker.set_param({mode});
  644. auto&& trait = Elemwise::ModeTrait::from_mode(mode);
  645. shapes.resize(trait.arity + 1);
  646. for (size_t i = 0; i < shapes.size() - 1; ++i) {
  647. shapes[i] = {3, 9, 7};
  648. }
  649. //! NOTE: force set output layout to empty to trigger layout deduce
  650. shapes[shapes.size() - 1] = {};
  651. auto do_run = [&](DType dtype, float eps = 1e-3) {
  652. // limit value ranges for some modes
  653. if (mode == Mode::LOG || mode == Mode::LOG1P) {
  654. checker.set_rng(0, &pos_rng_f32);
  655. } else if (mode == Mode::POW) {
  656. checker.set_rng(0, &small_pos_rng_f32);
  657. checker.set_rng(1, &small_rng_f32);
  658. } else if (mode == Mode::EXP || mode == Mode::EXPM1) {
  659. checker.set_rng(0, &small_rng_f32);
  660. } else if (mode == Mode::FAST_TANH) {
  661. checker.set_rng(0, &tanh_rng_f32);
  662. } else if (mode == Mode::LOG_SUM_EXP) {
  663. // check numerical stability with large values
  664. checker.set_rng(0, &big_nonzero_rng_f32);
  665. checker.set_rng(1, &big_nonzero_rng_f32);
  666. } else if (
  667. mode == Mode::ASIN || mode == Mode::ACOS ||
  668. mode == Mode::SIGMOID_GRAD || mode == Mode::TANH_GRAD ||
  669. mode == Mode::ERFINV) {
  670. checker.set_rng(0, &abslt1_rng_f32);
  671. checker.set_rng(1, &default_rng_f32);
  672. } else if (mode == Mode::ERFCINV) {
  673. checker.set_rng(0, &uniform_0_2_rng);
  674. } else if (
  675. mode == Mode::MOD || mode == Mode::TRUE_DIV ||
  676. mode == Mode::FLOOR_DIV) {
  677. if (dtype.category() == DTypeCategory::INT) {
  678. checker.set_rng(0, &default_rng_i32);
  679. checker.set_rng(1, &nonzero_rng_i32);
  680. } else {
  681. checker.set_rng(0, &default_rng_f32);
  682. checker.set_rng(1, &nonzero_rng_f32);
  683. }
  684. } else if (mode == Mode::EQ) {
  685. checker.set_rng(0, &small_rng_i32);
  686. checker.set_rng(1, &small_rng_i32);
  687. } else if (mode == Mode::SHL || mode == Mode::SHR) {
  688. checker.set_rng(0, &default_rng_i32);
  689. if (dtype.size() == 4) {
  690. checker.set_rng(1, &shift_rng_i32_i32);
  691. } else {
  692. megdnn_assert(dtype.size() == 1);
  693. checker.set_rng(1, &shift_rng_i32_i8);
  694. }
  695. } else if (mode == Mode::ATAN2) {
  696. checker.set_rng(0, &nonzero_rng_f32);
  697. checker.set_rng(1, &nonzero_rng_f32);
  698. } else {
  699. RNG* rng;
  700. if (dtype.category() == DTypeCategory::INT) {
  701. rng = &default_rng_i32;
  702. } else {
  703. rng = &default_rng_f32;
  704. }
  705. for (size_t i = 0; i < shapes.size(); ++i) {
  706. checker.set_rng(i, rng);
  707. }
  708. }
  709. checker.set_epsilon(eps);
  710. for (size_t i = 0; i < shapes.size(); ++i) {
  711. checker.set_dtype(i, dtype);
  712. }
  713. EXPECT_NO_THROW(checker.execs(shapes));
  714. if (!::testing::Test::HasFailure() && shapes.size() == 3) {
  715. // channel bcast
  716. shapes[1][0] = 1;
  717. shapes[1][2] = 1;
  718. EXPECT_NO_THROW(checker.execs(shapes));
  719. if (!::testing::Test::HasFailure()) {
  720. // scalar bcast
  721. shapes[1][1] = 1;
  722. EXPECT_NO_THROW(checker.execs(shapes));
  723. }
  724. }
  725. if (::testing::Test::HasFailure()) {
  726. printf("failed on mode=%d(%s) dtype=%s\n", mode_nr, trait.name,
  727. dtype.name());
  728. for (auto&& i : shapes) {
  729. printf("ishape: %s\n", i.to_string().c_str());
  730. }
  731. return false;
  732. }
  733. return true;
  734. };
  735. #define run(args...) \
  736. do { \
  737. if (!do_run(args)) { \
  738. return; \
  739. } \
  740. } while (0)
  741. if (trait.allow_int) {
  742. run(dtype::Int8{});
  743. run(dtype::Int32{});
  744. }
  745. if (trait.allow_float) {
  746. DNN_FLOAT16_SELECT(
  747. run(dtype::Float16{}, mode == Mode::FAST_TANH_GRAD ? 0.5 : 0.05), );
  748. run(dtype::Float32{});
  749. }
  750. }
  751. #undef run
  752. }
  753. #define UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
  754. checker.set_param(Mode::_optr) \
  755. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int8()}, {}}); \
  756. checker.set_param(Mode::_optr) \
  757. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int16()}, {}}); \
  758. checker.set_param(Mode::_optr) \
  759. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int32()}, {}});
  760. #define UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(_optr) \
  761. checker.set_param(Mode::_optr) \
  762. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Float32()}, {}});
  763. #define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT \
  764. UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(RELU); \
  765. UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(ABS);
  766. #define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT \
  767. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ABS) \
  768. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(LOG) \
  769. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(COS) \
  770. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIN) \
  771. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FLOOR) \
  772. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(CEIL) \
  773. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIGMOID) \
  774. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(EXP) \
  775. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(RELU) \
  776. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ROUND) \
  777. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(TANH) \
  778. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FAST_TANH)
  779. DEF_TEST(unary_negative_stride) {
  780. using Mode = ElemwiseForward::Param::Mode;
  781. Checker<ElemwiseForward> checker(handle);
  782. BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT;
  783. UniformFloatRNG rng(1e-2, 6e1);
  784. checker.set_rng(0, &rng);
  785. checker.set_epsilon(1e-5);
  786. BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT;
  787. }
  788. #undef UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
  789. #undef UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
  790. #undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
  791. #undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
  792. #define BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
  793. checker.set_param(Mode::_optr) \
  794. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int8()}, \
  795. {{1, 4, 1}, dtype::Int8()}, \
  796. {}}); \
  797. checker.set_param(Mode::_optr) \
  798. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int16()}, \
  799. {{1, 4, 1}, dtype::Int16()}, \
  800. {}}); \
  801. checker.set_param(Mode::_optr) \
  802. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int32()}, \
  803. {{1, 4, 1}, dtype::Int32()}, \
  804. {}});
  805. #define BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(_optr) \
  806. checker.set_param(Mode::_optr) \
  807. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Float32()}, \
  808. {{1, 4, 1}, dtype::Float32()}, \
  809. {}});
  810. #define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT \
  811. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(ADD) \
  812. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MUL) \
  813. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MAX) \
  814. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MIN) \
  815. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(SUB)
  816. #define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 \
  817. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(POW) \
  818. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(TRUE_DIV) \
  819. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_SIGMOID) \
  820. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_TANH) \
  821. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_RELU) \
  822. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_H_SWISH) \
  823. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FAST_TANH_GRAD) \
  824. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(H_SWISH_GRAD)
  825. DEF_TEST(binary_negative_stride) {
  826. using Mode = ElemwiseForward::Param::Mode;
  827. Checker<ElemwiseForward> checker(handle);
  828. BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT;
  829. UniformFloatRNG rng(1e-2, 2e1);
  830. checker.set_rng(0, &rng);
  831. checker.set_epsilon(1e-5);
  832. BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32;
  833. }
  834. #undef BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
  835. #undef BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
  836. #undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
  837. #undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
  838. DEF_TEST(ternary_negative_stride) {
  839. using Mode = ElemwiseForward::Param::Mode;
  840. Checker<ElemwiseForward> checker(handle);
  841. checker.set_param(Mode::FUSE_MUL_ADD3);
  842. checker.execl(
  843. {{{1, 7}, {-7, -1}, dtype::Int8()},
  844. {{1, 7}, {-3, -1}, dtype::Int8()},
  845. {{1, 7}, {-7, -1}, dtype::Int8()},
  846. {}});
  847. checker.execl(
  848. {{{1, 7}, {-7, -1}, dtype::Int16()},
  849. {{1, 7}, {-3, -1}, dtype::Int16()},
  850. {{1, 7}, {-7, -1}, dtype::Int16()},
  851. {}});
  852. checker.execl(
  853. {{{1, 7}, {-7, -1}, dtype::Int32()},
  854. {{1, 7}, {-3, -1}, dtype::Int32()},
  855. {{1, 7}, {-7, -1}, dtype::Int32()},
  856. {}});
  857. UniformFloatRNG rng(1e-2, 2e1);
  858. checker.set_rng(0, &rng);
  859. checker.set_epsilon(1e-5);
  860. checker.execl(
  861. {{{1, 7}, {-7, -1}, dtype::Float32()},
  862. {{1, 7}, {-3, -1}, dtype::Float32()},
  863. {{1, 7}, {-7, -1}, dtype::Float32()},
  864. {}});
  865. }
  866. TEST(TEST_ELEMWISE, MODE_TRAIT) {
  867. using M = Elemwise::Mode;
  868. using T = Elemwise::ModeTrait;
  869. ASSERT_EQ(1u, T::from_mode(M::RELU).arity);
  870. ASSERT_EQ(2u, T::from_mode(M::ADD).arity);
  871. ASSERT_EQ(3u, T::from_mode(M::FUSE_MUL_ADD3).arity);
  872. ASSERT_EQ(4u, T::from_mode(M::FUSE_MUL_ADD4).arity);
  873. ASSERT_TRUE(T::from_mode(M::ADD).commutable);
  874. ASSERT_FALSE(T::from_mode(M::TRUE_DIV).commutable);
  875. ASSERT_TRUE(T::from_mode(M::ADD).allow_int);
  876. ASSERT_FALSE(T::from_mode(M::EXP).allow_int);
  877. ASSERT_TRUE(T::from_mode(M::ADD).allow_float);
  878. ASSERT_FALSE(T::from_mode(M::SHL).allow_float);
  879. ASSERT_TRUE(T::from_mode(M::RMULH).commutable);
  880. ASSERT_FALSE(T::from_mode(M::RMULH).allow_float);
  881. ASSERT_TRUE(T::from_mode(M::XOR).allow_bool);
  882. }
  883. } // namespace elemwise
  884. } // namespace test
  885. } // namespace megdnn
  886. // vim: syntax=cpp.doxygen