You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

elemwise.cpp 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. #include "test/common/elemwise.h"
  2. #include "test/arm_common/fixture.h"
  3. #include "test/common/benchmarker.h"
  4. #include "test/common/checker.h"
  5. #include "test/common/task_record_check.h"
  6. #include "megdnn/opr_param_defs.h"
  7. #include "megdnn/oprs/general.h"
  8. using namespace megdnn;
  9. using namespace test;
  10. template <typename tag>
  11. class ARM_ELEMWISE : public ARM_COMMON {};
  12. TYPED_TEST_CASE(ARM_ELEMWISE, elemwise::test_types);
  13. TYPED_TEST(ARM_ELEMWISE, run) {
  14. elemwise::run_test<TypeParam>(this->handle());
  15. }
  16. template <typename tag>
  17. class ARM_ELEMWISE_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {};
  18. TYPED_TEST_CASE(ARM_ELEMWISE_MULTI_THREADS, elemwise::test_types);
  19. TYPED_TEST(ARM_ELEMWISE_MULTI_THREADS, run) {
  20. elemwise::run_test<TypeParam>(this->handle());
  21. }
  22. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) {
  23. using Mode = ElemwiseForward::Param::Mode;
  24. Checker<ElemwiseForward> checker(handle());
  25. checker.set_param(Mode::FUSE_MUL_ADD3);
  26. auto run = [&] {
  27. //! nchw44
  28. checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  29. checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  30. checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  31. checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  32. checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  33. //! nchw44
  34. checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  35. checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  36. checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  37. checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  38. checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  39. //! nchw88
  40. checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  41. checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  42. checker.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  43. checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  44. checker.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  45. //! nchw88
  46. checker.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  47. checker.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  48. checker.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  49. checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  50. checker.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  51. checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
  52. checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
  53. checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}});
  54. checker.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}});
  55. checker.execs({{1, 7}, {1, 7}, {1, 7}, {}});
  56. checker.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}});
  57. checker.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}});
  58. checker.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  59. checker.execs({{3, 4, 5}, {1}, {1}, {}});
  60. checker.execs({{1}, {3, 4, 5}, {1}, {}});
  61. };
  62. // case int
  63. checker.set_dtype(0, dtype::Int8());
  64. checker.set_dtype(1, dtype::Int8());
  65. checker.set_dtype(2, dtype::Int8());
  66. run();
  67. checker.set_dtype(0, dtype::Int16());
  68. checker.set_dtype(1, dtype::Int16());
  69. checker.set_dtype(2, dtype::Int16());
  70. run();
  71. checker.set_dtype(0, dtype::Int32());
  72. checker.set_dtype(1, dtype::Int32());
  73. checker.set_dtype(2, dtype::Int32());
  74. run();
  75. // case float
  76. UniformFloatRNG rng(1e-5, 7e1);
  77. checker.set_rng(0, &rng);
  78. checker.set_epsilon(1e-5);
  79. checker.set_dtype(0, dtype::Float32());
  80. checker.set_dtype(1, dtype::Float32());
  81. checker.set_dtype(2, dtype::Float32());
  82. run();
  83. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  84. // case half
  85. UniformFloatRNG rng_float16(1, 10);
  86. checker.set_rng(0, &rng_float16);
  87. checker.set_epsilon(1e-2);
  88. checker.set_dtype(0, dtype::Float16());
  89. checker.set_dtype(1, dtype::Float16());
  90. checker.set_dtype(2, dtype::Float16());
  91. run();
  92. #endif
  93. }
  94. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) {
  95. using Mode = ElemwiseForward::Param::Mode;
  96. Checker<ElemwiseForward> checker(handle());
  97. auto run = [&]() {
  98. // VEC_BCAST101x not PowOp
  99. checker.set_param(Mode::ADD).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  100. checker.set_param(Mode::ADD).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  101. checker.set_param(Mode::ADD).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  102. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  103. checker.set_param(Mode::ADD).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  104. checker.set_param(Mode::RMULH).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  105. checker.set_param(Mode::RMULH).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  106. checker.set_param(Mode::RMULH).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  107. checker.set_param(Mode::RMULH).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  108. checker.set_param(Mode::RMULH).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  109. checker.set_param(Mode::FUSE_ADD_RELU)
  110. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  111. checker.set_param(Mode::FUSE_ADD_RELU)
  112. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  113. checker.set_param(Mode::FUSE_ADD_RELU)
  114. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  115. checker.set_param(Mode::FUSE_ADD_RELU)
  116. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  117. checker.set_param(Mode::FUSE_ADD_RELU)
  118. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  119. // BCAST101x_VEC not PowOp
  120. checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  121. checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  122. checker.set_param(Mode::ADD).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  123. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  124. checker.set_param(Mode::ADD).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  125. checker.set_param(Mode::FUSE_ADD_RELU)
  126. .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  127. checker.set_param(Mode::FUSE_ADD_RELU)
  128. .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  129. checker.set_param(Mode::FUSE_ADD_RELU)
  130. .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  131. checker.set_param(Mode::FUSE_ADD_RELU)
  132. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  133. checker.set_param(Mode::FUSE_ADD_RELU)
  134. .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  135. };
  136. checker.set_dtype(0, dtype::Int8());
  137. checker.set_dtype(1, dtype::Int8());
  138. run();
  139. checker.set_dtype(0, dtype::Int16());
  140. checker.set_dtype(1, dtype::Int16());
  141. run();
  142. checker.set_dtype(0, dtype::Int32());
  143. checker.set_dtype(1, dtype::Int32());
  144. run();
  145. }
  146. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) {
  147. using Mode = ElemwiseForward::Param::Mode;
  148. Checker<ElemwiseForward> checker(handle());
  149. UniformFloatRNG rng(1e-5, 7e1);
  150. checker.set_rng(0, &rng);
  151. checker.set_epsilon(1e-5);
  152. checker.set_dtype(0, dtype::Float32());
  153. checker.set_dtype(1, dtype::Float32());
  154. checker.set_param(Mode::FUSE_ADD_RELU)
  155. .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  156. checker.set_param(Mode::FUSE_ADD_RELU)
  157. .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  158. checker.set_param(Mode::FUSE_ADD_RELU)
  159. .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  160. checker.set_param(Mode::FUSE_ADD_RELU)
  161. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  162. checker.set_param(Mode::FUSE_ADD_RELU)
  163. .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  164. checker.set_param(Mode::FUSE_ADD_RELU)
  165. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  166. checker.set_param(Mode::FUSE_ADD_RELU)
  167. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  168. checker.set_param(Mode::FUSE_ADD_RELU)
  169. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  170. checker.set_param(Mode::FUSE_ADD_RELU)
  171. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  172. checker.set_param(Mode::FUSE_ADD_RELU)
  173. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  174. auto run = [&](Mode mode) {
  175. // VEC_BCAST101x
  176. checker.set_param(mode).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  177. checker.set_param(mode).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  178. checker.set_param(mode).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  179. checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  180. checker.set_param(mode).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  181. // BCAST101x_VEC not powOp
  182. checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  183. checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  184. checker.set_param(mode).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  185. checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  186. checker.set_param(mode).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  187. };
  188. run(Mode::ADD);
  189. run(Mode::FUSE_ADD_H_SWISH);
  190. run(Mode::FUSE_ADD_RELU);
  191. run(Mode::MAX);
  192. run(Mode::MIN);
  193. run(Mode::MUL);
  194. run(Mode::SUB);
  195. run(Mode::TRUE_DIV);
  196. run(Mode::POW);
  197. }
  198. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) {
  199. using Mode = ElemwiseForward::Param::Mode;
  200. Checker<ElemwiseForward> checker(handle());
  201. checker.set_param(Mode::FUSE_ADD_RELU)
  202. .execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  203. checker.set_param(Mode::FUSE_ADD_RELU)
  204. .execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  205. checker.set_param(Mode::FUSE_ADD_RELU)
  206. .execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  207. checker.set_param(Mode::FUSE_ADD_RELU)
  208. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  209. checker.set_param(Mode::FUSE_ADD_RELU)
  210. .execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  211. checker.set_param(Mode::FUSE_ADD_RELU)
  212. .execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  213. checker.set_param(Mode::FUSE_ADD_RELU)
  214. .execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  215. checker.set_param(Mode::FUSE_ADD_RELU)
  216. .execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  217. checker.set_param(Mode::FUSE_ADD_RELU)
  218. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  219. checker.set_param(Mode::FUSE_ADD_RELU)
  220. .execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  221. auto run = [&](Mode mode) {
  222. // VEC_BCAST101x
  223. checker.set_param(mode).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  224. checker.set_param(mode).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  225. checker.set_param(mode).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  226. checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  227. checker.set_param(mode).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  228. // BCAST101x_VEC not powOp
  229. checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  230. checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  231. checker.set_param(mode).execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  232. checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  233. checker.set_param(mode).execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  234. };
  235. auto run_all = [&]() {
  236. run(Mode::ADD);
  237. run(Mode::FUSE_ADD_H_SWISH);
  238. run(Mode::FUSE_ADD_RELU);
  239. run(Mode::MAX);
  240. run(Mode::MIN);
  241. run(Mode::MUL);
  242. run(Mode::SUB);
  243. run(Mode::TRUE_DIV);
  244. run(Mode::POW);
  245. };
  246. {
  247. UniformFloatRNG rng(1e-5, 7e1);
  248. checker.set_rng(0, &rng);
  249. checker.set_epsilon(1e-5);
  250. checker.set_dtype(0, dtype::Float32());
  251. checker.set_dtype(1, dtype::Float32());
  252. run_all();
  253. }
  254. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  255. {
  256. UniformFloatRNG rng(1, 2);
  257. checker.set_rng(0, &rng);
  258. checker.set_epsilon(3e-3);
  259. checker.set_dtype(0, dtype::Float16());
  260. checker.set_dtype(1, dtype::Float16());
  261. run_all();
  262. }
  263. #endif
  264. }
  265. TEST_F(ARM_COMMON_MULTI_THREADS, ELEMWISE_FORWARD_NHWC_FP32_BCAST) {
  266. using Mode = ElemwiseForward::Param::Mode;
  267. Checker<ElemwiseForward> checker(handle());
  268. UniformFloatRNG rng(1e-5, 7e1);
  269. checker.set_rng(0, &rng);
  270. checker.set_epsilon(1e-5);
  271. checker.set_dtype(0, dtype::Float32());
  272. checker.set_dtype(1, dtype::Float32());
  273. //! 2 dim
  274. auto run = [&](Mode mode) {
  275. // VEC_BCAST111C
  276. checker.set_param(mode).execs({{1, 2, 2, 12}, {1, 1, 1, 12}, {}});
  277. checker.set_param(mode).execs({{2, 5, 3, 28}, {1, 1, 1, 28}, {}});
  278. checker.set_param(mode).execs({{3, 5, 8, 32}, {1, 1, 1, 32}, {}});
  279. // BCAST111C_VEC
  280. checker.set_param(mode).execs({{1, 1, 1, 12}, {1, 2, 2, 12}, {}});
  281. checker.set_param(mode).execs({{1, 1, 1, 28}, {2, 5, 3, 28}, {}});
  282. checker.set_param(mode).execs({{1, 1, 1, 32}, {3, 5, 8, 32}, {}});
  283. };
  284. run(Mode::ADD);
  285. run(Mode::MUL);
  286. run(Mode::SUB);
  287. //! 3 dim contig
  288. auto run_3d_contig = [&](Mode mode) {
  289. // BCAST111C_VEC_BCAST111C
  290. checker.set_param(mode).execs(
  291. {{1, 1, 1, 12}, {1, 2, 2, 12}, {1, 1, 1, 12}, {}});
  292. checker.set_param(mode).execs(
  293. {{1, 1, 1, 28}, {2, 5, 3, 28}, {1, 1, 1, 28}, {}});
  294. checker.set_param(mode).execs(
  295. {{1, 1, 1, 32}, {3, 5, 8, 32}, {1, 1, 1, 32}, {}});
  296. // VEC_BCAST111C_VEC
  297. checker.set_param(mode).execs(
  298. {{1, 2, 2, 12}, {1, 1, 1, 12}, {1, 2, 2, 12}, {}});
  299. checker.set_param(mode).execs(
  300. {{2, 5, 3, 28}, {1, 1, 1, 28}, {2, 5, 3, 28}, {}});
  301. checker.set_param(mode).execs(
  302. {{3, 5, 8, 32}, {1, 1, 1, 32}, {3, 5, 8, 32}, {}});
  303. };
  304. run_3d_contig(Mode::FUSE_MUL_ADD3);
  305. //! 3 dim incontig
  306. auto run_3d_incontig = [&](Mode mode) {
  307. megdnn::TensorLayout src0({1, 1, 1, 12}, dtype::Float32());
  308. megdnn::TensorLayout src1({1, 2, 2, 12}, {80, 40, 20, 1}, dtype::Float32());
  309. // BCAST111C_VEC_BCAST111C
  310. checker.set_param(mode).execl({src0, src1, src0, {}});
  311. // VEC_BCAST111C_VEC
  312. checker.set_param(mode).execl({src1, src0, src1, {}});
  313. };
  314. run_3d_incontig(Mode::FUSE_MUL_ADD3);
  315. }
  316. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_N1HW_FP32_BCAST) {
  317. using Mode = ElemwiseForward::Param::Mode;
  318. Checker<ElemwiseForward> checker(handle());
  319. UniformFloatRNG rng(1e-5, 7e1);
  320. checker.set_rng(0, &rng);
  321. checker.set_epsilon(1e-5);
  322. checker.set_dtype(0, dtype::Float32());
  323. checker.set_dtype(1, dtype::Float32());
  324. //! 2 dim
  325. auto run = [&](Mode mode) {
  326. // VEC_BCASTX0X
  327. checker.set_param(mode).execs({{2, 8, 4, 4}, {2, 1, 4, 4}, {}});
  328. checker.set_param(mode).execs({{4, 21, 78}, {4, 1, 78}, {}});
  329. // BCASTX0X_VEC
  330. checker.set_param(mode).execs({{2, 1, 4, 4}, {2, 8, 4, 4}, {}});
  331. checker.set_param(mode).execs({{4, 1, 78}, {4, 21, 78}, {}});
  332. };
  333. run(Mode::ADD);
  334. run(Mode::MUL);
  335. run(Mode::SUB);
  336. }
  337. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY_RECORD) {
  338. using Mode = ElemwiseForward::Param::Mode;
  339. TaskRecordChecker<ElemwiseForward> checker(0);
  340. checker.set_param(Mode::FUSE_MUL_ADD3);
  341. auto run = [&] {
  342. //! nchw44
  343. checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  344. checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  345. //! nchw88
  346. checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  347. checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  348. checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
  349. checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
  350. };
  351. // case int
  352. checker.set_dtype(0, dtype::Int32());
  353. checker.set_dtype(1, dtype::Int32());
  354. checker.set_dtype(2, dtype::Int32());
  355. run();
  356. // case float
  357. UniformFloatRNG rng(1e-5, 7e1);
  358. checker.set_rng(0, &rng);
  359. checker.set_epsilon(1e-5);
  360. checker.set_dtype(0, dtype::Float32());
  361. checker.set_dtype(1, dtype::Float32());
  362. checker.set_dtype(2, dtype::Float32());
  363. run();
  364. }
  365. #if MEGDNN_WITH_BENCHMARK
  366. namespace {
  367. void run_elemwise_benchmark(
  368. const TensorShapeArray& shapes, param::Elemwise::Mode mode,
  369. const char* mode_str, DType type, Handle* handle_bench) {
  370. auto handle_fallback = create_cpu_handle(1);
  371. Benchmarker<Elemwise> benchmarker_bench(handle_bench);
  372. Benchmarker<Elemwise> benchmarker_fallback(handle_fallback.get());
  373. float throughput = 0;
  374. SmallVector<TensorLayout> layouts;
  375. std::string src_strs;
  376. for (size_t i = 0; i < shapes.size(); i++) {
  377. layouts.emplace_back(shapes[i], type);
  378. throughput += layouts.back().span().dist_byte();
  379. src_strs += layouts.back().to_string();
  380. if (i != shapes.size() - 1) {
  381. src_strs += ",";
  382. }
  383. }
  384. constexpr size_t RUN = 50;
  385. benchmarker_fallback.set_times(RUN).set_display(false);
  386. benchmarker_bench.set_times(RUN).set_display(false);
  387. benchmarker_fallback.set_param(mode);
  388. benchmarker_bench.set_param(mode);
  389. TensorLayout dst_layout;
  390. auto opr = handle_bench->create_operator<Elemwise>();
  391. opr->param() = mode;
  392. opr->deduce_layout(layouts, dst_layout);
  393. float computations =
  394. dst_layout.total_nr_elems() * (std::max<size_t>(shapes.size(), 2) - 1);
  395. throughput += dst_layout.span().dist_byte();
  396. computations *= (1e3 / (1024.0 * 1024));
  397. throughput *= (1e3 / (1024.0 * 1024));
  398. layouts.emplace_back(dst_layout);
  399. auto fallback_time = benchmarker_fallback.execl(layouts) / RUN;
  400. auto bench_time = benchmarker_bench.execl(layouts) / RUN;
  401. float fallback_flops = computations / fallback_time;
  402. float bench_flops = computations / bench_time;
  403. float fallback_thr = throughput / fallback_time;
  404. float bench_thr = throughput / bench_time;
  405. printf("%s = %s (type: %s, mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS "
  406. "%fMB/s "
  407. "computations: %fx, throughput: %fx\n",
  408. src_strs.c_str(), dst_layout.to_string().c_str(), type.name(), mode_str,
  409. fallback_flops, fallback_thr, bench_flops, bench_thr,
  410. bench_flops / fallback_flops, bench_thr / fallback_thr);
  411. }
  412. } // namespace
  413. TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NHWC) {
  414. Benchmarker<Elemwise> benchmarker(handle());
  415. constexpr size_t RUN = 50;
  416. benchmarker.set_times(RUN).set_display(false);
  417. auto run = [&](size_t N, size_t C, size_t H, size_t W, param::Elemwise::Mode mode,
  418. const char* mode_name) {
  419. megdnn::param::Elemwise param;
  420. param.mode = mode;
  421. benchmarker.set_param(param);
  422. megdnn::TensorShape nhwc_src0{N, H, W, C};
  423. megdnn::TensorShape nhwc_src1{1, 1, 1, C};
  424. megdnn::TensorShape nchw_src0{N, C, H, W};
  425. megdnn::TensorShape nchw_src1{1, C, 1, 1};
  426. float computations = N * C * H * W;
  427. auto nhwc_time = benchmarker.execs({nhwc_src1, nhwc_src0, {}}) / RUN;
  428. auto nchw_time = benchmarker.execs({nchw_src1, nchw_src0, {}}) / RUN;
  429. auto perf_nhwc = computations / nhwc_time / 1e6;
  430. auto perf_nchw = computations / nchw_time / 1e6;
  431. printf("Elemwise Mode : %s\nNHWC : %fms %fGflops\nNCHW : %fms "
  432. "%fGflops\n",
  433. mode_name, nhwc_time, perf_nhwc, nchw_time, perf_nchw);
  434. };
  435. run(1, 120, 16, 24, param::Elemwise::Mode::ADD, "ADD");
  436. run(1, 120, 16, 24, param::Elemwise::Mode::MUL, "MUL");
  437. run(1, 120, 32, 48, param::Elemwise::Mode::ADD, "ADD");
  438. run(1, 120, 32, 48, param::Elemwise::Mode::MUL, "MUL");
  439. run(1, 120, 64, 96, param::Elemwise::Mode::ADD, "ADD");
  440. run(1, 120, 64, 96, param::Elemwise::Mode::MUL, "MUL");
  441. }
  //! Benchmark helper macros. Each one expands to a single statement
  //! (do { ... } while (0)), so it stays correct under an unbraced if/else and
  //! requires the trailing semicolon at the call site.

  //! Runs run_elemwise_benchmark for Int8, Int16 and Int32; #mode is printed
  //! as the mode name.
  #define INT_RUN(shape, mode)                                                  \
      do {                                                                      \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle());  \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle()); \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Int32{}, handle()); \
      } while (0)

  //! Runs run_elemwise_benchmark for Float32 and Float16; #mode is printed as
  //! the mode name.
  #define FLOAT_RUN(shape, mode)                                                  \
      do {                                                                        \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Float32{}, handle()); \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Float16{}, handle()); \
      } while (0)

  //! Runs the integer and float case lists that the enclosing test defines as
  //! INT_BENCHMARK_CASES / FLOAT_BENCHMARK_CASES (each list must end with ';').
  #define BENCHMARK_CASES(shape)           \
      do {                                 \
          INT_BENCHMARK_CASES(shape)       \
          FLOAT_BENCHMARK_CASES(shape)     \
      } while (0)
  452. TEST_F(ARM_COMMON, BENCHMARK_UNARY) {
  453. #define INT_BENCHMARK_CASES(shape) \
  454. INT_RUN(shape, Mode::RELU); \
  455. INT_RUN(shape, Mode::ABS);
  456. #define FLOAT_BENCHMARK_CASES(shape) \
  457. FLOAT_RUN(shape, Mode::RELU); \
  458. FLOAT_RUN(shape, Mode::ABS); \
  459. FLOAT_RUN(shape, Mode::SIGMOID); \
  460. FLOAT_RUN(shape, Mode::EXP); \
  461. FLOAT_RUN(shape, Mode::TANH); \
  462. FLOAT_RUN(shape, Mode::FAST_TANH);
  463. using Mode = param::Elemwise::Mode;
  464. BENCHMARK_CASES({{10000}});
  465. BENCHMARK_CASES({{50000}});
  466. #undef INT_BENCHMARK_CASES
  467. #undef FLOAT_BENCHMARK_CASES
  468. }
  469. TEST_F(ARM_COMMON, BENCHMARK_BINARY) {
  470. #define INT_BENCHMARK_CASES(shape) \
  471. INT_RUN(shape, Mode::MIN); \
  472. INT_RUN(shape, Mode::MAX); \
  473. INT_RUN(shape, Mode::ADD); \
  474. INT_RUN(shape, Mode::SUB); \
  475. INT_RUN(shape, Mode::MUL); \
  476. INT_RUN(shape, Mode::RMULH); \
  477. INT_RUN(shape, Mode::FUSE_ADD_RELU);
  478. #define FLOAT_BENCHMARK_CASES(shape) \
  479. FLOAT_RUN(shape, Mode::MIN); \
  480. FLOAT_RUN(shape, Mode::MAX); \
  481. FLOAT_RUN(shape, Mode::ADD); \
  482. FLOAT_RUN(shape, Mode::SUB); \
  483. FLOAT_RUN(shape, Mode::MUL); \
  484. FLOAT_RUN(shape, Mode::POW); \
  485. FLOAT_RUN(shape, Mode::TRUE_DIV); \
  486. FLOAT_RUN(shape, Mode::FUSE_ADD_RELU);
  487. using Mode = param::Elemwise::Mode;
  488. TensorShapeArray shapes = {{1, 112, 28, 28}, {1, 112, 28, 28}};
  489. BENCHMARK_CASES(shapes);
  490. shapes = {{1, 16, 1, 1}, {1, 16, 112, 112}};
  491. BENCHMARK_CASES(shapes);
  492. shapes = {{1, 448, 7, 7}, {1, 448, 7, 7}};
  493. BENCHMARK_CASES(shapes);
  494. #undef INT_BENCHMARK_CASES
  495. #undef FLOAT_BENCHMARK_CASES
  496. }
  497. TEST_F(ARM_COMMON, BENCHMARK_TERNARY_FMA3) {
  498. #define INT_BENCHMARK_CASES(shape) INT_RUN(shape, Mode::FUSE_MUL_ADD3);
  499. #define FLOAT_BENCHMARK_CASES(shape) FLOAT_RUN(shape, Mode::FUSE_MUL_ADD3);
  500. using Mode = param::Elemwise::Mode;
  501. TensorShapeArray shapes = {{30, 40, 70}, {30, 40, 70}, {30, 40, 70}};
  502. BENCHMARK_CASES(shapes);
  503. shapes = {{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}};
  504. BENCHMARK_CASES(shapes);
  505. shapes = {{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}};
  506. BENCHMARK_CASES(shapes);
  507. #undef INT_BENCHMARK_CASES
  508. #undef FLOAT_BENCHMARK_CASES
  509. }
  510. #undef BENCHMARK_CASES
  511. #undef INT_RUN
  512. #undef FLOAT_RUN
  513. #endif
  514. // vim: syntax=cpp.doxygen