You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.cpp 16 kB


  1. #include "test/fallback/fixture.h"
  2. #include <ctime>
  3. #include "test/common/checker.h"
  4. #include "test/common/elemwise.h"
  5. #include "test/common/task_record_check.h"
  6. #include "test/common/tensor.h"
  7. using namespace megdnn;
  8. using namespace test;
  9. template <typename tag>
  10. class FALLBACK_ELEMWISE : public FALLBACK {};
  11. TYPED_TEST_CASE(FALLBACK_ELEMWISE, elemwise::test_types);
  12. TYPED_TEST(FALLBACK_ELEMWISE, run) {
  13. elemwise::run_test<TypeParam>(this->handle());
  14. }
  15. TEST_F(FALLBACK, ELEMWISE_RECORD) {
  16. TaskRecordChecker<Elemwise> checker{1};
  17. checker.set_param({Elemwise::Mode::ADD});
  18. checker.set_dtype(0, dtype::Float32());
  19. checker.set_dtype(1, dtype::Float32());
  20. checker.set_dtype(2, dtype::Float32());
  21. UniformIntRNG rng{-100, 100};
  22. checker.set_rng(0, &rng);
  23. checker.set_rng(1, &rng);
  24. checker.set_rng(2, &rng);
  25. checker.execs({{10, 10, 32}, {10, 10, 32}, {}});
  26. }
  27. TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) {
  28. using Mode = ElemwiseForward::Param::Mode;
  29. Checker<ElemwiseForward> checker(handle());
  30. checker.set_param(Mode::FUSE_MUL_ADD3);
  31. auto run = [&] {
  32. //! nchw44
  33. checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  34. checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  35. checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  36. checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  37. checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  38. //! nchw44
  39. checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  40. checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  41. checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  42. checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  43. checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  44. //! nchw88
  45. checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  46. checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  47. checker.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  48. checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  49. checker.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  50. //! nchw88
  51. checker.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  52. checker.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  53. checker.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  54. checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  55. checker.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  56. checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
  57. checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
  58. checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}});
  59. checker.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}});
  60. checker.execs({{1, 7}, {1, 7}, {1, 7}, {}});
  61. checker.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}});
  62. checker.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}});
  63. checker.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  64. checker.execs({{3, 4, 5}, {1}, {1}, {}});
  65. checker.execs({{1}, {3, 4, 5}, {1}, {}});
  66. };
  67. // case int
  68. checker.set_dtype(0, dtype::Int8());
  69. checker.set_dtype(1, dtype::Int8());
  70. checker.set_dtype(2, dtype::Int8());
  71. run();
  72. checker.set_dtype(0, dtype::Int16());
  73. checker.set_dtype(1, dtype::Int16());
  74. checker.set_dtype(2, dtype::Int16());
  75. run();
  76. checker.set_dtype(0, dtype::Int32());
  77. checker.set_dtype(1, dtype::Int32());
  78. checker.set_dtype(2, dtype::Int32());
  79. run();
  80. // case float
  81. UniformFloatRNG rng(1e-5, 7e1);
  82. checker.set_rng(0, &rng);
  83. checker.set_epsilon(1e-5);
  84. checker.set_dtype(0, dtype::Float32());
  85. checker.set_dtype(1, dtype::Float32());
  86. checker.set_dtype(2, dtype::Float32());
  87. run();
  88. }
  89. TEST_F(FALLBACK, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) {
  90. using Mode = ElemwiseForward::Param::Mode;
  91. Checker<ElemwiseForward> checker(handle());
  92. auto run = [&]() {
  93. // VEC_BCAST101x not PowOp
  94. checker.set_param(Mode::ADD).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  95. checker.set_param(Mode::ADD).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  96. checker.set_param(Mode::ADD).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  97. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  98. checker.set_param(Mode::ADD).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  99. checker.set_param(Mode::RMULH).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  100. checker.set_param(Mode::RMULH).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  101. checker.set_param(Mode::RMULH).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  102. checker.set_param(Mode::RMULH).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  103. checker.set_param(Mode::RMULH).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  104. checker.set_param(Mode::FUSE_ADD_RELU)
  105. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  106. checker.set_param(Mode::FUSE_ADD_RELU)
  107. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  108. checker.set_param(Mode::FUSE_ADD_RELU)
  109. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  110. checker.set_param(Mode::FUSE_ADD_RELU)
  111. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  112. checker.set_param(Mode::FUSE_ADD_RELU)
  113. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  114. // BCAST101x_VEC not PowOp
  115. checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  116. checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  117. checker.set_param(Mode::ADD).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  118. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  119. checker.set_param(Mode::ADD).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  120. checker.set_param(Mode::FUSE_ADD_RELU)
  121. .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  122. checker.set_param(Mode::FUSE_ADD_RELU)
  123. .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  124. checker.set_param(Mode::FUSE_ADD_RELU)
  125. .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  126. checker.set_param(Mode::FUSE_ADD_RELU)
  127. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  128. checker.set_param(Mode::FUSE_ADD_RELU)
  129. .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  130. };
  131. checker.set_dtype(0, dtype::Int8());
  132. checker.set_dtype(1, dtype::Int8());
  133. run();
  134. checker.set_dtype(0, dtype::Int16());
  135. checker.set_dtype(1, dtype::Int16());
  136. run();
  137. checker.set_dtype(0, dtype::Int32());
  138. checker.set_dtype(1, dtype::Int32());
  139. run();
  140. }
  141. TEST_F(FALLBACK, ELEMWISE_FORWARD_NCHW44_FP32) {
  142. using Mode = ElemwiseForward::Param::Mode;
  143. Checker<ElemwiseForward> checker(handle());
  144. UniformFloatRNG rng(1e-5, 7e1);
  145. checker.set_rng(0, &rng);
  146. checker.set_epsilon(1e-5);
  147. checker.set_dtype(0, dtype::Float32());
  148. checker.set_dtype(1, dtype::Float32());
  149. checker.set_param(Mode::FUSE_ADD_RELU)
  150. .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  151. checker.set_param(Mode::FUSE_ADD_RELU)
  152. .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  153. checker.set_param(Mode::FUSE_ADD_RELU)
  154. .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  155. checker.set_param(Mode::FUSE_ADD_RELU)
  156. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  157. checker.set_param(Mode::FUSE_ADD_RELU)
  158. .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  159. checker.set_param(Mode::FUSE_ADD_RELU)
  160. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  161. checker.set_param(Mode::FUSE_ADD_RELU)
  162. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  163. checker.set_param(Mode::FUSE_ADD_RELU)
  164. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  165. checker.set_param(Mode::FUSE_ADD_RELU)
  166. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  167. checker.set_param(Mode::FUSE_ADD_RELU)
  168. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  169. auto run = [&](Mode mode) {
  170. // VEC_BCAST101x
  171. checker.set_param(mode).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  172. checker.set_param(mode).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  173. checker.set_param(mode).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  174. checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  175. checker.set_param(mode).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  176. // BCAST101x_VEC not powOp
  177. checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  178. checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  179. checker.set_param(mode).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  180. checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  181. checker.set_param(mode).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  182. };
  183. run(Mode::ADD);
  184. run(Mode::FUSE_ADD_H_SWISH);
  185. run(Mode::FUSE_ADD_RELU);
  186. run(Mode::MAX);
  187. run(Mode::MIN);
  188. run(Mode::MUL);
  189. run(Mode::SUB);
  190. run(Mode::TRUE_DIV);
  191. run(Mode::POW);
  192. }
  193. TEST_F(FALLBACK, ELEMWISE_FORWARD_NCHW88_FP) {
  194. using Mode = ElemwiseForward::Param::Mode;
  195. Checker<ElemwiseForward> checker(handle());
  196. checker.set_param(Mode::FUSE_ADD_RELU)
  197. .execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  198. checker.set_param(Mode::FUSE_ADD_RELU)
  199. .execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  200. checker.set_param(Mode::FUSE_ADD_RELU)
  201. .execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  202. checker.set_param(Mode::FUSE_ADD_RELU)
  203. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  204. checker.set_param(Mode::FUSE_ADD_RELU)
  205. .execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  206. checker.set_param(Mode::FUSE_ADD_RELU)
  207. .execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  208. checker.set_param(Mode::FUSE_ADD_RELU)
  209. .execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  210. checker.set_param(Mode::FUSE_ADD_RELU)
  211. .execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  212. checker.set_param(Mode::FUSE_ADD_RELU)
  213. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  214. checker.set_param(Mode::FUSE_ADD_RELU)
  215. .execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  216. auto run = [&](Mode mode) {
  217. // VEC_BCAST101x
  218. checker.set_param(mode).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  219. checker.set_param(mode).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  220. checker.set_param(mode).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  221. checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  222. checker.set_param(mode).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  223. // BCAST101x_VEC not powOp
  224. checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  225. checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  226. checker.set_param(mode).execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  227. checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  228. checker.set_param(mode).execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  229. };
  230. auto run_all = [&]() {
  231. run(Mode::ADD);
  232. run(Mode::FUSE_ADD_H_SWISH);
  233. run(Mode::FUSE_ADD_RELU);
  234. run(Mode::MAX);
  235. run(Mode::MIN);
  236. run(Mode::MUL);
  237. run(Mode::SUB);
  238. run(Mode::TRUE_DIV);
  239. run(Mode::POW);
  240. };
  241. {
  242. UniformFloatRNG rng(1e-5, 7e1);
  243. checker.set_rng(0, &rng);
  244. checker.set_epsilon(1e-5);
  245. checker.set_dtype(0, dtype::Float32());
  246. checker.set_dtype(1, dtype::Float32());
  247. run_all();
  248. }
  249. }
  250. TEST_F(FALLBACK, ELEMWISE_FORWARD_N1HW_FP32_BCAST) {
  251. using Mode = ElemwiseForward::Param::Mode;
  252. Checker<ElemwiseForward> checker(handle());
  253. UniformFloatRNG rng(1e-5, 7e1);
  254. checker.set_rng(0, &rng);
  255. checker.set_epsilon(1e-5);
  256. checker.set_dtype(0, dtype::Float32());
  257. checker.set_dtype(1, dtype::Float32());
  258. //! 2 dim
  259. auto run = [&](Mode mode) {
  260. // VEC_BCASTX0X
  261. checker.set_param(mode).execs({{2, 8, 4, 4}, {2, 1, 4, 4}, {}});
  262. checker.set_param(mode).execs({{4, 21, 78}, {4, 1, 78}, {}});
  263. // BCASTX0X_VEC
  264. checker.set_param(mode).execs({{2, 1, 4, 4}, {2, 8, 4, 4}, {}});
  265. checker.set_param(mode).execs({{4, 1, 78}, {4, 21, 78}, {}});
  266. };
  267. run(Mode::ADD);
  268. run(Mode::MUL);
  269. run(Mode::SUB);
  270. }
  271. TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY_RECORD) {
  272. using Mode = ElemwiseForward::Param::Mode;
  273. TaskRecordChecker<ElemwiseForward> checker(0);
  274. checker.set_param(Mode::FUSE_MUL_ADD3);
  275. auto run = [&] {
  276. //! nchw44
  277. checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  278. checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  279. //! nchw88
  280. checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  281. checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  282. checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
  283. checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
  284. };
  285. // case int
  286. checker.set_dtype(0, dtype::Int32());
  287. checker.set_dtype(1, dtype::Int32());
  288. checker.set_dtype(2, dtype::Int32());
  289. run();
  290. // case float
  291. UniformFloatRNG rng(1e-5, 7e1);
  292. checker.set_rng(0, &rng);
  293. checker.set_epsilon(1e-5);
  294. checker.set_dtype(0, dtype::Float32());
  295. checker.set_dtype(1, dtype::Float32());
  296. checker.set_dtype(2, dtype::Float32());
  297. run();
  298. }
  299. #if MEGDNN_WITH_BENCHMARK
  300. TEST_F(FALLBACK, BENCHMARK_ELEMWISE) {
  301. auto naive_handle = create_cpu_handle(2);
  302. auto run = [&](const TensorShape& shp0, const TensorShape& shp1) {
  303. TensorShape shpo;
  304. Elemwise::deduce_shape({shp0, shp1}, shpo);
  305. Tensor<> op0(handle(), {shp0, dtype::Float32()}),
  306. op1(handle(), {shp1, dtype::Float32()}),
  307. out(handle(), {shpo, dtype::Float32()});
  308. auto opr_cur = handle()->create_operator<Elemwise>();
  309. auto opr_naive = naive_handle->create_operator<Elemwise>();
  310. opr_cur->param() = {Elemwise::Mode::ADD};
  311. opr_naive->param() = {Elemwise::Mode::ADD};
  312. auto timeit = [&](Elemwise* opr) {
  313. opr->exec({op0.tensornd(), op1.tensornd()}, out.tensornd());
  314. auto start = clock();
  315. opr->exec({op0.tensornd(), op1.tensornd()}, out.tensornd());
  316. auto stop = clock();
  317. return (stop - start) * 1e3 / CLOCKS_PER_SEC;
  318. };
  319. auto t0 = timeit(opr_cur.get()), t1 = timeit(opr_naive.get());
  320. double tot_size_gb_ms =
  321. (op0.layout().span().dist_byte() + op1.layout().span().dist_byte() +
  322. out.layout().span().dist_byte()) /
  323. 1024.0 / 1024.0 / 1024.0 * 1e3;
  324. printf("%15s+%-15s: fallback=%7.3fms,%5.2fGiB/s "
  325. "naive=%7.3fms,%5.2fGiB/s\n",
  326. shp0.to_string().c_str(), shp1.to_string().c_str(), t0,
  327. tot_size_gb_ms / t0, t1, tot_size_gb_ms / t1);
  328. };
  329. // contig
  330. run({1024, 1024, 32}, {1024, 1024, 32});
  331. // bcast 101
  332. run({1024, 1024, 32}, {1, 1024, 1});
  333. // bcast 01
  334. run({4096 * 4, 1024}, {4096 * 4, 1});
  335. // bcast 10
  336. run({4096 * 4, 1024}, {1, 1024});
  337. // non-contig, fallback to naive
  338. run({1024, 1024, 32}, {1024, 1, 32});
  339. }
  340. #endif
  341. // vim: syntax=cpp.doxygen