You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

elemwise.cpp 39 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078
  1. /**
  2. * \file dnn/test/common/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/elemwise.h"
  12. #include "src/common/utils.cuh"
  13. #include "test/common/utils.h"
  14. #include "test/common/checker.h"
  15. #include "megdnn/oprs/general.h"
  16. #include "test/common/fix_gtest_on_platforms_without_exception.inl"
  17. using namespace megdnn;
  18. using namespace test;
  19. namespace {
  20. void fma3_extra_opr_impl(const TensorNDArray &data) {
  21. megdnn_assert(data.size() == 4);
  22. auto handle = create_cpu_handle(2);
  23. auto opr = handle->create_operator<Elemwise>();
  24. using Mode = Elemwise::Mode;
  25. opr->param().mode = Mode::MUL;
  26. opr->exec({data[0], data[1]}, data[3]);
  27. opr->param().mode = Mode::ADD;
  28. opr->exec({data[2], data[3]}, data[3]);
  29. }
  30. void fma4_extra_opr_impl(const TensorNDArray &data) {
  31. megdnn_assert(data.size() == 5);
  32. std::vector<uint8_t> tmp_storage(data[4].layout.span().dist_byte());
  33. TensorND tmp;
  34. tmp.raw_ptr = tmp_storage.data();
  35. tmp.layout = data[4].layout;
  36. tmp.layout.init_contiguous_stride();
  37. auto handle = create_cpu_handle(2);
  38. auto opr = handle->create_operator<Elemwise>();
  39. using Mode = Elemwise::Mode;
  40. opr->param().mode = Mode::MUL;
  41. opr->exec({data[0], data[1]}, data[4]);
  42. opr->exec({data[2], data[3]}, tmp);
  43. opr->param().mode = Mode::ADD;
  44. opr->exec({tmp, data[4]}, data[4]);
  45. }
  46. TensorLayout make_layout(const TensorShape &shp,
  47. std::initializer_list<ptrdiff_t> stride) {
  48. TensorLayout ret{shp, dtype::Float32()};
  49. megdnn_assert(stride.size() == shp.ndim);
  50. auto idx = 0;
  51. for (auto i: stride)
  52. ret.stride[idx ++] = i;
  53. return ret;
  54. }
  55. } // anonymous namespace
  56. namespace megdnn {
  57. namespace test {
  58. namespace elemwise {
  // Emits the header of the explicit run_test specialization for the test tag
  // `name`; the body that follows the macro invocation receives `handle`.
  #define DEF_TEST(name) \
      template <>        \
      void run_test<name>(Handle* handle)
  62. DEF_TEST(unary) {
  63. using Mode = ElemwiseForward::Param::Mode;
  64. Checker<ElemwiseForward> checker(handle);
  65. checker.set_param(Mode::SIN);
  66. checker.set_dtype(0, dtype::Float32()).execs({{3, 4, 1}, {}});
  67. checker.set_dtype(0, dtype::Float16()).execs({{3, 4, 1}, {}});
  68. }
  69. DEF_TEST(binary_brdcst) {
  70. auto run = [&](DType dtype) {
  71. using Mode = ElemwiseForward::Param::Mode;
  72. Checker<ElemwiseForward> checker(handle);
  73. checker.set_param(Mode::ADD);
  74. checker.set_dtype(0, dtype);
  75. checker.set_dtype(1, dtype);
  76. checker.execs({{3, 1}, {1, 3}, {3, 3}});
  77. {
  78. checker.execs({{10, 11},
  79. {10, 11},
  80. {10, 11}});
  81. //
  82. checker.execs({{2, 3, 4, 5, 6, 7},
  83. {1, 3, 1, 1, 6, 1},
  84. {2, 3, 4, 5, 6, 7}});
  85. checker.execs({{1, 3, 1, 1, 6, 1},
  86. {2, 3, 4, 5, 6, 7},
  87. {2, 3, 4, 5, 6, 7}});
  88. //
  89. checker.execs({{256, 256, 3},
  90. {1, 1, 3},
  91. {256, 256, 3}});
  92. checker.execs({{1, 1, 3},
  93. {256, 256, 3},
  94. {256, 256, 3}});
  95. //
  96. checker.execs({{8, 1, 6, 1},
  97. {1, 7, 1, 5},
  98. {8, 7, 6, 5}});
  99. checker.execs({{1, 7, 1, 5},
  100. {8, 1, 6, 1},
  101. {8, 7, 6, 5}});
  102. //
  103. checker.execs({{5, 4},
  104. {1, 1},
  105. {5, 4}});
  106. checker.execs({{1, 1},
  107. {5, 4},
  108. {5, 4}});
  109. //
  110. checker.execs({{5, 4},
  111. {1, 4},
  112. {5, 4}});
  113. checker.execs({{1, 4},
  114. {5, 4},
  115. {5, 4}});
  116. //
  117. checker.execs({{15, 3, 5},
  118. {15, 1, 5},
  119. {15, 3, 5}});
  120. checker.execs({{15, 1, 5},
  121. {15, 3, 5},
  122. {15, 3, 5}});
  123. //
  124. checker.execs({{15, 3, 5},
  125. {1, 3, 5},
  126. {15, 3, 5}});
  127. checker.execs({{1, 3, 5},
  128. {15, 3, 5},
  129. {15, 3, 5}});
  130. //
  131. checker.execs({{15, 3, 5},
  132. {1, 3, 1},
  133. {15, 3, 5}});
  134. checker.execs({{1, 3, 1},
  135. {15, 3, 5},
  136. {15, 3, 5}});
  137. //
  138. checker.execs({{3, 1},
  139. {1, 4},
  140. {3, 4}});
  141. // numpy broadcast
  142. checker.execs({
  143. {2, 3, 1, 5}, {4, 5}, {2, 3, 4, 5}});
  144. checker.execs({
  145. {3, 1, 1}, {4, 5}, {3, 4, 5}});
  146. }
  147. {
  148. // 1d
  149. {
  150. auto n = 1000u;
  151. checker.execs({{n}, {n}, {n}});
  152. checker.execs({{1}, {n}, {n}});
  153. checker.execs({{n}, {1}, {n}});
  154. }
  155. // 2d
  156. {
  157. auto m = 200u, n = 100u;
  158. auto collapse = [](size_t n, bool is_collapsed) {
  159. return is_collapsed ? 1u : n;
  160. };
  161. for (auto msk = 0u; msk < 16; ++msk) {
  162. checker.execs({
  163. {collapse(m, msk&1), collapse(n,msk&2)},
  164. {collapse(m, msk&4), collapse(n,msk&8)},
  165. {}});
  166. }
  167. }
  168. // nd
  169. {
  170. checker.execs({
  171. {2, 3, 4, 5, 6},
  172. {1, 3, 1, 5, 6},
  173. {2, 3, 4, 5, 6}});
  174. checker.execs({
  175. {2, 3, 4, 5, 6},
  176. {2, 1, 4, 1, 6},
  177. {2, 3, 4, 5, 6}});
  178. }
  179. }
  180. };
  181. run(dtype::Float32());
  182. //run(dtype::Float16());
  183. }
  184. DEF_TEST(binary_non_contig) {
  185. using Mode = ElemwiseForward::Param::Mode;
  186. Checker<ElemwiseForward> checker(handle);
  187. checker.set_param(Mode::ADD);
  188. TensorLayout ly{{2, 3}, dtype::Float32()};
  189. ly.stride[0] = 4;
  190. checker.execl({ly, ly, {{2, 3}, dtype::Float32()}});
  191. }
  192. DEF_TEST(ternary) {
  193. using Mode = ElemwiseForward::Param::Mode;
  194. Checker<ElemwiseForward> checker(handle);
  195. checker.set_param(Mode::COND_LEQ_MOV);
  196. checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  197. checker.set_dtype(0, dtype::Float32()).
  198. set_dtype(1, dtype::Float32()).
  199. set_dtype(2, dtype::Float32()).
  200. execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  201. checker.set_dtype(0, dtype::Float16()).
  202. set_dtype(1, dtype::Float16()).
  203. set_dtype(2, dtype::Float16()).
  204. set_dtype(3, dtype::Float16()).
  205. execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
  206. checker.execs({{2, 1, 1, 5}, {4, 5}, {3, 1, 1}, {2, 3, 4, 5}});
  207. checker.execs({{3, 1, 1}, {5}, {4, 1}, {3, 4, 5}});
  208. ASSERT_THROW(checker.execs({{2, 3, 4}, {4, 1}, {1}, {2, 3, 4}}),
  209. MegDNNError);
  210. ASSERT_THROW(checker.execs({{2, 4, 4}, {4, 1}, {3, 1, 1}, {2, 3, 4}}),
  211. MegDNNError);
  212. }
  213. DEF_TEST(ternary_non_contig) {
  214. using Mode = ElemwiseForward::Param::Mode;
  215. Checker<ElemwiseForward> checker(handle);
  216. checker.set_param(Mode::COND_LEQ_MOV);
  217. TensorLayout ly{{2, 3}, dtype::Float32()};
  218. ly.stride[0] = 4;
  219. checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
  220. }
  221. DEF_TEST(fuse_mul_add3) {
  222. using Mode = ElemwiseForward::Param::Mode;
  223. Checker<ElemwiseForward> checker(handle);
  224. checker.set_param(Mode::FUSE_MUL_ADD3)
  225. .set_extra_opr_impl(fma3_extra_opr_impl);
  226. auto make_shape = [](const TensorShape &s0, const TensorShape &s1,
  227. const TensorShape &s2) {
  228. TensorShape dest;
  229. dest.ndim = s0.ndim;
  230. for (size_t i = 0; i < dest.ndim; ++ i) {
  231. auto a = i < s0.ndim ? s0[i] : 1;
  232. auto b = i < s1.ndim ? s1[i] : 1;
  233. dest[i] = std::max(a, b);
  234. }
  235. return TensorShapeArray{s0, s1, s2, dest};
  236. };
  237. checker.exec(make_shape({2, 1}, {2, 2}, {2, 2}));
  238. checker.exec(make_shape({2, 2}, {2, 1}, {2, 2}));
  239. checker.exec(make_shape({2, 2}, {2, 2}, {1}));
  240. checker.exec(make_shape({3, 1}, {1, 3}, {3, 1}));
  241. checker.exec(make_shape(
  242. {2, 1, 2, 1, 2, 1},
  243. {1, 2, 1, 2, 1, 2},
  244. {1}));
  245. checker.exec(make_shape({1, 1, 3}, {5, 8, 1}, {5, 8, 1}));
  246. checker.exec(make_shape({1, 192, 9, 16}, {1}, {1, 192, 9, 16}));
  247. }
  248. DEF_TEST(fuse_mul_add3_non_contig) {
  249. using Mode = ElemwiseForward::Param::Mode;
  250. Checker<ElemwiseForward> checker(handle);
  251. checker.set_param(Mode::FUSE_MUL_ADD3)
  252. .set_extra_opr_impl(fma3_extra_opr_impl);
  253. TensorLayout ly{{2, 3}, dtype::Float32()};
  254. ly.stride[0] = 4;
  255. checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
  256. }
  257. DEF_TEST(fuse_mul_add4) {
  258. using Mode = ElemwiseForward::Param::Mode;
  259. Checker<ElemwiseForward> checker(handle);
  260. checker.set_param(Mode::FUSE_MUL_ADD4)
  261. .set_extra_opr_impl(fma4_extra_opr_impl);
  262. auto make_shape = [](const TensorShape &s0, const TensorShape &s1,
  263. bool swap = false) {
  264. TensorShape dest;
  265. dest.ndim = s0.ndim;
  266. for (size_t i = 0; i < dest.ndim; ++ i) {
  267. auto a = i < s0.ndim ? s0[i] : 1;
  268. auto b = i < s1.ndim ? s1[i] : 1;
  269. dest[i] = std::max(a, b);
  270. }
  271. TensorShapeArray ret{s0, s1, s0, s1, dest};
  272. if (swap)
  273. std::swap(ret[2], ret[3]);
  274. return ret;
  275. };
  276. checker.exec(make_shape({2, 2}, {2, 2}));
  277. checker.exec(make_shape({3, 1}, {1, 3}));
  278. checker.exec(make_shape(
  279. {2, 1, 2, 1, 2, 1},
  280. {1, 2, 1, 2, 1, 2}));
  281. checker.exec(make_shape({4, 2}, {1, 2}, true));
  282. }
  283. DEF_TEST(rmulh) {
  284. using Mode = ElemwiseForward::Param::Mode;
  285. Checker<ElemwiseForward> checker(handle);
  286. auto run_for_dtype = [&checker](auto dtype) {
  287. auto minv = DTypeTrait<decltype(dtype)>::min();
  288. auto maxv = DTypeTrait<decltype(dtype)>::max();
  289. UniformIntRNG rng0{minv, maxv};
  290. UniformIntRNG rngM{(maxv >> 1) + 1, maxv};
  291. checker.set_param({Mode::RMULH})
  292. .set_dtype(0, dtype)
  293. .set_dtype(1, dtype)
  294. .set_dtype(2, dtype)
  295. .set_rng(0, &rng0)
  296. .set_rng(1, &rngM);
  297. checker.execs({{7, 9, 11, 13}, {1}, {}})
  298. .execs({{16, 3, 256, 256}, {1}, {}})
  299. .execs({{2, 3, 1, 7}, {2, 3, 1, 7}, {}})
  300. .execs({{9, 5, 4}, {1, 5, 1}, {}})
  301. .execs({{233}, {1}, {}});
  302. };
  303. run_for_dtype(dtype::Int8());
  304. run_for_dtype(dtype::Int16());
  305. run_for_dtype(dtype::Int32());
  306. }
  307. /* ============= migrated from x86 tests ============= */
  // Helper macros for DEF_TEST(unary1). They expand to complete statements and
  // expect `checker` and a `Mode` alias to be in scope at the expansion site.

  // One unary mode on a {1, 127} and a {1, 7} input.
  #define UNARY_TEST_CASE(_optr)                             \
      checker.set_param(Mode::_optr).execs({{1, 127}, {}}); \
      checker.set_param(Mode::_optr).execs({{1, 7}, {}});

  // Modes exercised for the integer dtypes.
  #define BUILD_UNARY_TEST_CASE_INT \
      UNARY_TEST_CASE(RELU)         \
      UNARY_TEST_CASE(ABS)

  // Modes exercised for Float32.
  #define BUILD_UNARY_TEST_CASE_FLOAT \
      UNARY_TEST_CASE(ABS)            \
      UNARY_TEST_CASE(LOG)            \
      UNARY_TEST_CASE(COS)            \
      UNARY_TEST_CASE(SIN)            \
      UNARY_TEST_CASE(FLOOR)          \
      UNARY_TEST_CASE(CEIL)           \
      UNARY_TEST_CASE(SIGMOID)        \
      UNARY_TEST_CASE(EXP)            \
      UNARY_TEST_CASE(TANH)           \
      UNARY_TEST_CASE(FAST_TANH)      \
      UNARY_TEST_CASE(RELU)           \
      UNARY_TEST_CASE(ROUND)
  327. DEF_TEST(unary1) {
  328. using Mode = ElemwiseForward::Param::Mode;
  329. Checker<ElemwiseForward> checker(handle);
  330. // case int
  331. checker.set_dtype(0, dtype::Int8());
  332. BUILD_UNARY_TEST_CASE_INT
  333. checker.set_dtype(0, dtype::Int16());
  334. BUILD_UNARY_TEST_CASE_INT
  335. checker.set_dtype(0, dtype::Int32());
  336. BUILD_UNARY_TEST_CASE_INT
  337. // case float
  338. UniformFloatRNG rng(1e-2, 6e1);
  339. checker.set_rng(0, &rng);
  340. checker.set_epsilon(1e-5);
  341. checker.set_dtype(0, dtype::Float32());
  342. BUILD_UNARY_TEST_CASE_FLOAT
  343. }
  344. #undef UNARY_TEST_CASE
  345. #undef BUILD_UNARY_TEST_CASE_INT
  346. #undef BUILD_UNARY_TEST_CASE_FLOAT
  // Helper macros for DEF_TEST(binary1). They expand to complete statements
  // and expect `checker` and a `Mode` alias to be in scope at the expansion
  // site.

  // Short shape list: identical shapes plus all-ones broadcast in either
  // operand position.
  #define BINARY_TEST_CASE(_optr)                                              \
      checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}});       \
      checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
      checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
      checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});

  #define BUILD_BINARY_TEST_CASE \
      BINARY_TEST_CASE(MIN)      \
      BINARY_TEST_CASE(MAX)

  // Full shape list: additionally covers broadcast along a single middle axis.
  // NOTE: the last line must NOT end with a line continuation; otherwise the
  // following #define is spliced into this macro's replacement list.
  #define BINARY_COMPLATE_TEST_CASE(_optr)                                     \
      checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}});       \
      checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
      checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
      checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}});       \
      checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
      checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
      checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});             \
      checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}});       \
      checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});

  // Modes run with the full shape list for every dtype.
  #define BUILD_BINARY_COMPLATE_TEST_CASE \
      BINARY_COMPLATE_TEST_CASE(ADD)      \
      BINARY_COMPLATE_TEST_CASE(MUL)      \
      BINARY_COMPLATE_TEST_CASE(MAX)      \
      BINARY_COMPLATE_TEST_CASE(MIN)      \
      BINARY_COMPLATE_TEST_CASE(SUB)

  // Modes run with the full shape list for Float32 only.
  #define BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32     \
      BINARY_COMPLATE_TEST_CASE(POW)                  \
      BINARY_COMPLATE_TEST_CASE(TRUE_DIV)             \
      BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID)     \
      BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH)        \
      BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU)        \
      BINARY_COMPLATE_TEST_CASE(FUSE_ADD_H_SWISH)     \
      BINARY_COMPLATE_TEST_CASE(FAST_TANH_GRAD)       \
      BINARY_COMPLATE_TEST_CASE(H_SWISH_GRAD)
  386. DEF_TEST(binary1) {
  387. using Mode = ElemwiseForward::Param::Mode;
  388. Checker<ElemwiseForward> checker(handle);
  389. // case float
  390. UniformFloatRNG rng(1e-5, 7e1);
  391. checker.set_rng(0, &rng);
  392. checker.set_epsilon(1e-5);
  393. checker.set_dtype(0, dtype::Float32());
  394. checker.set_dtype(1, dtype::Float32());
  395. BUILD_BINARY_COMPLATE_TEST_CASE
  396. BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  397. // case int
  398. checker.set_dtype(0, dtype::Int8());
  399. checker.set_dtype(1, dtype::Int8());
  400. BUILD_BINARY_TEST_CASE
  401. BUILD_BINARY_COMPLATE_TEST_CASE
  402. checker.set_dtype(0, dtype::Int16());
  403. checker.set_dtype(1, dtype::Int16());
  404. BUILD_BINARY_TEST_CASE
  405. BUILD_BINARY_COMPLATE_TEST_CASE
  406. checker.set_dtype(0, dtype::Int32());
  407. checker.set_dtype(1, dtype::Int32());
  408. BUILD_BINARY_TEST_CASE
  409. BUILD_BINARY_COMPLATE_TEST_CASE
  410. }
  411. #undef BINARY_TEST_CASE
  412. #undef BUILD_BINARY_TEST_CASE
  413. #undef BINARY_COMPLATE_TEST_CASE
  414. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  415. #undef BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  // Helper macros for DEF_TEST(ternary1). They expand to complete statements
  // and expect `checker` and a `Mode` alias to be in scope at the expansion
  // site.
  // NOTE: the last line must NOT end with a line continuation; otherwise the
  // following #define is spliced into this macro's replacement list.
  #define TERNARY_COMPLATE_TEST_CASE(_optr)                                                   \
      checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});           \
      checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});  \
      checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}});           \
      checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}});  \
      checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}});                    \
      checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}});           \
      checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}});           \
      checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});

  // Modes run with the ternary shape list.
  #define BUILD_TERNARY_COMPLATE_TEST_CASE \
      TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  427. DEF_TEST(ternary1) {
  428. using Mode = ElemwiseForward::Param::Mode;
  429. Checker<ElemwiseForward> checker(handle);
  430. // case int
  431. checker.set_dtype(0, dtype::Int8());
  432. checker.set_dtype(1, dtype::Int8());
  433. checker.set_dtype(2, dtype::Int8());
  434. //BUILD_TERNARY_TEST_CASE
  435. BUILD_TERNARY_COMPLATE_TEST_CASE
  436. checker.set_dtype(0, dtype::Int16());
  437. checker.set_dtype(1, dtype::Int16());
  438. checker.set_dtype(2, dtype::Int16());
  439. //BUILD_TERNARY_TEST_CASE
  440. BUILD_TERNARY_COMPLATE_TEST_CASE
  441. checker.set_dtype(0, dtype::Int32());
  442. checker.set_dtype(1, dtype::Int32());
  443. checker.set_dtype(2, dtype::Int32());
  444. //BUILD_TERNARY_TEST_CASE
  445. BUILD_TERNARY_COMPLATE_TEST_CASE
  446. // case float
  447. UniformFloatRNG rng(1e-5, 7e1);
  448. checker.set_rng(0, &rng);
  449. checker.set_epsilon(1e-5);
  450. checker.set_dtype(0, dtype::Float32());
  451. checker.set_dtype(1, dtype::Float32());
  452. checker.set_dtype(2, dtype::Float32());
  453. //BUILD_TERNARY_TEST_CASE
  454. BUILD_TERNARY_COMPLATE_TEST_CASE
  455. //TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  456. }
  457. #undef TERNARY_COMPLATE_TEST_CASE
  458. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  459. /* ============= migrated from arm tests ============= */
  // Helper macros for DEF_TEST(unary2). They expand to complete statements and
  // expect `checker` and a `Mode` alias to be in scope at the expansion site.

  // One unary mode on a {1, 129} and a {1, 7} input.
  #define UNARY_TEST_CASE(_optr)                             \
      checker.set_param(Mode::_optr).execs({{1, 129}, {}}); \
      checker.set_param(Mode::_optr).execs({{1, 7}, {}});

  // Modes exercised for the integer dtypes.
  #define BUILD_UNARY_TEST_CASE_INT \
      UNARY_TEST_CASE(RELU)         \
      UNARY_TEST_CASE(ABS)          \
      UNARY_TEST_CASE(NEGATE)

  // Float dtypes run the integer list first, then the float-only modes.
  #define BUILD_UNARY_TEST_CASE_FLOAT \
      BUILD_UNARY_TEST_CASE_INT       \
      UNARY_TEST_CASE(SIGMOID)        \
      UNARY_TEST_CASE(EXP)            \
      UNARY_TEST_CASE(TANH)           \
      UNARY_TEST_CASE(FAST_TANH)      \
      UNARY_TEST_CASE(H_SWISH)
  474. DEF_TEST(unary2) {
  475. using Mode = ElemwiseForward::Param::Mode;
  476. Checker<ElemwiseForward> checker(handle);
  477. // case int
  478. checker.set_dtype(0, dtype::Int8());
  479. BUILD_UNARY_TEST_CASE_INT
  480. checker.set_dtype(0, dtype::Int16());
  481. BUILD_UNARY_TEST_CASE_INT
  482. checker.set_dtype(0, dtype::Int32());
  483. BUILD_UNARY_TEST_CASE_INT
  484. // case float
  485. {
  486. UniformFloatRNG rng(1e-5, 7e1);
  487. checker.set_rng(0, &rng);
  488. checker.set_epsilon(1e-5);
  489. checker.set_dtype(0, dtype::Float32());
  490. BUILD_UNARY_TEST_CASE_FLOAT
  491. }
  492. {
  493. UniformFloatRNG rng(1e-2, 1e1);
  494. checker.set_rng(0, &rng);
  495. checker.set_epsilon(6e-3);
  496. checker.set_dtype(0, dtype::Float16());
  497. BUILD_UNARY_TEST_CASE_FLOAT
  498. }
  499. // tanh NaN bug case
  500. {
  501. UniformFloatRNG rng(100, 200);
  502. checker.set_rng(0, &rng);
  503. checker.set_epsilon(1e-5);
  504. checker.set_dtype(0, dtype::Float32());
  505. checker.set_param(Mode::TANH).execs({{1, 1025}, {}});
  506. checker.set_param(Mode::TANH).execs({{1, 7}, {}});
  507. }
  508. }
  509. #undef UNARY_TEST_CASE
  510. #undef BUILD_UNARY_TEST_CASE_INT
  511. #undef BUILD_UNARY_TEST_CASE_FLOAT
  // Helper macros for DEF_TEST(binary2). They expand to complete statements
  // and expect `checker` and a `Mode` alias to be in scope at the expansion
  // site.

  // Short shape list: identical shapes plus all-ones broadcast in either
  // operand position.
  #define BINARY_TEST_CASE(_optr)                                              \
      checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}});       \
      checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
      checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
      checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});

  #define BUILD_BINARY_TEST_CASE \
      BINARY_TEST_CASE(MIN)      \
      BINARY_TEST_CASE(MAX)

  // Full shape list: additionally covers broadcast along a single middle axis.
  #define BINARY_COMPLATE_TEST_CASE(_optr)                                     \
      checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}});       \
      checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
      checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
      checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}});       \
      checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
      checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
      checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});             \
      checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}});       \
      checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});

  // Modes run with the full shape list.
  #define BUILD_BINARY_COMPLATE_TEST_CASE \
      BINARY_COMPLATE_TEST_CASE(ADD)      \
      BINARY_COMPLATE_TEST_CASE(MUL)      \
      BINARY_COMPLATE_TEST_CASE(MAX)      \
      BINARY_COMPLATE_TEST_CASE(MIN)      \
      BINARY_COMPLATE_TEST_CASE(SUB)      \
      BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU)
  543. DEF_TEST(binary2) {
  544. using Mode = ElemwiseForward::Param::Mode;
  545. Checker<ElemwiseForward> checker(handle);
  546. // case float
  547. UniformFloatRNG rng(1e-5, 7e1);
  548. checker.set_rng(0, &rng);
  549. checker.set_epsilon(1e-5);
  550. checker.set_dtype(0, dtype::Float32());
  551. checker.set_dtype(1, dtype::Float32());
  552. BUILD_BINARY_COMPLATE_TEST_CASE
  553. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID)
  554. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH)
  555. // case int
  556. checker.set_dtype(0, dtype::Int8());
  557. checker.set_dtype(1, dtype::Int8());
  558. //BUILD_BINARY_TEST_CASE
  559. BUILD_BINARY_COMPLATE_TEST_CASE
  560. checker.set_dtype(0, dtype::Int16());
  561. checker.set_dtype(1, dtype::Int16());
  562. //BUILD_BINARY_TEST_CASE
  563. BUILD_BINARY_COMPLATE_TEST_CASE
  564. checker.set_dtype(0, dtype::Int32());
  565. checker.set_dtype(1, dtype::Int32());
  566. BUILD_BINARY_TEST_CASE
  567. BUILD_BINARY_COMPLATE_TEST_CASE
  568. // case float
  569. checker.set_rng(0, &rng);
  570. checker.set_epsilon(1e-5);
  571. checker.set_dtype(0, dtype::Float32());
  572. checker.set_dtype(1, dtype::Float32());
  573. checker.set_param(Mode::FUSE_ADD_SIGMOID).execs({{3, 4, 7}, {1}, {}});
  574. checker.set_param(Mode::FUSE_ADD_TANH).execs({{3, 4, 7}, {1}, {}});
  575. // commutable
  576. checker.set_param(Mode::TRUE_DIV).execs({{1}, {4}, {}});
  577. BUILD_BINARY_TEST_CASE
  578. BUILD_BINARY_COMPLATE_TEST_CASE
  579. BINARY_COMPLATE_TEST_CASE(TRUE_DIV)
  580. {
  581. UniformFloatRNG rng(1e-3, 3e1);
  582. checker.set_rng(0, &rng);
  583. checker.set_rng(1, &rng);
  584. checker.set_epsilon(1e-3);
  585. checker.set_dtype(0, dtype::Float16());
  586. checker.set_dtype(1, dtype::Float16());
  587. checker.set_param(Mode::FUSE_ADD_SIGMOID).execs({{3, 4, 7}, {1}, {}});
  588. checker.set_param(Mode::FUSE_ADD_TANH).execs({{3, 4, 7}, {1}, {}});
  589. BUILD_BINARY_TEST_CASE
  590. BUILD_BINARY_COMPLATE_TEST_CASE
  591. BINARY_COMPLATE_TEST_CASE(TRUE_DIV)
  592. // commutable
  593. checker.set_param(Mode::TRUE_DIV).execs({{1}, {4}, {}});
  594. }
  595. }
  596. #undef BINARY_TEST_CASE
  597. #undef BUILD_BINARY_TEST_CASE
  598. #undef BINARY_COMPLATE_TEST_CASE
  599. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  // Helper macros for DEF_TEST(ternary2). They expand to complete statements
  // and expect `checker` and a `Mode` alias to be in scope at the expansion
  // site.
  // NOTE: the last line must NOT end with a line continuation; otherwise the
  // following #define is spliced into this macro's replacement list.
  #define TERNARY_COMPLATE_TEST_CASE(_optr)                                                        \
      checker.set_param(Mode::_optr).execs({{1, 123, 1}, {300, 123, 253}, {300, 123, 253}, {}});  \
      checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});                \
      checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}});                \
      checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}});       \
      checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}});                         \
      checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}});                \
      checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}});                \
      checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});                \
      checker.set_param(Mode::_optr).execs({{3, 4, 1}, {1, 1, 1}, {3, 4, 1}, {}});

  // Modes run with the ternary shape list.
  #define BUILD_TERNARY_COMPLATE_TEST_CASE \
      TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  613. DEF_TEST(ternary2) {
  614. using Mode = ElemwiseForward::Param::Mode;
  615. Checker<ElemwiseForward> checker(handle);
  616. // case int
  617. checker.set_dtype(0, dtype::Int8());
  618. checker.set_dtype(1, dtype::Int8());
  619. checker.set_dtype(2, dtype::Int8());
  620. BUILD_TERNARY_COMPLATE_TEST_CASE
  621. checker.set_dtype(0, dtype::Int16());
  622. checker.set_dtype(1, dtype::Int16());
  623. checker.set_dtype(2, dtype::Int16());
  624. BUILD_TERNARY_COMPLATE_TEST_CASE
  625. checker.set_dtype(0, dtype::Int32());
  626. checker.set_dtype(1, dtype::Int32());
  627. checker.set_dtype(2, dtype::Int32());
  628. BUILD_TERNARY_COMPLATE_TEST_CASE
  629. // case float
  630. UniformFloatRNG rng(1e-5, 7e1);
  631. checker.set_rng(0, &rng);
  632. checker.set_epsilon(1e-5);
  633. checker.set_dtype(0, dtype::Float32());
  634. checker.set_dtype(1, dtype::Float32());
  635. checker.set_dtype(2, dtype::Float32());
  636. BUILD_TERNARY_COMPLATE_TEST_CASE
  637. {
  638. UniformFloatRNG rng(1e-3, 3e1);
  639. checker.set_rng(0, &rng);
  640. checker.set_rng(1, &rng);
  641. checker.set_rng(2, &rng);
  642. checker.set_epsilon(1e-3);
  643. checker.set_dtype(0, dtype::Float16());
  644. checker.set_dtype(1, dtype::Float16());
  645. checker.set_dtype(2, dtype::Float16());
  646. BUILD_TERNARY_COMPLATE_TEST_CASE
  647. }
  648. }
  649. #undef TERNARY_COMPLATE_TEST_CASE
  650. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  651. /* ============= migrated from fallback tests ============= */
  652. DEF_TEST(unary3) {
  653. Checker<Elemwise> checker(handle);
  654. auto make_layouts = [](
  655. const TensorShape &shp, std::initializer_list<ptrdiff_t> stride) ->
  656. TensorLayoutArray{
  657. return {make_layout(shp, stride), {shp, dtype::Float32()}};
  658. };
  659. checker.set_param({Elemwise::Mode::SIN});
  660. checker.exec(make_layouts({2, 2}, {2, 1}));
  661. checker.exec(make_layouts({4}, {3}));
  662. }
  663. DEF_TEST(binary3) {
  664. Checker<Elemwise> checker(handle);
  665. checker.set_param({Elemwise::Mode::ADD});
  666. auto run = [&](
  667. const TensorShape &shp0, std::initializer_list<ptrdiff_t> stride0,
  668. const TensorShape &shp1, std::initializer_list<ptrdiff_t> stride1) {
  669. TensorShape shpo;
  670. Elemwise::deduce_shape({shp0, shp1}, shpo);
  671. auto ly0 = make_layout(shp0, stride0),
  672. ly1 = make_layout(shp1, stride1),
  673. lyo = TensorLayout{shpo, dtype::Float32()};
  674. checker.execl({ly0, ly1, lyo});
  675. checker.execl({ly1, ly0, lyo});
  676. };
  677. run({2, 2}, {2, 1}, {2, 2}, {2, 1});
  678. run({1}, {1}, {3, 3}, {1, 2});
  679. run({3, 4, 5}, {40, 10, 2}, {1, 4, 1}, {1, 1, 1});
  680. }
  681. DEF_TEST(all_modes) {
  682. // test correctness of all elemwise modes
  683. Checker<Elemwise> checker(handle);
  684. TensorShapeArray shapes;
  685. UniformFloatRNG default_rng_f32{-100.f, 100.f}, pos_rng_f32{.1f, 1000.f},
  686. small_pos_rng_f32{.1f, .10f}, small_rng_f32{-3.f, 3.f},
  687. abslt1_rng_f32{-1.f, 1.f}, uniform_0_2_rng{0.f, 2.f},
  688. tanh_rng_f32{-5.f, 5.f};
  689. UniformFloatNonZeroRNG nonzero_rng_f32{.1f, 1000.f},
  690. big_nonzero_rng_f32{100.f, 1000.f};
  691. UniformIntRNG default_rng_i32{-100, 100}, small_rng_i32{-2, 2},
  692. shift_rng_i32_i32{0, 31}, shift_rng_i32_i8{0, 7};
  693. UniformIntNonZeroRNG nonzero_rng_i32{1, 100};
  694. using Mode = Elemwise::Mode;
  695. auto should_ignore = [handle](Mode mode) {
  696. MEGDNN_MARK_USED_VAR(mode);
  697. return false;
  698. };
  699. for (int mode_nr = 0;
  700. mode_nr < static_cast<int>(Elemwise::Param::MODE_NR_MEMBER);
  701. ++mode_nr) {
  702. auto mode = static_cast<Mode>(mode_nr);
  703. // ignore unsupported modes
  704. if (should_ignore(mode)) {
  705. continue;
  706. }
  707. checker.set_param({mode});
  708. auto&& trait = Elemwise::ModeTrait::from_mode(mode);
  709. shapes.resize(trait.arity + 1);
  710. for (size_t i = 0; i < shapes.size() - 1; ++i) {
  711. shapes[i] = {3, 9, 7};
  712. }
  713. auto do_run = [&](DType dtype, float eps = 1e-3) {
  714. // limit value ranges for some modes
  715. if (mode == Mode::LOG || mode == Mode::LOG1P) {
  716. checker.set_rng(0, &pos_rng_f32);
  717. } else if (mode == Mode::POW) {
  718. checker.set_rng(0, &small_pos_rng_f32);
  719. checker.set_rng(1, &small_rng_f32);
  720. } else if (mode == Mode::EXP || mode == Mode::EXPM1) {
  721. checker.set_rng(0, &small_rng_f32);
  722. } else if (mode == Mode::FAST_TANH) {
  723. checker.set_rng(0, &tanh_rng_f32);
  724. } else if (mode == Mode::LOG_SUM_EXP) {
  725. // check numerical stability with large values
  726. checker.set_rng(0, &big_nonzero_rng_f32);
  727. checker.set_rng(1, &big_nonzero_rng_f32);
  728. } else if (mode == Mode::ASIN || mode == Mode::ACOS ||
  729. mode == Mode::SIGMOID_GRAD || mode == Mode::TANH_GRAD ||
  730. mode == Mode::ERFINV) {
  731. checker.set_rng(0, &abslt1_rng_f32);
  732. checker.set_rng(1, &default_rng_f32);
  733. } else if (mode == Mode::ERFCINV) {
  734. checker.set_rng(0, &uniform_0_2_rng);
  735. } else if (mode == Mode::MOD || mode == Mode::TRUE_DIV ||
  736. mode == Mode::FLOOR_DIV) {
  737. if (dtype.category() == DTypeCategory::INT) {
  738. checker.set_rng(0, &default_rng_i32);
  739. checker.set_rng(1, &nonzero_rng_i32);
  740. } else {
  741. checker.set_rng(0, &default_rng_f32);
  742. checker.set_rng(1, &nonzero_rng_f32);
  743. }
  744. } else if (mode == Mode::EQ) {
  745. checker.set_rng(0, &small_rng_i32);
  746. checker.set_rng(1, &small_rng_i32);
  747. } else if (mode == Mode::SHL || mode == Mode::SHR) {
  748. checker.set_rng(0, &default_rng_i32);
  749. if (dtype.size() == 4) {
  750. checker.set_rng(1, &shift_rng_i32_i32);
  751. } else {
  752. megdnn_assert(dtype.size() == 1);
  753. checker.set_rng(1, &shift_rng_i32_i8);
  754. }
  755. } else if (mode == Mode::ATAN2) {
  756. checker.set_rng(0, &nonzero_rng_f32);
  757. checker.set_rng(1, &nonzero_rng_f32);
  758. } else {
  759. RNG* rng;
  760. if (dtype.category() == DTypeCategory::INT) {
  761. rng = &default_rng_i32;
  762. } else {
  763. rng = &default_rng_f32;
  764. }
  765. for (size_t i = 0; i < shapes.size(); ++i) {
  766. checker.set_rng(i, rng);
  767. }
  768. }
  769. checker.set_epsilon(eps);
  770. for (size_t i = 0; i < shapes.size(); ++i) {
  771. checker.set_dtype(i, dtype);
  772. }
  773. EXPECT_NO_THROW(checker.execs(shapes));
  774. if (!::testing::Test::HasFailure() && shapes.size() == 3) {
  775. // channel bcast
  776. shapes[1][0] = 1;
  777. shapes[1][2] = 1;
  778. EXPECT_NO_THROW(checker.execs(shapes));
  779. if (!::testing::Test::HasFailure()) {
  780. // scalar bcast
  781. shapes[1][1] = 1;
  782. EXPECT_NO_THROW(checker.execs(shapes));
  783. }
  784. }
  785. if (::testing::Test::HasFailure()) {
  786. printf("failed on mode=%d(%s) dtype=%s\n", mode_nr, trait.name,
  787. dtype.name());
  788. for (auto&& i : shapes) {
  789. printf("ishape: %s\n", i.to_string().c_str());
  790. }
  791. return false;
  792. }
  793. return true;
  794. };
  795. #define run(args...) \
  796. do { \
  797. if (!do_run(args)) { \
  798. return; \
  799. } \
  800. } while (0)
  801. if (trait.allow_int) {
  802. run(dtype::Int8{});
  803. run(dtype::Int32{});
  804. }
  805. if (trait.allow_float) {
  806. DNN_FLOAT16_SELECT(
  807. run(dtype::Float16{},
  808. mode == Mode::FAST_TANH_GRAD ? 0.5 : 0.05), );
  809. run(dtype::Float32{});
  810. }
  811. }
  812. #undef run
  813. }
  814. #define UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
  815. checker.set_param(Mode::_optr) \
  816. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int8()}, {}}); \
  817. checker.set_param(Mode::_optr) \
  818. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int16()}, {}}); \
  819. checker.set_param(Mode::_optr) \
  820. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int32()}, {}});
  821. #define UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(_optr) \
  822. checker.set_param(Mode::_optr) \
  823. .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Float32()}, {}});
  824. #define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT \
  825. UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(RELU); \
  826. UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(ABS);
  827. #define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT \
  828. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ABS) \
  829. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(LOG) \
  830. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(COS) \
  831. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIN) \
  832. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FLOOR) \
  833. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(CEIL) \
  834. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIGMOID) \
  835. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(EXP) \
  836. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(RELU) \
  837. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ROUND) \
  838. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(TANH) \
  839. UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FAST_TANH)
  840. DEF_TEST(unary_negative_stride) {
  841. using Mode = ElemwiseForward::Param::Mode;
  842. Checker<ElemwiseForward> checker(handle);
  843. BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT;
  844. UniformFloatRNG rng(1e-2, 6e1);
  845. checker.set_rng(0, &rng);
  846. checker.set_epsilon(1e-5);
  847. BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT;
  848. }
  849. #undef UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
  850. #undef UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
  851. #undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
  852. #undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
  853. #define BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
  854. checker.set_param(Mode::_optr) \
  855. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int8()}, \
  856. {{1, 4, 1}, dtype::Int8()}, \
  857. {}}); \
  858. checker.set_param(Mode::_optr) \
  859. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int16()}, \
  860. {{1, 4, 1}, dtype::Int16()}, \
  861. {}}); \
  862. checker.set_param(Mode::_optr) \
  863. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int32()}, \
  864. {{1, 4, 1}, dtype::Int32()}, \
  865. {}});
  866. #define BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(_optr) \
  867. checker.set_param(Mode::_optr) \
  868. .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Float32()}, \
  869. {{1, 4, 1}, dtype::Float32()}, \
  870. {}});
  871. #define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT \
  872. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(ADD) \
  873. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MUL) \
  874. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MAX) \
  875. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MIN) \
  876. BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(SUB)
  877. #define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 \
  878. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(POW) \
  879. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(TRUE_DIV) \
  880. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_SIGMOID) \
  881. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_TANH) \
  882. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_RELU) \
  883. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_H_SWISH) \
  884. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FAST_TANH_GRAD) \
  885. BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(H_SWISH_GRAD)
  886. DEF_TEST(binary_negative_stride) {
  887. using Mode = ElemwiseForward::Param::Mode;
  888. Checker<ElemwiseForward> checker(handle);
  889. BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT;
  890. UniformFloatRNG rng(1e-2, 2e1);
  891. checker.set_rng(0, &rng);
  892. checker.set_epsilon(1e-5);
  893. BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32;
  894. }
  895. #undef BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
  896. #undef BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
  897. #undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
  898. #undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
  899. DEF_TEST(ternary_negative_stride) {
  900. using Mode = ElemwiseForward::Param::Mode;
  901. Checker<ElemwiseForward> checker(handle);
  902. checker.set_param(Mode::FUSE_MUL_ADD3);
  903. checker.execl({{{1, 7}, {-7, -1}, dtype::Int8()},
  904. {{1, 7}, {-3, -1}, dtype::Int8()},
  905. {{1, 7}, {-7, -1}, dtype::Int8()},
  906. {}});
  907. checker.execl({{{1, 7}, {-7, -1}, dtype::Int16()},
  908. {{1, 7}, {-3, -1}, dtype::Int16()},
  909. {{1, 7}, {-7, -1}, dtype::Int16()},
  910. {}});
  911. checker.execl({{{1, 7}, {-7, -1}, dtype::Int32()},
  912. {{1, 7}, {-3, -1}, dtype::Int32()},
  913. {{1, 7}, {-7, -1}, dtype::Int32()},
  914. {}});
  915. UniformFloatRNG rng(1e-2, 2e1);
  916. checker.set_rng(0, &rng);
  917. checker.set_epsilon(1e-5);
  918. checker.execl({{{1, 7}, {-7, -1}, dtype::Float32()},
  919. {{1, 7}, {-3, -1}, dtype::Float32()},
  920. {{1, 7}, {-7, -1}, dtype::Float32()},
  921. {}});
  922. }
  923. TEST(TEST_ELEMWISE, MODE_TRAIT) {
  924. using M = Elemwise::Mode;
  925. using T = Elemwise::ModeTrait;
  926. ASSERT_EQ(1u, T::from_mode(M::RELU).arity);
  927. ASSERT_EQ(2u, T::from_mode(M::ADD).arity);
  928. ASSERT_EQ(3u, T::from_mode(M::FUSE_MUL_ADD3).arity);
  929. ASSERT_EQ(4u, T::from_mode(M::FUSE_MUL_ADD4).arity);
  930. ASSERT_TRUE(T::from_mode(M::ADD).commutable);
  931. ASSERT_FALSE(T::from_mode(M::TRUE_DIV).commutable);
  932. ASSERT_TRUE(T::from_mode(M::ADD).allow_int);
  933. ASSERT_FALSE(T::from_mode(M::EXP).allow_int);
  934. ASSERT_TRUE(T::from_mode(M::ADD).allow_float);
  935. ASSERT_FALSE(T::from_mode(M::SHL).allow_float);
  936. ASSERT_TRUE(T::from_mode(M::RMULH).commutable);
  937. ASSERT_FALSE(T::from_mode(M::RMULH).allow_float);
  938. ASSERT_TRUE(T::from_mode(M::XOR).allow_bool);
  939. }
  940. } // namespace elemwise
  941. } // namespace test
  942. } // namespace megdnn
  943. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。