You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.cpp 24 kB


  1. #include "test/common/elemwise.h"
  2. #include "test/arm_common/fixture.h"
  3. #include "test/common/benchmarker.h"
  4. #include "test/common/checker.h"
  5. #include "test/common/task_record_check.h"
  6. #include "megdnn/opr_param_defs.h"
  7. #include "megdnn/oprs/general.h"
  8. using namespace megdnn;
  9. using namespace test;
  10. template <typename tag>
  11. class ARM_ELEMWISE : public ARM_COMMON {};
  12. TYPED_TEST_CASE(ARM_ELEMWISE, elemwise::test_types);
  13. TYPED_TEST(ARM_ELEMWISE, run) {
  14. elemwise::run_test<TypeParam>(this->handle());
  15. }
  16. template <typename tag>
  17. class ARM_ELEMWISE_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {};
  18. TYPED_TEST_CASE(ARM_ELEMWISE_MULTI_THREADS, elemwise::test_types);
  19. TYPED_TEST(ARM_ELEMWISE_MULTI_THREADS, run) {
  20. elemwise::run_test<TypeParam>(this->handle());
  21. }
  22. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) {
  23. using Mode = ElemwiseForward::Param::Mode;
  24. Checker<ElemwiseForward> checker(handle());
  25. checker.set_param(Mode::FUSE_MUL_ADD3);
  26. auto run = [&] {
  27. //! nchw44
  28. checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  29. checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  30. checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  31. checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  32. checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  33. //! nchw44
  34. checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  35. checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  36. checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  37. checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  38. checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  39. //! nchw88
  40. checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  41. checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  42. checker.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  43. checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  44. checker.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  45. //! nchw88
  46. checker.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  47. checker.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  48. checker.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  49. checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  50. checker.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  51. checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
  52. checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
  53. checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}});
  54. checker.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}});
  55. checker.execs({{1, 7}, {1, 7}, {1, 7}, {}});
  56. checker.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}});
  57. checker.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}});
  58. checker.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  59. checker.execs({{3, 4, 5}, {1}, {1}, {}});
  60. checker.execs({{1}, {3, 4, 5}, {1}, {}});
  61. };
  62. // case int
  63. checker.set_dtype(0, dtype::Int8());
  64. checker.set_dtype(1, dtype::Int8());
  65. checker.set_dtype(2, dtype::Int8());
  66. run();
  67. checker.set_dtype(0, dtype::Int16());
  68. checker.set_dtype(1, dtype::Int16());
  69. checker.set_dtype(2, dtype::Int16());
  70. run();
  71. checker.set_dtype(0, dtype::Int32());
  72. checker.set_dtype(1, dtype::Int32());
  73. checker.set_dtype(2, dtype::Int32());
  74. run();
  75. // case float
  76. UniformFloatRNG rng(1e-5, 7e1);
  77. checker.set_rng(0, &rng);
  78. checker.set_epsilon(1e-5);
  79. checker.set_dtype(0, dtype::Float32());
  80. checker.set_dtype(1, dtype::Float32());
  81. checker.set_dtype(2, dtype::Float32());
  82. run();
  83. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  84. // case half
  85. UniformFloatRNG rng_float16(1, 10);
  86. checker.set_rng(0, &rng_float16);
  87. checker.set_epsilon(1e-2);
  88. checker.set_dtype(0, dtype::Float16());
  89. checker.set_dtype(1, dtype::Float16());
  90. checker.set_dtype(2, dtype::Float16());
  91. run();
  92. #endif
  93. }
  94. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) {
  95. using Mode = ElemwiseForward::Param::Mode;
  96. Checker<ElemwiseForward> checker(handle());
  97. auto run = [&]() {
  98. // VEC_BCAST101x not PowOp
  99. checker.set_param(Mode::ADD).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  100. checker.set_param(Mode::ADD).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  101. checker.set_param(Mode::ADD).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  102. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  103. checker.set_param(Mode::ADD).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  104. checker.set_param(Mode::RMULH).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  105. checker.set_param(Mode::RMULH).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  106. checker.set_param(Mode::RMULH).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  107. checker.set_param(Mode::RMULH).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  108. checker.set_param(Mode::RMULH).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  109. checker.set_param(Mode::FUSE_ADD_RELU)
  110. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  111. checker.set_param(Mode::FUSE_ADD_RELU)
  112. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  113. checker.set_param(Mode::FUSE_ADD_RELU)
  114. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  115. checker.set_param(Mode::FUSE_ADD_RELU)
  116. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  117. checker.set_param(Mode::FUSE_ADD_RELU)
  118. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  119. // BCAST101x_VEC not PowOp
  120. checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  121. checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  122. checker.set_param(Mode::ADD).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  123. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  124. checker.set_param(Mode::ADD).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  125. checker.set_param(Mode::FUSE_ADD_RELU)
  126. .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  127. checker.set_param(Mode::FUSE_ADD_RELU)
  128. .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  129. checker.set_param(Mode::FUSE_ADD_RELU)
  130. .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  131. checker.set_param(Mode::FUSE_ADD_RELU)
  132. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  133. checker.set_param(Mode::FUSE_ADD_RELU)
  134. .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  135. };
  136. checker.set_dtype(0, dtype::Int8());
  137. checker.set_dtype(1, dtype::Int8());
  138. run();
  139. checker.set_dtype(0, dtype::Int16());
  140. checker.set_dtype(1, dtype::Int16());
  141. run();
  142. checker.set_dtype(0, dtype::Int32());
  143. checker.set_dtype(1, dtype::Int32());
  144. run();
  145. }
  146. TEST_F(ARM_COMMON, ELEMWISE_SIGMOID) {
  147. using Mode = ElemwiseForward::Param::Mode;
  148. Checker<ElemwiseForward> checker(handle());
  149. checker.set_epsilon(1e-3);
  150. checker.set_dtype(0, dtype::Float16());
  151. checker.set_param(Mode::SIGMOID);
  152. for (size_t n : {1, 2, 3}) {
  153. for (size_t ic : {8, 16, 24, 32}) {
  154. for (size_t ih : {5, 10, 15, 20, 21, 37}) {
  155. for (size_t iw : {7, 9, 11, 13, 14, 20, 35}) {
  156. checker.exec({{n, ic, ih, iw}, {}});
  157. }
  158. }
  159. }
  160. }
  161. }
  162. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) {
  163. using Mode = ElemwiseForward::Param::Mode;
  164. Checker<ElemwiseForward> checker(handle());
  165. UniformFloatRNG rng(1e-5, 7e1);
  166. checker.set_rng(0, &rng);
  167. checker.set_epsilon(1e-5);
  168. checker.set_dtype(0, dtype::Float32());
  169. checker.set_dtype(1, dtype::Float32());
  170. checker.set_param(Mode::FUSE_ADD_RELU)
  171. .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  172. checker.set_param(Mode::FUSE_ADD_RELU)
  173. .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  174. checker.set_param(Mode::FUSE_ADD_RELU)
  175. .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  176. checker.set_param(Mode::FUSE_ADD_RELU)
  177. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  178. checker.set_param(Mode::FUSE_ADD_RELU)
  179. .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  180. checker.set_param(Mode::FUSE_ADD_RELU)
  181. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  182. checker.set_param(Mode::FUSE_ADD_RELU)
  183. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  184. checker.set_param(Mode::FUSE_ADD_RELU)
  185. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  186. checker.set_param(Mode::FUSE_ADD_RELU)
  187. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  188. checker.set_param(Mode::FUSE_ADD_RELU)
  189. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  190. auto run = [&](Mode mode) {
  191. // VEC_BCAST101x
  192. checker.set_param(mode).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  193. checker.set_param(mode).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  194. checker.set_param(mode).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  195. checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  196. checker.set_param(mode).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  197. // BCAST101x_VEC not powOp
  198. checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  199. checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  200. checker.set_param(mode).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  201. checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  202. checker.set_param(mode).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  203. };
  204. run(Mode::ADD);
  205. run(Mode::FUSE_ADD_H_SWISH);
  206. run(Mode::FUSE_ADD_RELU);
  207. run(Mode::MAX);
  208. run(Mode::MIN);
  209. run(Mode::MUL);
  210. run(Mode::SUB);
  211. run(Mode::TRUE_DIV);
  212. run(Mode::POW);
  213. }
  214. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) {
  215. using Mode = ElemwiseForward::Param::Mode;
  216. Checker<ElemwiseForward> checker(handle());
  217. checker.set_param(Mode::FUSE_ADD_RELU)
  218. .execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  219. checker.set_param(Mode::FUSE_ADD_RELU)
  220. .execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  221. checker.set_param(Mode::FUSE_ADD_RELU)
  222. .execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  223. checker.set_param(Mode::FUSE_ADD_RELU)
  224. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  225. checker.set_param(Mode::FUSE_ADD_RELU)
  226. .execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  227. checker.set_param(Mode::FUSE_ADD_RELU)
  228. .execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  229. checker.set_param(Mode::FUSE_ADD_RELU)
  230. .execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  231. checker.set_param(Mode::FUSE_ADD_RELU)
  232. .execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  233. checker.set_param(Mode::FUSE_ADD_RELU)
  234. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  235. checker.set_param(Mode::FUSE_ADD_RELU)
  236. .execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  237. auto run = [&](Mode mode) {
  238. // VEC_BCAST101x
  239. checker.set_param(mode).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  240. checker.set_param(mode).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  241. checker.set_param(mode).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  242. checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  243. checker.set_param(mode).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  244. // BCAST101x_VEC not powOp
  245. checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  246. checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  247. checker.set_param(mode).execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  248. checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  249. checker.set_param(mode).execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  250. };
  251. auto run_all = [&]() {
  252. run(Mode::ADD);
  253. run(Mode::FUSE_ADD_H_SWISH);
  254. run(Mode::FUSE_ADD_RELU);
  255. run(Mode::MAX);
  256. run(Mode::MIN);
  257. run(Mode::MUL);
  258. run(Mode::SUB);
  259. run(Mode::TRUE_DIV);
  260. run(Mode::POW);
  261. };
  262. {
  263. UniformFloatRNG rng(1e-5, 7e1);
  264. checker.set_rng(0, &rng);
  265. checker.set_epsilon(1e-5);
  266. checker.set_dtype(0, dtype::Float32());
  267. checker.set_dtype(1, dtype::Float32());
  268. run_all();
  269. }
  270. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  271. {
  272. UniformFloatRNG rng(1, 2);
  273. checker.set_rng(0, &rng);
  274. checker.set_epsilon(3e-3);
  275. checker.set_dtype(0, dtype::Float16());
  276. checker.set_dtype(1, dtype::Float16());
  277. run_all();
  278. }
  279. #endif
  280. }
  281. TEST_F(ARM_COMMON_MULTI_THREADS, ELEMWISE_FORWARD_NHWC_FP32_BCAST) {
  282. using Mode = ElemwiseForward::Param::Mode;
  283. Checker<ElemwiseForward> checker(handle());
  284. UniformFloatRNG rng(1e-5, 7e1);
  285. checker.set_rng(0, &rng);
  286. checker.set_epsilon(1e-5);
  287. checker.set_dtype(0, dtype::Float32());
  288. checker.set_dtype(1, dtype::Float32());
  289. //! 2 dim
  290. auto run = [&](Mode mode) {
  291. // VEC_BCAST111C
  292. checker.set_param(mode).execs({{1, 2, 2, 12}, {1, 1, 1, 12}, {}});
  293. checker.set_param(mode).execs({{2, 5, 3, 28}, {1, 1, 1, 28}, {}});
  294. checker.set_param(mode).execs({{3, 5, 8, 32}, {1, 1, 1, 32}, {}});
  295. // BCAST111C_VEC
  296. checker.set_param(mode).execs({{1, 1, 1, 12}, {1, 2, 2, 12}, {}});
  297. checker.set_param(mode).execs({{1, 1, 1, 28}, {2, 5, 3, 28}, {}});
  298. checker.set_param(mode).execs({{1, 1, 1, 32}, {3, 5, 8, 32}, {}});
  299. };
  300. run(Mode::ADD);
  301. run(Mode::MUL);
  302. run(Mode::SUB);
  303. //! 3 dim contig
  304. auto run_3d_contig = [&](Mode mode) {
  305. // BCAST111C_VEC_BCAST111C
  306. checker.set_param(mode).execs(
  307. {{1, 1, 1, 12}, {1, 2, 2, 12}, {1, 1, 1, 12}, {}});
  308. checker.set_param(mode).execs(
  309. {{1, 1, 1, 28}, {2, 5, 3, 28}, {1, 1, 1, 28}, {}});
  310. checker.set_param(mode).execs(
  311. {{1, 1, 1, 32}, {3, 5, 8, 32}, {1, 1, 1, 32}, {}});
  312. // VEC_BCAST111C_VEC
  313. checker.set_param(mode).execs(
  314. {{1, 2, 2, 12}, {1, 1, 1, 12}, {1, 2, 2, 12}, {}});
  315. checker.set_param(mode).execs(
  316. {{2, 5, 3, 28}, {1, 1, 1, 28}, {2, 5, 3, 28}, {}});
  317. checker.set_param(mode).execs(
  318. {{3, 5, 8, 32}, {1, 1, 1, 32}, {3, 5, 8, 32}, {}});
  319. };
  320. run_3d_contig(Mode::FUSE_MUL_ADD3);
  321. //! 3 dim incontig
  322. auto run_3d_incontig = [&](Mode mode) {
  323. megdnn::TensorLayout src0({1, 1, 1, 12}, dtype::Float32());
  324. megdnn::TensorLayout src1({1, 2, 2, 12}, {80, 40, 20, 1}, dtype::Float32());
  325. // BCAST111C_VEC_BCAST111C
  326. checker.set_param(mode).execl({src0, src1, src0, {}});
  327. // VEC_BCAST111C_VEC
  328. checker.set_param(mode).execl({src1, src0, src1, {}});
  329. };
  330. run_3d_incontig(Mode::FUSE_MUL_ADD3);
  331. }
  332. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_N1HW_FP32_BCAST) {
  333. using Mode = ElemwiseForward::Param::Mode;
  334. Checker<ElemwiseForward> checker(handle());
  335. UniformFloatRNG rng(1e-5, 7e1);
  336. checker.set_rng(0, &rng);
  337. checker.set_epsilon(1e-5);
  338. checker.set_dtype(0, dtype::Float32());
  339. checker.set_dtype(1, dtype::Float32());
  340. //! 2 dim
  341. auto run = [&](Mode mode) {
  342. // VEC_BCASTX0X
  343. checker.set_param(mode).execs({{2, 8, 4, 4}, {2, 1, 4, 4}, {}});
  344. checker.set_param(mode).execs({{4, 21, 78}, {4, 1, 78}, {}});
  345. // BCASTX0X_VEC
  346. checker.set_param(mode).execs({{2, 1, 4, 4}, {2, 8, 4, 4}, {}});
  347. checker.set_param(mode).execs({{4, 1, 78}, {4, 21, 78}, {}});
  348. };
  349. run(Mode::ADD);
  350. run(Mode::MUL);
  351. run(Mode::SUB);
  352. }
  353. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY_RECORD) {
  354. using Mode = ElemwiseForward::Param::Mode;
  355. TaskRecordChecker<ElemwiseForward> checker(0);
  356. checker.set_param(Mode::FUSE_MUL_ADD3);
  357. auto run = [&] {
  358. //! nchw44
  359. checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  360. checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  361. //! nchw88
  362. checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  363. checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  364. checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
  365. checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
  366. };
  367. // case int
  368. checker.set_dtype(0, dtype::Int32());
  369. checker.set_dtype(1, dtype::Int32());
  370. checker.set_dtype(2, dtype::Int32());
  371. run();
  372. // case float
  373. UniformFloatRNG rng(1e-5, 7e1);
  374. checker.set_rng(0, &rng);
  375. checker.set_epsilon(1e-5);
  376. checker.set_dtype(0, dtype::Float32());
  377. checker.set_dtype(1, dtype::Float32());
  378. checker.set_dtype(2, dtype::Float32());
  379. run();
  380. }
  381. #if MEGDNN_WITH_BENCHMARK
  382. namespace {
  383. void run_elemwise_benchmark(
  384. const TensorShapeArray& shapes, param::Elemwise::Mode mode,
  385. const char* mode_str, DType type, Handle* handle_bench) {
  386. auto handle_fallback = create_cpu_handle(1);
  387. Benchmarker<Elemwise> benchmarker_bench(handle_bench);
  388. Benchmarker<Elemwise> benchmarker_fallback(handle_fallback.get());
  389. float throughput = 0;
  390. SmallVector<TensorLayout> layouts;
  391. std::string src_strs;
  392. for (size_t i = 0; i < shapes.size(); i++) {
  393. layouts.emplace_back(shapes[i], type);
  394. throughput += layouts.back().span().dist_byte();
  395. src_strs += layouts.back().to_string();
  396. if (i != shapes.size() - 1) {
  397. src_strs += ",";
  398. }
  399. }
  400. constexpr size_t RUN = 50;
  401. benchmarker_fallback.set_times(RUN).set_display(false);
  402. benchmarker_bench.set_times(RUN).set_display(false);
  403. benchmarker_fallback.set_param(mode);
  404. benchmarker_bench.set_param(mode);
  405. TensorLayout dst_layout;
  406. auto opr = handle_bench->create_operator<Elemwise>();
  407. opr->param() = mode;
  408. opr->deduce_layout(layouts, dst_layout);
  409. float computations =
  410. dst_layout.total_nr_elems() * (std::max<size_t>(shapes.size(), 2) - 1);
  411. throughput += dst_layout.span().dist_byte();
  412. computations *= (1e3 / (1024.0 * 1024));
  413. throughput *= (1e3 / (1024.0 * 1024));
  414. layouts.emplace_back(dst_layout);
  415. auto fallback_time = benchmarker_fallback.execl(layouts) / RUN;
  416. auto bench_time = benchmarker_bench.execl(layouts) / RUN;
  417. float fallback_flops = computations / fallback_time;
  418. float bench_flops = computations / bench_time;
  419. float fallback_thr = throughput / fallback_time;
  420. float bench_thr = throughput / bench_time;
  421. printf("%s = %s (type: %s, mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS "
  422. "%fMB/s "
  423. "computations: %fx, throughput: %fx\n",
  424. src_strs.c_str(), dst_layout.to_string().c_str(), type.name(), mode_str,
  425. fallback_flops, fallback_thr, bench_flops, bench_thr,
  426. bench_flops / fallback_flops, bench_thr / fallback_thr);
  427. }
  428. } // namespace
  429. TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NHWC) {
  430. Benchmarker<Elemwise> benchmarker(handle());
  431. constexpr size_t RUN = 50;
  432. benchmarker.set_times(RUN).set_display(false);
  433. auto run = [&](size_t N, size_t C, size_t H, size_t W, param::Elemwise::Mode mode,
  434. const char* mode_name) {
  435. megdnn::param::Elemwise param;
  436. param.mode = mode;
  437. benchmarker.set_param(param);
  438. megdnn::TensorShape nhwc_src0{N, H, W, C};
  439. megdnn::TensorShape nhwc_src1{1, 1, 1, C};
  440. megdnn::TensorShape nchw_src0{N, C, H, W};
  441. megdnn::TensorShape nchw_src1{1, C, 1, 1};
  442. float computations = N * C * H * W;
  443. auto nhwc_time = benchmarker.execs({nhwc_src1, nhwc_src0, {}}) / RUN;
  444. auto nchw_time = benchmarker.execs({nchw_src1, nchw_src0, {}}) / RUN;
  445. auto perf_nhwc = computations / nhwc_time / 1e6;
  446. auto perf_nchw = computations / nchw_time / 1e6;
  447. printf("Elemwise Mode : %s\nNHWC : %fms %fGflops\nNCHW : %fms "
  448. "%fGflops\n",
  449. mode_name, nhwc_time, perf_nhwc, nchw_time, perf_nchw);
  450. };
  451. run(1, 120, 16, 24, param::Elemwise::Mode::ADD, "ADD");
  452. run(1, 120, 16, 24, param::Elemwise::Mode::MUL, "MUL");
  453. run(1, 120, 32, 48, param::Elemwise::Mode::ADD, "ADD");
  454. run(1, 120, 32, 48, param::Elemwise::Mode::MUL, "MUL");
  455. run(1, 120, 64, 96, param::Elemwise::Mode::ADD, "ADD");
  456. run(1, 120, 64, 96, param::Elemwise::Mode::MUL, "MUL");
  457. }
  //! INT_RUN / FLOAT_RUN benchmark one mode on one shape array for each integer
  //! or float dtype. Each expands to a single do { } while (0) statement, so it is
  //! safe wherever one statement is expected. The float16 run is only compiled in
  //! when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is set, matching the guard on the
  //! float16 checker tests in this file.
  //! BENCHMARK_CASES expects each benchmark test to define INT_BENCHMARK_CASES and
  //! FLOAT_BENCHMARK_CASES locally.
  #define INT_RUN(shape, mode)                                                  \
      do {                                                                      \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle());  \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle()); \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Int32{}, handle()); \
      } while (0)
  #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  #define FLOAT_RUN(shape, mode)                                                  \
      do {                                                                        \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Float32{}, handle()); \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Float16{}, handle()); \
      } while (0)
  #else
  #define FLOAT_RUN(shape, mode)                                                  \
      do {                                                                        \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Float32{}, handle()); \
      } while (0)
  #endif
  #define BENCHMARK_CASES(shape) \
      INT_BENCHMARK_CASES(shape) \
      FLOAT_BENCHMARK_CASES(shape)
  468. TEST_F(ARM_COMMON, BENCHMARK_UNARY) {
  469. #define INT_BENCHMARK_CASES(shape) \
  470. INT_RUN(shape, Mode::RELU); \
  471. INT_RUN(shape, Mode::ABS);
  472. #define FLOAT_BENCHMARK_CASES(shape) \
  473. FLOAT_RUN(shape, Mode::RELU); \
  474. FLOAT_RUN(shape, Mode::ABS); \
  475. FLOAT_RUN(shape, Mode::SIGMOID); \
  476. FLOAT_RUN(shape, Mode::EXP); \
  477. FLOAT_RUN(shape, Mode::TANH); \
  478. FLOAT_RUN(shape, Mode::FAST_TANH);
  479. using Mode = param::Elemwise::Mode;
  480. BENCHMARK_CASES({{10000}});
  481. BENCHMARK_CASES({{50000}});
  482. #undef INT_BENCHMARK_CASES
  483. #undef FLOAT_BENCHMARK_CASES
  484. }
  485. TEST_F(ARM_COMMON, BENCHMARK_BINARY) {
  486. #define INT_BENCHMARK_CASES(shape) \
  487. INT_RUN(shape, Mode::MIN); \
  488. INT_RUN(shape, Mode::MAX); \
  489. INT_RUN(shape, Mode::ADD); \
  490. INT_RUN(shape, Mode::SUB); \
  491. INT_RUN(shape, Mode::MUL); \
  492. INT_RUN(shape, Mode::RMULH); \
  493. INT_RUN(shape, Mode::FUSE_ADD_RELU);
  494. #define FLOAT_BENCHMARK_CASES(shape) \
  495. FLOAT_RUN(shape, Mode::MIN); \
  496. FLOAT_RUN(shape, Mode::MAX); \
  497. FLOAT_RUN(shape, Mode::ADD); \
  498. FLOAT_RUN(shape, Mode::SUB); \
  499. FLOAT_RUN(shape, Mode::MUL); \
  500. FLOAT_RUN(shape, Mode::POW); \
  501. FLOAT_RUN(shape, Mode::TRUE_DIV); \
  502. FLOAT_RUN(shape, Mode::FUSE_ADD_RELU);
  503. using Mode = param::Elemwise::Mode;
  504. TensorShapeArray shapes = {{1, 112, 28, 28}, {1, 112, 28, 28}};
  505. BENCHMARK_CASES(shapes);
  506. shapes = {{1, 16, 1, 1}, {1, 16, 112, 112}};
  507. BENCHMARK_CASES(shapes);
  508. shapes = {{1, 448, 7, 7}, {1, 448, 7, 7}};
  509. BENCHMARK_CASES(shapes);
  510. #undef INT_BENCHMARK_CASES
  511. #undef FLOAT_BENCHMARK_CASES
  512. }
  513. TEST_F(ARM_COMMON, BENCHMARK_TERNARY_FMA3) {
  514. #define INT_BENCHMARK_CASES(shape) INT_RUN(shape, Mode::FUSE_MUL_ADD3);
  515. #define FLOAT_BENCHMARK_CASES(shape) FLOAT_RUN(shape, Mode::FUSE_MUL_ADD3);
  516. using Mode = param::Elemwise::Mode;
  517. TensorShapeArray shapes = {{30, 40, 70}, {30, 40, 70}, {30, 40, 70}};
  518. BENCHMARK_CASES(shapes);
  519. shapes = {{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}};
  520. BENCHMARK_CASES(shapes);
  521. shapes = {{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}};
  522. BENCHMARK_CASES(shapes);
  523. #undef INT_BENCHMARK_CASES
  524. #undef FLOAT_BENCHMARK_CASES
  525. }
  // Drop the file-local benchmark helper macros.
  #undef INT_RUN
  #undef FLOAT_RUN
  #undef BENCHMARK_CASES
  529. #endif
  530. // vim: syntax=cpp.doxygen