
convolution.cpp 20 kB

/**
 * \file dnn/test/rocm/convolution.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/common/convolution.h"
#include "hcc_detail/hcc_defs_prologue.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "test/common/workspace_wrapper.h"
#include "test/rocm/benchmarker.h"
#include "test/rocm/fixture.h"

#include "src/common/utils.h"
#include "src/rocm/utils.h"
namespace megdnn {
namespace test {
namespace convolution {

std::vector<TestArg> get_args_0() {
    std::vector<TestArg> args, tmp_args;
#define ADD_ARGS(NAME)            \
    tmp_args = get_args_##NAME(); \
    args.insert(args.end(), tmp_args.begin(), tmp_args.end());
    ADD_ARGS(common)
    ADD_ARGS(padding)
    ADD_ARGS(large_channel)
    ADD_ARGS(1x1)
    ADD_ARGS(large_filter)
    ADD_ARGS(exhaustive_search)
    ADD_ARGS(4x4)
    ADD_ARGS(large_channels)
    ADD_ARGS(x86_direct_case_2)
    ADD_ARGS(cudnn_5_1_failures)
    ADD_ARGS(x86_winograd_algorithm)
    ADD_ARGS(BRAIN_481)
#undef ADD_ARGS
    return args;
}

std::vector<TestArg> get_args_1() {
    return get_args_fallback_templated_impl();
}

std::vector<TestArg> get_args_2() {
    return get_args_fallback_non_templated_impl();
}

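//! get_group_conv_args() enumerates grouped-convolution cases: two groups,
//! several (icpg, ocpg) channels-per-group combinations, 3/5/7 filter sizes
//! (square and near-square), optional padding, stride 1 or 2, dilation 1 or 2.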
std::vector<TestArg> get_group_conv_args() {
    std::vector<TestArg> args;
    for (size_t batch_size : {2}) {
        for (size_t ih : {23}) {
            for (size_t iw : {ih + 1}) {
                for (size_t icpg : {2, 4, 8}) {
                    for (size_t ocpg : {4, 8}) {
                        for (size_t fh : {3, 5, 7}) {
                            for (size_t fw : {fh, fh + 1}) {
                                for (size_t ph : {0_z, size_t{fw / 2}}) {
                                    for (size_t sh : {1, 2}) {
                                        for (size_t dh : {1, 2}) {
                                            param::Convolution param;
                                            size_t groups = 2;
                                            param.sparse =
                                                    param::Convolution::Sparse::GROUP;
                                            param.mode = param::Convolution::Mode::
                                                    CROSS_CORRELATION;
                                            param.stride_h = param.stride_w = sh;
                                            param.pad_h = param.pad_w = ph;
                                            param.dilate_h = param.dilate_w = dh;
                                            args.emplace_back(
                                                    param,
                                                    TensorShape{
                                                            batch_size, icpg * groups,
                                                            ih, iw},
                                                    TensorShape{
                                                            groups, ocpg, icpg, fh,
                                                            fw});
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    return args;
}

} // namespace convolution

TEST_F(ROCM, CONV_GROUP) {
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_group_conv_args();
    Checker<ConvolutionForward> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        checker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
    }
}

TEST_F(ROCM, CONV_CHANNWISE) {
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_chanwise_args();
    Checker<ConvolutionForward> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        checker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
    }
}

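//! The forward tests below share one pattern: skip Mode::CONVOLUTION cases
//! (MIOpen only implements cross-correlation here), check FP32 results with
//! epsilon 1e-3, and, unless FP16 support is compiled out, re-check FP16 with
//! inputs drawn from a uniform RNG scaled by 1/sqrt(IC * FH * FW) and a
//! looser epsilon of 1e-1.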
TEST_F(ROCM, CONVOLUTION_FORWARD_0) {
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args_0();
    Checker<ConvolutionForward> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        float scale = 1.0f / sqrt(arg.filter[1] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        checker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
#if !MEGDNN_DISABLE_FLOAT16
        checker.set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
#endif
    }
}

TEST_F(ROCM, CONVOLUTION_FORWARD_1) {
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args_1();
    Checker<ConvolutionForward> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        float scale = 1.0f / sqrt(arg.filter[1] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        checker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
#if !MEGDNN_DISABLE_FLOAT16
        checker.set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
#endif
    }
}

TEST_F(ROCM, CONVOLUTION_FORWARD_2) {
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args_2();
    Checker<ConvolutionForward> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        float scale = 1.0f / sqrt(arg.filter[1] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        checker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
#if !MEGDNN_DISABLE_FLOAT16
        checker.set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
#endif
    }
}

TEST_F(ROCM, CONVOLUTION_1X1_FORWARD) {
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_1x1_args();
    Checker<ConvolutionForward> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        float scale = 1.0f / sqrt(arg.filter[1] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        checker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
    }
}

#if MEGDNN_WITH_BENCHMARK
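//! This benchmark times the 1x1 cases with MIOpen algorithm search enabled
//! and reports achieved GFLOPS, counting 2 * N * OC * OH * OW * IC * FH * FW
//! floating point operations per convolution.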
TEST_F(ROCM, CONVOLUTION_1X1_FORWARD_ALL_ALGOS) {
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
    using namespace convolution;
    OprProxy<ConvolutionForward> proxy{true};
    proxy.warmup_times = 1;
    proxy.exec_times = 10;
    Benchmarker<ConvolutionForward> checker(handle_rocm());
    checker.set_times(1);
    auto get_computation = [&](TestArg arg) -> float {
        megdnn_assert(arg.param.format == param::Convolution::Format::NCHW);
        size_t N = arg.src[0], IC = arg.src[1], IH = arg.src[2], IW = arg.src[3],
               OC = arg.filter[0], FH = arg.filter[2], FW = arg.filter[3],
               SH = arg.param.stride_h, SW = arg.param.stride_w, PH = arg.param.pad_h,
               PW = arg.param.pad_w;
        size_t OH = infer_conv_shape(IH, FH, SH, PH);
        size_t OW = infer_conv_shape(IW, FW, SW, PW);
        float flops = 2.0 * N * OC * OH * OW * IC * FH * FW;
        return flops;
    };
    std::vector<TestArg> args = get_1x1_args();
    NormalRNG default_rng;
    for (auto&& arg : args) {
        float scale = 1.0f / sqrt(arg.filter[1] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        checker.set_proxy(proxy)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_param(arg.param);
        float time_in_ms = checker.execs({arg.src, arg.filter, {}});
        float flops = get_computation(arg);
        printf("inp=%s,flt=%s,flops=%.2fGflo,time = %.2f ms, perf = %.2f "
               "GFLOPS\n",
               arg.src.to_string().c_str(), arg.filter.to_string().c_str(),
               flops / 1e9, time_in_ms, flops / (1e6 * time_in_ms));
    }
}
#endif

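//! The backward-data tests first deduce the forward output layout with a
//! Convolution operator, then run the gradient kernel on the
//! (filter, diff, grad) = (filter, dst, src) layouts in FP32 and, when
//! enabled, FP16.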
TEST_F(ROCM, CONVOLUTION_BACKWARD_DATA_0) {
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args_0();
    Checker<ConvolutionBackwardData> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        float scale = 1.0f / sqrt(arg.filter[0] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        auto src = TensorLayout(arg.src, dtype::Float32());
        auto filter = TensorLayout(arg.filter, dtype::Float32());
        TensorLayout dst;
        {
            auto opr = handle_rocm()->create_operator<Convolution>();
            opr->param() = arg.param;
            opr->deduce_layout(src, filter, dst);
        }
        src.dtype = dst.dtype = filter.dtype = dtype::Float32();
        checker.set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .exec(TensorLayoutArray{filter, dst, src});
#if !MEGDNN_DISABLE_FLOAT16
        src.dtype = dst.dtype = filter.dtype = dtype::Float16();
        checker.set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .exec(TensorLayoutArray{filter, dst, src});
#endif
    }
}

TEST_F(ROCM, CONVOLUTION_BACKWARD_DATA_1) {
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args_1();
    Checker<ConvolutionBackwardData> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        float scale = 1.0f / sqrt(arg.filter[0] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        auto src = TensorLayout(arg.src, dtype::Float32());
        auto filter = TensorLayout(arg.filter, dtype::Float32());
        TensorLayout dst;
        {
            auto opr = handle_rocm()->create_operator<Convolution>();
            opr->param() = arg.param;
            opr->deduce_layout(src, filter, dst);
        }
        src.dtype = dst.dtype = filter.dtype = dtype::Float32();
        checker.set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .exec(TensorLayoutArray{filter, dst, src});
#if !MEGDNN_DISABLE_FLOAT16
        src.dtype = dst.dtype = filter.dtype = dtype::Float16();
        checker.set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .exec(TensorLayoutArray{filter, dst, src});
#endif
    }
}

TEST_F(ROCM, CONVOLUTION_BACKWARD_DATA_2) {
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args_2();
    Checker<ConvolutionBackwardData> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        float scale = 1.0f / sqrt(arg.filter[0] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        auto src = TensorLayout(arg.src, dtype::Float32());
        auto filter = TensorLayout(arg.filter, dtype::Float32());
        TensorLayout dst;
        {
            auto opr = handle_rocm()->create_operator<Convolution>();
            opr->param() = arg.param;
            opr->deduce_layout(src, filter, dst);
        }
        src.dtype = dst.dtype = filter.dtype = dtype::Float32();
        checker.set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .exec(TensorLayoutArray{filter, dst, src});
#if !MEGDNN_DISABLE_FLOAT16
        src.dtype = dst.dtype = filter.dtype = dtype::Float16();
        checker.set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .exec(TensorLayoutArray{filter, dst, src});
#endif
    }
}

TEST_F(ROCM, DISABLED_CONVOLUTION_BACKWARD_FILTER) {
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args();
    Checker<ConvolutionBackwardFilter> checker(handle_rocm());
    NormalRNG default_rng;
    bool f16_checked = false;
    MEGDNN_MARK_USED_VAR(f16_checked);
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        auto src = TensorLayout(arg.src, dtype::Float32());
        auto filter = TensorLayout(arg.filter, dtype::Float32());
        TensorLayout dst;
        {
            auto opr = handle_rocm()->create_operator<Convolution>();
            opr->param() = arg.param;
            opr->deduce_layout(src, filter, dst);
        }
        float scale = 1.0f / sqrt(dst[2] * dst[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        src.dtype = dst.dtype = filter.dtype = dtype::Float32();
        checker.set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .exec(TensorLayoutArray{src, dst, filter});
#if !MEGDNN_DISABLE_FLOAT16
        //! FIXME: MIOpen convolution backward weights for FP16 is buggy
#if 0
        // reduce on large f16 array may introduce significant error
        if (dst.total_nr_elems() >= 1000 && f16_checked)
            continue;
        f16_checked = true;
        src.dtype = dst.dtype = filter.dtype = dtype::Float16();
        checker.set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .exec(TensorLayoutArray{src, dst, filter});
#endif
#endif
    }
}

#if MEGDNN_WITH_BENCHMARK
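//! CONV_FWD_BENCHMARK runs each layer shape twice: the first execs() call
//! warms up MIOpen algorithm selection, the second is timed and converted to
//! TFLOPS assuming 2 * N * OC * IC * OH * OW * FH * FW operations per layer.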
TEST_F(ROCM, CONV_FWD_BENCHMARK) {
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
    auto benchmarker =
            ROCMBenchmarker<ConvolutionForward>(handle_rocm(), handle_naive(false));
    auto run = [&](size_t N, size_t OC, size_t IC, size_t IH, size_t IW, size_t SH = 1,
                   size_t SW = 1, size_t FH = 1, size_t FW = 1, size_t PH = 0,
                   size_t PW = 0, DType dtype = dtype::Float32()) {
        benchmarker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        benchmarker.set_display(true);
        ConvolutionForward::Param param;
        param.stride_h = SH;
        param.stride_w = SW;
        param.pad_h = PH;
        param.pad_w = PW;
        benchmarker.set_param(param);
        size_t OH = (IH - FH + 2 * PH) / SH + 1;
        size_t OW = (IW - FW + 2 * PW) / SW + 1;
        // warm up to find the best algo
        benchmarker.execs({{N, IC, IH, IW}, {OC, IC, FH, FW}, {N, OC, OH, OW}});
        // do the actual benchmark
        auto time_ms =
                benchmarker.execs({{N, IC, IH, IW}, {OC, IC, FH, FW}, {N, OC, OH, OW}});
        auto flo = (double)N * OC * IC * OH * OW * FH * FW * 2;
        auto flops = flo / (time_ms * 1e9);
        printf("%.3fG FLO, flops %.3fTFLOPS\n", flo / 1e9, flops);
    };
    run(32, 24, 16, 224, 224, 2, 2, 7, 7, 3, 3);
    run(32, 128, 32, 112, 112, 1, 1, 3, 3, 1, 1);
    run(32, 128, 128, 56, 56, 1, 1, 3, 3, 1, 1);
    run(32, 128, 256, 28, 28, 1, 1, 3, 3, 1, 1);
    run(32, 256, 256, 28, 28, 1, 1, 1, 1, 0, 0);
    run(32, 256, 256, 28, 28, 2, 2, 3, 3, 1, 1);
    run(32, 256, 256, 14, 14, 1, 1, 3, 3, 1, 1);
    run(32, 512, 512, 7, 7, 1, 1, 3, 3, 1, 1);
#if !MEGDNN_DISABLE_FLOAT16
    run(32, 256, 256, 56, 56, 1, 1, 1, 1, 0, 0, dtype::Float16());
#endif
}
#endif

} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen
