
algos.cpp (21 kB)

/**
 * \file dnn/src/fallback/convolution/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/fallback/convolution/algos.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/convolution/col2img_helper.h"
#include "src/fallback/convolution/run_conv.h"

#include "midout.h"

using namespace megdnn;
using namespace fallback;

MIDOUT_DECL(megdnn_fallback_conv)

namespace {

template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}
using NCBKernSizeParam = ConvolutionBackwardDataImpl::NCBKernSizeParam;
using NCBKernParam = ConvolutionBackwardDataImpl::NCBKernParam;

Relayout* get_relayout_opr() {
    static CpuOprDelegationStorage<> storage;
    return storage.get<Relayout>();
}

MatrixMul* get_matmul_opr(const NCBKernSizeParam& param) {
    using ConvCM = param::Convolution::ComputeMode;
    using MmCM = param::MatrixMul::ComputeMode;
    static CpuOprDelegationStorage<2> storage;
    switch (param.compute_mode) {
        default:
            return storage.get<MatrixMul, 0>({});
        case ConvCM::FLOAT32: {
            MatrixMul::Param p;
            p.compute_mode = MmCM::FLOAT32;
            return storage.get<MatrixMul, 1>(p);
        }
    }
}
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) {
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(N);
    MEGDNN_MARK_USED_VAR(OH);
    MEGDNN_MARK_USED_VAR(OW);
    bool can_matrix_mul_direct =
            (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);
    // part0: temp space to store the unrolled matrix
    // part1: workspace for the matrix mul opr
    // part2: workspace for the relayout opr (transposed filter)
    size_t part0, part1, part2;
    if (can_matrix_mul_direct) {
        part0 = 0;
    } else {
        part0 = (IC * FH * FW * IH * IW) * param.grad_type.size();
    }
    part2 = (OC * IC * FH * FW) * param.filter_type.size();
    {
        TensorLayout A_, B_, C_;
        A_ = TensorLayout({IC * FH * FW, OC}, param.filter_type);
        B_ = TensorLayout({OC, IH * IW}, param.diff_type);
        C_ = TensorLayout({IC * FH * FW, IH * IW}, param.grad_type);
        part1 = get_matmul_opr(param)->get_workspace_in_bytes(A_, B_, C_);
    }
    return {nullptr, {part0, part1, part2}};
}
//! Backward data via matrix multiplication: compute the unrolled gradient
//! matrix (IC*FH*FW, IH*IW) as filter^T (IC*FH*FW, OC) times diff
//! (OC, IH*IW), then scatter-add it back into grad with col2img.
template <typename ftype, typename dtype, typename gtype>
void kern_matmul(const NCBKernParam& param) {
    bool is_xcorr = !param.filter_meta.should_flip;
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    auto bundle = get_bundle(param);
    bundle.set(param.workspace_ptr);
    bool is1X1 =
            (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);
    typedef void (*Func1)(const gtype*, gtype*, int, int, int, int, int, int,
                          int);
    typedef void (*Func2)(const gtype*, gtype*, int, int, int, int, int, int,
                          int, int, int, int, int);
    Func1 f1 = nullptr;
    Func2 f2 = nullptr;
    if (is_xcorr) {
        f1 = col2img<true>;
        f2 = col2img_stride_padding<true>;
    } else {
        f1 = col2img<false>;
        f2 = col2img_stride_padding<false>;
    }
    ftype* filter = const_cast<ftype*>(param.filter<ftype>());
    TensorND A_src, A_dst;
    {
        A_src.layout = TensorLayout({IC * FH * FW, OC},
                                    {static_cast<std::ptrdiff_t>(1),
                                     static_cast<std::ptrdiff_t>(IC * FH * FW)},
                                    param.filter_type);
        A_src.raw_ptr = static_cast<void*>(filter);
        A_dst.layout = TensorLayout({IC * FH * FW, OC}, param.filter_type);
        A_dst.raw_ptr = static_cast<void*>(bundle.get(2));
        // TODO: remove once armv8 convolution supports transpose.
        get_relayout_opr()->exec(A_src, A_dst, inplace_cpu_handle().get());
    }
    for (size_t n = 0; n < N; ++n) {
        gtype *C_src, *C_dst;
        dtype* diff =
                const_cast<dtype*>(param.diff<dtype>() + n * param.inp_bs);
        gtype* grad = param.grad<gtype>() + n * param.out_bs;
        if (is1X1) {
            // 1x1, stride 1, no padding: the matmul result is already the
            // gradient, so write it into grad directly and skip col2img
            C_src = grad;
        } else {
            C_src = static_cast<gtype*>(bundle.get(0));
        }
        {
            TensorND B_, C_;
            B_.layout = TensorLayout({OC, IH * IW}, param.diff_type);
            B_.raw_ptr = static_cast<void*>(diff);
            C_.layout = TensorLayout({IC * FH * FW, IH * IW}, param.grad_type);
            C_.raw_ptr = C_src;
            Workspace workspace(static_cast<dt_byte*>(bundle.get(1)),
                                bundle.get_size(1));
            get_matmul_opr(param)->exec(A_dst, B_, C_, workspace);
        }
        if (!is1X1) {
            C_dst = grad;
            std::memset(C_dst, 0, param.grad_type.size() * IC * OH * OW);
            if (PH == 0 && PW == 0 && SH == 1 && SW == 1) {
                f1(C_src, C_dst, OH, OW, IC, IH, IW, FH, FW);
            } else {
                f2(C_src, C_dst, OH, OW, IC, IH, IW, FH, FW, SH, SW, PH, PW);
            }
        }
    }
}
void kern_direct(const NCBKernParam& param) {
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    auto diff = param.diff<float>(), filter = param.filter<float>();
    auto grad = param.grad<float>();
    for (size_t n = 0; n < N; ++n) {
        convolution::run_conv_backward_data(
                diff + n * param.inp_bs, filter, grad + n * param.out_bs,
                param.workspace_ptr, IH, IW, IC, FH, FW, OH, OW, OC, PH, PW,
                SH, SW, !param.filter_meta.should_flip);
    }
}

}  // namespace
/* ===================== fallback algo ===================== */

bool ConvolutionImpl::AlgoFallback::usable(
        ConvolutionImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    auto&& fm = param.filter_meta;
    return fm.format == param::Convolution::Format::NCHW &&
           param.src_type.enumv() == DTypeEnum::Float32 &&
           param.filter_type.enumv() == DTypeEnum::Float32 &&
           param.dst_type.enumv() == DTypeEnum::Float32 &&
           fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1;
}

size_t ConvolutionImpl::AlgoFallback::get_workspace(
        ConvolutionImpl*, const NCBKernSizeParam& param) const {
    auto FH = param.filter_meta.spatial[0], FW = param.filter_meta.spatial[1];
    size_t nr_threads = param.nr_threads;
    if (param.filter_meta.should_flip) {
        // need to transpose the filter; one scratch buffer per thread
        return WorkspaceBundle{nullptr, {FH * FW * sizeof(float)}}
                       .total_size_in_bytes() *
               nr_threads;
    } else {
        return 0;
    }
}

SmallVector<ConvolutionImpl::NCBKern>
ConvolutionImpl::AlgoFallback::dispatch_kern(
        ConvolutionImpl* opr, const NCBKernSizeParam& param) const {
    size_t group = param.filter_meta.group;
    size_t N = param.n;
    size_t nr_threads = param.nr_threads;
    size_t workspace_per_thread = get_workspace(opr, param) / nr_threads;
    auto kern_fallback = [workspace_per_thread](const NCBKernParam& p,
                                                const NCBKernIndex& ncb_index) {
        UNPACK_CONV_F32_NCB_KERN_SIZES(p);
        size_t batch_id = ncb_index.ndrange_id[1];
        size_t group_id = ncb_index.ndrange_id[0];
        MEGDNN_MARK_USED_VAR(N);
        auto src = p.src<float>(batch_id, group_id),
             filter = p.filter<float>(group_id);
        auto dst = p.dst<float>(batch_id, group_id);
        size_t thread_id = ncb_index.thread_id;
        void* workspace_ptr = reinterpret_cast<void*>(
                reinterpret_cast<ptrdiff_t>(p.workspace_ptr) +
                workspace_per_thread * thread_id);
        convolution::run_conv(src, filter, dst, workspace_ptr, IH, IW, IC, FH,
                              FW, OH, OW, OC, PH, PW, SH, SW,
                              !p.filter_meta.should_flip);
    };
    return {{kern_fallback, {group, N, 1_z}}};
}
/* ===================== naive algo ===================== */

bool ConvolutionImpl::AlgoNaive::usable(
        ConvolutionImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    bool ret = false;
#define cb(dt) ret |= (param.src_type.enumv() == DTypeTrait<dt>::enumv);
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
#define cb(dt_src, dt_dst)                                            \
    ret |= (param.src_type.enumv() == DTypeTrait<dt_src>::enumv &&    \
            param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \
            param.dst_type.enumv() == DTypeTrait<dt_dst>::enumv)
    cb(dtype::Int8, dtype::Int16);
    cb(dtype::Int8, dtype::Int32);
    cb(dtype::Quantized8Asymm, dtype::QuantizedS32);
    cb(dtype::QuantizedS8, dtype::QuantizedS32);
#undef cb
    ret = ret &&
          (param.filter_meta.format == param::Convolution::Format::NCHW ||
           param.filter_meta.format == param::Convolution::Format::NHWC);
    return ret;
}

SmallVector<ConvolutionImpl::NCBKern> ConvolutionImpl::AlgoNaive::dispatch_kern(
        ConvolutionImpl*, const NCBKernSizeParam& param) const {
    size_t N = param.n;
    size_t group = param.filter_meta.group;
#define cb(dt, cmode, compute_type)                                      \
    do {                                                                 \
        if (param.src_type.enumv() == DTypeTrait<dt>::enumv &&           \
            param.compute_mode == param::ConvBias::ComputeMode::cmode) { \
            using ctype = DTypeTrait<dt>::ctype;                         \
            using comp_type = DTypeTrait<compute_type>::ctype;           \
            MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(1)) {           \
                return {{kern_naive_forward<ctype, ctype, comp_type>,    \
                         {group, N, 1_z}}};                              \
            }                                                            \
            MIDOUT_END();                                                \
        }                                                                \
    } while (0)
    cb(dtype::Float32, DEFAULT, dtype::Float32);
#if !MEGDNN_DISABLE_FLOAT16
    cb(dtype::Float16, DEFAULT, dtype::Float16);
    cb(dtype::Float16, FLOAT32, dtype::Float32);
#endif
#undef cb
#define cb(dt_src, dt_dst)                                            \
    do {                                                              \
        if (param.src_type.enumv() == DTypeTrait<dt_src>::enumv &&    \
            param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \
            param.dst_type.enumv() == DTypeTrait<dt_dst>::enumv) {    \
            MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(2)) {        \
                return {{kern_naive_forward<DTypeTrait<dt_src>::ctype, \
                                            DTypeTrait<dt_dst>::ctype, \
                                            DTypeTrait<dt_dst>::ctype>, \
                         {group, N, 1_z}}};                           \
            }                                                         \
            MIDOUT_END();                                             \
        }                                                             \
    } while (0)
    cb(dtype::Int8, dtype::Int16);
    cb(dtype::Int8, dtype::Int32);
    cb(dtype::Quantized8Asymm, dtype::QuantizedS32);
    cb(dtype::QuantizedS8, dtype::QuantizedS32);
    megdnn_throw(megdnn_mangle("unknown convolution data type"));
#undef cb
}
/* ===================== default algo ===================== */

ConvolutionImpl::AlgoDefault::AlgoDefault(fallback::ConvBiasImpl* conv_bias_opr,
                                          ConvBiasImpl::AlgoBase* algorithm)
        : m_conv_bias_opr(conv_bias_opr), m_algorithm(algorithm) {
    megdnn_assert_internal(algorithm);
    m_name = ssprintf("CONVOLUTION_DEFAULT_%s", m_algorithm->name());
}

ConvBiasImpl::NCBKernSizeParam
ConvolutionImpl::AlgoDefault::init_convbias_opr_and_param(
        ConvBiasImpl* conv_bias_opr, const NCBKernSizeParam& param) {
    DType bias_type = param.dst_type;
    if (bias_type.category() == DTypeCategory::QUANTIZED) {
        bias_type = dtype::QuantizedS32(
                mul_scale(param.src_type, param.filter_type));
    }
    ::ConvBiasImpl::NCBKernSizeParam conv_bias_size_param(
            param, 0, param::MatrixMul::Format::DEFAULT, bias_type, 0,
            BiasMode::NO_BIAS, param::ConvBias::NonlineMode::IDENTITY);
    // nonline mode
    conv_bias_opr->param().nonlineMode = conv_bias_size_param.nonlineMode;
    // convolution mode
    if (conv_bias_size_param.filter_meta.should_flip) {
        conv_bias_opr->param().mode = param::ConvolutionV0::Mode::CONVOLUTION;
    } else {
        conv_bias_opr->param().mode =
                param::ConvolutionV0::Mode::CROSS_CORRELATION;
    }
    // sparse
    if (conv_bias_size_param.filter_meta.group > 1) {
        conv_bias_opr->param().sparse = param::ConvolutionV0::Sparse::GROUP;
    } else {
        conv_bias_opr->param().sparse = param::ConvolutionV0::Sparse::DENSE;
    }
    // format
    conv_bias_opr->param().format = conv_bias_size_param.filter_meta.format;
    // pad, stride, dilate
    conv_bias_opr->param().pad_h = conv_bias_size_param.filter_meta.padding[0];
    conv_bias_opr->param().pad_w = conv_bias_size_param.filter_meta.padding[1];
    conv_bias_opr->param().stride_h =
            conv_bias_size_param.filter_meta.stride[0];
    conv_bias_opr->param().stride_w =
            conv_bias_size_param.filter_meta.stride[1];
    conv_bias_opr->param().dilate_h =
            conv_bias_size_param.filter_meta.dilation[0];
    conv_bias_opr->param().dilate_w =
            conv_bias_size_param.filter_meta.dilation[1];
    // output_block_size
    conv_bias_opr->param().output_block_size =
            conv_bias_size_param.output_block_size;
    // compute_mode
    conv_bias_opr->param().compute_mode = conv_bias_size_param.compute_mode;
    return conv_bias_size_param;
}

bool ConvolutionImpl::AlgoDefault::is_preferred(
        ConvolutionImpl*, const NCBKernSizeParam& param) const {
    ::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
            init_convbias_opr_and_param(m_conv_bias_opr, param);
    return m_algorithm->is_preferred(m_conv_bias_opr, conv_bias_param);
}

bool ConvolutionImpl::AlgoDefault::usable(
        ConvolutionImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    ::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
            init_convbias_opr_and_param(m_conv_bias_opr, param);
    return m_algorithm->usable(m_conv_bias_opr, conv_bias_param,
                               static_cast<ConvBiasImpl::AlgoSelectionStrategy>(
                                       algo_selection_strategy));
}

WorkspaceBundle ConvolutionImpl::AlgoDefault::get_bundle(
        const NCBKernSizeParam& param) const {
    ::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
            init_convbias_opr_and_param(m_conv_bias_opr, param);
    m_conv_bias_opr->execution_policy() = {m_algorithm};
    return WorkspaceBundle(nullptr, {m_algorithm->get_workspace(
                                            m_conv_bias_opr, conv_bias_param)});
}

size_t ConvolutionImpl::AlgoDefault::get_workspace(
        ConvolutionImpl*, const NCBKernSizeParam& param) const {
    return get_bundle(param).total_size_in_bytes();
}
//! Return the implementation kernel
SmallVector<ConvolutionImpl::NCBKern> ConvolutionImpl::AlgoDefault::get_kimpl(
        ::ConvBiasImpl* conv_bias_opr, ConvBiasImpl::AlgoBase* algo,
        const NCBKernSizeParam& param) {
    MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(0)) {
        // construct the conv_bias kern param
        ::ConvBiasImpl::NCBKernParam conv_bias_param;
        ::ConvBiasImpl::NCBKernSizeParam conv_bias_size_param =
                init_convbias_opr_and_param(conv_bias_opr, param);
        static_cast<::ConvBiasImpl::NCBKernSizeParam&>(conv_bias_param) =
                conv_bias_size_param;
        auto conv_bias_kerns =
                algo->dispatch_kerns(conv_bias_opr, conv_bias_param);
        SmallVector<ConvolutionImpl::NCBKern> convolution_kerns;
        //! set the conv_bias run-time addresses from the convolution param
        auto set_copy_param_run_time_address =
                [](const NCBKernParam& conv_param,
                   ::ConvBiasImpl::NCBKernParam& copied_param) {
                    copied_param.src_ptr = conv_param.src_ptr;
                    copied_param.filter_ptr = conv_param.filter_ptr;
                    copied_param.dst_ptr = conv_param.dst_ptr;
                    copied_param.workspace_ptr = conv_param.workspace_ptr;
                    copied_param.workspace_size = conv_param.workspace_size;
                };
        for (size_t i = 0; i < conv_bias_kerns.size(); i++) {
            auto kernel = conv_bias_kerns[i];
            //! wrap each conv_bias kernel so it can run with the
            //! convolution NCBKernParam
            auto run = [=](const NCBKernParam& p,
                           const NCBKernIndex& ncb_index) {
                auto copy_param = conv_bias_param;
                set_copy_param_run_time_address(p, copy_param);
                kernel.kern(copy_param,
                            {ncb_index.thread_id, ncb_index.ndrange_id});
            };
            convolution_kerns.push_back({run, kernel.global_size});
        }
        return convolution_kerns;
    }
    MIDOUT_END();
}
/* ===================== direct algo ===================== */

bool ConvolutionBackwardDataImpl::AlgoDirect::usable(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
    auto&& fm = param.filter_meta;
    return fm.format == param::Convolution::Format::NCHW &&
           param.diff_type.enumv() == DTypeEnum::Float32 &&
           param.filter_type.enumv() == DTypeEnum::Float32 &&
           param.grad_type.enumv() == DTypeEnum::Float32 &&
           fm.spatial_ndim == 2 && fm.group == 1 && fm.dilation[0] == 1 &&
           fm.dilation[1] == 1;
}

size_t ConvolutionBackwardDataImpl::AlgoDirect::get_workspace(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
    auto FH = param.filter_meta.spatial[0], FW = param.filter_meta.spatial[1];
    if (param.filter_meta.should_flip) {
        // need to transpose the filter
        return FH * FW * sizeof(float);
    } else {
        return 0;
    }
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoDirect::dispatch_kern(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
    return kern_direct;
}

/* ===================== Matrix mul algo ===================== */

bool ConvolutionBackwardDataImpl::AlgoMatrixMul::usable(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
    auto&& fm = param.filter_meta;
    return fm.format == param::Convolution::Format::NCHW &&
           fm.spatial_ndim == 2 && fm.group == 1 && fm.dilation[0] == 1 &&
           fm.dilation[1] == 1;
}

size_t ConvolutionBackwardDataImpl::AlgoMatrixMul::get_workspace(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
    return get_bundle(param).total_size_in_bytes();
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoMatrixMul::dispatch_kern(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
#define cb(dt)                                                    \
    do {                                                          \
        if (param.filter_type.enumv() == DTypeTrait<dt>::enumv) { \
            using ctype = DTypeTrait<dt>::ctype;                  \
            return kern_matmul<ctype, ctype, ctype>;              \
        }                                                         \
    } while (0);
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
#define cb(dt_src, dt_dst)                                            \
    do {                                                              \
        if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv &&   \
            param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \
            param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) {   \
            return kern_matmul<DTypeTrait<dt_src>::ctype,             \
                               DTypeTrait<dt_src>::ctype,             \
                               DTypeTrait<dt_dst>::ctype>;            \
        }                                                             \
    } while (0)
    cb(dtype::Int8, dtype::Int32);
    cb(dtype::QuantizedS8, dtype::QuantizedS32);
    cb(dtype::Quantized8Asymm, dtype::QuantizedS32);
    megdnn_throw("unsupported data type on matrix mul");
#undef cb
}

// vim: syntax=cpp.doxygen
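
For readers following kern_matmul above: the sketch below is a minimal, scalar restatement of the same matmul + col2img idea for the cross-correlation case with stride 1 and no padding, using the file's naming (diff has spatial size IH x IW, grad has OH x OW). It is illustrative only; backward_data_naive and its signature are assumptions for this example, not part of the MegDNN API.

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // diff:   (OC, IH, IW)        gradient w.r.t. the convolution output
    // filter: (OC, IC, FH, FW)    forward filter, contiguous in memory
    // grad:   (IC, OH, OW)        gradient w.r.t. the convolution input,
    //                             where OH = IH + FH - 1, OW = IW + FW - 1
    void backward_data_naive(const float* diff, const float* filter, float* grad,
                             size_t OC, size_t IC, size_t IH, size_t IW,
                             size_t FH, size_t FW) {
        const size_t OH = IH + FH - 1, OW = IW + FW - 1;
        // step 1 (the matmul): col = filter^T (IC*FH*FW, OC) x diff (OC, IH*IW)
        std::vector<float> col(IC * FH * FW * IH * IW, 0.f);
        for (size_t r = 0; r < IC * FH * FW; ++r)
            for (size_t oc = 0; oc < OC; ++oc)
                for (size_t p = 0; p < IH * IW; ++p)
                    col[r * IH * IW + p] += filter[oc * IC * FH * FW + r] *
                                            diff[oc * IH * IW + p];
        // step 2 (col2img): scatter-add every unrolled row back into grad
        std::fill(grad, grad + IC * OH * OW, 0.f);
        for (size_t ic = 0; ic < IC; ++ic)
            for (size_t fh = 0; fh < FH; ++fh)
                for (size_t fw = 0; fw < FW; ++fw)
                    for (size_t ih = 0; ih < IH; ++ih)
                        for (size_t iw = 0; iw < IW; ++iw)
                            grad[(ic * OH + ih + fh) * OW + (iw + fw)] +=
                                    col[((ic * FH + fh) * FW + fw) * IH * IW +
                                        ih * IW + iw];
    }

The real kernel replaces step 1 with the delegated MatrixMul opr (fed the relayouted filter stored in workspace part 2) and step 2 with the templated col2img / col2img_stride_padding helpers, and it skips step 2 entirely in the 1x1, stride-1, no-padding case, where the matmul output is already the gradient.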

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.