You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

elemwise.cpp 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. /**
  2. * \file dnn/test/arm_common/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/common/elemwise.h"
  13. #include "test/arm_common/fixture.h"
  14. #include "test/common/benchmarker.h"
  15. #include "test/common/checker.h"
  16. #include "test/common/task_record_check.h"
  17. #include "megdnn/opr_param_defs.h"
  18. #include "megdnn/oprs/general.h"
  19. using namespace megdnn;
  20. using namespace test;
  21. template <typename tag>
  22. class ARM_ELEMWISE : public ARM_COMMON {};
  23. TYPED_TEST_CASE(ARM_ELEMWISE, elemwise::test_types);
  24. TYPED_TEST(ARM_ELEMWISE, run) {
  25. elemwise::run_test<TypeParam>(this->handle());
  26. }
  27. template <typename tag>
  28. class ARM_ELEMWISE_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {};
  29. TYPED_TEST_CASE(ARM_ELEMWISE_MULTI_THREADS, elemwise::test_types);
  30. TYPED_TEST(ARM_ELEMWISE_MULTI_THREADS, run) {
  31. elemwise::run_test<TypeParam>(this->handle());
  32. }
  33. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) {
  34. using Mode = ElemwiseForward::Param::Mode;
  35. Checker<ElemwiseForward> checker(handle());
  36. checker.set_param(Mode::FUSE_MUL_ADD3);
  37. auto run = [&] {
  38. //! nchw44
  39. checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  40. checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  41. checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  42. checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  43. checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  44. //! nchw44
  45. checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  46. checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  47. checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  48. checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  49. checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  50. //! nchw88
  51. checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  52. checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  53. checker.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  54. checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  55. checker.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  56. //! nchw88
  57. checker.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  58. checker.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  59. checker.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  60. checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  61. checker.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  62. checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
  63. checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
  64. checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}});
  65. checker.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}});
  66. checker.execs({{1, 7}, {1, 7}, {1, 7}, {}});
  67. checker.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}});
  68. checker.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}});
  69. checker.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  70. checker.execs({{3, 4, 5}, {1}, {1}, {}});
  71. checker.execs({{1}, {3, 4, 5}, {1}, {}});
  72. };
  73. // case int
  74. checker.set_dtype(0, dtype::Int8());
  75. checker.set_dtype(1, dtype::Int8());
  76. checker.set_dtype(2, dtype::Int8());
  77. run();
  78. checker.set_dtype(0, dtype::Int16());
  79. checker.set_dtype(1, dtype::Int16());
  80. checker.set_dtype(2, dtype::Int16());
  81. run();
  82. checker.set_dtype(0, dtype::Int32());
  83. checker.set_dtype(1, dtype::Int32());
  84. checker.set_dtype(2, dtype::Int32());
  85. run();
  86. // case float
  87. UniformFloatRNG rng(1e-5, 7e1);
  88. checker.set_rng(0, &rng);
  89. checker.set_epsilon(1e-5);
  90. checker.set_dtype(0, dtype::Float32());
  91. checker.set_dtype(1, dtype::Float32());
  92. checker.set_dtype(2, dtype::Float32());
  93. run();
  94. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  95. // case half
  96. UniformFloatRNG rng_float16(1, 10);
  97. checker.set_rng(0, &rng_float16);
  98. checker.set_epsilon(1e-2);
  99. checker.set_dtype(0, dtype::Float16());
  100. checker.set_dtype(1, dtype::Float16());
  101. checker.set_dtype(2, dtype::Float16());
  102. run();
  103. #endif
  104. }
  105. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) {
  106. using Mode = ElemwiseForward::Param::Mode;
  107. Checker<ElemwiseForward> checker(handle());
  108. auto run = [&]() {
  109. // VEC_BCAST101x not PowOp
  110. checker.set_param(Mode::ADD).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  111. checker.set_param(Mode::ADD).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  112. checker.set_param(Mode::ADD).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  113. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  114. checker.set_param(Mode::ADD).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  115. checker.set_param(Mode::RMULH).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  116. checker.set_param(Mode::RMULH).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  117. checker.set_param(Mode::RMULH).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  118. checker.set_param(Mode::RMULH).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  119. checker.set_param(Mode::RMULH).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  120. checker.set_param(Mode::FUSE_ADD_RELU)
  121. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  122. checker.set_param(Mode::FUSE_ADD_RELU)
  123. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  124. checker.set_param(Mode::FUSE_ADD_RELU)
  125. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  126. checker.set_param(Mode::FUSE_ADD_RELU)
  127. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  128. checker.set_param(Mode::FUSE_ADD_RELU)
  129. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  130. // BCAST101x_VEC not PowOp
  131. checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  132. checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  133. checker.set_param(Mode::ADD).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  134. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  135. checker.set_param(Mode::ADD).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  136. checker.set_param(Mode::FUSE_ADD_RELU)
  137. .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  138. checker.set_param(Mode::FUSE_ADD_RELU)
  139. .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  140. checker.set_param(Mode::FUSE_ADD_RELU)
  141. .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  142. checker.set_param(Mode::FUSE_ADD_RELU)
  143. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  144. checker.set_param(Mode::FUSE_ADD_RELU)
  145. .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  146. };
  147. checker.set_dtype(0, dtype::Int8());
  148. checker.set_dtype(1, dtype::Int8());
  149. run();
  150. checker.set_dtype(0, dtype::Int16());
  151. checker.set_dtype(1, dtype::Int16());
  152. run();
  153. checker.set_dtype(0, dtype::Int32());
  154. checker.set_dtype(1, dtype::Int32());
  155. run();
  156. }
  157. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) {
  158. using Mode = ElemwiseForward::Param::Mode;
  159. Checker<ElemwiseForward> checker(handle());
  160. UniformFloatRNG rng(1e-5, 7e1);
  161. checker.set_rng(0, &rng);
  162. checker.set_epsilon(1e-5);
  163. checker.set_dtype(0, dtype::Float32());
  164. checker.set_dtype(1, dtype::Float32());
  165. checker.set_param(Mode::FUSE_ADD_RELU)
  166. .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  167. checker.set_param(Mode::FUSE_ADD_RELU)
  168. .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  169. checker.set_param(Mode::FUSE_ADD_RELU)
  170. .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  171. checker.set_param(Mode::FUSE_ADD_RELU)
  172. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  173. checker.set_param(Mode::FUSE_ADD_RELU)
  174. .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  175. checker.set_param(Mode::FUSE_ADD_RELU)
  176. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  177. checker.set_param(Mode::FUSE_ADD_RELU)
  178. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  179. checker.set_param(Mode::FUSE_ADD_RELU)
  180. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  181. checker.set_param(Mode::FUSE_ADD_RELU)
  182. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  183. checker.set_param(Mode::FUSE_ADD_RELU)
  184. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  185. auto run = [&](Mode mode) {
  186. // VEC_BCAST101x
  187. checker.set_param(mode).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  188. checker.set_param(mode).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  189. checker.set_param(mode).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  190. checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  191. checker.set_param(mode).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  192. // BCAST101x_VEC not powOp
  193. checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  194. checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  195. checker.set_param(mode).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  196. checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  197. checker.set_param(mode).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  198. };
  199. run(Mode::ADD);
  200. run(Mode::FUSE_ADD_H_SWISH);
  201. run(Mode::FUSE_ADD_RELU);
  202. run(Mode::MAX);
  203. run(Mode::MIN);
  204. run(Mode::MUL);
  205. run(Mode::SUB);
  206. run(Mode::TRUE_DIV);
  207. run(Mode::POW);
  208. }
  209. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) {
  210. using Mode = ElemwiseForward::Param::Mode;
  211. Checker<ElemwiseForward> checker(handle());
  212. checker.set_param(Mode::FUSE_ADD_RELU)
  213. .execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  214. checker.set_param(Mode::FUSE_ADD_RELU)
  215. .execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  216. checker.set_param(Mode::FUSE_ADD_RELU)
  217. .execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  218. checker.set_param(Mode::FUSE_ADD_RELU)
  219. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  220. checker.set_param(Mode::FUSE_ADD_RELU)
  221. .execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  222. checker.set_param(Mode::FUSE_ADD_RELU)
  223. .execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  224. checker.set_param(Mode::FUSE_ADD_RELU)
  225. .execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  226. checker.set_param(Mode::FUSE_ADD_RELU)
  227. .execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  228. checker.set_param(Mode::FUSE_ADD_RELU)
  229. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  230. checker.set_param(Mode::FUSE_ADD_RELU)
  231. .execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  232. auto run = [&](Mode mode) {
  233. // VEC_BCAST101x
  234. checker.set_param(mode).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  235. checker.set_param(mode).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  236. checker.set_param(mode).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  237. checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  238. checker.set_param(mode).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  239. // BCAST101x_VEC not powOp
  240. checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  241. checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  242. checker.set_param(mode).execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  243. checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  244. checker.set_param(mode).execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  245. };
  246. auto run_all = [&]() {
  247. run(Mode::ADD);
  248. run(Mode::FUSE_ADD_H_SWISH);
  249. run(Mode::FUSE_ADD_RELU);
  250. run(Mode::MAX);
  251. run(Mode::MIN);
  252. run(Mode::MUL);
  253. run(Mode::SUB);
  254. run(Mode::TRUE_DIV);
  255. run(Mode::POW);
  256. };
  257. {
  258. UniformFloatRNG rng(1e-5, 7e1);
  259. checker.set_rng(0, &rng);
  260. checker.set_epsilon(1e-5);
  261. checker.set_dtype(0, dtype::Float32());
  262. checker.set_dtype(1, dtype::Float32());
  263. run_all();
  264. }
  265. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  266. {
  267. UniformFloatRNG rng(1, 2);
  268. checker.set_rng(0, &rng);
  269. checker.set_epsilon(3e-3);
  270. checker.set_dtype(0, dtype::Float16());
  271. checker.set_dtype(1, dtype::Float16());
  272. run_all();
  273. }
  274. #endif
  275. }
  276. TEST_F(ARM_COMMON_MULTI_THREADS, ELEMWISE_FORWARD_NHWC_FP32_BCAST) {
  277. using Mode = ElemwiseForward::Param::Mode;
  278. Checker<ElemwiseForward> checker(handle());
  279. UniformFloatRNG rng(1e-5, 7e1);
  280. checker.set_rng(0, &rng);
  281. checker.set_epsilon(1e-5);
  282. checker.set_dtype(0, dtype::Float32());
  283. checker.set_dtype(1, dtype::Float32());
  284. //! 2 dim
  285. auto run = [&](Mode mode) {
  286. // VEC_BCAST111C
  287. checker.set_param(mode).execs({{1, 2, 2, 12}, {1, 1, 1, 12}, {}});
  288. checker.set_param(mode).execs({{2, 5, 3, 28}, {1, 1, 1, 28}, {}});
  289. checker.set_param(mode).execs({{3, 5, 8, 32}, {1, 1, 1, 32}, {}});
  290. // BCAST111C_VEC
  291. checker.set_param(mode).execs({{1, 1, 1, 12}, {1, 2, 2, 12}, {}});
  292. checker.set_param(mode).execs({{1, 1, 1, 28}, {2, 5, 3, 28}, {}});
  293. checker.set_param(mode).execs({{1, 1, 1, 32}, {3, 5, 8, 32}, {}});
  294. };
  295. run(Mode::ADD);
  296. run(Mode::MUL);
  297. run(Mode::SUB);
  298. //! 3 dim contig
  299. auto run_3d_contig = [&](Mode mode) {
  300. // BCAST111C_VEC_BCAST111C
  301. checker.set_param(mode).execs(
  302. {{1, 1, 1, 12}, {1, 2, 2, 12}, {1, 1, 1, 12}, {}});
  303. checker.set_param(mode).execs(
  304. {{1, 1, 1, 28}, {2, 5, 3, 28}, {1, 1, 1, 28}, {}});
  305. checker.set_param(mode).execs(
  306. {{1, 1, 1, 32}, {3, 5, 8, 32}, {1, 1, 1, 32}, {}});
  307. // VEC_BCAST111C_VEC
  308. checker.set_param(mode).execs(
  309. {{1, 2, 2, 12}, {1, 1, 1, 12}, {1, 2, 2, 12}, {}});
  310. checker.set_param(mode).execs(
  311. {{2, 5, 3, 28}, {1, 1, 1, 28}, {2, 5, 3, 28}, {}});
  312. checker.set_param(mode).execs(
  313. {{3, 5, 8, 32}, {1, 1, 1, 32}, {3, 5, 8, 32}, {}});
  314. };
  315. run_3d_contig(Mode::FUSE_MUL_ADD3);
  316. //! 3 dim incontig
  317. auto run_3d_incontig = [&](Mode mode) {
  318. megdnn::TensorLayout src0({1, 1, 1, 12}, dtype::Float32());
  319. megdnn::TensorLayout src1({1, 2, 2, 12}, {80, 40, 20, 1}, dtype::Float32());
  320. // BCAST111C_VEC_BCAST111C
  321. checker.set_param(mode).execl({src0, src1, src0, {}});
  322. // VEC_BCAST111C_VEC
  323. checker.set_param(mode).execl({src1, src0, src1, {}});
  324. };
  325. run_3d_incontig(Mode::FUSE_MUL_ADD3);
  326. }
  327. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_N1HW_FP32_BCAST) {
  328. using Mode = ElemwiseForward::Param::Mode;
  329. Checker<ElemwiseForward> checker(handle());
  330. UniformFloatRNG rng(1e-5, 7e1);
  331. checker.set_rng(0, &rng);
  332. checker.set_epsilon(1e-5);
  333. checker.set_dtype(0, dtype::Float32());
  334. checker.set_dtype(1, dtype::Float32());
  335. //! 2 dim
  336. auto run = [&](Mode mode) {
  337. // VEC_BCASTX0X
  338. checker.set_param(mode).execs({{2, 8, 4, 4}, {2, 1, 4, 4}, {}});
  339. checker.set_param(mode).execs({{4, 21, 78}, {4, 1, 78}, {}});
  340. // BCASTX0X_VEC
  341. checker.set_param(mode).execs({{2, 1, 4, 4}, {2, 8, 4, 4}, {}});
  342. checker.set_param(mode).execs({{4, 1, 78}, {4, 21, 78}, {}});
  343. };
  344. run(Mode::ADD);
  345. run(Mode::MUL);
  346. run(Mode::SUB);
  347. }
  348. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY_RECORD) {
  349. using Mode = ElemwiseForward::Param::Mode;
  350. TaskRecordChecker<ElemwiseForward> checker(0);
  351. checker.set_param(Mode::FUSE_MUL_ADD3);
  352. auto run = [&] {
  353. //! nchw44
  354. checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  355. checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  356. //! nchw88
  357. checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  358. checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  359. checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
  360. checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
  361. };
  362. // case int
  363. checker.set_dtype(0, dtype::Int32());
  364. checker.set_dtype(1, dtype::Int32());
  365. checker.set_dtype(2, dtype::Int32());
  366. run();
  367. // case float
  368. UniformFloatRNG rng(1e-5, 7e1);
  369. checker.set_rng(0, &rng);
  370. checker.set_epsilon(1e-5);
  371. checker.set_dtype(0, dtype::Float32());
  372. checker.set_dtype(1, dtype::Float32());
  373. checker.set_dtype(2, dtype::Float32());
  374. run();
  375. }
  376. #if MEGDNN_WITH_BENCHMARK
  377. namespace {
  378. void run_elemwise_benchmark(
  379. const TensorShapeArray& shapes, param::Elemwise::Mode mode,
  380. const char* mode_str, DType type, Handle* handle_bench) {
  381. auto handle_fallback = create_cpu_handle(1);
  382. Benchmarker<Elemwise> benchmarker_bench(handle_bench);
  383. Benchmarker<Elemwise> benchmarker_fallback(handle_fallback.get());
  384. float throughput = 0;
  385. SmallVector<TensorLayout> layouts;
  386. std::string src_strs;
  387. for (size_t i = 0; i < shapes.size(); i++) {
  388. layouts.emplace_back(shapes[i], type);
  389. throughput += layouts.back().span().dist_byte();
  390. src_strs += layouts.back().to_string();
  391. if (i != shapes.size() - 1) {
  392. src_strs += ",";
  393. }
  394. }
  395. constexpr size_t RUN = 50;
  396. benchmarker_fallback.set_times(RUN).set_display(false);
  397. benchmarker_bench.set_times(RUN).set_display(false);
  398. benchmarker_fallback.set_param(mode);
  399. benchmarker_bench.set_param(mode);
  400. TensorLayout dst_layout;
  401. auto opr = handle_bench->create_operator<Elemwise>();
  402. opr->param() = mode;
  403. opr->deduce_layout(layouts, dst_layout);
  404. float computations =
  405. dst_layout.total_nr_elems() * (std::max<size_t>(shapes.size(), 2) - 1);
  406. throughput += dst_layout.span().dist_byte();
  407. computations *= (1e3 / (1024.0 * 1024));
  408. throughput *= (1e3 / (1024.0 * 1024));
  409. layouts.emplace_back(dst_layout);
  410. auto fallback_time = benchmarker_fallback.execl(layouts) / RUN;
  411. auto bench_time = benchmarker_bench.execl(layouts) / RUN;
  412. float fallback_flops = computations / fallback_time;
  413. float bench_flops = computations / bench_time;
  414. float fallback_thr = throughput / fallback_time;
  415. float bench_thr = throughput / bench_time;
  416. printf("%s = %s (type: %s, mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS "
  417. "%fMB/s "
  418. "computations: %fx, throughput: %fx\n",
  419. src_strs.c_str(), dst_layout.to_string().c_str(), type.name(), mode_str,
  420. fallback_flops, fallback_thr, bench_flops, bench_thr,
  421. bench_flops / fallback_flops, bench_thr / fallback_thr);
  422. }
  423. } // namespace
  424. TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NHWC) {
  425. Benchmarker<Elemwise> benchmarker(handle());
  426. constexpr size_t RUN = 50;
  427. benchmarker.set_times(RUN).set_display(false);
  428. auto run = [&](size_t N, size_t C, size_t H, size_t W, param::Elemwise::Mode mode,
  429. const char* mode_name) {
  430. megdnn::param::Elemwise param;
  431. param.mode = mode;
  432. benchmarker.set_param(param);
  433. megdnn::TensorShape nhwc_src0{N, H, W, C};
  434. megdnn::TensorShape nhwc_src1{1, 1, 1, C};
  435. megdnn::TensorShape nchw_src0{N, C, H, W};
  436. megdnn::TensorShape nchw_src1{1, C, 1, 1};
  437. float computations = N * C * H * W;
  438. auto nhwc_time = benchmarker.execs({nhwc_src1, nhwc_src0, {}}) / RUN;
  439. auto nchw_time = benchmarker.execs({nchw_src1, nchw_src0, {}}) / RUN;
  440. auto perf_nhwc = computations / nhwc_time / 1e6;
  441. auto perf_nchw = computations / nchw_time / 1e6;
  442. printf("Elemwise Mode : %s\nNHWC : %fms %fGflops\nNCHW : %fms "
  443. "%fGflops\n",
  444. mode_name, nhwc_time, perf_nhwc, nchw_time, perf_nchw);
  445. };
  446. run(1, 120, 16, 24, param::Elemwise::Mode::ADD, "ADD");
  447. run(1, 120, 16, 24, param::Elemwise::Mode::MUL, "MUL");
  448. run(1, 120, 32, 48, param::Elemwise::Mode::ADD, "ADD");
  449. run(1, 120, 32, 48, param::Elemwise::Mode::MUL, "MUL");
  450. run(1, 120, 64, 96, param::Elemwise::Mode::ADD, "ADD");
  451. run(1, 120, 64, 96, param::Elemwise::Mode::MUL, "MUL");
  452. }
  //! Benchmark `mode` on `shape` once per integer dtype (Int8, Int16, Int32).
  //! `#mode` turns the mode expression into the label that gets printed.
  //! The body is wrapped in do { } while (0) so the expansion is a single
  //! statement (safe under an unbraced if/else); the caller writes the
  //! trailing semicolon.
  #define INT_RUN(shape, mode)                                                   \
      do {                                                                       \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle());   \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle());  \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Int32{}, handle());  \
      } while (0)
  //! Same as INT_RUN for the floating-point dtypes (Float32, Float16).
  #define FLOAT_RUN(shape, mode)                                                  \
      do {                                                                        \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Float32{}, handle()); \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Float16{}, handle()); \
      } while (0)
  //! Expands to the integer cases followed by the float cases. Each benchmark
  //! test defines INT_BENCHMARK_CASES / FLOAT_BENCHMARK_CASES before using it.
  #define BENCHMARK_CASES(shape) \
      INT_BENCHMARK_CASES(shape) \
      FLOAT_BENCHMARK_CASES(shape)
  463. TEST_F(ARM_COMMON, BENCHMARK_UNARY) {
  464. #define INT_BENCHMARK_CASES(shape) \
  465. INT_RUN(shape, Mode::RELU); \
  466. INT_RUN(shape, Mode::ABS);
  467. #define FLOAT_BENCHMARK_CASES(shape) \
  468. FLOAT_RUN(shape, Mode::RELU); \
  469. FLOAT_RUN(shape, Mode::ABS); \
  470. FLOAT_RUN(shape, Mode::SIGMOID); \
  471. FLOAT_RUN(shape, Mode::EXP); \
  472. FLOAT_RUN(shape, Mode::TANH); \
  473. FLOAT_RUN(shape, Mode::FAST_TANH);
  474. using Mode = param::Elemwise::Mode;
  475. BENCHMARK_CASES({{10000}});
  476. BENCHMARK_CASES({{50000}});
  477. #undef INT_BENCHMARK_CASES
  478. #undef FLOAT_BENCHMARK_CASES
  479. }
  480. TEST_F(ARM_COMMON, BENCHMARK_BINARY) {
  481. #define INT_BENCHMARK_CASES(shape) \
  482. INT_RUN(shape, Mode::MIN); \
  483. INT_RUN(shape, Mode::MAX); \
  484. INT_RUN(shape, Mode::ADD); \
  485. INT_RUN(shape, Mode::SUB); \
  486. INT_RUN(shape, Mode::MUL); \
  487. INT_RUN(shape, Mode::RMULH); \
  488. INT_RUN(shape, Mode::FUSE_ADD_RELU);
  489. #define FLOAT_BENCHMARK_CASES(shape) \
  490. FLOAT_RUN(shape, Mode::MIN); \
  491. FLOAT_RUN(shape, Mode::MAX); \
  492. FLOAT_RUN(shape, Mode::ADD); \
  493. FLOAT_RUN(shape, Mode::SUB); \
  494. FLOAT_RUN(shape, Mode::MUL); \
  495. FLOAT_RUN(shape, Mode::POW); \
  496. FLOAT_RUN(shape, Mode::TRUE_DIV); \
  497. FLOAT_RUN(shape, Mode::FUSE_ADD_RELU);
  498. using Mode = param::Elemwise::Mode;
  499. TensorShapeArray shapes = {{1, 112, 28, 28}, {1, 112, 28, 28}};
  500. BENCHMARK_CASES(shapes);
  501. shapes = {{1, 16, 1, 1}, {1, 16, 112, 112}};
  502. BENCHMARK_CASES(shapes);
  503. shapes = {{1, 448, 7, 7}, {1, 448, 7, 7}};
  504. BENCHMARK_CASES(shapes);
  505. #undef INT_BENCHMARK_CASES
  506. #undef FLOAT_BENCHMARK_CASES
  507. }
  508. TEST_F(ARM_COMMON, BENCHMARK_TERNARY_FMA3) {
  509. #define INT_BENCHMARK_CASES(shape) INT_RUN(shape, Mode::FUSE_MUL_ADD3);
  510. #define FLOAT_BENCHMARK_CASES(shape) FLOAT_RUN(shape, Mode::FUSE_MUL_ADD3);
  511. using Mode = param::Elemwise::Mode;
  512. TensorShapeArray shapes = {{30, 40, 70}, {30, 40, 70}, {30, 40, 70}};
  513. BENCHMARK_CASES(shapes);
  514. shapes = {{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}};
  515. BENCHMARK_CASES(shapes);
  516. shapes = {{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}};
  517. BENCHMARK_CASES(shapes);
  518. #undef INT_BENCHMARK_CASES
  519. #undef FLOAT_BENCHMARK_CASES
  520. }
  521. #undef BENCHMARK_CASES
  522. #undef INT_RUN
  523. #undef FLOAT_RUN
  524. #endif
  525. // vim: syntax=cpp.doxygen