You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

elemwise.cpp 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. #include "test/common/elemwise.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/rng.h"
  5. #include "test/common/task_record_check.h"
  6. #include "test/x86/fixture.h"
  7. using namespace megdnn;
  8. using namespace test;
  9. void print4D(const TensorND& tensor) {
  10. TensorLayout layout = tensor.layout;
  11. float* result = tensor.ptr<float>();
  12. size_t N = layout.shape[0], C = layout.shape[1], H = layout.shape[2],
  13. W = layout.shape[3];
  14. size_t it = 0;
  15. rep(n, N) {
  16. rep(c, C) {
  17. rep(h, H) {
  18. rep(w, W) { printf("%.4f ", result[it++]); }
  19. printf("\n");
  20. }
  21. printf("\n");
  22. }
  23. printf("\n");
  24. }
  25. }
  26. #define UNARY_TEST_CASE(_optr) \
  27. checker.set_param(Mode::_optr).execs({{1, 1556011}, {}}); \
  28. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  29. #define BUILD_UNARY_TEST_CASE_INT \
  30. UNARY_TEST_CASE(RELU) \
  31. UNARY_TEST_CASE(ABS)
  32. #define BUILD_UNARY_TEST_CASE_FLOAT \
  33. UNARY_TEST_CASE(ABS) \
  34. UNARY_TEST_CASE(LOG) \
  35. UNARY_TEST_CASE(COS) \
  36. UNARY_TEST_CASE(SIN) \
  37. UNARY_TEST_CASE(FLOOR) \
  38. UNARY_TEST_CASE(CEIL) \
  39. UNARY_TEST_CASE(SIGMOID) \
  40. UNARY_TEST_CASE(EXP) \
  41. UNARY_TEST_CASE(TANH) \
  42. UNARY_TEST_CASE(RELU) \
  43. UNARY_TEST_CASE(ROUND)
  44. TEST_F(X86, ELEMWISE_FORWARD_UNARY) {
  45. using Mode = ElemwiseForward::Param::Mode;
  46. Checker<ElemwiseForward> checker(handle());
  47. // case int
  48. checker.set_dtype(0, dtype::Int8());
  49. BUILD_UNARY_TEST_CASE_INT
  50. checker.set_dtype(0, dtype::Int16());
  51. BUILD_UNARY_TEST_CASE_INT
  52. checker.set_dtype(0, dtype::Int32());
  53. BUILD_UNARY_TEST_CASE_INT
  54. // case float
  55. UniformFloatRNG rng(1e-2, 6e1);
  56. checker.set_rng(0, &rng);
  57. checker.set_epsilon(1e-6);
  58. checker.set_dtype(0, dtype::Float32());
  59. BUILD_UNARY_TEST_CASE_FLOAT
  60. }
  61. #define BINARY_TEST_CASE(_optr) \
  62. checker.set_param(Mode::_optr).execs({{3, 4, 17}, {3, 4, 17}, {}}); \
  63. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  64. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  65. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});
  66. #define BUILD_BINARY_TEST_CASE \
  67. BINARY_TEST_CASE(MIN) \
  68. BINARY_TEST_CASE(MAX)
  69. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  70. printf("Check binary optr %s by all cases.\n", #_optr); \
  71. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  72. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  73. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); \
  74. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7, 8}, {1, 4, 1, 1, 8}, {}}); \
  75. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  76. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  77. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  78. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  79. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  80. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  81. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  82. BINARY_COMPLATE_TEST_CASE(ADD) \
  83. BINARY_COMPLATE_TEST_CASE(MUL) \
  84. BINARY_COMPLATE_TEST_CASE(MAX) \
  85. BINARY_COMPLATE_TEST_CASE(MIN) \
  86. BINARY_COMPLATE_TEST_CASE(SUB)
  87. #define BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32 \
  88. BINARY_COMPLATE_TEST_CASE(TRUE_DIV) \
  89. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID) \
  90. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH) \
  91. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU)
  92. TEST_F(X86, ELEMWISE_FORWARD_NCHW88) {
  93. using Mode = ElemwiseForward::Param::Mode;
  94. Checker<ElemwiseForward> checker(handle());
  95. // case float
  96. UniformFloatRNG rng(1e-5, 7e1);
  97. checker.set_rng(0, &rng);
  98. checker.set_epsilon(1e-5);
  99. checker.set_dtype(0, dtype::Float32());
  100. checker.set_dtype(1, dtype::Float32());
  101. checker.set_param(Mode::ADD).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  102. checker.set_param(Mode::ADD).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  103. checker.set_param(Mode::ADD).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  104. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  105. checker.set_param(Mode::ADD).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  106. checker.set_param(Mode::FUSE_ADD_RELU)
  107. .execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  108. checker.set_param(Mode::FUSE_ADD_RELU)
  109. .execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  110. checker.set_param(Mode::FUSE_ADD_RELU)
  111. .execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  112. checker.set_param(Mode::FUSE_ADD_RELU)
  113. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  114. checker.set_param(Mode::FUSE_ADD_RELU)
  115. .execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  116. }
  117. TEST_F(X86, ELEMWISE_FORWARD_BINARY) {
  118. using Mode = ElemwiseForward::Param::Mode;
  119. Checker<ElemwiseForward> checker(handle());
  120. // case float
  121. UniformFloatRNG rng(1e-5, 7e1);
  122. checker.set_rng(0, &rng);
  123. checker.set_epsilon(1e-5);
  124. checker.set_dtype(0, dtype::Float32());
  125. checker.set_dtype(1, dtype::Float32());
  126. BUILD_BINARY_COMPLATE_TEST_CASE
  127. BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  128. // case int
  129. checker.set_dtype(0, dtype::Int8());
  130. checker.set_dtype(1, dtype::Int8());
  131. BUILD_BINARY_TEST_CASE
  132. BUILD_BINARY_COMPLATE_TEST_CASE
  133. checker.set_dtype(0, dtype::Int16());
  134. checker.set_dtype(1, dtype::Int16());
  135. BUILD_BINARY_TEST_CASE
  136. BUILD_BINARY_COMPLATE_TEST_CASE
  137. checker.set_dtype(0, dtype::Int32());
  138. checker.set_dtype(1, dtype::Int32());
  139. BUILD_BINARY_TEST_CASE
  140. BUILD_BINARY_COMPLATE_TEST_CASE
  141. }
  142. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  143. printf("Check ternary optr %s by all cases.\n", #_optr); \
  144. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  145. checker.set_param(Mode::_optr) \
  146. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  147. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  148. checker.set_param(Mode::_optr) \
  149. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  150. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  151. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  152. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  153. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  154. #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  155. TEST_F(X86, ELEMWISE_FORWARD_TERNARY) {
  156. using Mode = ElemwiseForward::Param::Mode;
  157. Checker<ElemwiseForward> checker(handle());
  158. // case int
  159. checker.set_dtype(0, dtype::Int8());
  160. checker.set_dtype(1, dtype::Int8());
  161. checker.set_dtype(2, dtype::Int8());
  162. // BUILD_TERNARY_TEST_CASE
  163. BUILD_TERNARY_COMPLATE_TEST_CASE
  164. checker.set_dtype(0, dtype::Int16());
  165. checker.set_dtype(1, dtype::Int16());
  166. checker.set_dtype(2, dtype::Int16());
  167. // BUILD_TERNARY_TEST_CASE
  168. BUILD_TERNARY_COMPLATE_TEST_CASE
  169. checker.set_dtype(0, dtype::Int32());
  170. checker.set_dtype(1, dtype::Int32());
  171. checker.set_dtype(2, dtype::Int32());
  172. // BUILD_TERNARY_TEST_CASE
  173. BUILD_TERNARY_COMPLATE_TEST_CASE
  174. // case float
  175. UniformFloatRNG rng(1e-5, 7e1);
  176. checker.set_rng(0, &rng);
  177. checker.set_epsilon(1e-5);
  178. checker.set_dtype(0, dtype::Float32());
  179. checker.set_dtype(1, dtype::Float32());
  180. checker.set_dtype(2, dtype::Float32());
  181. // BUILD_TERNARY_TEST_CASE
  182. BUILD_TERNARY_COMPLATE_TEST_CASE
  183. }
  184. template <typename tag>
  185. class X86_ELEMWISE : public X86 {};
  186. TYPED_TEST_CASE(X86_ELEMWISE, elemwise::test_types);
  187. TYPED_TEST(X86_ELEMWISE, run) {
  188. elemwise::run_test<TypeParam>(this->handle());
  189. }
  190. #undef UNARY_TEST_CASE
  191. #undef BUILD_UNARY_TEST_CASE_FLOAT
  192. #define UNARY_TEST_CASE(_optr) checker.set_param(Mode::_optr).execs({{1, 155}, {}});
  193. #define BUILD_UNARY_TEST_CASE_FLOAT UNARY_TEST_CASE(ABS)
  194. TEST_F(X86, ELEMWISE_UNARY_RECORD) {
  195. using Mode = ElemwiseForward::Param::Mode;
  196. TaskRecordChecker<ElemwiseForward> checker(0);
  197. // case float
  198. UniformFloatRNG rng(1e-2, 6e1);
  199. checker.set_rng(0, &rng);
  200. checker.set_epsilon(1e-6);
  201. checker.set_dtype(0, dtype::Float32());
  202. BUILD_UNARY_TEST_CASE_FLOAT
  203. }
  204. #undef BINARY_COMPLATE_TEST_CASE
  205. #undef BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  206. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  207. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}});
  208. #define BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32 BINARY_COMPLATE_TEST_CASE(ADD)
  209. TEST_F(X86, ELEMWISE_BINARY_RECORD) {
  210. using Mode = ElemwiseForward::Param::Mode;
  211. TaskRecordChecker<ElemwiseForward> checker(0);
  212. // case float
  213. UniformFloatRNG rng(1e-5, 7e1);
  214. checker.set_rng(0, &rng);
  215. checker.set_epsilon(1e-5);
  216. checker.set_dtype(0, dtype::Float32());
  217. checker.set_dtype(1, dtype::Float32());
  218. BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  219. }
  220. #undef TERNARY_COMPLATE_TEST_CASE
  221. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  222. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  223. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
  224. #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  225. TEST_F(X86, ELEMWISE_TERNARY_RECORD) {
  226. using Mode = ElemwiseForward::Param::Mode;
  227. TaskRecordChecker<ElemwiseForward> checker(0);
  228. // case int
  229. checker.set_dtype(0, dtype::Int8());
  230. checker.set_dtype(1, dtype::Int8());
  231. checker.set_dtype(2, dtype::Int8());
  232. // BUILD_TERNARY_TEST_CASE
  233. BUILD_TERNARY_COMPLATE_TEST_CASE
  234. }
  235. // vim: syntax=cpp.doxygen