You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.cpp 9.1 kB


  1. /**
  2. * \file dnn/test/x86/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/elemwise.h"
  12. #include "megdnn/oprs.h"
  13. #include "test/common/checker.h"
  14. #include "test/common/rng.h"
  15. #include "test/x86/fixture.h"
  16. using namespace megdnn;
  17. using namespace test;
  18. void print4D(const TensorND& tensor) {
  19. TensorLayout layout = tensor.layout;
  20. float* result = tensor.ptr<float>();
  21. size_t N = layout.shape[0], C = layout.shape[1], H = layout.shape[2],
  22. W = layout.shape[3];
  23. size_t it = 0;
  24. rep(n, N) {
  25. rep(c, C) {
  26. rep(h, H) {
  27. rep(w, W) { printf("%.4f ", result[it++]); }
  28. printf("\n");
  29. }
  30. printf("\n");
  31. }
  32. printf("\n");
  33. }
  34. }
  // UNARY_TEST_CASE(_optr): runs elemwise mode Mode::_optr on two single-input
  // shapes: a large {1, 1556011} tensor and a small {1, 7} one. The trailing
  // {} leaves the output shape to the checker (presumably deduced).
  // The macro expands in the caller's scope, so `checker` and `Mode` must be
  // declared there. It does not set dtype, rng or epsilon itself.
  35. #define UNARY_TEST_CASE(_optr) \
  36. checker.set_param(Mode::_optr).execs({{1, 1556011}, {}}); \
  37. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  // Unary modes exercised with integer input dtypes.
  38. #define BUILD_UNARY_TEST_CASE_INT \
  39. UNARY_TEST_CASE(RELU) \
  40. UNARY_TEST_CASE(ABS)
  // Unary modes exercised with float input dtypes. This list includes LOG,
  // so callers must keep inputs strictly positive.
  41. #define BUILD_UNARY_TEST_CASE_FLOAT \
  42. UNARY_TEST_CASE(ABS) \
  43. UNARY_TEST_CASE(LOG) \
  44. UNARY_TEST_CASE(COS) \
  45. UNARY_TEST_CASE(SIN) \
  46. UNARY_TEST_CASE(FLOOR) \
  47. UNARY_TEST_CASE(CEIL) \
  48. UNARY_TEST_CASE(SIGMOID) \
  49. UNARY_TEST_CASE(EXP) \
  50. UNARY_TEST_CASE(TANH) \
  51. UNARY_TEST_CASE(RELU) \
  52. UNARY_TEST_CASE(ROUND)
  // Unary elemwise forward checks on the x86 handle.
  // - Integer modes (BUILD_UNARY_TEST_CASE_INT) run for Int8, Int16 and Int32.
  // - Float modes (BUILD_UNARY_TEST_CASE_FLOAT) run for Float32.
  // The macros do not configure the checker themselves. Each one uses whatever
  // the preceding checker.set_* calls set up, so keep this statement order.
  53. TEST_F(X86, ELEMWISE_FORWARD_UNARY) {
  54. using Mode = ElemwiseForward::Param::Mode;
  55. Checker<ElemwiseForward> checker(handle());
  56. // case int: no set_rng call has been made yet, so input 0 presumably uses
  // the checker's default rng here.
  57. checker.set_dtype(0, dtype::Int8());
  58. BUILD_UNARY_TEST_CASE_INT
  59. checker.set_dtype(0, dtype::Int16());
  60. BUILD_UNARY_TEST_CASE_INT
  61. checker.set_dtype(0, dtype::Int32());
  62. BUILD_UNARY_TEST_CASE_INT
  63. // case float
  // Inputs are presumably drawn uniformly from [1e-2, 6e1]. The strictly
  // positive lower bound keeps LOG's argument positive.
  64. UniformFloatRNG rng(1e-2, 6e1);
  // The rng is passed by address. The checker presumably keeps the pointer,
  // so `rng` must stay alive until the last execs() below.
  65. checker.set_rng(0, &rng);
  66. checker.set_epsilon(1e-6);
  67. checker.set_dtype(0, dtype::Float32());
  68. BUILD_UNARY_TEST_CASE_FLOAT
  69. }
  // BINARY_TEST_CASE(_optr): runs binary mode Mode::_optr on four shape pairs:
  // - same-shape {3, 4, 17};
  // - a {1, 1, 1, 1} second operand broadcast against {3, 4, 5, 7};
  // - a {1, 1, 1} second operand broadcast against {1, 2, 2};
  // - same-shape {1, 7}.
  // The macro expands in the caller's scope and needs `checker` and `Mode`.
  70. #define BINARY_TEST_CASE(_optr) \
  71. checker.set_param(Mode::_optr).execs({{3, 4, 17}, {3, 4, 17}, {}}); \
  72. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  73. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  74. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}});
  // The short binary shape list, applied to MIN and MAX.
  75. #define BUILD_BINARY_TEST_CASE \
  76. BINARY_TEST_CASE(MIN) \
  77. BINARY_TEST_CASE(MAX)
  // BINARY_COMPLATE_TEST_CASE(_optr) ("COMPLATE" presumably means "complete";
  // the name is kept because other macros and tests refer to it).
  // It prints the mode name, then runs Mode::_optr over the full list of
  // binary shape pairs:
  // - same-shape 2-d, 3-d, 4-d and 5-d operands;
  // - second operands with size-1 dims that broadcast along some axes;
  // - tiny tensors such as {1, 7} and {1, 2, 2}.
  // Needs `checker` and `Mode` in the caller's scope.
  78. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  79. printf("Check binary optr %s by all cases.\n", #_optr); \
  80. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  81. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  82. checker.set_param(Mode::_optr) \
  83. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); \
  84. checker.set_param(Mode::_optr) \
  85. .execs({{3, 4, 5, 7, 8}, {1, 4, 1, 1, 8}, {}}); \
  86. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  87. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  88. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  89. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  90. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  91. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  // Binary modes checked with every dtype used by ELEMWISE_FORWARD_BINARY.
  92. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  93. BINARY_COMPLATE_TEST_CASE(ADD) \
  94. BINARY_COMPLATE_TEST_CASE(MUL) \
  95. BINARY_COMPLATE_TEST_CASE(MAX) \
  96. BINARY_COMPLATE_TEST_CASE(MIN) \
  97. BINARY_COMPLATE_TEST_CASE(SUB)
  // Binary modes that ELEMWISE_FORWARD_BINARY only checks with Float32 inputs.
  98. #define BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32 \
  99. BINARY_COMPLATE_TEST_CASE(TRUE_DIV) \
  100. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_SIGMOID) \
  101. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_TANH) \
  102. BINARY_COMPLATE_TEST_CASE(FUSE_ADD_RELU)
  103. TEST_F(X86, ELEMWISE_FORWARD_NCHW88) {
  104. using Mode = ElemwiseForward::Param::Mode;
  105. Checker<ElemwiseForward> checker(handle());
  106. // case float
  107. UniformFloatRNG rng(1e-5, 7e1);
  108. checker.set_rng(0, &rng);
  109. checker.set_epsilon(1e-5);
  110. checker.set_dtype(0, dtype::Float32());
  111. checker.set_dtype(1, dtype::Float32());
  112. checker.set_param(Mode::ADD).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  113. checker.set_param(Mode::ADD).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  114. checker.set_param(Mode::ADD).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  115. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  116. checker.set_param(Mode::ADD).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  117. checker.set_param(Mode::FUSE_ADD_RELU)
  118. .execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  119. checker.set_param(Mode::FUSE_ADD_RELU)
  120. .execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  121. checker.set_param(Mode::FUSE_ADD_RELU)
  122. .execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  123. checker.set_param(Mode::FUSE_ADD_RELU)
  124. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  125. checker.set_param(Mode::FUSE_ADD_RELU)
  126. .execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  127. }
  // Binary elemwise forward checks on the x86 handle.
  // - Float32: the full mode list (BUILD_BINARY_COMPLATE_TEST_CASE) plus the
  //   float-only modes.
  // - Int8, Int16 and Int32: MIN/MAX (BUILD_BINARY_TEST_CASE) followed by the
  //   full mode list.
  // The macros rely on the checker state set just before them, so keep this
  // statement order.
  128. TEST_F(X86, ELEMWISE_FORWARD_BINARY) {
  129. using Mode = ElemwiseForward::Param::Mode;
  130. Checker<ElemwiseForward> checker(handle());
  131. // case float
  132. UniformFloatRNG rng(1e-5, 7e1);
  // NOTE(review): only input 0 is given `rng`, and it is never reset. It
  // therefore also feeds input 0 in the integer cases below. Input 1
  // presumably uses the checker's default rng throughout. Confirm this is
  // intended.
  133. checker.set_rng(0, &rng);
  134. checker.set_epsilon(1e-5);
  135. checker.set_dtype(0, dtype::Float32());
  136. checker.set_dtype(1, dtype::Float32());
  137. BUILD_BINARY_COMPLATE_TEST_CASE
  138. BUILD_BINARY_COMPLATE_TEST_CASE_FLOAT32
  139. // case int
  // MIN and MAX run twice per int dtype. BUILD_BINARY_TEST_CASE adds only the
  // {3, 4, 17} shape pair; its other pairs are repeated by the full list.
  140. checker.set_dtype(0, dtype::Int8());
  141. checker.set_dtype(1, dtype::Int8());
  142. BUILD_BINARY_TEST_CASE
  143. BUILD_BINARY_COMPLATE_TEST_CASE
  144. checker.set_dtype(0, dtype::Int16());
  145. checker.set_dtype(1, dtype::Int16());
  146. BUILD_BINARY_TEST_CASE
  147. BUILD_BINARY_COMPLATE_TEST_CASE
  148. checker.set_dtype(0, dtype::Int32());
  149. checker.set_dtype(1, dtype::Int32());
  150. BUILD_BINARY_TEST_CASE
  151. BUILD_BINARY_COMPLATE_TEST_CASE
  152. }
  // TERNARY_COMPLATE_TEST_CASE(_optr) ("COMPLATE" presumably means "complete";
  // the name is kept because other code refers to it).
  // It prints the mode name, then runs the three-input mode Mode::_optr over:
  // - same-shape operands;
  // - operands where the first and/or last input has size-1 dims that
  //   broadcast against the middle input.
  // Needs `checker` and `Mode` in the caller's scope.
  153. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  154. printf("Check ternary optr %s by all cases.\n", #_optr); \
  155. checker.set_param(Mode::_optr) \
  156. .execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  157. checker.set_param(Mode::_optr) \
  158. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  159. checker.set_param(Mode::_optr) \
  160. .execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  161. checker.set_param(Mode::_optr) \
  162. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  163. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  164. checker.set_param(Mode::_optr) \
  165. .execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  166. checker.set_param(Mode::_optr) \
  167. .execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  168. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  // Ternary modes under test. FUSE_MUL_ADD3 is presumably a * b + c; confirm
  // against the Elemwise mode definitions.
  169. #define BUILD_TERNARY_COMPLATE_TEST_CASE \
  170. TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  // Ternary elemwise forward checks on the x86 handle.
  // - Integer dtypes (Int8, Int16, Int32) run first, before any set_rng call,
  //   so they presumably use the checker's default rng.
  // - Float32 runs last, with the uniform rng on input 0 only.
  // The macro relies on the checker state set just before it, so keep this
  // statement order.
  171. TEST_F(X86, ELEMWISE_FORWARD_TERNARY) {
  172. using Mode = ElemwiseForward::Param::Mode;
  173. Checker<ElemwiseForward> checker(handle());
  174. // case int
  175. checker.set_dtype(0, dtype::Int8());
  176. checker.set_dtype(1, dtype::Int8());
  177. checker.set_dtype(2, dtype::Int8());
  178. // BUILD_TERNARY_TEST_CASE is not defined in this file; only the full list runs.
  179. BUILD_TERNARY_COMPLATE_TEST_CASE
  180. checker.set_dtype(0, dtype::Int16());
  181. checker.set_dtype(1, dtype::Int16());
  182. checker.set_dtype(2, dtype::Int16());
  183. // BUILD_TERNARY_TEST_CASE is not defined in this file; only the full list runs.
  184. BUILD_TERNARY_COMPLATE_TEST_CASE
  185. checker.set_dtype(0, dtype::Int32());
  186. checker.set_dtype(1, dtype::Int32());
  187. checker.set_dtype(2, dtype::Int32());
  188. // BUILD_TERNARY_TEST_CASE is not defined in this file; only the full list runs.
  189. BUILD_TERNARY_COMPLATE_TEST_CASE
  190. // case float
  // The rng is passed by address, so `rng` must stay alive until the last
  // execs() below.
  191. UniformFloatRNG rng(1e-5, 7e1);
  192. checker.set_rng(0, &rng);
  193. checker.set_epsilon(1e-5);
  194. checker.set_dtype(0, dtype::Float32());
  195. checker.set_dtype(1, dtype::Float32());
  196. checker.set_dtype(2, dtype::Float32());
  197. // BUILD_TERNARY_TEST_CASE is not defined in this file; only the full list runs.
  198. BUILD_TERNARY_COMPLATE_TEST_CASE
  199. }
  200. template <typename tag>
  201. class X86_ELEMWISE : public X86 {};
  202. TYPED_TEST_CASE(X86_ELEMWISE, elemwise::test_types);
  203. TYPED_TEST(X86_ELEMWISE, run) {
  204. elemwise::run_test<TypeParam>(this->handle());
  205. }
  206. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。