
codegen.cpp

/**
 * \file src/jit/test/codegen.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include <memory>

#include "./helper.h"
#include "megbrain/jit/executor_opr.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/test/helper.h"
#include "megdnn/dtype.h"

#if MGB_JIT

using namespace mgb;
using namespace jit;

#define FOREACH_CASE(cb) cb(simple) cb(grad)

namespace {

#define def_tag(x) \
    struct x {};
FOREACH_CASE(def_tag)
#undef def_tag

#define t(n) n,
using test_types = ::testing::Types<FOREACH_CASE(t) void>;
#undef t
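
// Each tag in test_types selects one specialization of run(); the trailing
// `void` entry only absorbs the comma left by the t() macro and is handled
// by the empty run<void>() below.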
template <typename tag>
void run(Backend backend, CompNode cn);

template <>
void run<simple>(Backend backend, CompNode cn) {
    set_backend(backend);
    auto graph = ComputingGraph::make();
    HostTensorGenerator<> gen;
    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 1}, cn),
         host_x2 = gen({1, 42}, cn);
    auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
         b = opr::Host2DeviceCopy::make(*graph, host_x1),
         c = opr::Host2DeviceCopy::make(*graph, host_x2);
    a = opr::TypeCvt::make(a, dtype::Float16{});
    auto y = a + b * c;
    y = opr::TypeCvt::make(y, dtype::Float16{});
    y = opr::TypeCvt::make((y + y.make_scalar_dt(1.f)), dtype::Float32{});
    VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()};
    auto ig_gen =
            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
    for (auto i : get_rev_topo_order(y)) {
        if (!i->same_type<opr::Host2DeviceCopy>()) {
            ig_gen->add_opr(i);
        }
    }
    auto igraph = ig_gen->generate();
    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
    HostTensorND host_y, host_y_jit;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_jit, host_y_jit)});
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_jit, 5e-3);
}
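
// run<grad>: also exercises gradient computation through the JIT executor.
// floor_div is not differentiable w.r.t. either operand, so grad(a) and
// grad(b) are expected to be null, while grad(c) flows through sin/mul.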
template <>
void run<grad>(Backend backend, CompNode cn) {
    set_backend(backend);
    auto graph = ComputingGraph::make();
    HostTensorGenerator<> gen;
    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 1}, cn),
         host_x2 = gen({1, 42}, cn);
    auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
         b = opr::Host2DeviceCopy::make(*graph, host_x1),
         c = opr::Host2DeviceCopy::make(*graph, host_x2);
    a = opr::TypeCvt::make(a, dtype::Float16{});
    auto y = opr::floor_div(a, opr::abs(b) + 0.1f) * opr::sin(c);
    VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()};
    auto ig_gen =
            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
    for (auto i : get_rev_topo_order(y)) {
        if (!i->same_type<opr::Host2DeviceCopy>()) {
            ig_gen->add_opr(i);
        }
    }
    auto igraph = ig_gen->generate();
    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
    HostTensorND host_y, host_y_jit;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_jit, host_y_jit)});
    func->execute();
    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
    auto grad = [loss = opr::reduce_sum(y_jit, y_jit.make_scalar(1))](
                        SymbolVar x) {
        return cg::grad(loss, x, false, false).node();
    };
    ASSERT_EQ(nullptr, grad(a));
    ASSERT_EQ(nullptr, grad(b));
    ASSERT_NE(nullptr, grad(c));
}

template <>
void run<void>(Backend, CompNode) {}

#if MGB_JIT_MLIR
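
// Builds a simple fused elementwise expression and checks that the MLIR
// backend produces the same result as the unfused reference graph.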
void run_mlir(CompNode cn) {
    set_backend(Backend::MLIR);
    auto graph = ComputingGraph::make();
    HostTensorGenerator<dtype::Float32> gen;
    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 1}, cn),
         host_x2 = gen({23, 42}, cn);
    auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
         b = opr::Host2DeviceCopy::make(*graph, host_x1),
         c = opr::Host2DeviceCopy::make(*graph, host_x2);
    auto y = a + b * c + 0.3f;
    auto ig_gen =
            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
    for (auto i : get_rev_topo_order(y)) {
        if (!i->same_type<opr::Host2DeviceCopy>()) {
            ig_gen->add_opr(i);
        }
    }
    auto igraph = ig_gen->generate();
    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
    HostTensorND host_y, host_y_jit;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_jit, host_y_jit)});
    func->execute();
    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}
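
// Same check with inputs that broadcast along different axes, including a
// FUSE_MUL_ADD3 subexpression.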
void run_mlir_broadcast(CompNode cn) {
    set_backend(Backend::MLIR);
    auto graph = ComputingGraph::make();
    HostTensorGenerator<dtype::Float32> gen;
    auto host_x0 = gen({10, 20, 5, 6}, cn), host_x1 = gen({1, 20, 1, 1}, cn),
         host_x2 = gen({10, 1, 5, 1}, cn), host_x3 = gen({10, 1, 1, 1}, cn);
    auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
         b = opr::Host2DeviceCopy::make(*graph, host_x1),
         c = opr::Host2DeviceCopy::make(*graph, host_x2),
         d = opr::Host2DeviceCopy::make(*graph, host_x3);
    auto y =
            opr::Elemwise::make({a, b, c}, opr::Elemwise::Mode::FUSE_MUL_ADD3) +
            opr::Elemwise::make({d}, opr::Elemwise::Mode::ABS) - 0.3f;
    auto ig_gen =
            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
    for (auto i : get_rev_topo_order(y)) {
        if (!i->same_type<opr::Host2DeviceCopy>()) {
            ig_gen->add_opr(i);
        }
    }
    auto igraph = ig_gen->generate();
    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
    HostTensorND host_y, host_y_jit;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_jit, host_y_jit)});
    func->execute();
    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}
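
// Per-mode input range and tolerance: some modes are only defined (or only
// numerically stable) on a restricted domain, e.g. LOG needs positive inputs
// and ERFINV is only defined on (-1, 1).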
struct MlirTestOpt {
    float low;
    float high;
    float maxerr;
};

MlirTestOpt get_mode_opt(opr::Elemwise::Mode mode) {
    MlirTestOpt opt = {0, 1, 1e-6};
    if (mode == opr::Elemwise::Mode::ABS) {
        opt.low = -10;
        opt.high = 10;
    } else if (mode == opr::Elemwise::Mode::LOG) {
        opt.low = 0.1;
        opt.high = 4;
    } else if (mode == opr::Elemwise::Mode::ERF ||
               mode == opr::Elemwise::Mode::ERFC) {
        opt.low = -5;
        opt.high = 5;
    } else if (mode == opr::Elemwise::Mode::ERFINV) {
        opt.low = -0.999;
        opt.high = 0.999;
        opt.maxerr = 1e-4;
    } else if (mode == opr::Elemwise::Mode::ERFCINV) {
        opt.low = 0.001;
        opt.high = 1.999;
        opt.maxerr = 1e-4;
    }
    return opt;
}
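
// Generic driver: builds a single Elemwise opr of the given mode and arity,
// compiles it with the MLIR backend and compares against the reference output.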
template <typename tag, int arity>
void run_mlir_mode(CompNode cn) {
    set_backend(Backend::MLIR);
    auto graph = ComputingGraph::make();
    auto opt = get_mode_opt(tag::mode);
    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen(
            opt.low, opt.high);
    SmallVector<std::shared_ptr<HostTensorND>> hosts;
    VarNodeArray input_vars;
    for (int i = 0; i < arity; i++) {
        hosts.push_back(gen({2323, 4242}, cn));
        input_vars.push_back(
                opr::Host2DeviceCopy::make(*graph, hosts[i]).node());
    }
    auto y = opr::Elemwise::make(input_vars, tag::mode);
    auto ig_gen =
            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
    for (auto i : get_rev_topo_order(y)) {
        if (!i->template same_type<opr::Host2DeviceCopy>()) {
            ig_gen->add_opr(i);
        }
    }
    auto igraph = ig_gen->generate();
    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
    HostTensorND host_y, host_y_jit;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_jit, host_y_jit)});
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_jit, opt.maxerr);
}
#endif

}  // anonymous namespace
/* ===================== TestJITHalideCodeGenCuda ===================== */
#if MGB_JIT_HALIDE
template <typename tag>
class TestJITHalideCodeGenCuda : public ::testing::Test {};
TYPED_TEST_CASE(TestJITHalideCodeGenCuda, test_types);
TYPED_TEST(TestJITHalideCodeGenCuda, run) {
    REQUIRE_GPU(1);
    run<TypeParam>(Backend::HALIDE, CompNode::load("gpu0"));
}
#endif
/* ===================== TestJITNvrtcCodeGen ===================== */
template <typename tag>
class TestJITNvrtcCodeGen : public ::testing::Test {};
TYPED_TEST_CASE(TestJITNvrtcCodeGen, test_types);
TYPED_TEST(TestJITNvrtcCodeGen, run) {
    REQUIRE_GPU(1);
    run<TypeParam>(Backend::NVRTC, CompNode::load("gpu0"));
}
/* ===================== TestJITMlirCodeGen ===================== */
#if MGB_JIT_MLIR
TEST(TestJITMlirCodeGen, Basic) {
    auto cn = CompNode::load("cpu0");
    run_mlir(cn);
    run_mlir_broadcast(cn);
}

TEST(TestJITMlirCodeGen, BasicGPU) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    run_mlir(cn);
    run_mlir_broadcast(cn);
}
/* ===================== TestJITMlirUnaryElemwise ===================== */
// clang-format off
#define FOREACH_UNARY_MODE(cb) \
    cb(RELU) \
    cb(ABS) \
    cb(NEGATE) \
    cb(ACOS) \
    cb(ASIN) \
    cb(CEIL) \
    cb(EXP) \
    cb(FLOOR) \
    cb(LOG) \
    cb(LOG1P) \
    cb(SIN) \
    cb(COS) \
    cb(TANH) \
    cb(FAST_TANH) \
    cb(H_SWISH) \
    cb(SIGMOID) \
    cb(EXPM1) \
    cb(ROUND) \
    cb(ERF) \
    cb(ERFINV) \
    cb(ERFC) \
    cb(ERFCINV)
// clang-format on
template <typename tag>
class TestJITMlirUnaryElemwise : public ::testing::Test {};

#define def_tag(x)                                                          \
    struct x {                                                              \
        static constexpr opr::Elemwise::Mode mode = opr::Elemwise::Mode::x; \
    };
FOREACH_UNARY_MODE(def_tag)
#undef def_tag

#define t(n) n,
using mlir_elemwise_unary_types = ::testing::Types<FOREACH_UNARY_MODE(t) ABS>;
#undef t
TYPED_TEST_CASE(TestJITMlirUnaryElemwise, mlir_elemwise_unary_types);
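
// Skip the given elemwise mode in a test body; used below for modes the
// MLIR backend cannot lower yet (the test prints "skip" and returns).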
#define SKIP_MODE(_mode)                                     \
    if (TypeParam::mode == opr::Elemwise::Mode::_mode) {     \
        printf("skip\n");                                    \
        return;                                              \
    }

TYPED_TEST(TestJITMlirUnaryElemwise, run) {
    auto cn = CompNode::load("cpu0");
    SKIP_MODE(ROUND);
    run_mlir_mode<TypeParam, 1>(cn);
}

TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    SKIP_MODE(ROUND);
    run_mlir_mode<TypeParam, 1>(cn);
}
/* ===================== TestJITMlirBinaryElemwise ===================== */
// clang-format off
#define FOREACH_BINARY_MODE(cb) \
    cb(ADD) \
    cb(FLOOR_DIV) \
    cb(MUL) \
    cb(MAX) \
    cb(MIN) \
    cb(MOD) \
    cb(SUB) \
    cb(TRUE_DIV) \
    cb(POW) \
    cb(ABS_GRAD) \
    cb(SIGMOID_GRAD) \
    cb(SWITCH_GT0) \
    cb(TANH_GRAD) \
    cb(LT) \
    cb(LEQ) \
    cb(EQ) \
    cb(FUSE_ADD_RELU) \
    cb(LOG_SUM_EXP) \
    cb(FUSE_ADD_TANH) \
    cb(FAST_TANH_GRAD) \
    cb(FUSE_ADD_SIGMOID) \
    cb(H_SWISH_GRAD) \
    cb(FUSE_ADD_H_SWISH) \
    cb(ATAN2)
// clang-format on
template <typename tag>
class TestJITMlirBinaryElemwise : public ::testing::Test {};

#define def_tag(x)                                                          \
    struct x {                                                              \
        static constexpr opr::Elemwise::Mode mode = opr::Elemwise::Mode::x; \
    };
FOREACH_BINARY_MODE(def_tag)
#undef def_tag

#define t(n) n,
using mlir_elemwise_binary_types =
        ::testing::Types<FOREACH_BINARY_MODE(t) ADD>;
#undef t
TYPED_TEST_CASE(TestJITMlirBinaryElemwise, mlir_elemwise_binary_types);

TYPED_TEST(TestJITMlirBinaryElemwise, run) {
    auto cn = CompNode::load("cpu0");
    run_mlir_mode<TypeParam, 2>(cn);
}

TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    SKIP_MODE(MOD);
    run_mlir_mode<TypeParam, 2>(cn);
}
/* ===================== TestJITMlirTernaryElemwise ===================== */
// clang-format off
#define FOREACH_TERNARY_MODE(cb) \
    cb(COND_LEQ_MOV) \
    cb(FUSE_MUL_ADD3)
// clang-format on
template <typename tag>
class TestJITMlirTernaryElemwise : public ::testing::Test {};

#define def_tag(x)                                                          \
    struct x {                                                              \
        static constexpr opr::Elemwise::Mode mode = opr::Elemwise::Mode::x; \
    };
FOREACH_TERNARY_MODE(def_tag)
#undef def_tag

#define t(n) n,
using mlir_elemwise_ternary_types =
        ::testing::Types<FOREACH_TERNARY_MODE(t) COND_LEQ_MOV>;
#undef t
TYPED_TEST_CASE(TestJITMlirTernaryElemwise, mlir_elemwise_ternary_types);

TYPED_TEST(TestJITMlirTernaryElemwise, run) {
    auto cn = CompNode::load("cpu0");
    run_mlir_mode<TypeParam, 3>(cn);
}

TYPED_TEST(TestJITMlirTernaryElemwise, runGpu) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    run_mlir_mode<TypeParam, 3>(cn);
}

#undef SKIP_MODE
/* ===================== TestJITMlirTypeCvt ===================== */
template <typename itype, typename otype>
void run_typecvt(CompNode cn) {
    set_backend(Backend::MLIR);
    auto graph = ComputingGraph::make();
    HostTensorGenerator<itype, RandomDistribution::UNIFORM> gen(-10, 10);
    auto host_x = gen({23, 42}, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto y = opr::TypeCvt::make(x, otype());
    auto ig_gen =
            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
    for (auto i : get_rev_topo_order(y)) {
        if (!i->template same_type<opr::Host2DeviceCopy>()) {
            ig_gen->add_opr(i);
        }
    }
    auto igraph = ig_gen->generate();
    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
    HostTensorND host_y, host_y_jit;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_jit, host_y_jit)});
    func->execute();
    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}

#define add_typecvt_gtest(itype, otype)                                  \
    TEST(TestJITMlirTypeCvt, itype##_to_##otype) {                       \
        run_typecvt<dtype::itype, dtype::otype>(CompNode::load("cpu0")); \
    }                                                                    \
    TEST(TestJITMlirTypeCvt, itype##_to_##otype##_GPU) {                 \
        REQUIRE_GPU(1);                                                  \
        run_typecvt<dtype::itype, dtype::otype>(CompNode::load("gpu0")); \
    }
#if !MEGDNN_DISABLE_FLOAT16
// TODO: the support for f16 and bf16 is currently not complete in mlir
// FPExtOp
// add_typecvt_gtest(Float16, Float32);
// add_typecvt_gtest(BFloat16, Float32);
// add_typecvt_gtest(Float16, BFloat16);
// FPTruncOp
// add_typecvt_gtest(Float32, Float16);
// add_typecvt_gtest(Float32, BFloat16);
// add_typecvt_gtest(Float16, BFloat16);
#endif

// FPToSIOp
add_typecvt_gtest(Float32, Int8);
add_typecvt_gtest(Float32, Int16);
add_typecvt_gtest(Float32, Int32);

// FPToUIOp
add_typecvt_gtest(Float32, Uint8);

// SIToFPOp
add_typecvt_gtest(Int8, Float32);
add_typecvt_gtest(Int16, Float32);
add_typecvt_gtest(Int32, Float32);

// UIToFPOp
add_typecvt_gtest(Uint8, Float32);

#undef add_typecvt_gtest
/* ===================== TestJITMlirDimshuffle ===================== */
void run_dimshuffle(CompNode cn, TensorShape ishape,
                    const std::vector<int>& pattern) {
    set_backend(Backend::MLIR);
    auto graph = ComputingGraph::make();
    HostTensorGenerator<> gen;
    auto host_x = gen(ishape, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto y = opr::Dimshuffle::make(x, pattern);
    auto ig_gen =
            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
    for (auto i : get_rev_topo_order(y)) {
        if (!i->template same_type<opr::Host2DeviceCopy>()) {
            ig_gen->add_opr(i);
        }
    }
    auto igraph = ig_gen->generate();
    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
    HostTensorND host_y, host_y_jit;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_jit, host_y_jit)});
    func->execute();
    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}
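
// A -1 in the dimshuffle pattern inserts a new axis of size 1 at that
// position, so {1, -1, 0, 2} also checks axis insertion.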
void run_dimshuffle_cases(CompNode cn) {
    run_dimshuffle(cn, {3, 4, 5}, {2, 0, 1});
    run_dimshuffle(cn, {3, 4, 5}, {1, -1, 0, 2});
}

TEST(TestJITMlirDimshuffle, Basic) {
    run_dimshuffle_cases(CompNode::load("cpu0"));
}

TEST(TestJITMlirDimshuffle, BasicGPU) {
    REQUIRE_GPU(1);
    run_dimshuffle_cases(CompNode::load("gpu0"));
}

#endif  // MGB_JIT_MLIR
#endif  // MGB_JIT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
