You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.cpp 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. /**
  2. * \file dnn/test/cuda/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/common/elemwise.h"
  13. #include "./utils.h"
  14. #include "megdnn/oprs.h"
  15. #include "test/common/benchmarker.h"
  16. #include "test/common/checker.h"
  17. #include "test/common/rng.h"
  18. #include "test/common/tensor.h"
  19. #include "test/cuda/fixture.h"
  20. #include <cuda_profiler_api.h>
  21. #include <cudnn.h>
  22. using namespace megdnn;
  23. using namespace test;
  // Evaluates the cuDNN call `e` exactly once and asserts that it returned
  // CUDNN_STATUS_SUCCESS. On failure the assertion message names the call and
  // the status text reported by cudnnGetErrorString.
  // NOTE(review): relies on megdnn_assert accepting a printf-style message
  // after the condition -- confirm against the megdnn headers.
  #define cudnn_check(e)                                                      \
      do {                                                                    \
          cudnnStatus_t cudnn_check_status_ = (e);                            \
          megdnn_assert(                                                      \
                  cudnn_check_status_ == CUDNN_STATUS_SUCCESS,                \
                  "cudnn call `%s' failed: %s", #e,                           \
                  cudnnGetErrorString(cudnn_check_status_));                  \
      } while (0)
  25. namespace {
  26. __attribute__((unused)) cudnnTensorDescriptor_t make_cudnn_tensor_desc(
  27. const TensorLayout& ly) {
  28. megdnn_assert(ly.ndim && ly.ndim <= 4 && ly.is_contiguous());
  29. int dim[4] = {1, 1, 1, 1}, stride[4] = {1, 1, 1, 1};
  30. for (size_t i = 0; i < ly.ndim; ++i) {
  31. dim[i] = ly.shape[i];
  32. stride[i] = ly.stride[i];
  33. }
  34. cudnnTensorDescriptor_t ret;
  35. cudnn_check(cudnnCreateTensorDescriptor(&ret));
  36. // cudnn requires tensors to be at-least 4D
  37. cudnn_check(cudnnSetTensor4dDescriptorEx(
  38. ret, CUDNN_DATA_FLOAT, dim[0], dim[1], dim[2], dim[3], stride[0], stride[1],
  39. stride[2], stride[3]));
  40. return ret;
  41. }
  42. void run_tensor_add(
  43. Handle* handle_cuda, const TensorND& a, const TensorND& b, const TensorND& c) {
  44. #if 1
  45. cudnnHandle_t cudnn_handle;
  46. cudnn_check(cudnnCreate(&cudnn_handle));
  47. cuda_check(cudaDeviceSynchronize());
  48. cuda_check(cudaMemcpy(
  49. c.raw_ptr, a.raw_ptr, a.layout.span().dist_byte(),
  50. cudaMemcpyDeviceToDevice));
  51. auto bdesc = make_cudnn_tensor_desc(b.layout),
  52. cdesc = make_cudnn_tensor_desc(c.layout);
  53. float alpha = 1, beta = 1;
  54. cudaProfilerStart();
  55. cudnn_check(cudnnAddTensor(
  56. cudnn_handle, &alpha, bdesc, b.raw_ptr, &beta, cdesc, c.raw_ptr));
  57. cudaProfilerStop();
  58. cudnn_check(cudnnDestroyTensorDescriptor(cdesc));
  59. cudnn_check(cudnnDestroyTensorDescriptor(bdesc));
  60. cudnn_check(cudnnDestroy(cudnn_handle));
  61. cuda_check(cudaMemset(c.raw_ptr, 0, c.layout.span().dist_byte()));
  62. cuda_check(cudaDeviceSynchronize());
  63. #endif
  64. auto opr = handle_cuda->create_operator<ElemwiseForward>();
  65. opr->param().mode = ElemwiseForward::Mode::ADD;
  66. cudaProfilerStart();
  67. opr->exec({a, b}, c);
  68. cudaProfilerStop();
  69. }
  70. } // anonymous namespace
  71. template <typename tag>
  72. class CUDA_ELEMWISE : public CUDA {};
  73. TYPED_TEST_CASE(CUDA_ELEMWISE, elemwise::test_types);
  74. TYPED_TEST(CUDA_ELEMWISE, run) {
  75. elemwise::run_test<TypeParam>(this->handle_cuda());
  76. }
  77. TEST_F(CUDA, ELEMWISE_IBYTE) {
  78. Checker<ElemwiseForward> checker(handle_cuda());
  79. using Mode = ElemwiseForward::Param::Mode;
  80. UniformIntRNG i_rng{-128, 127};
  81. UniformIntRNG ui_rng{0, 255};
  82. checker.set_rng(0, &i_rng);
  83. auto run_unary = [&](size_t N, Mode mode, DType dtype) {
  84. checker.set_param(mode).set_dtype(0, dtype);
  85. checker.execs({{N}, {}});
  86. };
  87. #define RUN_UNARY_IBYTE(_dt) \
  88. run_unary(100, Mode::RELU, _dt); \
  89. run_unary(100, Mode::ABS, _dt);
  90. RUN_UNARY_IBYTE(dtype::Int8());
  91. checker.set_rng(0, &i_rng);
  92. RUN_UNARY_IBYTE(dtype::Uint8());
  93. #undef RUN_UNARY_IBYTE
  94. auto run_binary = [&](size_t N, size_t C, size_t H, size_t W, Mode mode,
  95. DType dtype) {
  96. checker.set_param(mode).set_dtype(0, dtype).set_dtype(1, dtype);
  97. checker.execs({{5}, {5}, {}});
  98. checker.execs({{4}, {4}, {}});
  99. checker.execs({{4}, {1}, {}});
  100. checker.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}});
  101. checker.execs({{N, C / 4, H, W, 4}, {1, C / 4, 1, 1, 4}, {}});
  102. checker.execs({{N, C / 32, H, W, 32}, {N, C / 32, H, W, 32}, {}});
  103. checker.execs({{N, C / 32, H, W, 32}, {1, C / 32, 1, 1, 32}, {}});
  104. checker.execs({{3, 5, 7}, {3, 5, 7}, {}});
  105. checker.execs({{3, 5, 7}, {3, 5, 1}, {}});
  106. checker.execs({{3, 5, 1}, {3, 5, 7}, {}});
  107. checker.execs({{1}, {3, 5, 7}, {}});
  108. checker.execs({{3, 5, 7}, {1}, {}});
  109. };
  110. #define RUN_BINARY_IBYTE(_dt) \
  111. run_binary(4, 32, 10, 10, Mode::ADD, _dt); \
  112. run_binary(4, 32, 10, 10, Mode::MUL, _dt); \
  113. run_binary(4, 32, 10, 10, Mode::MAX, _dt); \
  114. run_binary(4, 32, 10, 10, Mode::MIN, _dt); \
  115. run_binary(4, 32, 10, 10, Mode::SUB, _dt);
  116. checker.set_rng(0, &i_rng).set_rng(1, &i_rng);
  117. RUN_BINARY_IBYTE(dtype::Int8());
  118. checker.set_rng(0, &ui_rng).set_rng(1, &ui_rng);
  119. RUN_BINARY_IBYTE(dtype::Uint8());
  120. #undef RUN_BINARY_IBYTE
  121. auto run_ternary = [&](size_t N, size_t C, size_t H, size_t W, Mode mode,
  122. DType dtype) {
  123. checker.set_param(mode).set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(
  124. 2, dtype);
  125. checker.execs({{5}, {5}, {5}, {}});
  126. checker.execs({{4}, {4}, {1}, {}});
  127. checker.execs(
  128. {{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}});
  129. checker.execs(
  130. {{N, C / 4, H, W, 4}, {1, C / 4, 1, 1, 4}, {1, C / 4, 1, 1, 4}, {}});
  131. checker.execs(
  132. {{N, C / 32, H, W, 32},
  133. {N, C / 32, H, W, 32},
  134. {N, C / 32, H, W, 32},
  135. {}});
  136. checker.execs(
  137. {{N, C / 32, H, W, 32},
  138. {1, C / 32, 1, 1, 32},
  139. {1, C / 32, 1, 1, 32},
  140. {}});
  141. checker.execs({{1}, {3, 5, 7}, {3, 5, 7}, {}});
  142. checker.execs({{3, 5, 7}, {3, 5, 1}, {3, 5, 1}, {}});
  143. checker.execs({{3, 5, 1}, {3, 5, 7}, {3, 5, 1}, {}});
  144. checker.execs({{1}, {3, 5, 7}, {1}, {}});
  145. checker.execs({{3, 5, 7}, {1}, {3, 5, 7}, {}});
  146. };
  147. #define RUN_TERNARY_IBYTE(_dt) run_ternary(4, 32, 10, 10, Mode::FUSE_MUL_ADD3, _dt);
  148. checker.set_rng(0, &i_rng).set_rng(1, &i_rng);
  149. RUN_TERNARY_IBYTE(dtype::Int8());
  150. checker.set_rng(0, &ui_rng).set_rng(1, &ui_rng);
  151. RUN_TERNARY_IBYTE(dtype::Uint8());
  152. #undef RUN_TERNARY_IBYTE
  153. }
  154. // from common/elemwise.cpp
  155. TEST_F(CUDA, ELEMWISE_BFLOAT16) {
  156. using Mode = ElemwiseForward::Param::Mode;
  157. Checker<ElemwiseForward> checker(handle_cuda());
  158. // unary
  159. #define UNARY_TEST_CASE(_optr) \
  160. checker.set_param(Mode::_optr).execs({{1, 127}, {}}); \
  161. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  162. #define BUILD_UNARY_TEST_CASE_FLOAT \
  163. UNARY_TEST_CASE(ABS) \
  164. UNARY_TEST_CASE(LOG) \
  165. UNARY_TEST_CASE(COS) \
  166. UNARY_TEST_CASE(SIN) \
  167. UNARY_TEST_CASE(FLOOR) \
  168. UNARY_TEST_CASE(CEIL) \
  169. UNARY_TEST_CASE(SIGMOID) \
  170. UNARY_TEST_CASE(EXP) \
  171. UNARY_TEST_CASE(TANH) \
  172. UNARY_TEST_CASE(FAST_TANH) \
  173. UNARY_TEST_CASE(RELU) \
  174. UNARY_TEST_CASE(ROUND)
  175. checker.set_dtype(0, dtype::BFloat16());
  176. checker.set_dtype(1, dtype::BFloat16());
  177. UniformFloatRNG rng0(1e-2, 6e1);
  178. checker.set_rng(0, &rng0);
  179. checker.set_epsilon(1e-2);
  180. BUILD_UNARY_TEST_CASE_FLOAT
  181. #undef UNARY_TEST_CASE
  182. #undef BUILD_UNARY_TEST_CASE_FLOAT
  183. // binary
  184. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  185. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  186. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  187. checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
  188. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  189. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}}); \
  190. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  191. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  192. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  193. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  194. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}}); \
  195. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  196. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  197. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  198. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  199. BINARY_COMPLATE_TEST_CASE(ADD) \
  200. BINARY_COMPLATE_TEST_CASE(MUL) \
  201. BINARY_COMPLATE_TEST_CASE(MAX) \
  202. BINARY_COMPLATE_TEST_CASE(MIN) \
  203. BINARY_COMPLATE_TEST_CASE(SUB)
  204. UniformFloatRNG rng1(1e-5, 7e1);
  205. checker.set_rng(0, &rng1);
  206. checker.set_epsilon(1e-2);
  207. checker.set_dtype(0, dtype::BFloat16());
  208. checker.set_dtype(1, dtype::BFloat16());
  209. BUILD_BINARY_COMPLATE_TEST_CASE
  210. #undef BINARY_COMPLATE_TEST_CASE
  211. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  212. // ternary
  213. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  214. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  215. checker.set_param(Mode::_optr) \
  216. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  217. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  218. checker.set_param(Mode::_optr) \
  219. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  220. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  221. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  222. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  223. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  224. #define BUILD_TERNARY_COMPLATE_TEST_CASE TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  225. UniformFloatRNG rng2(1e-5, 7e1);
  226. checker.set_rng(0, &rng2);
  227. checker.set_epsilon(1e-2);
  228. checker.set_dtype(0, dtype::BFloat16());
  229. checker.set_dtype(1, dtype::BFloat16());
  230. checker.set_dtype(2, dtype::BFloat16());
  231. BUILD_TERNARY_COMPLATE_TEST_CASE
  232. #undef TERNARY_COMPLATE_TEST_CASE
  233. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  234. }
  235. TEST_F(CUDA, ELEMWISE_ADD_BCAST_10_INT8_INPLACE) {
  236. constexpr size_t A = 2, B = 48, C0 = 14, C1 = 14, C = C0 * C1;
  237. SyncedTensor<dt_int8> t0(handle_cuda(), {TensorShape{A, B, C0, C1}, dtype::Int8()}),
  238. t1(handle_cuda(), {TensorShape{1, B, C0, C1}, dtype::Int8()}),
  239. t2(handle_cuda(), {TensorShape{A, B, C0, C1}, dtype::Int8()});
  240. UniformIntRNG rng{-128, 127};
  241. rng.gen(t0.tensornd_host());
  242. rng.gen(t1.tensornd_host());
  243. auto p0 = t0.ptr_host(), p1 = t1.ptr_host();
  244. auto p2 = t2.ptr_mutable_host();
  245. for (size_t i = 0; i < A; ++i) {
  246. for (size_t j = 0; j < B; ++j) {
  247. for (size_t k = 0; k < C; ++k) {
  248. auto off0 = j * C + k;
  249. auto off1 = i * B * C + j * C + k;
  250. p2[off1] = p0[off1] + p1[off0];
  251. }
  252. }
  253. }
  254. auto opr = handle_cuda()->create_operator<ElemwiseForward>();
  255. opr->param().mode = ElemwiseForward::Mode::ADD;
  256. opr->exec({t0.tensornd_dev(), t1.tensornd_dev()}, t0.tensornd_dev());
  257. auto pt = t0.ptr_host();
  258. for (size_t i = 0; i < A; ++i) {
  259. for (size_t j = 0; j < B; ++j) {
  260. for (size_t k = 0; k < C; ++k) {
  261. auto off = i * B * C + j * C + k;
  262. ASSERT_EQ(pt[off], p2[off]);
  263. }
  264. }
  265. }
  266. }
  267. //! the memory of this test case is too large, sometimes will fail on tx1
  268. TEST_F(CUDA, ELEMWISE_BENCHMARK_DENSE) {
  269. constexpr size_t A = 256 * 1024 * 64, S0 = 16, S1 = 256, S2 = 64, S3 = 64;
  270. static_assert(A == S0 * S1 * S2 * S3, "bad value");
  271. SyncedTensor<> t0(handle_cuda(), {TensorShape{S0, S1, S2, S3}, dtype::Float32()}),
  272. t1(handle_cuda(), {TensorShape{S0, S1, S2, S3}, dtype::Float32()});
  273. UniformFloatRNG rng{-2.f, 2.f};
  274. rng.gen(t0.tensornd_host());
  275. run_tensor_add(
  276. handle_cuda(), t0.tensornd_dev(), t0.tensornd_dev(), t1.tensornd_dev());
  277. auto p0 = t0.ptr_host(), p1 = t1.ptr_host();
  278. for (size_t i = 0; i < A; ++i) {
  279. ASSERT_EQ(p0[i] + p0[i], p1[i]) << "at index " << i << "/" << A;
  280. }
  281. }
  282. #if MEGDNN_WITH_BENCHMARK
  283. TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_101) {
  284. constexpr size_t A = 511, B = 509, C0 = 23, C1 = 23, C = C0 * C1;
  285. SyncedTensor<> t0(handle_cuda(), {TensorShape{A, B, C0, C1}, dtype::Float32()}),
  286. t1(handle_cuda(), {TensorShape{1, B, 1, 1}, dtype::Float32()}),
  287. t2(handle_cuda(), {TensorShape{A, B, C0, C1}, dtype::Float32()});
  288. UniformFloatRNG rng{-2.f, 2.f};
  289. rng.gen(t0.tensornd_host());
  290. rng.gen(t1.tensornd_host());
  291. run_tensor_add(
  292. handle_cuda(), t0.tensornd_dev(), t1.tensornd_dev(), t2.tensornd_dev());
  293. auto p0 = t0.ptr_host(), p1 = t1.ptr_host(), p2 = t2.ptr_host();
  294. for (size_t i = 0; i < A; ++i) {
  295. for (size_t j = 0; j < B; ++j) {
  296. for (size_t k = 0; k < C; ++k) {
  297. auto off = i * B * C + j * C + k;
  298. ASSERT_EQ(p0[off] + p1[j], p2[off]);
  299. }
  300. }
  301. }
  302. }
  303. TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_10) {
  304. constexpr size_t A = 11583, B = 11587;
  305. SyncedTensor<> t0(handle_cuda(), {TensorShape{A, B}, dtype::Float32()}),
  306. t1(handle_cuda(), {TensorShape{1, B}, dtype::Float32()}),
  307. t2(handle_cuda(), {TensorShape{A, B}, dtype::Float32()});
  308. UniformFloatRNG rng{-2.f, 2.f};
  309. rng.gen(t0.tensornd_host());
  310. rng.gen(t1.tensornd_host());
  311. run_tensor_add(
  312. handle_cuda(), t0.tensornd_dev(), t1.tensornd_dev(), t2.tensornd_dev());
  313. auto p0 = t0.ptr_host(), p1 = t1.ptr_host(), p2 = t2.ptr_host();
  314. for (size_t i = 0; i < A; ++i) {
  315. for (size_t j = 0; j < B; ++j) {
  316. auto off = i * B + j;
  317. ASSERT_EQ(p0[off] + p1[j], p2[off]);
  318. }
  319. }
  320. }
  321. TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_01) {
  322. constexpr size_t A = 11583, B = 11587;
  323. SyncedTensor<> t0(handle_cuda(), {TensorShape{1, A, B}, dtype::Float32()}),
  324. t1(handle_cuda(), {TensorShape{1, A, 1}, dtype::Float32()}),
  325. t2(handle_cuda(), {TensorShape{1, A, B}, dtype::Float32()});
  326. UniformFloatRNG rng{-2.f, 2.f};
  327. rng.gen(t0.tensornd_host());
  328. rng.gen(t1.tensornd_host());
  329. run_tensor_add(
  330. handle_cuda(), t0.tensornd_dev(), t1.tensornd_dev(), t2.tensornd_dev());
  331. auto p0 = t0.ptr_host(), p1 = t1.ptr_host(), p2 = t2.ptr_host();
  332. for (size_t i = 0; i < A; ++i) {
  333. for (size_t j = 0; j < B; ++j) {
  334. auto off = i * B + j;
  335. ASSERT_EQ(p0[off] + p1[i], p2[off]);
  336. }
  337. }
  338. }
  339. TEST_F(CUDA, BENCHMARK_ELEMWISE_IBYTE) {
  340. Benchmarker<ElemwiseForward> bencher(handle_cuda());
  341. using Mode = ElemwiseForward::Param::Mode;
  342. auto run_bench = [&](size_t N, size_t C, size_t H, size_t W) {
  343. size_t nr_times = 100;
  344. bencher.set_times(nr_times)
  345. .set_param(Mode::FUSE_ADD_RELU)
  346. .set_dtype(0, dtype::Int8())
  347. .set_dtype(1, dtype::Int8());
  348. auto time = bencher.execs({{N * C * H * W + 1}, {N * C * H * W + 1}, {}}) /
  349. nr_times;
  350. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  351. (3.0 * (N * C * H * W + 1)) / (time * 1e6));
  352. time = bencher.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}}) / nr_times;
  353. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  354. (3.0 * N * C * H * W) / (time * 1e6));
  355. time = bencher.execs({{N, C / 4, H, W, 4}, {1, C / 4, 1, 1, 4}, {}}) / nr_times;
  356. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  357. (C + 2.0 * N * C * H * W) / (time * 1e6));
  358. time = bencher.execs({{N, C / 4, H, W, 4}, {1}, {}}) / nr_times;
  359. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  360. (2.0 * N * C * H * W + 1) / (time * 1e6));
  361. time = bencher.execs({{N, C / 32, H, W, 32}, {N, C / 32, H, W, 32}, {}}) /
  362. nr_times;
  363. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  364. (3.0 * N * C * H * W) / (time * 1e6));
  365. time = bencher.execs({{N, C / 32, H, W, 32}, {1, C / 32, 1, 1, 32}, {}}) /
  366. nr_times;
  367. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  368. (C + 2.0 * N * C * H * W) / (time * 1e6));
  369. bencher.set_dtype(0, dtype::Float32()).set_dtype(1, dtype::Float32());
  370. time = bencher.execs({{N, C / 4, H, W}, {N, C / 4, H, W}, {}}) / nr_times;
  371. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  372. (3.0 * N * C * H * W) / (time * 1e6));
  373. time = bencher.execs({{N, C / 4, H, W}, {1, C / 4, 1, 1}, {}}) / nr_times;
  374. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  375. (C + 2.0 * N * C * H * W) / (time * 1e6));
  376. };
  377. run_bench(256, 256, 56, 56);
  378. }
  379. TEST_F(CUDA, BENCHMARK_ELEMWISE_MIN_MAX) {
  380. Benchmarker<ElemwiseForward> bencher(handle_cuda());
  381. using Mode = ElemwiseForward::Param::Mode;
  382. UniformIntRNG const_1{1, 1}, rng{-128, 127};
  383. auto run_bench = [&](size_t N, size_t C, size_t H, size_t W, DType dtype) {
  384. size_t nr_times = 1000;
  385. bencher.set_times(nr_times)
  386. .set_param(Mode::MIN)
  387. .set_rng(0, &rng)
  388. .set_rng(1, &rng)
  389. .set_dtype(0, dtype)
  390. .set_dtype(1, dtype);
  391. auto time = bencher.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}}) /
  392. nr_times;
  393. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  394. (3.0 * N * C * H * W) / (time * 1e6));
  395. bencher.set_param(Mode::MAX).set_rng(0, &const_1).set_rng(1, &const_1);
  396. time = bencher.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}}) / nr_times;
  397. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  398. (3.0 * N * C * H * W) / (time * 1e6));
  399. };
  400. run_bench(256, 256, 56, 56, dtype::Int8());
  401. }
  402. #endif
  403. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台