You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.cpp 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. /**
  2. * \file dnn/test/arm_common/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/common/elemwise.h"
  13. #include "test/arm_common/fixture.h"
  14. #include "test/common/benchmarker.h"
  15. #include "test/common/checker.h"
  16. #include "megdnn/oprs/general.h"
  17. using namespace megdnn;
  18. using namespace test;
  19. template <typename tag>
  20. class ARM_ELEMWISE : public ARM_COMMON {};
  21. TYPED_TEST_CASE(ARM_ELEMWISE, elemwise::test_types);
  22. TYPED_TEST(ARM_ELEMWISE, run) {
  23. elemwise::run_test<TypeParam>(this->handle());
  24. }
  25. template <typename tag>
  26. class ARM_ELEMWISE_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {};
  27. TYPED_TEST_CASE(ARM_ELEMWISE_MULTI_THREADS, elemwise::test_types);
  28. TYPED_TEST(ARM_ELEMWISE_MULTI_THREADS, run) {
  29. elemwise::run_test<TypeParam>(this->handle());
  30. }
  31. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) {
  32. using Mode = ElemwiseForward::Param::Mode;
  33. Checker<ElemwiseForward> checker(handle());
  34. checker.set_param(Mode::FUSE_MUL_ADD3);
  35. auto run = [&] {
  36. //! nchw44
  37. checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  38. checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  39. checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  40. checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  41. checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  42. //! nchw44
  43. checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  44. checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  45. checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  46. checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  47. checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  48. //! nchw88
  49. checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  50. checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  51. checker.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  52. checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  53. checker.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  54. //! nchw88
  55. checker.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  56. checker.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  57. checker.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  58. checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  59. checker.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  60. checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
  61. checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
  62. checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}});
  63. checker.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}});
  64. checker.execs({{1, 7}, {1, 7}, {1, 7}, {}});
  65. checker.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}});
  66. checker.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}});
  67. checker.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  68. checker.execs({{3, 4, 5}, {1}, {1}, {}});
  69. checker.execs({{1}, {3, 4, 5}, {1}, {}});
  70. };
  71. // case int
  72. checker.set_dtype(0, dtype::Int8());
  73. checker.set_dtype(1, dtype::Int8());
  74. checker.set_dtype(2, dtype::Int8());
  75. run();
  76. checker.set_dtype(0, dtype::Int16());
  77. checker.set_dtype(1, dtype::Int16());
  78. checker.set_dtype(2, dtype::Int16());
  79. run();
  80. checker.set_dtype(0, dtype::Int32());
  81. checker.set_dtype(1, dtype::Int32());
  82. checker.set_dtype(2, dtype::Int32());
  83. run();
  84. // case float
  85. UniformFloatRNG rng(1e-5, 7e1);
  86. checker.set_rng(0, &rng);
  87. checker.set_epsilon(1e-5);
  88. checker.set_dtype(0, dtype::Float32());
  89. checker.set_dtype(1, dtype::Float32());
  90. checker.set_dtype(2, dtype::Float32());
  91. run();
  92. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  93. // case half
  94. UniformFloatRNG rng_float16(1, 10);
  95. checker.set_rng(0, &rng_float16);
  96. checker.set_epsilon(1e-2);
  97. checker.set_dtype(0, dtype::Float16());
  98. checker.set_dtype(1, dtype::Float16());
  99. checker.set_dtype(2, dtype::Float16());
  100. run();
  101. #endif
  102. }
  103. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) {
  104. using Mode = ElemwiseForward::Param::Mode;
  105. Checker<ElemwiseForward> checker(handle());
  106. auto run = [&]() {
  107. // VEC_BCAST101x not PowOp
  108. checker.set_param(Mode::ADD).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  109. checker.set_param(Mode::ADD).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  110. checker.set_param(Mode::ADD).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  111. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  112. checker.set_param(Mode::ADD).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  113. checker.set_param(Mode::RMULH).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  114. checker.set_param(Mode::RMULH).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  115. checker.set_param(Mode::RMULH).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  116. checker.set_param(Mode::RMULH).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  117. checker.set_param(Mode::RMULH).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  118. checker.set_param(Mode::FUSE_ADD_RELU)
  119. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  120. checker.set_param(Mode::FUSE_ADD_RELU)
  121. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  122. checker.set_param(Mode::FUSE_ADD_RELU)
  123. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  124. checker.set_param(Mode::FUSE_ADD_RELU)
  125. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  126. checker.set_param(Mode::FUSE_ADD_RELU)
  127. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  128. // BCAST101x_VEC not PowOp
  129. checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  130. checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  131. checker.set_param(Mode::ADD).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  132. checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  133. checker.set_param(Mode::ADD).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  134. checker.set_param(Mode::FUSE_ADD_RELU)
  135. .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  136. checker.set_param(Mode::FUSE_ADD_RELU)
  137. .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  138. checker.set_param(Mode::FUSE_ADD_RELU)
  139. .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  140. checker.set_param(Mode::FUSE_ADD_RELU)
  141. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  142. checker.set_param(Mode::FUSE_ADD_RELU)
  143. .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  144. };
  145. checker.set_dtype(0, dtype::Int8());
  146. checker.set_dtype(1, dtype::Int8());
  147. run();
  148. checker.set_dtype(0, dtype::Int16());
  149. checker.set_dtype(1, dtype::Int16());
  150. run();
  151. checker.set_dtype(0, dtype::Int32());
  152. checker.set_dtype(1, dtype::Int32());
  153. run();
  154. }
  155. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) {
  156. using Mode = ElemwiseForward::Param::Mode;
  157. Checker<ElemwiseForward> checker(handle());
  158. UniformFloatRNG rng(1e-5, 7e1);
  159. checker.set_rng(0, &rng);
  160. checker.set_epsilon(1e-5);
  161. checker.set_dtype(0, dtype::Float32());
  162. checker.set_dtype(1, dtype::Float32());
  163. checker.set_param(Mode::FUSE_ADD_RELU)
  164. .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  165. checker.set_param(Mode::FUSE_ADD_RELU)
  166. .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  167. checker.set_param(Mode::FUSE_ADD_RELU)
  168. .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  169. checker.set_param(Mode::FUSE_ADD_RELU)
  170. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  171. checker.set_param(Mode::FUSE_ADD_RELU)
  172. .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  173. checker.set_param(Mode::FUSE_ADD_RELU)
  174. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  175. checker.set_param(Mode::FUSE_ADD_RELU)
  176. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  177. checker.set_param(Mode::FUSE_ADD_RELU)
  178. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  179. checker.set_param(Mode::FUSE_ADD_RELU)
  180. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  181. checker.set_param(Mode::FUSE_ADD_RELU)
  182. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  183. auto run = [&](Mode mode) {
  184. // VEC_BCAST101x
  185. checker.set_param(mode).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  186. checker.set_param(mode).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  187. checker.set_param(mode).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  188. checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  189. checker.set_param(mode).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  190. // BCAST101x_VEC not powOp
  191. checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  192. checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  193. checker.set_param(mode).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  194. checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  195. checker.set_param(mode).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  196. };
  197. run(Mode::ADD);
  198. run(Mode::FUSE_ADD_H_SWISH);
  199. run(Mode::FUSE_ADD_RELU);
  200. run(Mode::MAX);
  201. run(Mode::MIN);
  202. run(Mode::MUL);
  203. run(Mode::SUB);
  204. run(Mode::TRUE_DIV);
  205. run(Mode::POW);
  206. }
  207. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) {
  208. using Mode = ElemwiseForward::Param::Mode;
  209. Checker<ElemwiseForward> checker(handle());
  210. checker.set_param(Mode::FUSE_ADD_RELU)
  211. .execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  212. checker.set_param(Mode::FUSE_ADD_RELU)
  213. .execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  214. checker.set_param(Mode::FUSE_ADD_RELU)
  215. .execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  216. checker.set_param(Mode::FUSE_ADD_RELU)
  217. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  218. checker.set_param(Mode::FUSE_ADD_RELU)
  219. .execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  220. checker.set_param(Mode::FUSE_ADD_RELU)
  221. .execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  222. checker.set_param(Mode::FUSE_ADD_RELU)
  223. .execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  224. checker.set_param(Mode::FUSE_ADD_RELU)
  225. .execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  226. checker.set_param(Mode::FUSE_ADD_RELU)
  227. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  228. checker.set_param(Mode::FUSE_ADD_RELU)
  229. .execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  230. auto run = [&](Mode mode) {
  231. // VEC_BCAST101x
  232. checker.set_param(mode).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  233. checker.set_param(mode).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  234. checker.set_param(mode).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  235. checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  236. checker.set_param(mode).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  237. // BCAST101x_VEC not powOp
  238. checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  239. checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  240. checker.set_param(mode).execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  241. checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  242. checker.set_param(mode).execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  243. };
  244. auto run_all = [&]() {
  245. run(Mode::ADD);
  246. run(Mode::FUSE_ADD_H_SWISH);
  247. run(Mode::FUSE_ADD_RELU);
  248. run(Mode::MAX);
  249. run(Mode::MIN);
  250. run(Mode::MUL);
  251. run(Mode::SUB);
  252. run(Mode::TRUE_DIV);
  253. run(Mode::POW);
  254. };
  255. {
  256. UniformFloatRNG rng(1e-5, 7e1);
  257. checker.set_rng(0, &rng);
  258. checker.set_epsilon(1e-5);
  259. checker.set_dtype(0, dtype::Float32());
  260. checker.set_dtype(1, dtype::Float32());
  261. run_all();
  262. }
  263. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  264. {
  265. UniformFloatRNG rng(1, 2);
  266. checker.set_rng(0, &rng);
  267. checker.set_epsilon(3e-3);
  268. checker.set_dtype(0, dtype::Float16());
  269. checker.set_dtype(1, dtype::Float16());
  270. run_all();
  271. }
  272. #endif
  273. }
  274. #if MEGDNN_WITH_BENCHMARK
  275. namespace {
  276. void run_elemwise_benchmark(
  277. const TensorShapeArray& shapes, param::Elemwise::Mode mode,
  278. const char* mode_str, DType type, Handle* handle_bench) {
  279. auto handle_fallback = create_cpu_handle(1);
  280. Benchmarker<Elemwise> benchmarker_bench(handle_bench);
  281. Benchmarker<Elemwise> benchmarker_fallback(handle_fallback.get());
  282. float throughput = 0;
  283. SmallVector<TensorLayout> layouts;
  284. std::string src_strs;
  285. for (size_t i = 0; i < shapes.size(); i++) {
  286. layouts.emplace_back(shapes[i], type);
  287. throughput += layouts.back().span().dist_byte();
  288. src_strs += layouts.back().to_string();
  289. if (i != shapes.size() - 1) {
  290. src_strs += ",";
  291. }
  292. }
  293. constexpr size_t RUN = 50;
  294. benchmarker_fallback.set_times(RUN).set_display(false);
  295. benchmarker_bench.set_times(RUN).set_display(false);
  296. benchmarker_fallback.set_param(mode);
  297. benchmarker_bench.set_param(mode);
  298. TensorLayout dst_layout;
  299. auto opr = handle_bench->create_operator<Elemwise>();
  300. opr->param() = mode;
  301. opr->deduce_layout(layouts, dst_layout);
  302. float computations =
  303. dst_layout.total_nr_elems() * (std::max<size_t>(shapes.size(), 2) - 1);
  304. throughput += dst_layout.span().dist_byte();
  305. computations *= (1e3 / (1024.0 * 1024));
  306. throughput *= (1e3 / (1024.0 * 1024));
  307. layouts.emplace_back(dst_layout);
  308. auto fallback_time = benchmarker_fallback.execl(layouts) / RUN;
  309. auto bench_time = benchmarker_bench.execl(layouts) / RUN;
  310. float fallback_flops = computations / fallback_time;
  311. float bench_flops = computations / bench_time;
  312. float fallback_thr = throughput / fallback_time;
  313. float bench_thr = throughput / bench_time;
  314. printf("%s = %s (type: %s, mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS "
  315. "%fMB/s "
  316. "computations: %fx, throughput: %fx\n",
  317. src_strs.c_str(), dst_layout.to_string().c_str(), type.name(), mode_str,
  318. fallback_flops, fallback_thr, bench_flops, bench_thr,
  319. bench_flops / fallback_flops, bench_thr / fallback_thr);
  320. }
  321. } // namespace
  /*
   * INT_RUN / FLOAT_RUN: benchmark `shape` with `mode` once per listed dtype,
   * calling run_elemwise_benchmark with the stringified mode and handle().
   * FLOAT_RUN always includes Float16. The bodies are wrapped in
   * do { ... } while (0) so that `INT_RUN(...);` is a single statement. The
   * old form expanded to several statements ending in ';', which broke
   * unbraced if/else.
   */
  #define INT_RUN(shape, mode)                                                    \
      do {                                                                        \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle());    \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle());   \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Int32{}, handle());   \
      } while (0)
  #define FLOAT_RUN(shape, mode)                                                  \
      do {                                                                        \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Float32{}, handle()); \
          run_elemwise_benchmark(shape, mode, #mode, dtype::Float16{}, handle()); \
      } while (0)
  /*
   * BENCHMARK_CASES expands to whatever INT_BENCHMARK_CASES and
   * FLOAT_BENCHMARK_CASES are #defined to at the point of use.
   */
  #define BENCHMARK_CASES(shape) \
      INT_BENCHMARK_CASES(shape) \
      FLOAT_BENCHMARK_CASES(shape)
  332. TEST_F(ARM_COMMON, BENCHMARK_UNARY) {
  333. #define INT_BENCHMARK_CASES(shape) \
  334. INT_RUN(shape, Mode::RELU); \
  335. INT_RUN(shape, Mode::ABS);
  336. #define FLOAT_BENCHMARK_CASES(shape) \
  337. FLOAT_RUN(shape, Mode::RELU); \
  338. FLOAT_RUN(shape, Mode::ABS); \
  339. FLOAT_RUN(shape, Mode::SIGMOID); \
  340. FLOAT_RUN(shape, Mode::EXP); \
  341. FLOAT_RUN(shape, Mode::TANH); \
  342. FLOAT_RUN(shape, Mode::FAST_TANH);
  343. using Mode = param::Elemwise::Mode;
  344. BENCHMARK_CASES({{10000}});
  345. BENCHMARK_CASES({{50000}});
  346. #undef INT_BENCHMARK_CASES
  347. #undef FLOAT_BENCHMARK_CASES
  348. }
  349. TEST_F(ARM_COMMON, BENCHMARK_BINARY) {
  350. #define INT_BENCHMARK_CASES(shape) \
  351. INT_RUN(shape, Mode::MIN); \
  352. INT_RUN(shape, Mode::MAX); \
  353. INT_RUN(shape, Mode::ADD); \
  354. INT_RUN(shape, Mode::SUB); \
  355. INT_RUN(shape, Mode::MUL); \
  356. INT_RUN(shape, Mode::RMULH); \
  357. INT_RUN(shape, Mode::FUSE_ADD_RELU);
  358. #define FLOAT_BENCHMARK_CASES(shape) \
  359. FLOAT_RUN(shape, Mode::MIN); \
  360. FLOAT_RUN(shape, Mode::MAX); \
  361. FLOAT_RUN(shape, Mode::ADD); \
  362. FLOAT_RUN(shape, Mode::SUB); \
  363. FLOAT_RUN(shape, Mode::MUL); \
  364. FLOAT_RUN(shape, Mode::POW); \
  365. FLOAT_RUN(shape, Mode::TRUE_DIV); \
  366. FLOAT_RUN(shape, Mode::FUSE_ADD_RELU);
  367. using Mode = param::Elemwise::Mode;
  368. TensorShapeArray shapes = {{1, 112, 28, 28}, {1, 112, 28, 28}};
  369. BENCHMARK_CASES(shapes);
  370. shapes = {{1, 16, 1, 1}, {1, 16, 112, 112}};
  371. BENCHMARK_CASES(shapes);
  372. shapes = {{1, 448, 7, 7}, {1, 448, 7, 7}};
  373. BENCHMARK_CASES(shapes);
  374. #undef INT_BENCHMARK_CASES
  375. #undef FLOAT_BENCHMARK_CASES
  376. }
  377. TEST_F(ARM_COMMON, BENCHMARK_TERNARY_FMA3) {
  378. #define INT_BENCHMARK_CASES(shape) INT_RUN(shape, Mode::FUSE_MUL_ADD3);
  379. #define FLOAT_BENCHMARK_CASES(shape) FLOAT_RUN(shape, Mode::FUSE_MUL_ADD3);
  380. using Mode = param::Elemwise::Mode;
  381. TensorShapeArray shapes = {{30, 40, 70}, {30, 40, 70}, {30, 40, 70}};
  382. BENCHMARK_CASES(shapes);
  383. shapes = {{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}};
  384. BENCHMARK_CASES(shapes);
  385. shapes = {{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}};
  386. BENCHMARK_CASES(shapes);
  387. #undef INT_BENCHMARK_CASES
  388. #undef FLOAT_BENCHMARK_CASES
  389. }
  390. #undef BENCHMARK_CASES
  391. #undef INT_RUN
  392. #undef FLOAT_RUN
  393. #endif
  394. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。