You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.cpp 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. /**
  2. * \file dnn/test/x86/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/elemwise.h"
  12. #include "megdnn/oprs.h"
  13. #include "test/common/checker.h"
  14. #include "test/common/rng.h"
  15. #include "test/common/task_record_check.h"
  16. #include "test/x86/fixture.h"
  17. using namespace megdnn;
  18. using namespace test;
  19. void print4D(const TensorND& tensor) {
  20. TensorLayout layout = tensor.layout;
  21. float* result = tensor.ptr<float>();
  22. size_t N = layout.shape[0], C = layout.shape[1], H = layout.shape[2],
  23. W = layout.shape[3];
  24. size_t it = 0;
  25. rep(n, N) {
  26. rep(c, C) {
  27. rep(h, H) {
  28. rep(w, W) { printf("%.4f ", result[it++]); }
  29. printf("\n");
  30. }
  31. printf("\n");
  32. }
  33. printf("\n");
  34. }
  35. }
  // Runs one unary mode through `checker` on two input shapes: {1, 1556011}
  // and {1, 7}. The trailing {} is presumably the output shape, left for the
  // checker to deduce -- confirm against Checker::execs.
  // Expects `checker` and a `Mode` alias to be in scope at the expansion site.
  #define UNARY_TEST_CASE(_mode) \
      checker.set_param(Mode::_mode).execs({{1, 1556011}, {}}); \
      checker.set_param(Mode::_mode).execs({{1, 7}, {}});

  // Unary modes exercised for the integer dtypes.
  #define BUILD_UNARY_TEST_CASE_INT \
      UNARY_TEST_CASE(RELU) \
      UNARY_TEST_CASE(ABS)

  // Unary modes exercised for Float32, in execution order.
  #define BUILD_UNARY_TEST_CASE_FLOAT \
      UNARY_TEST_CASE(ABS) \
      UNARY_TEST_CASE(LOG) \
      UNARY_TEST_CASE(COS) \
      UNARY_TEST_CASE(SIN) \
      UNARY_TEST_CASE(FLOOR) \
      UNARY_TEST_CASE(CEIL) \
      UNARY_TEST_CASE(SIGMOID) \
      UNARY_TEST_CASE(EXP) \
      UNARY_TEST_CASE(TANH) \
      UNARY_TEST_CASE(RELU) \
      UNARY_TEST_CASE(ROUND)
  54. TEST_F(X86, ELEMWISE_FORWARD_UNARY) {
  55. using Mode = ElemwiseForward::Param::Mode;
  56. Checker<ElemwiseForward> checker(handle());
  57. // case int
  58. checker.set_dtype(0, dtype::Int8());
  59. BUILD_UNARY_TEST_CASE_INT
  60. checker.set_dtype(0, dtype::Int16());
  61. BUILD_UNARY_TEST_CASE_INT
  62. checker.set_dtype(0, dtype::Int32());
  63. BUILD_UNARY_TEST_CASE_INT
  64. // case float
  65. UniformFloatRNG rng(1e-2, 6e1);
  66. checker.set_rng(0, &rng);
  67. checker.set_epsilon(1e-6);
  68. checker.set_dtype(0, dtype::Float32());
  69. BUILD_UNARY_TEST_CASE_FLOAT
  70. }
  // Short binary shape list: equal shapes plus operands whose every extent
  // is 1. Expects `checker` and a `Mode` alias at the expansion site.
  #define BINARY_TEST_CASE(_mode) \
      checker.set_param(Mode::_mode).execs({{3, 4, 17}, {3, 4, 17}, {}}); \
      checker.set_param(Mode::_mode).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
      checker.set_param(Mode::_mode).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
      checker.set_param(Mode::_mode).execs({{1, 7}, {1, 7}, {}});

  // Modes run with the short shape list.
  #define BUILD_BINARY_TEST_CASE \
      BINARY_TEST_CASE(MIN) \
      BINARY_TEST_CASE(MAX)

  // Full binary shape list: prints the mode name, then runs equal-shape pairs
  // and several pairs in which the second operand has extent 1 on some axes.
  #define BINARY_COMPLATE_TEST_CASE(_mode) \
      printf("Check binary optr %s by all cases.\n", #_mode); \
      checker.set_param(Mode::_mode).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
      checker.set_param(Mode::_mode).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
      checker.set_param(Mode::_mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); \
      checker.set_param(Mode::_mode).execs({{3, 4, 5, 7, 8}, {1, 4, 1, 1, 8}, {}}); \
      checker.set_param(Mode::_mode).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
      checker.set_param(Mode::_mode).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
      checker.set_param(Mode::_mode).execs({{1, 7}, {1, 7}, {}}); \
      checker.set_param(Mode::_mode).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
      checker.set_param(Mode::_mode).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
      checker.set_param(Mode::_mode).execs({{3, 4, 1}, {3, 4, 1}, {}});

  // Modes run with the full shape list for every dtype.
  #define BUILD_BINARY_COMPLATE_TEST_CASE \
      BINARY_COMPLATE_TEST_CASE(ADD) \
      BINARY_COMPLATE_TEST_CASE(MUL) \
      BINARY_COMPLATE_TEST_CASE(MAX) \
      BINARY_COMPLATE_TEST_CASE(MIN) \
      BINARY_COMPLATE_TEST_CASE(SUB)

  // Additional modes run with the full shape list for Float32 only.
  #define BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32 \
      BINARY_COMPLATE_TEST_CASE(TRUE_DIV) \
      BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID) \
      BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH) \
      BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU)
  102. TEST_F(X86, ELEMWISE_FORWARD_NCHW88) {
  103. using Mode = ElemwiseForward::Param::Mode;
  104. Checker<ElemwiseForward> checker(handle());
  105. // case float
  106. UniformFloatRNG rng(1e-5, 7e1);
  107. checker.set_rng(0, &rng);
  108. checker.set_epsilon(1e-5);
  109. checker.set_dtype(0, dtype::Float32());
  110. checker.set_dtype(1, dtype::Float32());
  111. checker.set_param(Mode::ADD).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  112. checker.set_param(Mode::ADD).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  113. checker.set_param(Mode::ADD).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  114. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  115. checker.set_param(Mode::ADD).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  116. checker.set_param(Mode::FUSE_ADD_RELU)
  117. .execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  118. checker.set_param(Mode::FUSE_ADD_RELU)
  119. .execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  120. checker.set_param(Mode::FUSE_ADD_RELU)
  121. .execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  122. checker.set_param(Mode::FUSE_ADD_RELU)
  123. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  124. checker.set_param(Mode::FUSE_ADD_RELU)
  125. .execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  126. }
  127. TEST_F(X86, ELEMWISE_FORWARD_BINARY) {
  128. using Mode = ElemwiseForward::Param::Mode;
  129. Checker<ElemwiseForward> checker(handle());
  130. // case float
  131. UniformFloatRNG rng(1e-5, 7e1);
  132. checker.set_rng(0, &rng);
  133. checker.set_epsilon(1e-5);
  134. checker.set_dtype(0, dtype::Float32());
  135. checker.set_dtype(1, dtype::Float32());
  136. BUILD_BINARY_COMPLATE_TEST_CASE
  137. BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  138. // case int
  139. checker.set_dtype(0, dtype::Int8());
  140. checker.set_dtype(1, dtype::Int8());
  141. BUILD_BINARY_TEST_CASE
  142. BUILD_BINARY_COMPLATE_TEST_CASE
  143. checker.set_dtype(0, dtype::Int16());
  144. checker.set_dtype(1, dtype::Int16());
  145. BUILD_BINARY_TEST_CASE
  146. BUILD_BINARY_COMPLATE_TEST_CASE
  147. checker.set_dtype(0, dtype::Int32());
  148. checker.set_dtype(1, dtype::Int32());
  149. BUILD_BINARY_TEST_CASE
  150. BUILD_BINARY_COMPLATE_TEST_CASE
  151. }
  // Full ternary shape list: prints the mode name, then runs equal-shape
  // triples and triples in which some operands have extent 1 on some axes.
  // Expects `checker` and a `Mode` alias to be in scope at the expansion site.
  #define TERNARY_COMPLATE_TEST_CASE(_mode) \
      printf("Check ternary optr %s by all cases.\n", #_mode); \
      checker.set_param(Mode::_mode).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
      checker.set_param(Mode::_mode).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
      checker.set_param(Mode::_mode).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
      checker.set_param(Mode::_mode).execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
      checker.set_param(Mode::_mode).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
      checker.set_param(Mode::_mode).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
      checker.set_param(Mode::_mode).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
      checker.set_param(Mode::_mode).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});

  // Only FUSE_MUL_ADD3 is run with the ternary shape list.
  #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  165. TEST_F(X86, ELEMWISE_FORWARD_TERNARY) {
  166. using Mode = ElemwiseForward::Param::Mode;
  167. Checker<ElemwiseForward> checker(handle());
  168. // case int
  169. checker.set_dtype(0, dtype::Int8());
  170. checker.set_dtype(1, dtype::Int8());
  171. checker.set_dtype(2, dtype::Int8());
  172. // BUILD_TERNARY_TEST_CASE
  173. BUILD_TERNARY_COMPLATE_TEST_CASE
  174. checker.set_dtype(0, dtype::Int16());
  175. checker.set_dtype(1, dtype::Int16());
  176. checker.set_dtype(2, dtype::Int16());
  177. // BUILD_TERNARY_TEST_CASE
  178. BUILD_TERNARY_COMPLATE_TEST_CASE
  179. checker.set_dtype(0, dtype::Int32());
  180. checker.set_dtype(1, dtype::Int32());
  181. checker.set_dtype(2, dtype::Int32());
  182. // BUILD_TERNARY_TEST_CASE
  183. BUILD_TERNARY_COMPLATE_TEST_CASE
  184. // case float
  185. UniformFloatRNG rng(1e-5, 7e1);
  186. checker.set_rng(0, &rng);
  187. checker.set_epsilon(1e-5);
  188. checker.set_dtype(0, dtype::Float32());
  189. checker.set_dtype(1, dtype::Float32());
  190. checker.set_dtype(2, dtype::Float32());
  191. // BUILD_TERNARY_TEST_CASE
  192. BUILD_TERNARY_COMPLATE_TEST_CASE
  193. }
  194. template <typename tag>
  195. class X86_ELEMWISE : public X86 {};
  196. TYPED_TEST_CASE(X86_ELEMWISE, elemwise::test_types);
  197. TYPED_TEST(X86_ELEMWISE, run) {
  198. elemwise::run_test<TypeParam>(this->handle());
  199. }
  // Replace the unary helper macros with reduced versions: a single {1, 155}
  // input shape and the ABS mode only.
  #undef BUILD_UNARY_TEST_CASE_FLOAT
  #undef UNARY_TEST_CASE
  #define UNARY_TEST_CASE(_mode) \
      checker.set_param(Mode::_mode).execs({{1, 155}, {}});
  #define BUILD_UNARY_TEST_CASE_FLOAT UNARY_TEST_CASE(ABS)
  204. TEST_F(X86, ELEMWISE_UNARY_RECORD) {
  205. using Mode = ElemwiseForward::Param::Mode;
  206. TaskRecordChecker<ElemwiseForward> checker(0);
  207. // case float
  208. UniformFloatRNG rng(1e-2, 6e1);
  209. checker.set_rng(0, &rng);
  210. checker.set_epsilon(1e-6);
  211. checker.set_dtype(0, dtype::Float32());
  212. BUILD_UNARY_TEST_CASE_FLOAT
  213. }
  // Replace the binary helper macros with reduced versions: a single
  // {3, 4, 7} x {3, 4, 7} shape pair and the ADD mode only.
  #undef BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  #undef BINARY_COMPLATE_TEST_CASE
  #define BINARY_COMPLATE_TEST_CASE(_mode) \
      checker.set_param(Mode::_mode).execs({{3, 4, 7}, {3, 4, 7}, {}});
  #define BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32 BINARY_COMPLATE_TEST_CASE(ADD)
  219. TEST_F(X86, ELEMWISE_BINARY_RECORD) {
  220. using Mode = ElemwiseForward::Param::Mode;
  221. TaskRecordChecker<ElemwiseForward> checker(0);
  222. // case float
  223. UniformFloatRNG rng(1e-5, 7e1);
  224. checker.set_rng(0, &rng);
  225. checker.set_epsilon(1e-5);
  226. checker.set_dtype(0, dtype::Float32());
  227. checker.set_dtype(1, dtype::Float32());
  228. BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  229. }
  // Replace the ternary helper macros with reduced versions: a single
  // all-{3, 4, 7} shape triple and the FUSE_MUL_ADD3 mode only.
  #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  #undef TERNARY_COMPLATE_TEST_CASE
  #define TERNARY_COMPLATE_TEST_CASE(_mode) \
      checker.set_param(Mode::_mode).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
  #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  235. TEST_F(X86, ELEMWISE_TERNARY_RECORD) {
  236. using Mode = ElemwiseForward::Param::Mode;
  237. TaskRecordChecker<ElemwiseForward> checker(0);
  238. // case int
  239. checker.set_dtype(0, dtype::Int8());
  240. checker.set_dtype(1, dtype::Int8());
  241. checker.set_dtype(2, dtype::Int8());
  242. // BUILD_TERNARY_TEST_CASE
  243. BUILD_TERNARY_COMPLATE_TEST_CASE
  244. }
  245. // vim: syntax=cpp.doxygen