You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

convolution.cpp 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. /**
  2. * \file dnn/test/rocm/convolution.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/convolution.h"
  12. #include "hcc_detail/hcc_defs_prologue.h"
  13. #include "megdnn/opr_param_defs.h"
  14. #include "megdnn/oprs.h"
  15. #include "test/common/benchmarker.h"
  16. #include "test/common/checker.h"
  17. #include "test/common/rng.h"
  18. #include "test/common/tensor.h"
  19. #include "test/common/workspace_wrapper.h"
  20. #include "test/rocm/benchmarker.h"
  21. #include "test/rocm/fixture.h"
  22. #include "src/common/utils.h"
  23. #include "src/rocm/utils.h"
  24. namespace megdnn {
  25. namespace test {
  26. namespace convolution {
  27. std::vector<TestArg> get_args_0() {
  28. std::vector<TestArg> args, tmp_args;
  29. #define ADD_ARGS(NAME) \
  30. tmp_args = get_args_##NAME(); \
  31. args.insert(args.end(), tmp_args.begin(), tmp_args.end());
  32. ADD_ARGS(common)
  33. ADD_ARGS(padding)
  34. ADD_ARGS(large_channel)
  35. ADD_ARGS(1x1)
  36. ADD_ARGS(large_filter)
  37. ADD_ARGS(exhaustive_search)
  38. ADD_ARGS(4x4)
  39. ADD_ARGS(large_channels)
  40. ADD_ARGS(x86_direct_case_2)
  41. ADD_ARGS(cudnn_5_1_failures)
  42. ADD_ARGS(x86_winograd_algorithm)
  43. ADD_ARGS(BRAIN_481)
  44. #undef ADD_ARGS
  45. return args;
  46. }
//! Cases exercising the fallback templated implementation
//! (see get_args_fallback_templated_impl in test/common/convolution.h).
std::vector<TestArg> get_args_1() {
    return get_args_fallback_templated_impl();
}
//! Cases exercising the fallback non-templated implementation
//! (see get_args_fallback_non_templated_impl in test/common/convolution.h).
std::vector<TestArg> get_args_2() {
    return get_args_fallback_non_templated_impl();
}
  53. std::vector<TestArg> get_group_conv_args() {
  54. std::vector<TestArg> args;
  55. for (size_t batch_size : {2}) {
  56. for (size_t ih : {23}) {
  57. for (size_t iw : {ih + 1}) {
  58. for (size_t icpg : {2, 4, 8}) {
  59. for (size_t ocpg : {4, 8}) {
  60. for (size_t fh : {3, 5, 7}) {
  61. for (size_t fw : {fh, fh + 1}) {
  62. for (size_t ph : {0_z, size_t{fw / 2}}) {
  63. for (size_t sh : {1, 2}) {
  64. for (size_t dh : {1, 2}) {
  65. param::Convolution param;
  66. size_t groups = 2;
  67. param.sparse = param::Convolution::
  68. Sparse::GROUP;
  69. param.mode = param::Convolution::
  70. Mode::CROSS_CORRELATION;
  71. param.stride_h = param.stride_w =
  72. sh;
  73. param.pad_h = param.pad_w = ph;
  74. param.dilate_h = param.dilate_w =
  75. dh;
  76. args.emplace_back(
  77. param,
  78. TensorShape{batch_size,
  79. icpg * groups,
  80. ih, iw},
  81. TensorShape{groups, ocpg,
  82. icpg, fh, fw});
  83. }
  84. }
  85. }
  86. }
  87. }
  88. }
  89. }
  90. }
  91. }
  92. }
  93. return args;
  94. }
  95. } // namespace convolution
  96. TEST_F(ROCM, CONV_GROUP) {
  97. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
  98. using namespace convolution;
  99. std::vector<TestArg> args = get_group_conv_args();
  100. Checker<ConvolutionForward> checker(handle_rocm());
  101. NormalRNG default_rng;
  102. for (auto&& arg : args) {
  103. checker.set_dtype(0, dtype::Float32())
  104. .set_dtype(1, dtype::Float32())
  105. .set_rng(0, &default_rng)
  106. .set_rng(1, &default_rng)
  107. .set_epsilon(1e-3)
  108. .set_param(arg.param)
  109. .execs({arg.src, arg.filter, {}});
  110. }
  111. }
  112. TEST_F(ROCM, CONV_CHANNWISE) {
  113. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
  114. using namespace convolution;
  115. std::vector<TestArg> args = get_chanwise_args();
  116. Checker<ConvolutionForward> checker(handle_rocm());
  117. NormalRNG default_rng;
  118. for (auto&& arg : args) {
  119. using Mode = param::Convolution::Mode;
  120. //! non xcorr not supported for miopen
  121. if (arg.param.mode == Mode::CONVOLUTION) {
  122. continue;
  123. }
  124. checker.set_dtype(0, dtype::Float32())
  125. .set_dtype(1, dtype::Float32())
  126. .set_rng(0, &default_rng)
  127. .set_rng(1, &default_rng)
  128. .set_epsilon(1e-3)
  129. .set_param(arg.param)
  130. .execs({arg.src, arg.filter, {}});
  131. }
  132. }
//! Forward convolution over the aggregated arg set (get_args_0), verified
//! against the naive reference in fp32 and, when fp16 is enabled, Float16.
TEST_F(ROCM, CONVOLUTION_FORWARD_0) {
    // immediate-mode MIOpen algorithm selection; no exhaustive search
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args_0();
    Checker<ConvolutionForward> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        // shrink fp16 inputs by the filter's fan-in (presumably
        // IC * FH * FW for OIHW filters — confirm) so accumulated sums
        // stay in a low-error range; only used by the fp16 pass below
        float scale =
                1.0f / sqrt(arg.filter[1] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        // fp32 pass: Gaussian inputs, tight tolerance
        checker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
#if !MEGDNN_DISABLE_FLOAT16
        // fp16 pass: scaled uniform inputs, much looser tolerance
        checker.set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
#endif
    }
}
//! Same as CONVOLUTION_FORWARD_0 but over the fallback-templated cases
//! (get_args_1): fp32 plus optional fp16 pass against the naive reference.
TEST_F(ROCM, CONVOLUTION_FORWARD_1) {
    // immediate-mode MIOpen algorithm selection; no exhaustive search
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args_1();
    Checker<ConvolutionForward> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        // fan-in-scaled RNG for the fp16 pass below (keeps sums small)
        float scale =
                1.0f / sqrt(arg.filter[1] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        // fp32 pass: Gaussian inputs, tight tolerance
        checker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
#if !MEGDNN_DISABLE_FLOAT16
        // fp16 pass: scaled uniform inputs, much looser tolerance
        checker.set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
#endif
    }
}
//! Same as CONVOLUTION_FORWARD_0 but over the fallback-non-templated cases
//! (get_args_2): fp32 plus optional fp16 pass against the naive reference.
TEST_F(ROCM, CONVOLUTION_FORWARD_2) {
    // immediate-mode MIOpen algorithm selection; no exhaustive search
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args_2();
    Checker<ConvolutionForward> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        // fan-in-scaled RNG for the fp16 pass below (keeps sums small)
        float scale =
                1.0f / sqrt(arg.filter[1] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        // fp32 pass: Gaussian inputs, tight tolerance
        checker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
#if !MEGDNN_DISABLE_FLOAT16
        // fp16 pass: scaled uniform inputs, much looser tolerance
        checker.set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .execs({arg.src, arg.filter, {}});
#endif
    }
}
  238. TEST_F(ROCM, CONVOLUTION_1X1_FORWARD) {
  239. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
  240. using namespace convolution;
  241. std::vector<TestArg> args = get_1x1_args();
  242. Checker<ConvolutionForward> checker(handle_rocm());
  243. NormalRNG default_rng;
  244. for (auto&& arg : args) {
  245. float scale =
  246. 1.0f / sqrt(arg.filter[1] * arg.filter[2] * arg.filter[3]);
  247. UniformFloatRNG rng(scale, 2 * scale);
  248. checker.set_dtype(0, dtype::Float32())
  249. .set_dtype(1, dtype::Float32())
  250. .set_rng(0, &default_rng)
  251. .set_rng(1, &default_rng)
  252. .set_epsilon(1e-3)
  253. .set_param(arg.param)
  254. .execs({arg.src, arg.filter, {}});
  255. }
  256. }
  257. #if MEGDNN_WITH_BENCHMARK
  258. TEST_F(ROCM, CONVOLUTION_1X1_FORWARD_ALL_ALGOS) {
  259. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
  260. using namespace convolution;
  261. OprProxy<ConvolutionForward> proxy{true};
  262. proxy.warmup_times = 1;
  263. proxy.exec_times = 10;
  264. Benchmarker<ConvolutionForward> checker(handle_rocm());
  265. checker.set_times(1);
  266. auto get_computation = [&](TestArg arg) -> float {
  267. megdnn_assert(arg.param.format == param::Convolution::Format::NCHW);
  268. size_t N = arg.src[0], IC = arg.src[1], IH = arg.src[2],
  269. IW = arg.src[3], OC = arg.filter[0], FH = arg.filter[2],
  270. FW = arg.filter[3], SH = arg.param.stride_h,
  271. SW = arg.param.stride_w, PH = arg.param.pad_h,
  272. PW = arg.param.pad_w;
  273. size_t OH = infer_conv_shape(IH, FH, SH, PH);
  274. size_t OW = infer_conv_shape(IW, FW, SW, PW);
  275. float flops = 2.0 * N * OC * OH * OW * IC * FH * FW;
  276. return flops;
  277. };
  278. std::vector<TestArg> args = get_1x1_args();
  279. NormalRNG default_rng;
  280. for (auto&& arg : args) {
  281. float scale =
  282. 1.0f / sqrt(arg.filter[1] * arg.filter[2] * arg.filter[3]);
  283. UniformFloatRNG rng(scale, 2 * scale);
  284. checker.set_proxy(proxy)
  285. .set_dtype(0, dtype::Float32())
  286. .set_dtype(1, dtype::Float32())
  287. .set_rng(0, &default_rng)
  288. .set_rng(1, &default_rng)
  289. .set_param(arg.param);
  290. float time_in_ms = checker.execs({arg.src, arg.filter, {}});
  291. float flops = get_computation(arg);
  292. printf("inp=%s,flt=%s,flops=%.2fGflo,time = %.2f ms, perf = %.2f "
  293. "GFLOPS\n",
  294. arg.src.to_string().c_str(), arg.filter.to_string().c_str(),
  295. flops / 1e9, time_in_ms, flops / (1e6 * time_in_ms));
  296. }
  297. }
  298. #endif
//! Gradient w.r.t. the input over the aggregated arg set (get_args_0),
//! in fp32 and (when enabled) fp16.
TEST_F(ROCM, CONVOLUTION_BACKWARD_DATA_0) {
    // immediate-mode MIOpen algorithm selection; no exhaustive search
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args_0();
    Checker<ConvolutionBackwardData> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        // fp16 RNG scaled by OC * FH * FW (filter dims 0, 2, 3) to keep
        // the accumulated gradient small; used only in the fp16 pass
        float scale =
                1.0f / sqrt(arg.filter[0] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        auto src = TensorLayout(arg.src, dtype::Float32());
        auto filter = TensorLayout(arg.filter, dtype::Float32());
        TensorLayout dst;
        {
            // deduce the forward output layout; it serves as the diff input
            auto opr = handle_rocm()->create_operator<Convolution>();
            opr->param() = arg.param;
            opr->deduce_layout(src, filter, dst);
        }
        src.dtype = dst.dtype = filter.dtype = dtype::Float32();
        // operand order for backward-data: {filter, diff(dst), grad(src)}
        checker.set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .exec(TensorLayoutArray{filter, dst, src});
#if !MEGDNN_DISABLE_FLOAT16
        // fp16 pass: same layouts retyped, looser tolerance
        src.dtype = dst.dtype = filter.dtype = dtype::Float16();
        checker.set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .exec(TensorLayoutArray{filter, dst, src});
#endif
    }
}
//! Gradient w.r.t. the input over the fallback-templated cases (get_args_1),
//! in fp32 and (when enabled) fp16.
TEST_F(ROCM, CONVOLUTION_BACKWARD_DATA_1) {
    // immediate-mode MIOpen algorithm selection; no exhaustive search
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args_1();
    Checker<ConvolutionBackwardData> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        // fp16 RNG scaled by OC * FH * FW (filter dims 0, 2, 3); used
        // only in the fp16 pass below
        float scale =
                1.0f / sqrt(arg.filter[0] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        auto src = TensorLayout(arg.src, dtype::Float32());
        auto filter = TensorLayout(arg.filter, dtype::Float32());
        TensorLayout dst;
        {
            // deduce the forward output layout; it serves as the diff input
            auto opr = handle_rocm()->create_operator<Convolution>();
            opr->param() = arg.param;
            opr->deduce_layout(src, filter, dst);
        }
        src.dtype = dst.dtype = filter.dtype = dtype::Float32();
        // operand order for backward-data: {filter, diff(dst), grad(src)}
        checker.set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .exec(TensorLayoutArray{filter, dst, src});
#if !MEGDNN_DISABLE_FLOAT16
        // fp16 pass: same layouts retyped, looser tolerance
        src.dtype = dst.dtype = filter.dtype = dtype::Float16();
        checker.set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .exec(TensorLayoutArray{filter, dst, src});
#endif
    }
}
//! Gradient w.r.t. the input over the fallback-non-templated cases
//! (get_args_2), in fp32 and (when enabled) fp16.
TEST_F(ROCM, CONVOLUTION_BACKWARD_DATA_2) {
    // immediate-mode MIOpen algorithm selection; no exhaustive search
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args_2();
    Checker<ConvolutionBackwardData> checker(handle_rocm());
    NormalRNG default_rng;
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        // fp16 RNG scaled by OC * FH * FW (filter dims 0, 2, 3); used
        // only in the fp16 pass below
        float scale =
                1.0f / sqrt(arg.filter[0] * arg.filter[2] * arg.filter[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        auto src = TensorLayout(arg.src, dtype::Float32());
        auto filter = TensorLayout(arg.filter, dtype::Float32());
        TensorLayout dst;
        {
            // deduce the forward output layout; it serves as the diff input
            auto opr = handle_rocm()->create_operator<Convolution>();
            opr->param() = arg.param;
            opr->deduce_layout(src, filter, dst);
        }
        src.dtype = dst.dtype = filter.dtype = dtype::Float32();
        // operand order for backward-data: {filter, diff(dst), grad(src)}
        checker.set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .exec(TensorLayoutArray{filter, dst, src});
#if !MEGDNN_DISABLE_FLOAT16
        // fp16 pass: same layouts retyped, looser tolerance
        src.dtype = dst.dtype = filter.dtype = dtype::Float16();
        checker.set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .exec(TensorLayoutArray{filter, dst, src});
#endif
    }
}
//! Gradient w.r.t. the filter. Currently skipped via gtest's DISABLED_
//! prefix; the fp16 branch is additionally compiled out (#if 0) — see the
//! FIXME about MIOpen's backward-weights fp16 bugs below.
TEST_F(ROCM, DISABLED_CONVOLUTION_BACKWARD_FILTER) {
    // immediate-mode MIOpen algorithm selection; no exhaustive search
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), false);
    using namespace convolution;
    std::vector<TestArg> args = get_args();
    Checker<ConvolutionBackwardFilter> checker(handle_rocm());
    NormalRNG default_rng;
    // only referenced by the fp16 branch, which is compiled out below;
    // marked used to silence the unused-variable warning
    bool f16_checked = false;
    MEGDNN_MARK_USED_VAR(f16_checked);
    for (auto&& arg : args) {
        using Mode = param::Convolution::Mode;
        //! non xcorr not supported for miopen
        if (arg.param.mode == Mode::CONVOLUTION) {
            continue;
        }
        auto src = TensorLayout(arg.src, dtype::Float32());
        auto filter = TensorLayout(arg.filter, dtype::Float32());
        TensorLayout dst;
        {
            // deduce the forward output layout; it serves as the diff input
            auto opr = handle_rocm()->create_operator<Convolution>();
            opr->param() = arg.param;
            opr->deduce_layout(src, filter, dst);
        }
        // RNG scaled by the output spatial size (dst dims 2, 3); `rng` is
        // only used by the fp16 branch, currently compiled out
        float scale = 1.0f / sqrt(dst[2] * dst[3]);
        UniformFloatRNG rng(scale, 2 * scale);
        src.dtype = dst.dtype = filter.dtype = dtype::Float32();
        // operand order for backward-filter: {src, diff(dst), grad(filter)}
        checker.set_rng(0, &default_rng)
                .set_rng(1, &default_rng)
                .set_epsilon(1e-3)
                .set_param(arg.param)
                .exec(TensorLayoutArray{src, dst, filter});
#if !MEGDNN_DISABLE_FLOAT16
        //! FIXME: MIOpen convolution backward weights for FP16 with bugs
#if 0
        // reduce on large f16 array may introduce significant error
        if (dst.total_nr_elems() >= 1000 && f16_checked)
            continue;
        f16_checked = true;
        src.dtype = dst.dtype = filter.dtype = dtype::Float16();
        checker.set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_epsilon(1e-1)
                .set_param(arg.param)
                .exec(TensorLayoutArray{src, dst, filter});
#endif
#endif
    }
}
  463. #if MEGDNN_WITH_BENCHMARK
//! Throughput benchmark of forward convolution on a set of ResNet-like
//! layer shapes, printing FLOP count and achieved TFLOPS per case.
TEST_F(ROCM, CONV_FWD_BENCHMARK) {
    // enable full algorithm search so the fastest algorithm is measured
    megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
    auto benchmarker = ROCMBenchmarker<ConvolutionForward>(handle_rocm(),
                                                           handle_naive(false));
    // Run one (N, OC, IC, IH, IW) case with the given stride/filter/padding
    // and dtype; prints the result.
    auto run = [&](size_t N, size_t OC, size_t IC, size_t IH, size_t IW,
                   size_t SH = 1, size_t SW = 1, size_t FH = 1, size_t FW = 1,
                   size_t PH = 0, size_t PW = 0,
                   DType dtype = dtype::Float32()) {
        benchmarker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        benchmarker.set_display(true);
        ConvolutionForward::Param param;
        param.stride_h = SH;
        param.stride_w = SW;
        param.pad_h = PH;
        param.pad_w = PW;
        benchmarker.set_param(param);
        // standard conv output size: (I - F + 2P) / S + 1
        size_t OH = (IH - FH + 2 * PH) / SH + 1;
        size_t OW = (IW - FW + 2 * PW) / SW + 1;
        // warm up find best algo
        benchmarker.execs({{N, IC, IH, IW}, {OC, IC, FH, FW}, {N, OC, OH, OW}});
        // do actual benchmark
        auto time_ms = benchmarker.execs(
                {{N, IC, IH, IW}, {OC, IC, FH, FW}, {N, OC, OH, OW}});
        // dense conv FLOPs: 2 (mul + add) per tap per output element
        auto flo = (double)N * OC * IC * OH * OW * FH * FW * 2;
        // flo / (ms * 1e9) == TFLOPS (ms -> s is 1e-3, tera is 1e12)
        auto flops = flo / (time_ms * 1e9);
        printf("%.3fG FLO, flops %.3fTFLOPS\n", flo / 1e9, flops);
    };
    run(32, 24, 16, 224, 224, 2, 2, 7, 7, 3, 3);
    run(32, 128, 32, 112, 112, 1, 1, 3, 3, 1, 1);
    run(32, 128, 128, 56, 56, 1, 1, 3, 3, 1, 1);
    run(32, 128, 256, 28, 28, 1, 1, 3, 3, 1, 1);
    run(32, 256, 256, 28, 28, 1, 1, 1, 1, 0, 0);
    run(32, 256, 256, 28, 28, 2, 2, 3, 3, 1, 1);
    run(32, 256, 256, 14, 14, 1, 1, 3, 3, 1, 1);
    run(32, 512, 512, 7, 7, 1, 1, 3, 3, 1, 1);
#if !MEGDNN_DISABLE_FLOAT16
    run(32, 256, 256, 56, 56, 1, 1, 1, 1, 0, 0, dtype::Float16());
#endif
}
  503. #endif
  504. } // namespace test
  505. } // namespace megdnn
  506. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台