You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

elemwise.cpp 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. /**
  2. * \file dnn/test/cuda/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/common/elemwise.h"
  13. #include "test/cuda/fixture.h"
  14. #include "megdnn/oprs.h"
  15. #include "test/common/tensor.h"
  16. #include "test/common/rng.h"
  17. #include "./utils.h"
  18. #include "test/common/benchmarker.h"
  19. #include "test/common/checker.h"
  20. #include <cudnn.h>
  21. #include <cuda_profiler_api.h>
  22. using namespace megdnn;
  23. using namespace test;
  24. #define cudnn_check(e) megdnn_assert((e) == CUDNN_STATUS_SUCCESS)
  25. namespace {
  26. __attribute__((unused)) cudnnTensorDescriptor_t make_cudnn_tensor_desc(
  27. const TensorLayout& ly) {
  28. megdnn_assert(ly.ndim && ly.ndim <= 4 && ly.is_contiguous());
  29. int dim[4] = {1, 1, 1, 1}, stride[4] = {1, 1, 1, 1};
  30. for (size_t i = 0; i < ly.ndim; ++i) {
  31. dim[i] = ly.shape[i];
  32. stride[i] = ly.stride[i];
  33. }
  34. cudnnTensorDescriptor_t ret;
  35. cudnn_check(cudnnCreateTensorDescriptor(&ret));
  36. // cudnn requires tensors to be at-least 4D
  37. cudnn_check(cudnnSetTensor4dDescriptorEx(ret, CUDNN_DATA_FLOAT, dim[0],
  38. dim[1], dim[2], dim[3], stride[0],
  39. stride[1], stride[2], stride[3]));
  40. return ret;
  41. }
//! Compute c = a + b twice: first through cudnn (as a profiling reference,
//! result discarded), then through the megdnn Elemwise ADD operator, whose
//! result is what remains in c for the caller to verify.
void run_tensor_add(Handle* handle_cuda, const TensorND& a, const TensorND& b,
                    const TensorND& c) {
#if 1
    cudnnHandle_t cudnn_handle;
    cudnn_check(cudnnCreate(&cudnn_handle));
    cuda_check(cudaDeviceSynchronize());
    // cudnnAddTensor computes c = alpha*b + beta*c, so seed c with a first
    cuda_check(cudaMemcpy(c.raw_ptr, a.raw_ptr, a.layout.span().dist_byte(),
                          cudaMemcpyDeviceToDevice));
    auto bdesc = make_cudnn_tensor_desc(b.layout),
         cdesc = make_cudnn_tensor_desc(c.layout);
    float alpha = 1, beta = 1;
    // profiler range around the cudnn kernel for comparison in traces
    cudaProfilerStart();
    cudnn_check(cudnnAddTensor(cudnn_handle, &alpha, bdesc, b.raw_ptr, &beta,
                               cdesc, c.raw_ptr));
    cudaProfilerStop();
    cudnn_check(cudnnDestroyTensorDescriptor(cdesc));
    cudnn_check(cudnnDestroyTensorDescriptor(bdesc));
    cudnn_check(cudnnDestroy(cudnn_handle));
    // wipe the cudnn result so only the megdnn run below is observable
    cuda_check(cudaMemset(c.raw_ptr, 0, c.layout.span().dist_byte()));
    cuda_check(cudaDeviceSynchronize());
#endif
    auto opr = handle_cuda->create_operator<ElemwiseForward>();
    opr->param().mode = ElemwiseForward::Mode::ADD;
    cudaProfilerStart();
    opr->exec({a, b}, c);
    cudaProfilerStop();
}
  69. } // anonymous namespace
// Typed test suite forwarding to the shared elemwise test driver in
// test/common/elemwise.h; one instantiation per tag in elemwise::test_types.
template <typename tag>
class CUDA_ELEMWISE : public CUDA {};
TYPED_TEST_CASE(CUDA_ELEMWISE, elemwise::test_types);
TYPED_TEST(CUDA_ELEMWISE, run) {
    elemwise::run_test<TypeParam>(this->handle_cuda());
}
  76. TEST_F(CUDA, ELEMWISE_IBYTE) {
  77. Checker<ElemwiseForward> checker(handle_cuda());
  78. using Mode = ElemwiseForward::Param::Mode;
  79. UniformIntRNG i_rng{-128, 127};
  80. UniformIntRNG ui_rng{0, 255};
  81. checker.set_rng(0, &i_rng);
  82. auto run_unary = [&](size_t N, Mode mode, DType dtype) {
  83. checker.set_param(mode).set_dtype(0, dtype);
  84. checker.execs({{N}, {}});
  85. };
  86. #define RUN_UNARY_IBYTE(_dt) \
  87. run_unary(100, Mode::RELU, _dt); \
  88. run_unary(100, Mode::ABS, _dt);
  89. RUN_UNARY_IBYTE(dtype::Int8());
  90. checker.set_rng(0, &i_rng);
  91. RUN_UNARY_IBYTE(dtype::Uint8());
  92. #undef RUN_UNARY_IBYTE
  93. auto run_binary = [&](size_t N, size_t C, size_t H, size_t W, Mode mode,
  94. DType dtype) {
  95. checker.set_param(mode).set_dtype(0, dtype).set_dtype(1, dtype);
  96. checker.execs({{5}, {5}, {}});
  97. checker.execs({{4}, {4}, {}});
  98. checker.execs({{4}, {1}, {}});
  99. checker.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}});
  100. checker.execs({{N, C / 4, H, W, 4}, {1, C / 4, 1, 1, 4}, {}});
  101. checker.execs({{N, C / 32, H, W, 32}, {N, C / 32, H, W, 32}, {}});
  102. checker.execs({{N, C / 32, H, W, 32}, {1, C / 32, 1, 1, 32}, {}});
  103. checker.execs({{3, 5, 7}, {3, 5, 7}, {}});
  104. checker.execs({{3, 5, 7}, {3, 5, 1}, {}});
  105. checker.execs({{3, 5, 1}, {3, 5, 7}, {}});
  106. checker.execs({{1}, {3, 5, 7}, {}});
  107. checker.execs({{3, 5, 7}, {1}, {}});
  108. };
  109. #define RUN_BINARY_IBYTE(_dt) \
  110. run_binary(4, 32, 10, 10, Mode::ADD, _dt); \
  111. run_binary(4, 32, 10, 10, Mode::MUL, _dt); \
  112. run_binary(4, 32, 10, 10, Mode::MAX, _dt); \
  113. run_binary(4, 32, 10, 10, Mode::MIN, _dt); \
  114. run_binary(4, 32, 10, 10, Mode::SUB, _dt);
  115. checker.set_rng(0, &i_rng).set_rng(1, &i_rng);
  116. RUN_BINARY_IBYTE(dtype::Int8());
  117. checker.set_rng(0, &ui_rng).set_rng(1, &ui_rng);
  118. RUN_BINARY_IBYTE(dtype::Uint8());
  119. #undef RUN_BINARY_IBYTE
  120. auto run_ternary = [&](size_t N, size_t C, size_t H, size_t W, Mode mode,
  121. DType dtype) {
  122. checker.set_param(mode)
  123. .set_dtype(0, dtype)
  124. .set_dtype(1, dtype)
  125. .set_dtype(2, dtype);
  126. checker.execs({{5}, {5}, {5}, {}});
  127. checker.execs({{4}, {4}, {1}, {}});
  128. checker.execs({{N, C / 4, H, W, 4},
  129. {N, C / 4, H, W, 4},
  130. {N, C / 4, H, W, 4},
  131. {}});
  132. checker.execs({{N, C / 4, H, W, 4},
  133. {1, C / 4, 1, 1, 4},
  134. {1, C / 4, 1, 1, 4},
  135. {}});
  136. checker.execs({{N, C / 32, H, W, 32},
  137. {N, C / 32, H, W, 32},
  138. {N, C / 32, H, W, 32},
  139. {}});
  140. checker.execs({{N, C / 32, H, W, 32},
  141. {1, C / 32, 1, 1, 32},
  142. {1, C / 32, 1, 1, 32},
  143. {}});
  144. checker.execs({{1}, {3, 5, 7}, {3, 5, 7}, {}});
  145. checker.execs({{3, 5, 7}, {3, 5, 1}, {3, 5, 1}, {}});
  146. checker.execs({{3, 5, 1}, {3, 5, 7}, {3, 5, 1}, {}});
  147. checker.execs({{1}, {3, 5, 7}, {1}, {}});
  148. checker.execs({{3, 5, 7}, {1}, {3, 5, 7}, {}});
  149. };
  150. #define RUN_TERNARY_IBYTE(_dt) \
  151. run_ternary(4, 32, 10, 10, Mode::FUSE_MUL_ADD3, _dt);
  152. checker.set_rng(0, &i_rng).set_rng(1, &i_rng);
  153. RUN_TERNARY_IBYTE(dtype::Int8());
  154. checker.set_rng(0, &ui_rng).set_rng(1, &ui_rng);
  155. RUN_TERNARY_IBYTE(dtype::Uint8());
  156. #undef RUN_TERNARY_IBYTE
  157. }
  158. // from common/elemwise.cpp
  159. TEST_F(CUDA, ELEMWISE_BFLOAT16) {
  160. using Mode = ElemwiseForward::Param::Mode;
  161. Checker<ElemwiseForward> checker(handle_cuda());
  162. // unary
  163. #define UNARY_TEST_CASE(_optr) \
  164. checker.set_param(Mode::_optr).execs({{1, 127}, {}}); \
  165. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  166. #define BUILD_UNARY_TEST_CASE_FLOAT \
  167. UNARY_TEST_CASE(ABS) \
  168. UNARY_TEST_CASE(LOG) \
  169. UNARY_TEST_CASE(COS) \
  170. UNARY_TEST_CASE(SIN) \
  171. UNARY_TEST_CASE(FLOOR) \
  172. UNARY_TEST_CASE(CEIL) \
  173. UNARY_TEST_CASE(SIGMOID) \
  174. UNARY_TEST_CASE(EXP) \
  175. UNARY_TEST_CASE(TANH) \
  176. UNARY_TEST_CASE(FAST_TANH) \
  177. UNARY_TEST_CASE(RELU) \
  178. UNARY_TEST_CASE(ROUND)
  179. checker.set_dtype(0, dtype::BFloat16());
  180. checker.set_dtype(1, dtype::BFloat16());
  181. UniformFloatRNG rng0(1e-2, 6e1);
  182. checker.set_rng(0, &rng0);
  183. checker.set_epsilon(1e-2);
  184. BUILD_UNARY_TEST_CASE_FLOAT
  185. #undef UNARY_TEST_CASE
  186. #undef BUILD_UNARY_TEST_CASE_FLOAT
  187. // binary
  188. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  189. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  190. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  191. checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
  192. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  193. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}}); \
  194. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  195. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  196. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  197. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  198. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}}); \
  199. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  200. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  201. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  202. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  203. BINARY_COMPLATE_TEST_CASE(ADD) \
  204. BINARY_COMPLATE_TEST_CASE(MUL) \
  205. BINARY_COMPLATE_TEST_CASE(MAX) \
  206. BINARY_COMPLATE_TEST_CASE(MIN) \
  207. BINARY_COMPLATE_TEST_CASE(SUB)
  208. UniformFloatRNG rng1(1e-5, 7e1);
  209. checker.set_rng(0, &rng1);
  210. checker.set_epsilon(1e-2);
  211. checker.set_dtype(0, dtype::BFloat16());
  212. checker.set_dtype(1, dtype::BFloat16());
  213. BUILD_BINARY_COMPLATE_TEST_CASE
  214. #undef BINARY_COMPLATE_TEST_CASE
  215. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  216. // ternary
  217. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  218. checker.set_param(Mode::_optr) \
  219. .execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  220. checker.set_param(Mode::_optr) \
  221. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  222. checker.set_param(Mode::_optr) \
  223. .execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  224. checker.set_param(Mode::_optr) \
  225. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  226. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  227. checker.set_param(Mode::_optr) \
  228. .execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  229. checker.set_param(Mode::_optr) \
  230. .execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  231. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  232. #define BUILD_TERNARY_COMPLATE_TEST_CASE \
  233. TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  234. UniformFloatRNG rng2(1e-5, 7e1);
  235. checker.set_rng(0, &rng2);
  236. checker.set_epsilon(1e-2);
  237. checker.set_dtype(0, dtype::BFloat16());
  238. checker.set_dtype(1, dtype::BFloat16());
  239. checker.set_dtype(2, dtype::BFloat16());
  240. BUILD_TERNARY_COMPLATE_TEST_CASE
  241. #undef TERNARY_COMPLATE_TEST_CASE
  242. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  243. }
//! int8 ADD with dim-0 broadcast where the output aliases the first input
//! (in-place execution); result compared against a host-computed reference.
TEST_F(CUDA, ELEMWISE_ADD_BCAST_10_INT8_INPLACE) {
    constexpr size_t A = 2, B = 48, C0 = 14, C1 = 14, C = C0 * C1;
    SyncedTensor<dt_int8> t0(handle_cuda(),
                             {TensorShape{A, B, C0, C1}, dtype::Int8()}),
            t1(handle_cuda(), {TensorShape{1, B, C0, C1}, dtype::Int8()}),
            t2(handle_cuda(), {TensorShape{A, B, C0, C1}, dtype::Int8()});
    UniformIntRNG rng{-128, 127};
    rng.gen(t0.tensornd_host());
    rng.gen(t1.tensornd_host());
    auto p0 = t0.ptr_host(), p1 = t1.ptr_host();
    auto p2 = t2.ptr_mutable_host();
    // host reference: t2 = t0 + broadcast(t1) over the first dimension
    for (size_t i = 0; i < A; ++i) {
        for (size_t j = 0; j < B; ++j) {
            for (size_t k = 0; k < C; ++k) {
                // off0 indexes t1 (no i term: dim 0 is broadcast)
                auto off0 = j * C + k;
                auto off1 = i * B * C + j * C + k;
                p2[off1] = p0[off1] + p1[off0];
            }
        }
    }
    auto opr = handle_cuda()->create_operator<ElemwiseForward>();
    opr->param().mode = ElemwiseForward::Mode::ADD;
    // in-place: t0 is both the first input and the output tensor
    opr->exec({t0.tensornd_dev(), t1.tensornd_dev()}, t0.tensornd_dev());
    auto pt = t0.ptr_host();
    for (size_t i = 0; i < A; ++i) {
        for (size_t j = 0; j < B; ++j) {
            for (size_t k = 0; k < C; ++k) {
                auto off = i * B * C + j * C + k;
                ASSERT_EQ(pt[off], p2[off]);
            }
        }
    }
}
//! the memory of this test case is too large, sometimes will fail on tx1
TEST_F(CUDA, ELEMWISE_BENCHMARK_DENSE) {
    // A is the flat element count; the static_assert keeps it in sync with
    // the 4-d shape used below
    constexpr size_t A = 256 * 1024 * 64, S0 = 16, S1 = 256, S2 = 64, S3 = 64;
    static_assert(A == S0 * S1 * S2 * S3, "bad value");
    SyncedTensor<> t0(handle_cuda(),
                      {TensorShape{S0, S1, S2, S3}, dtype::Float32()}),
            t1(handle_cuda(), {TensorShape{S0, S1, S2, S3}, dtype::Float32()});
    UniformFloatRNG rng{-2.f, 2.f};
    rng.gen(t0.tensornd_host());
    // both inputs alias t0, so the expected result is t0 + t0
    run_tensor_add(handle_cuda(), t0.tensornd_dev(), t0.tensornd_dev(),
                   t1.tensornd_dev());
    auto p0 = t0.ptr_host(), p1 = t1.ptr_host();
    for (size_t i = 0; i < A; ++i) {
        ASSERT_EQ(p0[i] + p0[i], p1[i]) << "at index " << i << "/" << A;
    }
}
  293. #if MEGDNN_WITH_BENCHMARK
  294. TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_101) {
  295. constexpr size_t A = 511, B = 509, C0 = 23, C1 = 23, C = C0 * C1;
  296. SyncedTensor<> t0(handle_cuda(),
  297. {TensorShape{A, B, C0, C1}, dtype::Float32()}),
  298. t1(handle_cuda(), {TensorShape{1, B, 1, 1}, dtype::Float32()}),
  299. t2(handle_cuda(), {TensorShape{A, B, C0, C1}, dtype::Float32()});
  300. UniformFloatRNG rng{-2.f, 2.f};
  301. rng.gen(t0.tensornd_host());
  302. rng.gen(t1.tensornd_host());
  303. run_tensor_add(handle_cuda(), t0.tensornd_dev(), t1.tensornd_dev(),
  304. t2.tensornd_dev());
  305. auto p0 = t0.ptr_host(), p1 = t1.ptr_host(), p2 = t2.ptr_host();
  306. for (size_t i = 0; i < A; ++i) {
  307. for (size_t j = 0; j < B; ++j) {
  308. for (size_t k = 0; k < C; ++k) {
  309. auto off = i * B * C + j * C + k;
  310. ASSERT_EQ(p0[off] + p1[j], p2[off]);
  311. }
  312. }
  313. }
  314. }
  315. TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_10) {
  316. constexpr size_t A = 11583, B = 11587;
  317. SyncedTensor<> t0(handle_cuda(), {TensorShape{A, B}, dtype::Float32()}),
  318. t1(handle_cuda(), {TensorShape{1, B}, dtype::Float32()}),
  319. t2(handle_cuda(), {TensorShape{A, B}, dtype::Float32()});
  320. UniformFloatRNG rng{-2.f, 2.f};
  321. rng.gen(t0.tensornd_host());
  322. rng.gen(t1.tensornd_host());
  323. run_tensor_add(handle_cuda(), t0.tensornd_dev(), t1.tensornd_dev(),
  324. t2.tensornd_dev());
  325. auto p0 = t0.ptr_host(), p1 = t1.ptr_host(), p2 = t2.ptr_host();
  326. for (size_t i = 0; i < A; ++i) {
  327. for (size_t j = 0; j < B; ++j) {
  328. auto off = i * B + j;
  329. ASSERT_EQ(p0[off] + p1[j], p2[off]);
  330. }
  331. }
  332. }
  333. TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_01) {
  334. constexpr size_t A = 11583, B = 11587;
  335. SyncedTensor<> t0(handle_cuda(), {TensorShape{1, A, B}, dtype::Float32()}),
  336. t1(handle_cuda(), {TensorShape{1, A, 1}, dtype::Float32()}),
  337. t2(handle_cuda(), {TensorShape{1, A, B}, dtype::Float32()});
  338. UniformFloatRNG rng{-2.f, 2.f};
  339. rng.gen(t0.tensornd_host());
  340. rng.gen(t1.tensornd_host());
  341. run_tensor_add(handle_cuda(), t0.tensornd_dev(), t1.tensornd_dev(),
  342. t2.tensornd_dev());
  343. auto p0 = t0.ptr_host(), p1 = t1.ptr_host(), p2 = t2.ptr_host();
  344. for (size_t i = 0; i < A; ++i) {
  345. for (size_t j = 0; j < B; ++j) {
  346. auto off = i * B + j;
  347. ASSERT_EQ(p0[off] + p1[i], p2[off]);
  348. }
  349. }
  350. }
//! Bandwidth benchmark for int8 FUSE_ADD_RELU over dense, NCHW4/NCHW32 and
//! broadcast shapes, with a float32 comparison at the end. All bandwidth
//! formulas below count bytes: 3 tensors for dense (2 reads + 1 write),
//! 2 tensors plus the broadcast operand's bytes for broadcast cases.
TEST_F(CUDA, BENCHMARK_ELEMWISE_IBYTE) {
    Benchmarker<ElemwiseForward> bencher(handle_cuda());
    using Mode = ElemwiseForward::Param::Mode;
    auto run_bench = [&](size_t N, size_t C, size_t H, size_t W) {
        size_t nr_times = 100;
        bencher.set_times(nr_times)
                .set_param(Mode::FUSE_ADD_RELU)
                .set_dtype(0, dtype::Int8())
                .set_dtype(1, dtype::Int8());
        // flat shape with an odd +1 element (defeats vectorized alignment);
        // int8 => 1 byte/element, 3 tensors touched
        auto time =
                bencher.execs({{N * C * H * W + 1}, {N * C * H * W + 1}, {}}) /
                nr_times;
        printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
               (3.0 * (N * C * H * W + 1)) / (time * 1e6));
        // NCHW4-style dense layout
        time = bencher.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}}) /
               nr_times;
        printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
               (3.0 * N * C * H * W) / (time * 1e6));
        // channel broadcast: second operand contributes only C bytes
        time = bencher.execs({{N, C / 4, H, W, 4}, {1, C / 4, 1, 1, 4}, {}}) /
               nr_times;
        printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
               (C + 2.0 * N * C * H * W) / (time * 1e6));
        // scalar broadcast: one extra byte read
        time = bencher.execs({{N, C / 4, H, W, 4}, {1}, {}}) / nr_times;
        printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
               (2.0 * N * C * H * W + 1) / (time * 1e6));
        // NCHW32-style dense layout
        time = bencher.execs(
                       {{N, C / 32, H, W, 32}, {N, C / 32, H, W, 32}, {}}) /
               nr_times;
        printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
               (3.0 * N * C * H * W) / (time * 1e6));
        // NCHW32 channel broadcast
        time = bencher.execs(
                       {{N, C / 32, H, W, 32}, {1, C / 32, 1, 1, 32}, {}}) /
               nr_times;
        printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
               (C + 2.0 * N * C * H * W) / (time * 1e6));
        // float32 comparison: C/4 channels of 4-byte elements, so the byte
        // count per tensor is still N*C*H*W and formulas stay unchanged
        bencher.set_dtype(0, dtype::Float32()).set_dtype(1, dtype::Float32());
        time = bencher.execs({{N, C / 4, H, W}, {N, C / 4, H, W}, {}}) /
               nr_times;
        printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
               (3.0 * N * C * H * W) / (time * 1e6));
        time = bencher.execs({{N, C / 4, H, W}, {1, C / 4, 1, 1}, {}}) /
               nr_times;
        printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
               (C + 2.0 * N * C * H * W) / (time * 1e6));
    };
    run_bench(256, 256, 56, 56);
}
//! Bandwidth benchmark for int8 MIN/MAX on an NCHW4-style layout; MIN runs
//! on random data, MAX on all-equal data (const_1 yields identical inputs).
TEST_F(CUDA, BENCHMARK_ELEMWISE_MIN_MAX) {
    Benchmarker<ElemwiseForward> bencher(handle_cuda());
    using Mode = ElemwiseForward::Param::Mode;
    UniformIntRNG const_1{1, 1}, rng{-128, 127};
    auto run_bench = [&](size_t N, size_t C, size_t H, size_t W, DType dtype) {
        size_t nr_times = 1000;
        bencher.set_times(nr_times)
                .set_param(Mode::MIN)
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_dtype(0, dtype)
                .set_dtype(1, dtype);
        auto time =
                bencher.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}}) /
                nr_times;
        // 3 tensors touched (2 reads + 1 write); called with int8 below,
        // so bytes == elements
        printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
               (3.0 * N * C * H * W) / (time * 1e6));
        bencher.set_param(Mode::MAX).set_rng(0, &const_1).set_rng(1, &const_1);
        time = bencher.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}}) /
               nr_times;
        printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
               (3.0 * N * C * H * W) / (time * 1e6));
    };
    run_bench(256, 256, 56, 56, dtype::Int8());
}
  423. #endif
  424. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台