You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.cpp 40 kB


  1. #include "test/common/elemwise.h"
  2. #include "src/common/utils.cuh"
  3. #include "test/common/checker.h"
  4. #include "test/common/utils.h"
  5. #include "megdnn/oprs/general.h"
  6. #include "test/common/fix_gtest_on_platforms_without_exception.inl"
  7. using namespace megdnn;
  8. using namespace test;
  9. namespace {
  10. void fma3_extra_opr_impl(const TensorNDArray& data) {
  11. megdnn_assert(data.size() == 4);
  12. auto handle = create_cpu_handle(2);
  13. auto opr = handle->create_operator<Elemwise>();
  14. using Mode = Elemwise::Mode;
  15. opr->param().mode = Mode::MUL;
  16. opr->exec({data[0], data[1]}, data[3]);
  17. opr->param().mode = Mode::ADD;
  18. opr->exec({data[2], data[3]}, data[3]);
  19. }
  20. void fma4_extra_opr_impl(const TensorNDArray& data) {
  21. megdnn_assert(data.size() == 5);
  22. std::vector<uint8_t> tmp_storage(data[4].layout.span().dist_byte());
  23. TensorND tmp;
  24. tmp.reset_ptr(tmp_storage.data());
  25. tmp.layout = data[4].layout;
  26. tmp.layout.init_contiguous_stride();
  27. auto handle = create_cpu_handle(2);
  28. auto opr = handle->create_operator<Elemwise>();
  29. using Mode = Elemwise::Mode;
  30. opr->param().mode = Mode::MUL;
  31. opr->exec({data[0], data[1]}, data[4]);
  32. opr->exec({data[2], data[3]}, tmp);
  33. opr->param().mode = Mode::ADD;
  34. opr->exec({tmp, data[4]}, data[4]);
  35. }
  36. TensorLayout make_layout(
  37. const TensorShape& shp, std::initializer_list<ptrdiff_t> stride) {
  38. TensorLayout ret{shp, dtype::Float32()};
  39. megdnn_assert(stride.size() == shp.ndim);
  40. auto idx = 0;
  41. for (auto i : stride)
  42. ret.stride[idx++] = i;
  43. return ret;
  44. }
  45. } // anonymous namespace
  46. namespace megdnn {
  47. namespace test {
  48. namespace elemwise {
  49. #define DEF_TEST(name) \
  50. template <> \
  51. void run_test<name>(Handle * handle)
  52. DEF_TEST(unary) {
  53. using Mode = ElemwiseForward::Param::Mode;
  54. Checker<ElemwiseForward> checker(handle);
  55. checker.set_param(Mode::SIN);
  56. checker.set_dtype(0, dtype::Float32()).execs({{3, 4, 1}, {}});
  57. checker.set_dtype(0, dtype::Float16()).execs({{3, 4, 1}, {}});
  58. }
  59. DEF_TEST(binary_brdcst) {
  60. auto run = [&](DType dtype) {
  61. using Mode = ElemwiseForward::Param::Mode;
  62. Checker<ElemwiseForward> checker(handle);
  63. checker.set_param(Mode::ADD);
  64. checker.set_dtype(0, dtype);
  65. checker.set_dtype(1, dtype);
  66. checker.execs({{3, 1}, {1, 3}, {3, 3}});
  67. {
  68. checker.execs({{10, 11}, {10, 11}, {10, 11}});
  69. //
  70. checker.execs({{2, 3, 4, 5, 6, 7}, {1, 3, 1, 1, 6, 1}, {2, 3, 4, 5, 6, 7}});
  71. checker.execs({{1, 3, 1, 1, 6, 1}, {2, 3, 4, 5, 6, 7}, {2, 3, 4, 5, 6, 7}});
  72. //
  73. checker.execs({{256, 256, 3}, {1, 1, 3}, {256, 256, 3}});
  74. checker.execs({{1, 1, 3}, {256, 256, 3}, {256, 256, 3}});
  75. //
  76. checker.execs({{8, 1, 6, 1}, {1, 7, 1, 5}, {8, 7, 6, 5}});
  77. checker.execs({{1, 7, 1, 5}, {8, 1, 6, 1}, {8, 7, 6, 5}});
  78. //
  79. checker.execs({{5, 4}, {1, 1}, {5, 4}});
  80. checker.execs({{1, 1}, {5, 4}, {5, 4}});
  81. //
  82. checker.execs({{5, 4}, {1, 4}, {5, 4}});
  83. checker.execs({{1, 4}, {5, 4}, {5, 4}});
  84. //
  85. checker.execs({{15, 3, 5}, {15, 1, 5}, {15, 3, 5}});
  86. checker.execs({{15, 1, 5}, {15, 3, 5}, {15, 3, 5}});
  87. //
  88. checker.execs({{15, 3, 5}, {1, 3, 5}, {15, 3, 5}});
  89. checker.execs({{1, 3, 5}, {15, 3, 5}, {15, 3, 5}});
  90. //
  91. checker.execs({{15, 3, 5}, {1, 3, 1}, {15, 3, 5}});
  92. checker.execs({{1, 3, 1}, {15, 3, 5}, {15, 3, 5}});
  93. //
  94. checker.execs({{3, 1}, {1, 4}, {3, 4}});
  95. // numpy broadcast
  96. checker.execs({{2, 3, 1, 5}, {4, 5}, {2, 3, 4, 5}});
  97. checker.execs({{3, 1, 1}, {4, 5}, {3, 4, 5}});
  98. }
  99. {
  100. // 1d
  101. {
  102. auto n = 1000u;
  103. checker.execs({{n}, {n}, {n}});
  104. checker.execs({{1}, {n}, {n}});
  105. checker.execs({{n}, {1}, {n}});
  106. }
  107. // 2d
  108. {
  109. auto m = 200u, n = 100u;
  110. auto collapse = [](size_t n, bool is_collapsed) {
  111. return is_collapsed ? 1u : n;
  112. };
  113. for (auto msk = 0u; msk < 16; ++msk) {
  114. checker.execs(
  115. {{collapse(m, msk & 1), collapse(n, msk & 2)},
  116. {collapse(m, msk & 4), collapse(n, msk & 8)},
  117. {}});
  118. }
  119. }
  120. // nd
  121. {
  122. checker.execs({{2, 3, 4, 5, 6}, {1, 3, 1, 5, 6}, {2, 3, 4, 5, 6}});
  123. checker.execs({{2, 3, 4, 5, 6}, {2, 1, 4, 1, 6}, {2, 3, 4, 5, 6}});
  124. }
  125. }
  126. };
  127. run(dtype::Float32());
  128. // run(dtype::Float16());
  129. }
  130. DEF_TEST(binary_non_contig) {
  131. using Mode = ElemwiseForward::Param::Mode;
  132. Checker<ElemwiseForward> checker(handle);
  133. checker.set_param(Mode::ADD);
  134. TensorLayout ly{{2, 3}, dtype::Float32()};
  135. ly.stride[0] = 4;
  136. checker.execl({ly, ly, {{2, 3}, dtype::Float32()}});
  137. }
  138. DEF_TEST(ternary) {
  139. using Mode = ElemwiseForward::Param::Mode;
  140. Checker<ElemwiseForward> checker(handle);
  141. checker.set_param(Mode::COND_LEQ_MOV);
  142. checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  143. checker.set_dtype(0, dtype::Float32())
  144. .set_dtype(1, dtype::Float32())
  145. .set_dtype(2, dtype::Float32())
  146. .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  147. checker.set_dtype(0, dtype::Float16())
  148. .set_dtype(1, dtype::Float16())
  149. .set_dtype(2, dtype::Float16())
  150. .set_dtype(3, dtype::Float16())
  151. .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  152. checker.execs({{2, 1, 1, 5}, {4, 5}, {3, 1, 1}, {2, 3, 4, 5}});
  153. checker.execs({{3, 1, 1}, {5}, {4, 1}, {3, 4, 5}});
  154. ASSERT_THROW(checker.execs({{2, 3, 4}, {4, 1}, {1}, {2, 3, 4}}), MegDNNError);
  155. ASSERT_THROW(checker.execs({{2, 4, 4}, {4, 1}, {3, 1, 1}, {2, 3, 4}}), MegDNNError);
  156. }
  157. DEF_TEST(ternary_non_contig) {
  158. using Mode = ElemwiseForward::Param::Mode;
  159. Checker<ElemwiseForward> checker(handle);
  160. checker.set_param(Mode::COND_LEQ_MOV);
  161. TensorLayout ly{{2, 3}, dtype::Float32()};
  162. ly.stride[0] = 4;
  163. checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
  164. }
  165. DEF_TEST(ternary_lt) {
  166. using Mode = ElemwiseForward::Param::Mode;
  167. Checker<ElemwiseForward> checker(handle);
  168. checker.set_param(Mode::COND_LT_MOV);
  169. checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  170. checker.set_dtype(0, dtype::Float32())
  171. .set_dtype(1, dtype::Float32())
  172. .set_dtype(2, dtype::Float32())
  173. .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  174. checker.set_dtype(0, dtype::Float16())
  175. .set_dtype(1, dtype::Float16())
  176. .set_dtype(2, dtype::Float16())
  177. .set_dtype(3, dtype::Float16())
  178. .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  179. checker.execs({{2, 1, 1, 5}, {4, 5}, {3, 1, 1}, {2, 3, 4, 5}});
  180. checker.execs({{3, 1, 1}, {5}, {4, 1}, {3, 4, 5}});
  181. ASSERT_THROW(checker.execs({{2, 3, 4}, {4, 1}, {1}, {2, 3, 4}}), MegDNNError);
  182. ASSERT_THROW(checker.execs({{2, 4, 4}, {4, 1}, {3, 1, 1}, {2, 3, 4}}), MegDNNError);
  183. }
  184. DEF_TEST(ternary_lt_non_contig) {
  185. using Mode = ElemwiseForward::Param::Mode;
  186. Checker<ElemwiseForward> checker(handle);
  187. checker.set_param(Mode::COND_LT_MOV);
  188. TensorLayout ly{{2, 3}, dtype::Float32()};
  189. ly.stride[0] = 4;
  190. checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
  191. }
  192. DEF_TEST(fuse_mul_add3) {
  193. using Mode = ElemwiseForward::Param::Mode;
  194. Checker<ElemwiseForward> checker(handle);
  195. checker.set_param(Mode::FUSE_MUL_ADD3).set_extra_opr_impl(fma3_extra_opr_impl);
  196. auto make_shape = [](const TensorShape& s0, const TensorShape& s1,
  197. const TensorShape& s2) {
  198. TensorShape dest;
  199. dest.ndim = s0.ndim;
  200. for (size_t i = 0; i < dest.ndim; ++i) {
  201. auto a = i < s0.ndim ? s0[i] : 1;
  202. auto b = i < s1.ndim ? s1[i] : 1;
  203. dest[i] = std::max(a, b);
  204. }
  205. return TensorShapeArray{s0, s1, s2, dest};
  206. };
  207. checker.exec(make_shape({2, 1}, {2, 2}, {2, 2}));
  208. checker.exec(make_shape({2, 2}, {2, 1}, {2, 2}));
  209. checker.exec(make_shape({2, 2}, {2, 2}, {1}));
  210. checker.exec(make_shape({3, 1}, {1, 3}, {3, 1}));
  211. checker.exec(make_shape({2, 1, 2, 1, 2, 1}, {1, 2, 1, 2, 1, 2}, {1}));
  212. checker.exec(make_shape({1, 1, 3}, {5, 8, 1}, {5, 8, 1}));
  213. checker.exec(make_shape({1, 192, 9, 16}, {1}, {1, 192, 9, 16}));
  214. }
  215. DEF_TEST(fuse_mul_add3_non_contig) {
  216. using Mode = ElemwiseForward::Param::Mode;
  217. Checker<ElemwiseForward> checker(handle);
  218. checker.set_param(Mode::FUSE_MUL_ADD3).set_extra_opr_impl(fma3_extra_opr_impl);
  219. TensorLayout ly{{2, 3}, dtype::Float32()};
  220. ly.stride[0] = 4;
  221. checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
  222. }
  223. DEF_TEST(fuse_mul_add4) {
  224. using Mode = ElemwiseForward::Param::Mode;
  225. Checker<ElemwiseForward> checker(handle);
  226. checker.set_param(Mode::FUSE_MUL_ADD4).set_extra_opr_impl(fma4_extra_opr_impl);
  227. auto make_shape = [](const TensorShape& s0, const TensorShape& s1,
  228. bool swap = false) {
  229. TensorShape dest;
  230. dest.ndim = s0.ndim;
  231. for (size_t i = 0; i < dest.ndim; ++i) {
  232. auto a = i < s0.ndim ? s0[i] : 1;
  233. auto b = i < s1.ndim ? s1[i] : 1;
  234. dest[i] = std::max(a, b);
  235. }
  236. TensorShapeArray ret{s0, s1, s0, s1, dest};
  237. if (swap)
  238. std::swap(ret[2], ret[3]);
  239. return ret;
  240. };
  241. checker.exec(make_shape({2, 2}, {2, 2}));
  242. checker.exec(make_shape({3, 1}, {1, 3}));
  243. checker.exec(make_shape({2, 1, 2, 1, 2, 1}, {1, 2, 1, 2, 1, 2}));
  244. checker.exec(make_shape({4, 2}, {1, 2}, true));
  245. }
  246. DEF_TEST(rmulh) {
  247. using Mode = ElemwiseForward::Param::Mode;
  248. Checker<ElemwiseForward> checker(handle);
  249. auto run_for_dtype = [&checker](auto dtype) {
  250. auto minv = DTypeTrait<decltype(dtype)>::min();
  251. auto maxv = DTypeTrait<decltype(dtype)>::max();
  252. UniformIntRNG rng0{minv, maxv};
  253. UniformIntRNG rngM{(maxv >> 1) + 1, maxv};
  254. checker.set_param({Mode::RMULH})
  255. .set_dtype(0, dtype)
  256. .set_dtype(1, dtype)
  257. .set_dtype(2, dtype)
  258. .set_rng(0, &rng0)
  259. .set_rng(1, &rngM);
  260. checker.execs({{7, 9, 11, 13}, {1}, {}})
  261. .execs({{16, 3, 256, 256}, {1}, {}})
  262. .execs({{2, 3, 1, 7}, {2, 3, 1, 7}, {}})
  263. .execs({{9, 5, 4}, {1, 5, 1}, {}})
  264. .execs({{233}, {1}, {}});
  265. };
  266. run_for_dtype(dtype::Int8());
  267. run_for_dtype(dtype::Int16());
  268. run_for_dtype(dtype::Int32());
  269. }
  270. /* ============= migrated from x86 tests ============= */
  271. #define UNARY_TEST_CASE(_optr) \
  272. checker.set_param(Mode::_optr).execs({{1, 127}, {}}); \
  273. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  274. #define BUILD_UNARY_TEST_CASE_INT \
  275. UNARY_TEST_CASE(RELU) \
  276. UNARY_TEST_CASE(ABS)
  277. #define BUILD_UNARY_TEST_CASE_FLOAT \
  278. UNARY_TEST_CASE(ABS) \
  279. UNARY_TEST_CASE(LOG) \
  280. UNARY_TEST_CASE(COS) \
  281. UNARY_TEST_CASE(SIN) \
  282. UNARY_TEST_CASE(FLOOR) \
  283. UNARY_TEST_CASE(CEIL) \
  284. UNARY_TEST_CASE(SIGMOID) \
  285. UNARY_TEST_CASE(EXP) \
  286. UNARY_TEST_CASE(TANH) \
  287. UNARY_TEST_CASE(FAST_TANH) \
  288. UNARY_TEST_CASE(RELU) \
  289. UNARY_TEST_CASE(ROUND)
  290. DEF_TEST(unary1) {
  291. using Mode = ElemwiseForward::Param::Mode;
  292. Checker<ElemwiseForward> checker(handle);
  293. // case int
  294. checker.set_dtype(0, dtype::Int8());
  295. BUILD_UNARY_TEST_CASE_INT
  296. checker.set_dtype(0, dtype::Int16());
  297. BUILD_UNARY_TEST_CASE_INT
  298. checker.set_dtype(0, dtype::Int32());
  299. BUILD_UNARY_TEST_CASE_INT
  300. // case float
  301. UniformFloatRNG rng(1e-2, 6e1);
  302. checker.set_rng(0, &rng);
  303. checker.set_epsilon(1e-5);
  304. checker.set_dtype(0, dtype::Float32());
  305. BUILD_UNARY_TEST_CASE_FLOAT
  306. }
  307. #undef UNARY_TEST_CASE
  308. #undef BUILD_UNARY_TEST_CASE_INT
  309. #undef BUILD_UNARY_TEST_CASE_FLOAT
  310. #define BINARY_TEST_CASE(_optr) \
  311. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  312. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  313. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  314. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  315. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  316. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});
  317. #define BUILD_BINARY_TEST_CASE \
  318. BINARY_TEST_CASE(MIN) \
  319. BINARY_TEST_CASE(MAX)
  320. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  321. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  322. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  323. checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
  324. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  325. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}}); \
  326. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  327. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  328. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  329. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  330. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}}); \
  331. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  332. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  333. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  334. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  335. BINARY_COMPLATE_TEST_CASE(ADD) \
  336. BINARY_COMPLATE_TEST_CASE(MUL) \
  337. BINARY_COMPLATE_TEST_CASE(MAX) \
  338. BINARY_COMPLATE_TEST_CASE(MIN) \
  339. BINARY_COMPLATE_TEST_CASE(SUB)
  340. #define BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32 \
  341. BINARY_COMPLATE_TEST_CASE(POW) \
  342. BINARY_COMPLATE_TEST_CASE(TRUE_DIV) \
  343. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID) \
  344. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH) \
  345. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU) \
  346. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_H_SWISH) \
  347. BINARY_COMPLATE_TEST_CASE(FAST_TANH_GRAD) \
  348. BINARY_COMPLATE_TEST_CASE(H_SWISH_GRAD)
  349. DEF_TEST(binary1) {
  350. using Mode = ElemwiseForward::Param::Mode;
  351. Checker<ElemwiseForward> checker(handle);
  352. // case float
  353. UniformFloatRNG rng(1e-5, 7e1);
  354. checker.set_rng(0, &rng);
  355. checker.set_epsilon(1e-5);
  356. checker.set_dtype(0, dtype::Float32());
  357. checker.set_dtype(1, dtype::Float32());
  358. BUILD_BINARY_COMPLATE_TEST_CASE
  359. BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  360. // case int
  361. checker.set_dtype(0, dtype::Int8());
  362. checker.set_dtype(1, dtype::Int8());
  363. BUILD_BINARY_TEST_CASE
  364. BUILD_BINARY_COMPLATE_TEST_CASE
  365. checker.set_dtype(0, dtype::Int16());
  366. checker.set_dtype(1, dtype::Int16());
  367. BUILD_BINARY_TEST_CASE
  368. BUILD_BINARY_COMPLATE_TEST_CASE
  369. checker.set_dtype(0, dtype::Int32());
  370. checker.set_dtype(1, dtype::Int32());
  371. BUILD_BINARY_TEST_CASE
  372. BUILD_BINARY_COMPLATE_TEST_CASE
  373. }
  374. #undef BINARY_TEST_CASE
  375. #undef BUILD_BINARY_TEST_CASE
  376. #undef BINARY_COMPLATE_TEST_CASE
  377. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  378. #undef BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  379. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  380. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  381. checker.set_param(Mode::_optr) \
  382. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  383. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  384. checker.set_param(Mode::_optr) \
  385. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  386. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  387. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  388. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  389. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  390. #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  391. DEF_TEST(ternary1) {
  392. using Mode = ElemwiseForward::Param::Mode;
  393. Checker<ElemwiseForward> checker(handle);
  394. // case int
  395. checker.set_dtype(0, dtype::Int8());
  396. checker.set_dtype(1, dtype::Int8());
  397. checker.set_dtype(2, dtype::Int8());
  398. // BUILD_TERNARY_TEST_CASE
  399. BUILD_TERNARY_COMPLATE_TEST_CASE
  400. checker.set_dtype(0, dtype::Int16());
  401. checker.set_dtype(1, dtype::Int16());
  402. checker.set_dtype(2, dtype::Int16());
  403. // BUILD_TERNARY_TEST_CASE
  404. BUILD_TERNARY_COMPLATE_TEST_CASE
  405. checker.set_dtype(0, dtype::Int32());
  406. checker.set_dtype(1, dtype::Int32());
  407. checker.set_dtype(2, dtype::Int32());
  408. // BUILD_TERNARY_TEST_CASE
  409. BUILD_TERNARY_COMPLATE_TEST_CASE
  410. // case float
  411. UniformFloatRNG rng(1e-5, 7e1);
  412. checker.set_rng(0, &rng);
  413. checker.set_epsilon(1e-5);
  414. checker.set_dtype(0, dtype::Float32());
  415. checker.set_dtype(1, dtype::Float32());
  416. checker.set_dtype(2, dtype::Float32());
  417. // BUILD_TERNARY_TEST_CASE
  418. BUILD_TERNARY_COMPLATE_TEST_CASE
  419. // TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  420. }
  421. #undef TERNARY_COMPLATE_TEST_CASE
  422. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  423. /* ============= migrated from arm tests ============= */
  424. #define UNARY_TEST_CASE(_optr) \
  425. checker.set_param(Mode::_optr).execs({{1, 129}, {}}); \
  426. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  427. #define BUILD_UNARY_TEST_CASE_INT \
  428. UNARY_TEST_CASE(RELU) \
  429. UNARY_TEST_CASE(ABS) \
  430. UNARY_TEST_CASE(NEGATE)
  431. #define BUILD_UNARY_TEST_CASE_FLOAT \
  432. BUILD_UNARY_TEST_CASE_INT \
  433. UNARY_TEST_CASE(SIGMOID) \
  434. UNARY_TEST_CASE(EXP) \
  435. UNARY_TEST_CASE(TANH) \
  436. UNARY_TEST_CASE(FAST_TANH) \
  437. UNARY_TEST_CASE(H_SWISH)
  438. DEF_TEST(unary2) {
  439. using Mode = ElemwiseForward::Param::Mode;
  440. Checker<ElemwiseForward> checker(handle);
  441. // case int
  442. checker.set_dtype(0, dtype::Int8());
  443. BUILD_UNARY_TEST_CASE_INT
  444. checker.set_dtype(0, dtype::Int16());
  445. BUILD_UNARY_TEST_CASE_INT
  446. checker.set_dtype(0, dtype::Int32());
  447. BUILD_UNARY_TEST_CASE_INT
  448. // case float
  449. {
  450. UniformFloatRNG rng(1e-5, 7e1);
  451. checker.set_rng(0, &rng);
  452. checker.set_epsilon(1e-5);
  453. checker.set_dtype(0, dtype::Float32());
  454. BUILD_UNARY_TEST_CASE_FLOAT
  455. }
  456. {
  457. UniformFloatRNG rng(1e-2, 1e1);
  458. checker.set_rng(0, &rng);
  459. checker.set_epsilon(6e-3);
  460. checker.set_dtype(0, dtype::Float16());
  461. BUILD_UNARY_TEST_CASE_FLOAT
  462. }
  463. // tanh NaN bug case
  464. {
  465. UniformFloatRNG rng(100, 200);
  466. checker.set_rng(0, &rng);
  467. checker.set_epsilon(1e-5);
  468. checker.set_dtype(0, dtype::Float32());
  469. checker.set_param(Mode::TANH).execs({{1, 1025}, {}});
  470. checker.set_param(Mode::TANH).execs({{1, 7}, {}});
  471. }
  472. }
  473. #undef UNARY_TEST_CASE
  474. #undef BUILD_UNARY_TEST_CASE_INT
  475. #undef BUILD_UNARY_TEST_CASE_FLOAT
  476. #define BINARY_TEST_CASE(_optr) \
  477. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  478. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  479. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  480. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  481. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  482. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});
  483. #define BUILD_BINARY_TEST_CASE \
  484. BINARY_TEST_CASE(MIN) \
  485. BINARY_TEST_CASE(MAX)
  486. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  487. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  488. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  489. checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
  490. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  491. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}}); \
  492. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  493. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  494. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  495. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  496. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}}); \
  497. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  498. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  499. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  500. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  501. BINARY_COMPLATE_TEST_CASE(ADD) \
  502. BINARY_COMPLATE_TEST_CASE(MUL) \
  503. BINARY_COMPLATE_TEST_CASE(MAX) \
  504. BINARY_COMPLATE_TEST_CASE(MIN) \
  505. BINARY_COMPLATE_TEST_CASE(SUB) \
  506. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU)
  507. DEF_TEST(binary2) {
  508. using Mode = ElemwiseForward::Param::Mode;
  509. Checker<ElemwiseForward> checker(handle);
  510. // case float
  511. UniformFloatRNG rng(1e-5, 7e1);
  512. checker.set_rng(0, &rng);
  513. checker.set_epsilon(1e-5);
  514. checker.set_dtype(0, dtype::Float32());
  515. checker.set_dtype(1, dtype::Float32());
  516. BUILD_BINARY_COMPLATE_TEST_CASE
  517. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID)
  518. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH)
  519. // case int
  520. checker.set_dtype(0, dtype::Int8());
  521. checker.set_dtype(1, dtype::Int8());
  522. // BUILD_BINARY_TEST_CASE
  523. BUILD_BINARY_COMPLATE_TEST_CASE
  524. checker.set_dtype(0, dtype::Int16());
  525. checker.set_dtype(1, dtype::Int16());
  526. // BUILD_BINARY_TEST_CASE
  527. BUILD_BINARY_COMPLATE_TEST_CASE
  528. checker.set_dtype(0, dtype::Int32());
  529. checker.set_dtype(1, dtype::Int32());
  530. BUILD_BINARY_TEST_CASE
  531. BUILD_BINARY_COMPLATE_TEST_CASE
  532. // case float
  533. checker.set_rng(0, &rng);
  534. checker.set_epsilon(1e-5);
  535. checker.set_dtype(0, dtype::Float32());
  536. checker.set_dtype(1, dtype::Float32());
  537. checker.set_param(Mode::FUSE_ADD_SIGMOID).execs({{3, 4, 7}, {1}, {}});
  538. checker.set_param(Mode::FUSE_ADD_TANH).execs({{3, 4, 7}, {1}, {}});
  539. // commutable
  540. checker.set_param(Mode::TRUE_DIV).execs({{1}, {4}, {}});
  541. BUILD_BINARY_TEST_CASE
  542. BUILD_BINARY_COMPLATE_TEST_CASE
  543. BINARY_COMPLATE_TEST_CASE(TRUE_DIV)
  544. {
  545. UniformFloatRNG rng(1e-3, 3e1);
  546. checker.set_rng(0, &rng);
  547. checker.set_rng(1, &rng);
  548. checker.set_epsilon(1e-3);
  549. checker.set_dtype(0, dtype::Float16());
  550. checker.set_dtype(1, dtype::Float16());
  551. checker.set_param(Mode::FUSE_ADD_SIGMOID).execs({{3, 4, 7}, {1}, {}});
  552. checker.set_param(Mode::FUSE_ADD_TANH).execs({{3, 4, 7}, {1}, {}});
  553. BUILD_BINARY_TEST_CASE
  554. BUILD_BINARY_COMPLATE_TEST_CASE
  555. BINARY_COMPLATE_TEST_CASE(TRUE_DIV)
  556. // commutable
  557. checker.set_param(Mode::TRUE_DIV).execs({{1}, {4}, {}});
  558. }
  559. }
  560. #undef BINARY_TEST_CASE
  561. #undef BUILD_BINARY_TEST_CASE
  562. #undef BINARY_COMPLATE_TEST_CASE
  563. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  564. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  565. checker.set_param(Mode::_optr) \
  566. .execs({{1, 123, 1}, {300, 123, 253}, {300, 123, 253}, {}}); \
  567. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  568. checker.set_param(Mode::_optr) \
  569. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  570. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  571. checker.set_param(Mode::_optr) \
  572. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  573. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  574. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  575. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  576. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}}); \
  577. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {1, 1, 1}, {3, 4, 1}, {}});
  578. #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  579. DEF_TEST(ternary2) {
  580. using Mode = ElemwiseForward::Param::Mode;
  581. Checker<ElemwiseForward> checker(handle);
  582. // case int
  583. checker.set_dtype(0, dtype::Int8());
  584. checker.set_dtype(1, dtype::Int8());
  585. checker.set_dtype(2, dtype::Int8());
  586. BUILD_TERNARY_COMPLATE_TEST_CASE
  587. checker.set_dtype(0, dtype::Int16());
  588. checker.set_dtype(1, dtype::Int16());
  589. checker.set_dtype(2, dtype::Int16());
  590. BUILD_TERNARY_COMPLATE_TEST_CASE
  591. checker.set_dtype(0, dtype::Int32());
  592. checker.set_dtype(1, dtype::Int32());
  593. checker.set_dtype(2, dtype::Int32());
  594. BUILD_TERNARY_COMPLATE_TEST_CASE
  595. // case float
  596. UniformFloatRNG rng(1e-5, 7e1);
  597. checker.set_rng(0, &rng);
  598. checker.set_epsilon(1e-5);
  599. checker.set_dtype(0, dtype::Float32());
  600. checker.set_dtype(1, dtype::Float32());
  601. checker.set_dtype(2, dtype::Float32());
  602. BUILD_TERNARY_COMPLATE_TEST_CASE
  603. {
  604. UniformFloatRNG rng(1e-3, 3e1);
  605. checker.set_rng(0, &rng);
  606. checker.set_rng(1, &rng);
  607. checker.set_rng(2, &rng);
  608. checker.set_epsilon(1e-3);
  609. checker.set_dtype(0, dtype::Float16());
  610. checker.set_dtype(1, dtype::Float16());
  611. checker.set_dtype(2, dtype::Float16());
  612. BUILD_TERNARY_COMPLATE_TEST_CASE
  613. }
  614. }
  615. #undef TERNARY_COMPLATE_TEST_CASE
  616. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  617. /* ============= migrated from fallback tests ============= */
  618. DEF_TEST(unary3) {
  619. Checker<Elemwise> checker(handle);
  620. auto make_layouts =
  621. [](const TensorShape& shp,
  622. std::initializer_list<ptrdiff_t> stride) -> TensorLayoutArray {
  623. return {make_layout(shp, stride), {shp, dtype::Float32()}};
  624. };
  625. checker.set_param({Elemwise::Mode::SIN});
  626. checker.exec(make_layouts({2, 2}, {2, 1}));
  627. checker.exec(make_layouts({4}, {3}));
  628. }
  629. DEF_TEST(binary3) {
  630. Checker<Elemwise> checker(handle);
  631. checker.set_param({Elemwise::Mode::ADD});
  632. auto run = [&](const TensorShape& shp0, std::initializer_list<ptrdiff_t> stride0,
  633. const TensorShape& shp1, std::initializer_list<ptrdiff_t> stride1) {
  634. TensorShape shpo;
  635. Elemwise::deduce_shape({shp0, shp1}, shpo);
  636. auto ly0 = make_layout(shp0, stride0), ly1 = make_layout(shp1, stride1),
  637. lyo = TensorLayout{shpo, dtype::Float32()};
  638. checker.execl({ly0, ly1, lyo});
  639. checker.execl({ly1, ly0, lyo});
  640. };
  641. run({2, 2}, {2, 1}, {2, 2}, {2, 1});
  642. run({1}, {1}, {3, 3}, {1, 2});
  643. run({3, 4, 5}, {40, 10, 2}, {1, 4, 1}, {1, 1, 1});
  644. }
  645. DEF_TEST(all_modes) {
  646. // test correctness of all elemwise modes
  647. Checker<Elemwise> checker(handle);
  648. TensorShapeArray shapes;
  649. UniformFloatRNG default_rng_f32{-100.f, 100.f}, pos_rng_f32{.1f, 1000.f},
  650. small_pos_rng_f32{.1f, .10f}, small_rng_f32{-3.f, 3.f},
  651. abslt1_rng_f32{-1.f, 1.f}, uniform_0_2_rng{0.f, 2.f},
  652. tanh_rng_f32{-5.f, 5.f};
  653. UniformFloatNonZeroRNG nonzero_rng_f32{.1f, 1000.f},
  654. big_nonzero_rng_f32{100.f, 1000.f};
  655. UniformIntRNG default_rng_i32{-100, 100}, small_rng_i32{-2, 2},
  656. shift_rng_i32_i32{0, 31}, shift_rng_i32_i8{0, 7};
  657. UniformIntNonZeroRNG nonzero_rng_i32{1, 100};
  658. using Mode = Elemwise::Mode;
  659. auto should_ignore = [handle](Mode mode) {
  660. MEGDNN_MARK_USED_VAR(mode);
  661. switch (mode) {
  662. case Mode::NEQ:
  663. case Mode::ISNAN:
  664. case Mode::ISINF:
  665. return true;
  666. default:
  667. break;
  668. }
  669. return false;
  670. };
  671. for (int mode_nr = 0; mode_nr < static_cast<int>(Elemwise::Param::MODE_NR_MEMBER);
  672. ++mode_nr) {
  673. auto mode = static_cast<Mode>(mode_nr);
  674. // ignore unsupported modes
  675. if (should_ignore(mode)) {
  676. continue;
  677. }
  678. checker.set_param({mode});
  679. auto&& trait = Elemwise::ModeTrait::from_mode(mode);
  680. shapes.resize(trait.arity + 1);
  681. for (size_t i = 0; i < shapes.size() - 1; ++i) {
  682. shapes[i] = {3, 9, 7};
  683. }
  684. //! NOTE: force set output layout to empty to trigger layout deduce
  685. shapes[shapes.size() - 1] = {};
  686. auto do_run = [&](DType dtype, float eps = 1e-3) {
  687. // limit value ranges for some modes
  688. if (mode == Mode::LOG || mode == Mode::LOG1P) {
  689. checker.set_rng(0, &pos_rng_f32);
  690. } else if (mode == Mode::POW) {
  691. checker.set_rng(0, &small_pos_rng_f32);
  692. checker.set_rng(1, &small_rng_f32);
  693. } else if (mode == Mode::EXP || mode == Mode::EXPM1) {
  694. checker.set_rng(0, &small_rng_f32);
  695. } else if (mode == Mode::FAST_TANH) {
  696. checker.set_rng(0, &tanh_rng_f32);
  697. } else if (mode == Mode::LOG_SUM_EXP) {
  698. // check numerical stability with large values
  699. checker.set_rng(0, &big_nonzero_rng_f32);
  700. checker.set_rng(1, &big_nonzero_rng_f32);
  701. } else if (
  702. mode == Mode::ASIN || mode == Mode::ACOS ||
  703. mode == Mode::SIGMOID_GRAD || mode == Mode::TANH_GRAD ||
  704. mode == Mode::ERFINV) {
  705. checker.set_rng(0, &abslt1_rng_f32);
  706. checker.set_rng(1, &default_rng_f32);
  707. } else if (mode == Mode::ERFCINV) {
  708. checker.set_rng(0, &uniform_0_2_rng);
  709. } else if (
  710. mode == Mode::MOD || mode == Mode::TRUE_DIV ||
  711. mode == Mode::FLOOR_DIV) {
  712. if (dtype.category() == DTypeCategory::INT) {
  713. checker.set_rng(0, &default_rng_i32);
  714. checker.set_rng(1, &nonzero_rng_i32);
  715. } else {
  716. checker.set_rng(0, &default_rng_f32);
  717. checker.set_rng(1, &nonzero_rng_f32);
  718. }
  719. } else if (mode == Mode::EQ) {
  720. checker.set_rng(0, &small_rng_i32);
  721. checker.set_rng(1, &small_rng_i32);
  722. } else if (mode == Mode::SHL || mode == Mode::SHR) {
  723. checker.set_rng(0, &default_rng_i32);
  724. if (dtype.size() == 4) {
  725. checker.set_rng(1, &shift_rng_i32_i32);
  726. } else {
  727. megdnn_assert(dtype.size() == 1);
  728. checker.set_rng(1, &shift_rng_i32_i8);
  729. }
  730. } else if (mode == Mode::ATAN2) {
  731. checker.set_rng(0, &nonzero_rng_f32);
  732. checker.set_rng(1, &nonzero_rng_f32);
  733. } else {
  734. RNG* rng;
  735. if (dtype.category() == DTypeCategory::INT) {
  736. rng = &default_rng_i32;
  737. } else {
  738. rng = &default_rng_f32;
  739. }
  740. for (size_t i = 0; i < shapes.size(); ++i) {
  741. checker.set_rng(i, rng);
  742. }
  743. }
  744. checker.set_epsilon(eps);
  745. for (size_t i = 0; i < shapes.size(); ++i) {
  746. checker.set_dtype(i, dtype);
  747. }
  748. EXPECT_NO_THROW(checker.execs(shapes));
  749. if (!::testing::Test::HasFailure() && shapes.size() == 3) {
  750. // channel bcast
  751. shapes[1][0] = 1;
  752. shapes[1][2] = 1;
  753. EXPECT_NO_THROW(checker.execs(shapes));
  754. if (!::testing::Test::HasFailure()) {
  755. // scalar bcast
  756. shapes[1][1] = 1;
  757. EXPECT_NO_THROW(checker.execs(shapes));
  758. }
  759. }
  760. if (::testing::Test::HasFailure()) {
  761. printf("failed on mode=%d(%s) dtype=%s\n", mode_nr, trait.name,
  762. dtype.name());
  763. for (auto&& i : shapes) {
  764. printf("ishape: %s\n", i.to_string().c_str());
  765. }
  766. return false;
  767. }
  768. return true;
  769. };
  770. #define run(args...) \
  771. do { \
  772. if (!do_run(args)) { \
  773. return; \
  774. } \
  775. } while (0)
  776. if (trait.allow_int) {
  777. run(dtype::Int32{});
  778. run(dtype::Int8{});
  779. }
  780. if (trait.allow_float) {
  781. DNN_FLOAT16_SELECT(
  782. run(dtype::Float16{}, mode == Mode::FAST_TANH_GRAD ? 0.5 : 0.05), );
  783. run(dtype::Float32{});
  784. }
  785. }
  786. #undef run
  787. }
  788. #define UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
  789. checker.set_param(Mode::_optr) \
  790. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int8()}, {}}); \
  791. checker.set_param(Mode::_optr) \
  792. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int16()}, {}}); \
  793. checker.set_param(Mode::_optr) \
  794. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int32()}, {}});
  795. #define UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(_optr) \
  796. checker.set_param(Mode::_optr) \
  797. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Float32()}, {}});
  798. #define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT \
  799. UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(RELU); \
  800. UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(ABS);
  801. #define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT \
  802. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ABS) \
  803. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(LOG) \
  804. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(COS) \
  805. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIN) \
  806. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FLOOR) \
  807. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(CEIL) \
  808. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIGMOID) \
  809. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(EXP) \
  810. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(RELU) \
  811. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ROUND) \
  812. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(TANH) \
  813. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FAST_TANH)
  814. DEF_TEST(unary_negative_stride) {
  815. using Mode = ElemwiseForward::Param::Mode;
  816. Checker<ElemwiseForward> checker(handle);
  817. BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT;
  818. UniformFloatRNG rng(1e-2, 6e1);
  819. checker.set_rng(0, &rng);
  820. checker.set_epsilon(1e-5);
  821. BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT;
  822. }
  823. #undef UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
  824. #undef UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
  825. #undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
  826. #undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
  827. #define BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
  828. checker.set_param(Mode::_optr) \
  829. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int8()}, \
  830. {{1, 4, 1}, dtype::Int8()}, \
  831. {}}); \
  832. checker.set_param(Mode::_optr) \
  833. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int16()}, \
  834. {{1, 4, 1}, dtype::Int16()}, \
  835. {}}); \
  836. checker.set_param(Mode::_optr) \
  837. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int32()}, \
  838. {{1, 4, 1}, dtype::Int32()}, \
  839. {}});
  840. #define BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(_optr) \
  841. checker.set_param(Mode::_optr) \
  842. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Float32()}, \
  843. {{1, 4, 1}, dtype::Float32()}, \
  844. {}});
  845. #define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT \
  846. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(ADD) \
  847. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MUL) \
  848. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MAX) \
  849. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MIN) \
  850. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(SUB)
  851. #define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 \
  852. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(POW) \
  853. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(TRUE_DIV) \
  854. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_SIGMOID) \
  855. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_TANH) \
  856. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_RELU) \
  857. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_H_SWISH) \
  858. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FAST_TANH_GRAD) \
  859. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(H_SWISH_GRAD)
  860. DEF_TEST(binary_negative_stride) {
  861. using Mode = ElemwiseForward::Param::Mode;
  862. Checker<ElemwiseForward> checker(handle);
  863. BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT;
  864. UniformFloatRNG rng(1e-2, 2e1);
  865. checker.set_rng(0, &rng);
  866. checker.set_epsilon(1e-5);
  867. BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32;
  868. }
  869. #undef BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
  870. #undef BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
  871. #undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
  872. #undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
  873. DEF_TEST(ternary_negative_stride) {
  874. using Mode = ElemwiseForward::Param::Mode;
  875. Checker<ElemwiseForward> checker(handle);
  876. checker.set_param(Mode::FUSE_MUL_ADD3);
  877. checker.execl(
  878. {{{1, 7}, {-7, -1}, dtype::Int8()},
  879. {{1, 7}, {-3, -1}, dtype::Int8()},
  880. {{1, 7}, {-7, -1}, dtype::Int8()},
  881. {}});
  882. checker.execl(
  883. {{{1, 7}, {-7, -1}, dtype::Int16()},
  884. {{1, 7}, {-3, -1}, dtype::Int16()},
  885. {{1, 7}, {-7, -1}, dtype::Int16()},
  886. {}});
  887. checker.execl(
  888. {{{1, 7}, {-7, -1}, dtype::Int32()},
  889. {{1, 7}, {-3, -1}, dtype::Int32()},
  890. {{1, 7}, {-7, -1}, dtype::Int32()},
  891. {}});
  892. UniformFloatRNG rng(1e-2, 2e1);
  893. checker.set_rng(0, &rng);
  894. checker.set_epsilon(1e-5);
  895. checker.execl(
  896. {{{1, 7}, {-7, -1}, dtype::Float32()},
  897. {{1, 7}, {-3, -1}, dtype::Float32()},
  898. {{1, 7}, {-7, -1}, dtype::Float32()},
  899. {}});
  900. }
  901. TEST(TEST_ELEMWISE, MODE_TRAIT) {
  902. using M = Elemwise::Mode;
  903. using T = Elemwise::ModeTrait;
  904. ASSERT_EQ(1u, T::from_mode(M::RELU).arity);
  905. ASSERT_EQ(2u, T::from_mode(M::ADD).arity);
  906. ASSERT_EQ(3u, T::from_mode(M::FUSE_MUL_ADD3).arity);
  907. ASSERT_EQ(4u, T::from_mode(M::FUSE_MUL_ADD4).arity);
  908. ASSERT_TRUE(T::from_mode(M::ADD).commutable);
  909. ASSERT_FALSE(T::from_mode(M::TRUE_DIV).commutable);
  910. ASSERT_TRUE(T::from_mode(M::ADD).allow_int);
  911. ASSERT_FALSE(T::from_mode(M::EXP).allow_int);
  912. ASSERT_TRUE(T::from_mode(M::ADD).allow_float);
  913. ASSERT_FALSE(T::from_mode(M::SHL).allow_float);
  914. ASSERT_TRUE(T::from_mode(M::RMULH).commutable);
  915. ASSERT_FALSE(T::from_mode(M::RMULH).allow_float);
  916. ASSERT_TRUE(T::from_mode(M::XOR).allow_bool);
  917. }
  918. } // namespace elemwise
  919. } // namespace test
  920. } // namespace megdnn
  921. // vim: syntax=cpp.doxygen