You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

elemwise.cpp 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. /**
  2. * \file dnn/test/cuda/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/elemwise.h"
  12. #include "test/cuda/fixture.h"
  13. #include "megdnn/oprs.h"
  14. #include "test/common/tensor.h"
  15. #include "test/common/rng.h"
  16. #include "./utils.h"
  17. #include "test/common/benchmarker.h"
  18. #include "test/common/checker.h"
  19. #include <cudnn.h>
  20. #include <cuda_profiler_api.h>
  21. using namespace megdnn;
  22. using namespace test;
  23. #define cudnn_check(e) megdnn_assert((e) == CUDNN_STATUS_SUCCESS)
  24. namespace {
  25. __attribute__((unused))
  26. cudnnTensorDescriptor_t make_cudnn_tensor_desc(const TensorLayout &ly) {
  27. megdnn_assert(ly.ndim && ly.ndim <= 4 && ly.is_contiguous());
  28. int dim[4] = {1, 1, 1, 1}, stride[4] = {1, 1, 1, 1};
  29. for (size_t i = 0; i < ly.ndim; ++ i) {
  30. dim[i] = ly.shape[i];
  31. stride[i] = ly.stride[i];
  32. }
  33. cudnnTensorDescriptor_t ret;
  34. cudnn_check(cudnnCreateTensorDescriptor(&ret));
  35. // cudnn requires tensors to be at-least 4D
  36. cudnn_check(cudnnSetTensor4dDescriptorEx(ret,
  37. CUDNN_DATA_FLOAT,
  38. dim[0], dim[1], dim[2], dim[3],
  39. stride[0], stride[1], stride[2], stride[3]));
  40. return ret;
  41. }
  42. void run_tensor_add(
  43. Handle *handle_cuda,
  44. const TensorND &a, const TensorND &b,
  45. const TensorND &c) {
  46. #if 1
  47. cudnnHandle_t cudnn_handle;
  48. cudnn_check(cudnnCreate(&cudnn_handle));
  49. cuda_check(cudaDeviceSynchronize());
  50. cuda_check(cudaMemcpy(c.raw_ptr, a.raw_ptr, a.layout.span().dist_byte(),
  51. cudaMemcpyDeviceToDevice));
  52. auto bdesc = make_cudnn_tensor_desc(b.layout),
  53. cdesc = make_cudnn_tensor_desc(c.layout);
  54. float alpha = 1, beta = 1;
  55. cudaProfilerStart();
  56. cudnn_check(cudnnAddTensor(cudnn_handle,
  57. &alpha, bdesc, b.raw_ptr,
  58. &beta, cdesc, c.raw_ptr));
  59. cudaProfilerStop();
  60. cudnn_check(cudnnDestroyTensorDescriptor(cdesc));
  61. cudnn_check(cudnnDestroyTensorDescriptor(bdesc));
  62. cudnn_check(cudnnDestroy(cudnn_handle));
  63. cuda_check(cudaMemset(c.raw_ptr, 0, c.layout.span().dist_byte()));
  64. cuda_check(cudaDeviceSynchronize());
  65. #endif
  66. auto opr = handle_cuda->create_operator<ElemwiseForward>();
  67. opr->param().mode = ElemwiseForward::Mode::ADD;
  68. cudaProfilerStart();
  69. opr->exec({a, b}, c);
  70. cudaProfilerStop();
  71. }
  72. } // anonymous namespace
  73. template<typename tag>
  74. class CUDA_ELEMWISE: public CUDA {
  75. };
  76. TYPED_TEST_CASE(CUDA_ELEMWISE, elemwise::test_types);
  77. TYPED_TEST(CUDA_ELEMWISE, run) {
  78. elemwise::run_test<TypeParam>(this->handle_cuda());
  79. }
  80. TEST_F(CUDA, ELEMWISE_IBYTE) {
  81. Checker<ElemwiseForward> checker(handle_cuda());
  82. using Mode = ElemwiseForward::Param::Mode;
  83. UniformIntRNG i_rng{-128, 127};
  84. UniformIntRNG ui_rng{0, 255};
  85. checker.set_rng(0, &i_rng);
  86. auto run_unary = [&](size_t N, Mode mode, DType dtype) {
  87. checker.set_param(mode).set_dtype(0, dtype);
  88. checker.execs({{N}, {}});
  89. };
  90. #define RUN_UNARY_IBYTE(_dt) \
  91. run_unary(100, Mode::RELU, _dt); \
  92. run_unary(100, Mode::ABS, _dt);
  93. RUN_UNARY_IBYTE(dtype::Int8());
  94. checker.set_rng(0, &i_rng);
  95. RUN_UNARY_IBYTE(dtype::Uint8());
  96. #undef RUN_UNARY_IBYTE
  97. auto run_binary = [&](size_t N, size_t C, size_t H, size_t W, Mode mode,
  98. DType dtype) {
  99. checker.set_param(mode).set_dtype(0, dtype).set_dtype(1, dtype);
  100. checker.execs({{5}, {5}, {}});
  101. checker.execs({{4}, {4}, {}});
  102. checker.execs({{4}, {1}, {}});
  103. checker.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}});
  104. checker.execs({{N, C / 4, H, W, 4}, {1, C / 4, 1, 1, 4}, {}});
  105. checker.execs({{N, C / 32, H, W, 32}, {N, C / 32, H, W, 32}, {}});
  106. checker.execs({{N, C / 32, H, W, 32}, {1, C / 32, 1, 1, 32}, {}});
  107. checker.execs({{3, 5, 7}, {3, 5, 7}, {}});
  108. checker.execs({{3, 5, 7}, {3, 5, 1}, {}});
  109. checker.execs({{3, 5, 1}, {3, 5, 7}, {}});
  110. checker.execs({{1}, {3, 5, 7}, {}});
  111. checker.execs({{3, 5, 7}, {1}, {}});
  112. };
  113. #define RUN_BINARY_IBYTE(_dt) \
  114. run_binary(4, 32, 10, 10, Mode::ADD, _dt); \
  115. run_binary(4, 32, 10, 10, Mode::MUL, _dt); \
  116. run_binary(4, 32, 10, 10, Mode::MAX, _dt); \
  117. run_binary(4, 32, 10, 10, Mode::MIN, _dt); \
  118. run_binary(4, 32, 10, 10, Mode::SUB, _dt);
  119. checker.set_rng(0, &i_rng).set_rng(1, &i_rng);
  120. RUN_BINARY_IBYTE(dtype::Int8());
  121. checker.set_rng(0, &ui_rng).set_rng(1, &ui_rng);
  122. RUN_BINARY_IBYTE(dtype::Uint8());
  123. #undef RUN_BINARY_IBYTE
  124. auto run_ternary = [&](size_t N, size_t C, size_t H, size_t W, Mode mode,
  125. DType dtype) {
  126. checker.set_param(mode)
  127. .set_dtype(0, dtype)
  128. .set_dtype(1, dtype)
  129. .set_dtype(2, dtype);
  130. checker.execs({{5}, {5}, {5}, {}});
  131. checker.execs({{4}, {4}, {1}, {}});
  132. checker.execs({{N, C / 4, H, W, 4},
  133. {N, C / 4, H, W, 4},
  134. {N, C / 4, H, W, 4},
  135. {}});
  136. checker.execs({{N, C / 4, H, W, 4},
  137. {1, C / 4, 1, 1, 4},
  138. {1, C / 4, 1, 1, 4},
  139. {}});
  140. checker.execs({{N, C / 32, H, W, 32},
  141. {N, C / 32, H, W, 32},
  142. {N, C / 32, H, W, 32},
  143. {}});
  144. checker.execs({{N, C / 32, H, W, 32},
  145. {1, C / 32, 1, 1, 32},
  146. {1, C / 32, 1, 1, 32},
  147. {}});
  148. checker.execs({{1}, {3, 5, 7}, {3, 5, 7}, {}});
  149. checker.execs({{3, 5, 7}, {3, 5, 1}, {3, 5, 1}, {}});
  150. checker.execs({{3, 5, 1}, {3, 5, 7}, {3, 5, 1}, {}});
  151. checker.execs({{1}, {3, 5, 7}, {1}, {}});
  152. checker.execs({{3, 5, 7}, {1}, {3, 5, 7}, {}});
  153. };
  154. #define RUN_TERNARY_IBYTE(_dt) \
  155. run_ternary(4, 32, 10, 10, Mode::FUSE_MUL_ADD3, _dt);
  156. checker.set_rng(0, &i_rng).set_rng(1, &i_rng);
  157. RUN_TERNARY_IBYTE(dtype::Int8());
  158. checker.set_rng(0, &ui_rng).set_rng(1, &ui_rng);
  159. RUN_TERNARY_IBYTE(dtype::Uint8());
  160. #undef RUN_TERNARY_IBYTE
  161. }
  162. // from common/elemwise.cpp
  163. TEST_F(CUDA, ELEMWISE_BFLOAT16) {
  164. using Mode = ElemwiseForward::Param::Mode;
  165. Checker<ElemwiseForward> checker(handle_cuda());
  166. // unary
  167. #define UNARY_TEST_CASE(_optr) \
  168. checker.set_param(Mode::_optr).execs({{1, 127}, {}}); \
  169. checker.set_param(Mode::_optr).execs({{1, 7}, {}});
  170. #define BUILD_UNARY_TEST_CASE_FLOAT \
  171. UNARY_TEST_CASE(ABS) \
  172. UNARY_TEST_CASE(LOG) \
  173. UNARY_TEST_CASE(COS) \
  174. UNARY_TEST_CASE(SIN) \
  175. UNARY_TEST_CASE(FLOOR) \
  176. UNARY_TEST_CASE(CEIL) \
  177. UNARY_TEST_CASE(SIGMOID) \
  178. UNARY_TEST_CASE(EXP) \
  179. UNARY_TEST_CASE(TANH) \
  180. UNARY_TEST_CASE(FAST_TANH) \
  181. UNARY_TEST_CASE(RELU) \
  182. UNARY_TEST_CASE(ROUND)
  183. checker.set_dtype(0, dtype::BFloat16());
  184. checker.set_dtype(1, dtype::BFloat16());
  185. UniformFloatRNG rng0(1e-2, 6e1);
  186. checker.set_rng(0, &rng0);
  187. checker.set_epsilon(1e-2);
  188. BUILD_UNARY_TEST_CASE_FLOAT
  189. #undef UNARY_TEST_CASE
  190. #undef BUILD_UNARY_TEST_CASE_FLOAT
  191. // binary
  192. #define BINARY_COMPLATE_TEST_CASE(_optr) \
  193. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {3, 4, 7}, {}}); \
  194. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  195. checker.set_param(Mode::_optr).execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {}}); \
  196. checker.set_param(Mode::_optr).execs({{3, 4, 7}, {1, 4, 1}, {}}); \
  197. checker.set_param(Mode::_optr).execs({{1, 4, 1}, {3, 4, 7}, {}}); \
  198. checker.set_param(Mode::_optr).execs({{3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  199. checker.set_param(Mode::_optr).execs({{1, 1, 1, 1}, {3, 4, 5, 7}, {}}); \
  200. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {}}); \
  201. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 2, 1}, {}}); \
  202. checker.set_param(Mode::_optr).execs({{1, 2, 1}, {1, 2, 2}, {}}); \
  203. checker.set_param(Mode::_optr).execs({{1, 2, 2}, {1, 1, 1}, {}}); \
  204. checker.set_param(Mode::_optr).execs({{1, 1, 1}, {1, 2, 2}, {}}); \
  205. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {}});
  206. #define BUILD_BINARY_COMPLATE_TEST_CASE \
  207. BINARY_COMPLATE_TEST_CASE(ADD) \
  208. BINARY_COMPLATE_TEST_CASE(MUL) \
  209. BINARY_COMPLATE_TEST_CASE(MAX) \
  210. BINARY_COMPLATE_TEST_CASE(MIN) \
  211. BINARY_COMPLATE_TEST_CASE(SUB)
  212. UniformFloatRNG rng1(1e-5, 7e1);
  213. checker.set_rng(0, &rng1);
  214. checker.set_epsilon(1e-2);
  215. checker.set_dtype(0, dtype::BFloat16());
  216. checker.set_dtype(1, dtype::BFloat16());
  217. BUILD_BINARY_COMPLATE_TEST_CASE
  218. #undef BINARY_COMPLATE_TEST_CASE
  219. #undef BUILD_BINARY_COMPLATE_TEST_CASE
  220. // ternary
  221. #define TERNARY_COMPLATE_TEST_CASE(_optr) \
  222. checker.set_param(Mode::_optr) \
  223. .execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
  224. checker.set_param(Mode::_optr) \
  225. .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
  226. checker.set_param(Mode::_optr) \
  227. .execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
  228. checker.set_param(Mode::_optr) \
  229. .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
  230. checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
  231. checker.set_param(Mode::_optr) \
  232. .execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
  233. checker.set_param(Mode::_optr) \
  234. .execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
  235. checker.set_param(Mode::_optr).execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
  236. #define BUILD_TERNARY_COMPLATE_TEST_CASE \
  237. TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
  238. UniformFloatRNG rng2(1e-5, 7e1);
  239. checker.set_rng(0, &rng2);
  240. checker.set_epsilon(1e-2);
  241. checker.set_dtype(0, dtype::BFloat16());
  242. checker.set_dtype(1, dtype::BFloat16());
  243. checker.set_dtype(2, dtype::BFloat16());
  244. BUILD_TERNARY_COMPLATE_TEST_CASE
  245. #undef TERNARY_COMPLATE_TEST_CASE
  246. #undef BUILD_TERNARY_COMPLATE_TEST_CASE
  247. }
  248. //! the memory of this test case is too large, sometimes will fail on tx1
  249. TEST_F(CUDA, ELEMWISE_BENCHMARK_DENSE) {
  250. constexpr size_t A = 256 * 1024 * 64,
  251. S0 = 16, S1 = 256, S2 = 64, S3 = 64;
  252. static_assert(A == S0 * S1 * S2 * S3, "bad value");
  253. SyncedTensor<>
  254. t0(handle_cuda(), {TensorShape{S0, S1, S2, S3}, dtype::Float32()}),
  255. t1(handle_cuda(), {TensorShape{S0, S1, S2, S3}, dtype::Float32()});
  256. UniformFloatRNG rng{-2.f, 2.f};
  257. rng.gen(t0.tensornd_host());
  258. run_tensor_add(handle_cuda(),
  259. t0.tensornd_dev(), t0.tensornd_dev(), t1.tensornd_dev());
  260. auto p0 = t0.ptr_host(), p1 = t1.ptr_host();
  261. for (size_t i = 0; i < A; ++ i) {
  262. ASSERT_EQ(p0[i] + p0[i], p1[i]) << "at index " << i << "/" << A;
  263. }
  264. }
  265. #if MEGDNN_WITH_BENCHMARK
  266. TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_101) {
  267. constexpr size_t A = 511, B = 509, C0 = 23, C1 = 23, C = C0 * C1;
  268. SyncedTensor<>
  269. t0(handle_cuda(), {TensorShape{A, B, C0, C1}, dtype::Float32()}),
  270. t1(handle_cuda(), {TensorShape{1, B, 1, 1}, dtype::Float32()}),
  271. t2(handle_cuda(), {TensorShape{A, B, C0, C1}, dtype::Float32()});
  272. UniformFloatRNG rng{-2.f, 2.f};
  273. rng.gen(t0.tensornd_host());
  274. rng.gen(t1.tensornd_host());
  275. run_tensor_add(handle_cuda(),
  276. t0.tensornd_dev(), t1.tensornd_dev(), t2.tensornd_dev());
  277. auto p0 = t0.ptr_host(), p1 = t1.ptr_host(), p2 = t2.ptr_host();
  278. for (size_t i = 0; i < A; ++ i) {
  279. for (size_t j = 0; j < B; ++ j) {
  280. for (size_t k = 0; k < C; ++ k) {
  281. auto off = i * B * C + j * C + k;
  282. ASSERT_EQ(p0[off] + p1[j], p2[off]);
  283. }
  284. }
  285. }
  286. }
  287. TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_10) {
  288. constexpr size_t A = 11583, B = 11587;
  289. SyncedTensor<> t0(handle_cuda(), {TensorShape{A, B}, dtype::Float32()}),
  290. t1(handle_cuda(), {TensorShape{1, B}, dtype::Float32()}),
  291. t2(handle_cuda(), {TensorShape{A, B}, dtype::Float32()});
  292. UniformFloatRNG rng{-2.f, 2.f};
  293. rng.gen(t0.tensornd_host());
  294. rng.gen(t1.tensornd_host());
  295. run_tensor_add(handle_cuda(),
  296. t0.tensornd_dev(), t1.tensornd_dev(), t2.tensornd_dev());
  297. auto p0 = t0.ptr_host(), p1 = t1.ptr_host(), p2 = t2.ptr_host();
  298. for (size_t i = 0; i < A; ++ i) {
  299. for (size_t j = 0; j < B; ++ j) {
  300. auto off = i * B + j;
  301. ASSERT_EQ(p0[off] + p1[j], p2[off]);
  302. }
  303. }
  304. }
  305. TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_01) {
  306. constexpr size_t A = 11583, B = 11587;
  307. SyncedTensor<> t0(handle_cuda(), {TensorShape{1, A, B}, dtype::Float32()}),
  308. t1(handle_cuda(), {TensorShape{1, A, 1}, dtype::Float32()}),
  309. t2(handle_cuda(), {TensorShape{1, A, B}, dtype::Float32()});
  310. UniformFloatRNG rng{-2.f, 2.f};
  311. rng.gen(t0.tensornd_host());
  312. rng.gen(t1.tensornd_host());
  313. run_tensor_add(handle_cuda(),
  314. t0.tensornd_dev(), t1.tensornd_dev(), t2.tensornd_dev());
  315. auto p0 = t0.ptr_host(), p1 = t1.ptr_host(), p2 = t2.ptr_host();
  316. for (size_t i = 0; i < A; ++ i) {
  317. for (size_t j = 0; j < B; ++ j) {
  318. auto off = i * B + j;
  319. ASSERT_EQ(p0[off] + p1[i], p2[off]);
  320. }
  321. }
  322. }
  323. TEST_F(CUDA, BENCHMARK_ELEMWISE_IBYTE) {
  324. Benchmarker<ElemwiseForward> bencher(handle_cuda());
  325. using Mode = ElemwiseForward::Param::Mode;
  326. auto run_bench = [&](size_t N, size_t C, size_t H, size_t W) {
  327. size_t nr_times = 100;
  328. bencher.set_times(nr_times)
  329. .set_param(Mode::FUSE_ADD_RELU)
  330. .set_dtype(0, dtype::Int8())
  331. .set_dtype(1, dtype::Int8());
  332. auto time = bencher.execs({{N * C * H * W + 1}, {N * C * H * W + 1}, {}}) /
  333. nr_times;
  334. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  335. (3.0 * (N * C * H * W + 1)) / (time * 1e6));
  336. time = bencher.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}}) /
  337. nr_times;
  338. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  339. (3.0 * N * C * H * W) / (time * 1e6));
  340. time = bencher.execs({{N, C / 4, H, W, 4}, {1, C / 4, 1, 1, 4}, {}}) /
  341. nr_times;
  342. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  343. (C + 2.0 * N * C * H * W) / (time * 1e6));
  344. time = bencher.execs({{N, C / 4, H, W, 4}, {1}, {}}) / nr_times;
  345. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  346. (2.0 * N * C * H * W + 1) / (time * 1e6));
  347. time = bencher.execs(
  348. {{N, C / 32, H, W, 32}, {N, C / 32, H, W, 32}, {}}) /
  349. nr_times;
  350. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  351. (3.0 * N * C * H * W) / (time * 1e6));
  352. time = bencher.execs(
  353. {{N, C / 32, H, W, 32}, {1, C / 32, 1, 1, 32}, {}}) /
  354. nr_times;
  355. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  356. (C + 2.0 * N * C * H * W) / (time * 1e6));
  357. bencher.set_dtype(0, dtype::Float32()).set_dtype(1, dtype::Float32());
  358. time = bencher.execs({{N, C / 4, H, W}, {N, C / 4, H, W}, {}}) /
  359. nr_times;
  360. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  361. (3.0 * N * C * H * W) / (time * 1e6));
  362. time = bencher.execs({{N, C / 4, H, W}, {1, C / 4, 1, 1}, {}}) /
  363. nr_times;
  364. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  365. (C + 2.0 * N * C * H * W) / (time * 1e6));
  366. };
  367. run_bench(256, 256, 56, 56);
  368. }
  369. TEST_F(CUDA, BENCHMARK_ELEMWISE_MIN_MAX) {
  370. Benchmarker<ElemwiseForward> bencher(handle_cuda());
  371. using Mode = ElemwiseForward::Param::Mode;
  372. UniformIntRNG const_1{1, 1}, rng{-128, 127};
  373. auto run_bench = [&](size_t N, size_t C, size_t H, size_t W, DType dtype) {
  374. size_t nr_times = 1000;
  375. bencher.set_times(nr_times)
  376. .set_param(Mode::MIN)
  377. .set_rng(0, &rng)
  378. .set_rng(1, &rng)
  379. .set_dtype(0, dtype)
  380. .set_dtype(1, dtype);
  381. auto time =
  382. bencher.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}}) /
  383. nr_times;
  384. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  385. (3.0 * N * C * H * W) / (time * 1e6));
  386. bencher.set_param(Mode::MAX).set_rng(0, &const_1).set_rng(1, &const_1);
  387. time = bencher.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}}) /
  388. nr_times;
  389. printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
  390. (3.0 * N * C * H * W) / (time * 1e6));
  391. };
  392. run_bench(256, 256, 56, 56, dtype::Int8());
  393. }
  394. #endif
  395. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台