You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.cpp 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. /**
  2. * \file dnn/test/arm_common/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/common/elemwise.h"
  13. #include "test/arm_common/fixture.h"
  14. #include "test/common/benchmarker.h"
  15. #include "test/common/checker.h"
  16. #include "megdnn/oprs/general.h"
  17. using namespace megdnn;
  18. using namespace test;
  19. template <typename tag>
  20. class ARM_ELEMWISE : public ARM_COMMON {};
  21. TYPED_TEST_CASE(ARM_ELEMWISE, elemwise::test_types);
  22. TYPED_TEST(ARM_ELEMWISE, run) {
  23. elemwise::run_test<TypeParam>(this->handle());
  24. }
  25. template <typename tag>
  26. class ARM_ELEMWISE_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {};
  27. TYPED_TEST_CASE(ARM_ELEMWISE_MULTI_THREADS, elemwise::test_types);
  28. TYPED_TEST(ARM_ELEMWISE_MULTI_THREADS, run) {
  29. elemwise::run_test<TypeParam>(this->handle());
  30. }
  31. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) {
  32. using Mode = ElemwiseForward::Param::Mode;
  33. Checker<ElemwiseForward> checker(handle());
  34. checker.set_param(Mode::FUSE_MUL_ADD3);
  35. auto run = [&] {
  36. //! nchw44
  37. checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  38. checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  39. checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  40. checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  41. checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  42. //! nchw44
  43. checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  44. checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  45. checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  46. checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  47. checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  48. //! nchw88
  49. checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  50. checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  51. checker.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  52. checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  53. checker.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  54. //! nchw88
  55. checker.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  56. checker.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  57. checker.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  58. checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  59. checker.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  60. checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
  61. checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
  62. checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}});
  63. checker.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}});
  64. checker.execs({{1, 7}, {1, 7}, {1, 7}, {}});
  65. checker.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}});
  66. checker.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}});
  67. checker.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  68. checker.execs({{3, 4, 5}, {1}, {1}, {}});
  69. checker.execs({{1}, {3, 4, 5}, {1}, {}});
  70. };
  71. // case int
  72. checker.set_dtype(0, dtype::Int8());
  73. checker.set_dtype(1, dtype::Int8());
  74. checker.set_dtype(2, dtype::Int8());
  75. run();
  76. checker.set_dtype(0, dtype::Int16());
  77. checker.set_dtype(1, dtype::Int16());
  78. checker.set_dtype(2, dtype::Int16());
  79. run();
  80. checker.set_dtype(0, dtype::Int32());
  81. checker.set_dtype(1, dtype::Int32());
  82. checker.set_dtype(2, dtype::Int32());
  83. run();
  84. // case float
  85. UniformFloatRNG rng(1e-5, 7e1);
  86. checker.set_rng(0, &rng);
  87. checker.set_epsilon(1e-5);
  88. checker.set_dtype(0, dtype::Float32());
  89. checker.set_dtype(1, dtype::Float32());
  90. checker.set_dtype(2, dtype::Float32());
  91. run();
  92. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  93. // case half
  94. UniformFloatRNG rng_float16(1, 10);
  95. checker.set_rng(0, &rng_float16);
  96. checker.set_epsilon(1e-2);
  97. checker.set_dtype(0, dtype::Float16());
  98. checker.set_dtype(1, dtype::Float16());
  99. checker.set_dtype(2, dtype::Float16());
  100. run();
  101. #endif
  102. }
  103. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) {
  104. using Mode = ElemwiseForward::Param::Mode;
  105. Checker<ElemwiseForward> checker(handle());
  106. auto run = [&]() {
  107. // VEC_BCAST101x not PowOp
  108. checker.set_param(Mode::ADD).execs(
  109. {{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  110. checker.set_param(Mode::ADD).execs(
  111. {{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  112. checker.set_param(Mode::ADD).execs(
  113. {{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  114. checker.set_param(Mode::ADD).execs(
  115. {{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  116. checker.set_param(Mode::ADD).execs(
  117. {{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  118. checker.set_param(Mode::RMULH)
  119. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  120. checker.set_param(Mode::RMULH)
  121. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  122. checker.set_param(Mode::RMULH)
  123. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  124. checker.set_param(Mode::RMULH)
  125. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  126. checker.set_param(Mode::RMULH)
  127. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  128. checker.set_param(Mode::FUSE_ADD_RELU)
  129. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  130. checker.set_param(Mode::FUSE_ADD_RELU)
  131. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  132. checker.set_param(Mode::FUSE_ADD_RELU)
  133. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  134. checker.set_param(Mode::FUSE_ADD_RELU)
  135. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  136. checker.set_param(Mode::FUSE_ADD_RELU)
  137. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  138. // BCAST101x_VEC not PowOp
  139. checker.set_param(Mode::ADD).execs(
  140. {{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  141. checker.set_param(Mode::ADD).execs(
  142. {{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  143. checker.set_param(Mode::ADD).execs(
  144. {{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  145. checker.set_param(Mode::ADD).execs(
  146. {{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  147. checker.set_param(Mode::ADD).execs(
  148. {{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  149. checker.set_param(Mode::FUSE_ADD_RELU)
  150. .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  151. checker.set_param(Mode::FUSE_ADD_RELU)
  152. .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  153. checker.set_param(Mode::FUSE_ADD_RELU)
  154. .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  155. checker.set_param(Mode::FUSE_ADD_RELU)
  156. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  157. checker.set_param(Mode::FUSE_ADD_RELU)
  158. .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  159. };
  160. checker.set_dtype(0, dtype::Int8());
  161. checker.set_dtype(1, dtype::Int8());
  162. run();
  163. checker.set_dtype(0, dtype::Int16());
  164. checker.set_dtype(1, dtype::Int16());
  165. run();
  166. checker.set_dtype(0, dtype::Int32());
  167. checker.set_dtype(1, dtype::Int32());
  168. run();
  169. }
  170. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) {
  171. using Mode = ElemwiseForward::Param::Mode;
  172. Checker<ElemwiseForward> checker(handle());
  173. UniformFloatRNG rng(1e-5, 7e1);
  174. checker.set_rng(0, &rng);
  175. checker.set_epsilon(1e-5);
  176. checker.set_dtype(0, dtype::Float32());
  177. checker.set_dtype(1, dtype::Float32());
  178. checker.set_param(Mode::FUSE_ADD_RELU)
  179. .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  180. checker.set_param(Mode::FUSE_ADD_RELU)
  181. .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  182. checker.set_param(Mode::FUSE_ADD_RELU)
  183. .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  184. checker.set_param(Mode::FUSE_ADD_RELU)
  185. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  186. checker.set_param(Mode::FUSE_ADD_RELU)
  187. .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  188. checker.set_param(Mode::FUSE_ADD_RELU)
  189. .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  190. checker.set_param(Mode::FUSE_ADD_RELU)
  191. .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  192. checker.set_param(Mode::FUSE_ADD_RELU)
  193. .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  194. checker.set_param(Mode::FUSE_ADD_RELU)
  195. .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  196. checker.set_param(Mode::FUSE_ADD_RELU)
  197. .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  198. auto run = [&](Mode mode) {
  199. // VEC_BCAST101x
  200. checker.set_param(mode).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  201. checker.set_param(mode).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
  202. checker.set_param(mode).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
  203. checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  204. checker.set_param(mode).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
  205. // BCAST101x_VEC not powOp
  206. checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
  207. checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
  208. checker.set_param(mode).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
  209. checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
  210. checker.set_param(mode).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
  211. };
  212. run(Mode::ADD);
  213. run(Mode::FUSE_ADD_H_SWISH);
  214. run(Mode::FUSE_ADD_RELU);
  215. run(Mode::MAX);
  216. run(Mode::MIN);
  217. run(Mode::MUL);
  218. run(Mode::SUB);
  219. run(Mode::TRUE_DIV);
  220. run(Mode::POW);
  221. }
  222. TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) {
  223. using Mode = ElemwiseForward::Param::Mode;
  224. Checker<ElemwiseForward> checker(handle());
  225. checker.set_param(Mode::FUSE_ADD_RELU)
  226. .execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  227. checker.set_param(Mode::FUSE_ADD_RELU)
  228. .execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  229. checker.set_param(Mode::FUSE_ADD_RELU)
  230. .execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  231. checker.set_param(Mode::FUSE_ADD_RELU)
  232. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  233. checker.set_param(Mode::FUSE_ADD_RELU)
  234. .execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  235. checker.set_param(Mode::FUSE_ADD_RELU)
  236. .execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  237. checker.set_param(Mode::FUSE_ADD_RELU)
  238. .execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  239. checker.set_param(Mode::FUSE_ADD_RELU)
  240. .execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  241. checker.set_param(Mode::FUSE_ADD_RELU)
  242. .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  243. checker.set_param(Mode::FUSE_ADD_RELU)
  244. .execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  245. auto run = [&](Mode mode) {
  246. // VEC_BCAST101x
  247. checker.set_param(mode).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  248. checker.set_param(mode).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
  249. checker.set_param(mode).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
  250. checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  251. checker.set_param(mode).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
  252. // BCAST101x_VEC not powOp
  253. checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
  254. checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
  255. checker.set_param(mode).execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
  256. checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
  257. checker.set_param(mode).execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
  258. };
  259. auto run_all = [&]() {
  260. run(Mode::ADD);
  261. run(Mode::FUSE_ADD_H_SWISH);
  262. run(Mode::FUSE_ADD_RELU);
  263. run(Mode::MAX);
  264. run(Mode::MIN);
  265. run(Mode::MUL);
  266. run(Mode::SUB);
  267. run(Mode::TRUE_DIV);
  268. run(Mode::POW);
  269. };
  270. {
  271. UniformFloatRNG rng(1e-5, 7e1);
  272. checker.set_rng(0, &rng);
  273. checker.set_epsilon(1e-5);
  274. checker.set_dtype(0, dtype::Float32());
  275. checker.set_dtype(1, dtype::Float32());
  276. run_all();
  277. }
  278. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  279. {
  280. UniformFloatRNG rng(1, 2);
  281. checker.set_rng(0, &rng);
  282. checker.set_epsilon(3e-3);
  283. checker.set_dtype(0, dtype::Float16());
  284. checker.set_dtype(1, dtype::Float16());
  285. run_all();
  286. }
  287. #endif
  288. }
  289. #if MEGDNN_WITH_BENCHMARK
  290. namespace {
  291. void run_elemwise_benchmark(const TensorShapeArray& shapes,
  292. param::Elemwise::Mode mode, const char* mode_str,
  293. DType type, Handle* handle_bench) {
  294. auto handle_fallback = create_cpu_handle(1);
  295. Benchmarker<Elemwise> benchmarker_bench(handle_bench);
  296. Benchmarker<Elemwise> benchmarker_fallback(handle_fallback.get());
  297. float throughput = 0;
  298. SmallVector<TensorLayout> layouts;
  299. std::string src_strs;
  300. for (size_t i = 0; i < shapes.size(); i++) {
  301. layouts.emplace_back(shapes[i], type);
  302. throughput += layouts.back().span().dist_byte();
  303. src_strs += layouts.back().to_string();
  304. if (i != shapes.size() - 1) {
  305. src_strs += ",";
  306. }
  307. }
  308. constexpr size_t RUN = 50;
  309. benchmarker_fallback.set_times(RUN).set_display(false);
  310. benchmarker_bench.set_times(RUN).set_display(false);
  311. benchmarker_fallback.set_param(mode);
  312. benchmarker_bench.set_param(mode);
  313. TensorLayout dst_layout;
  314. auto opr = handle_bench->create_operator<Elemwise>();
  315. opr->param() = mode;
  316. opr->deduce_layout(layouts, dst_layout);
  317. float computations = dst_layout.total_nr_elems() *
  318. (std::max<size_t>(shapes.size(), 2) - 1);
  319. throughput += dst_layout.span().dist_byte();
  320. computations *= (1e3 / (1024.0 * 1024));
  321. throughput *= (1e3 / (1024.0 * 1024));
  322. layouts.emplace_back(dst_layout);
  323. auto fallback_time = benchmarker_fallback.execl(layouts) / RUN;
  324. auto bench_time = benchmarker_bench.execl(layouts) / RUN;
  325. float fallback_flops = computations / fallback_time;
  326. float bench_flops = computations / bench_time;
  327. float fallback_thr = throughput / fallback_time;
  328. float bench_thr = throughput / bench_time;
  329. printf("%s = %s (type: %s, mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS "
  330. "%fMB/s "
  331. "computations: %fx, throughput: %fx\n",
  332. src_strs.c_str(), dst_layout.to_string().c_str(), type.name(),
  333. mode_str, fallback_flops, fallback_thr, bench_flops, bench_thr,
  334. bench_flops / fallback_flops, bench_thr / fallback_thr);
  335. }
  336. } // namespace
  337. #define INT_RUN(shape, mode) \
  338. run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle()); \
  339. run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle()); \
  340. run_elemwise_benchmark(shape, mode, #mode, dtype::Int32{}, handle());
  341. #define FLOAT_RUN(shape, mode) \
  342. run_elemwise_benchmark(shape, mode, #mode, dtype::Float32{}, handle()); \
  343. run_elemwise_benchmark(shape, mode, #mode, dtype::Float16{}, handle());
  344. #define BENCHMARK_CASES(shape) \
  345. INT_BENCHMARK_CASES(shape) \
  346. FLOAT_BENCHMARK_CASES(shape)
  347. TEST_F(ARM_COMMON, BENCHMARK_UNARY) {
  348. #define INT_BENCHMARK_CASES(shape) \
  349. INT_RUN(shape, Mode::RELU); \
  350. INT_RUN(shape, Mode::ABS);
  351. #define FLOAT_BENCHMARK_CASES(shape) \
  352. FLOAT_RUN(shape, Mode::RELU); \
  353. FLOAT_RUN(shape, Mode::ABS); \
  354. FLOAT_RUN(shape, Mode::SIGMOID); \
  355. FLOAT_RUN(shape, Mode::EXP); \
  356. FLOAT_RUN(shape, Mode::TANH); \
  357. FLOAT_RUN(shape, Mode::FAST_TANH);
  358. using Mode = param::Elemwise::Mode;
  359. BENCHMARK_CASES({{10000}});
  360. BENCHMARK_CASES({{50000}});
  361. #undef INT_BENCHMARK_CASES
  362. #undef FLOAT_BENCHMARK_CASES
  363. }
  364. TEST_F(ARM_COMMON, BENCHMARK_BINARY) {
  365. #define INT_BENCHMARK_CASES(shape) \
  366. INT_RUN(shape, Mode::MIN); \
  367. INT_RUN(shape, Mode::MAX); \
  368. INT_RUN(shape, Mode::ADD); \
  369. INT_RUN(shape, Mode::SUB); \
  370. INT_RUN(shape, Mode::MUL); \
  371. INT_RUN(shape, Mode::RMULH); \
  372. INT_RUN(shape, Mode::FUSE_ADD_RELU);
  373. #define FLOAT_BENCHMARK_CASES(shape) \
  374. FLOAT_RUN(shape, Mode::MIN); \
  375. FLOAT_RUN(shape, Mode::MAX); \
  376. FLOAT_RUN(shape, Mode::ADD); \
  377. FLOAT_RUN(shape, Mode::SUB); \
  378. FLOAT_RUN(shape, Mode::MUL); \
  379. FLOAT_RUN(shape, Mode::POW); \
  380. FLOAT_RUN(shape, Mode::TRUE_DIV); \
  381. FLOAT_RUN(shape, Mode::FUSE_ADD_RELU);
  382. using Mode = param::Elemwise::Mode;
  383. TensorShapeArray shapes = {{1, 112, 28, 28}, {1, 112, 28, 28}};
  384. BENCHMARK_CASES(shapes);
  385. shapes = {{1, 16, 1, 1}, {1, 16, 112, 112}};
  386. BENCHMARK_CASES(shapes);
  387. shapes = {{1, 448, 7, 7}, {1, 448, 7, 7}};
  388. BENCHMARK_CASES(shapes);
  389. #undef INT_BENCHMARK_CASES
  390. #undef FLOAT_BENCHMARK_CASES
  391. }
  392. TEST_F(ARM_COMMON, BENCHMARK_TERNARY_FMA3) {
  393. #define INT_BENCHMARK_CASES(shape) INT_RUN(shape, Mode::FUSE_MUL_ADD3);
  394. #define FLOAT_BENCHMARK_CASES(shape) FLOAT_RUN(shape, Mode::FUSE_MUL_ADD3);
  395. using Mode = param::Elemwise::Mode;
  396. TensorShapeArray shapes = {{30, 40, 70}, {30, 40, 70}, {30, 40, 70}};
  397. BENCHMARK_CASES(shapes);
  398. shapes = {{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}};
  399. BENCHMARK_CASES(shapes);
  400. shapes = {{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}};
  401. BENCHMARK_CASES(shapes);
  402. #undef INT_BENCHMARK_CASES
  403. #undef FLOAT_BENCHMARK_CASES
  404. }
  405. #undef BENCHMARK_CASES
  406. #undef INT_RUN
  407. #undef FLOAT_RUN
  408. #endif
  409. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台