
opr_impl.cpp 24 kB

/**
 * \file dnn/src/fallback/convolution/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/fallback/convolution/opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/convolution/algos.h"
#include "src/fallback/convolution/run_conv.h"
#include "src/naive/convolution/helper.h"
#include "src/naive/handle.h"

#include "midout.h"

#include <cstring>

MIDOUT_DECL(megdnn_fb_conv_float)
MIDOUT_DECL(megdnn_fb_convbwd_float)

using namespace megdnn;
using namespace fallback;

namespace {

class NaiveConvolutionBackwardData final
        : public megdnn::ConvolutionBackwardData::Algorithm {
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "NCBD"; }
};
NaiveConvolutionBackwardData naive_conv_backward_data;

uint8_t fallback_deconv_algo_type_storage;
uint8_t fallback_conv_algo_type_storage;

template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}

}  // namespace
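// The two uint8_t storage bytes above exist only so that their addresses can
// act as unique algorithm-type tags (see sm_fallback_conv_algo_type and
// sm_fallback_deconv_algo_type below); the pointers are compared, never
// dereferenced.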
class ConvolutionImpl::AlgoPack : NonCopyableObj {
    AlgoFallback algo_fallback;
    AlgoNaive algo_naive;
    SmallVector<std::unique_ptr<AlgoBase>> refhold;

public:
    AlgoPack() {
        static CpuOprDelegationStorage<1> storage;
        auto conv_bias_opr = storage.get<ConvBias, 0>();
        auto&& conv_bias_algo =
                static_cast<ConvBiasImpl*>(conv_bias_opr)->algo_pack();
        for (auto&& algorithm : conv_bias_algo) {
            // fallback algo
            refhold.emplace_back(new AlgoDefault(
                    static_cast<ConvBiasImpl*>(conv_bias_opr), algorithm));
            all_algos.emplace_back(refhold.back().get());
        }
        all_algos.emplace_back(&algo_fallback);
        all_algos.emplace_back(&algo_naive);
    }

    SmallVector<AlgoBase*> all_algos;
};
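// AlgoPack reuses every ConvBias algorithm for plain convolution by wrapping
// it in an AlgoDefault (kept alive in refhold); the generic fallback and
// naive kernels are appended last so they act as catch-alls.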
void* const ConvolutionImpl::sm_fallback_conv_algo_type =
        &fallback_conv_algo_type_storage;

SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() {
    static AlgoPack sl_algo_pack;
    return sl_algo_pack.all_algos;
}

bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
    return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
}

void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                           _megdnn_tensor_out dst,
                           _megdnn_workspace workspace) {
    auto fparam = make_ncb_kern_param(src, filter, dst, workspace);
    ConvolutionImpl::Algorithm* algo = get_algorithm(fparam, workspace.size);
    if (!is_naive_algo(algo) &&
        ncb_algo_get_workspace(algo, fparam) <= workspace.size) {
        exec_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec(src, filter, dst, workspace);
    }
}
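// If the selected algorithm is naive, or needs more workspace than the caller
// provided, execution falls back to the reference naive implementation.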
size_t ConvolutionImpl::get_workspace_in_bytes(const TensorLayout& src,
                                               const TensorLayout& filter,
                                               const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst);
    Algorithm* algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_workspace_in_bytes(
                src, filter, dst);
    } else {
        return ncb_algo_get_workspace(algo, fparam);
    }
}

std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst);
    auto ret = get_all_algorithms_with_ncb(fparam);
    if (ret.empty()) {
        return naive::ConvolutionForwardImpl::get_all_algorithms(src, filter,
                                                                 dst);
    }
    return ret;
}

ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst, size_t workspace_limit_in_bytes,
        bool reproducible) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst);
    auto result = get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, reproducible);
    if (result == nullptr) {
        result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
                src, filter, dst, workspace_limit_in_bytes, reproducible);
    }
    return result;
}
ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW88 ||
        param().format == Param::Format::NCHW8 ||
        param().format == Param::Format::NCHW4) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NCHW ||
               param().format == Param::Format::NCHW_WINOGRAD) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NHWC) {
        spatial_pos = 1;
    } else {
        megdnn_assert(0, "invalid conv format %d",
                      static_cast<int>(param().format));
    }
    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();
    return {safe_u32(src[0]),
            {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
            {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
            check_layout_fwd(src, filter, dst),
            src.dtype,
            filter.dtype,
            dst.dtype,
            src.stride[0],
            dst.stride[0],
            {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
            {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
            param().compute_mode,
            nr_threads};
}
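// Note: the braced return above relies on the declaration order of
// NCBKernSizeParam (batch size, input/output spatial shape, filter meta from
// check_layout_fwd, the three dtypes, batch strides, full stride arrays,
// compute mode, thread count); it must be kept in sync with the struct.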
ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(src.layout, filter.layout, dst.layout);
    ret.src_ptr = src.raw_ptr;
    ret.filter_ptr = filter.raw_ptr;
    ret.dst_ptr = dst.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
                                         Algorithm* algo) {
    auto kerns = ncb_algo_dispatch_kern(algo, param);
    auto fallback_handle = handle();
    for (auto kernel : kerns) {
        megdnn_assert(param.filter_meta.format == Param::Format::NCHW ||
                              param.filter_meta.format == Param::Format::NHWC ||
                              param.filter_meta.format == Param::Format::NCHW88,
                      "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}
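// Each NCB kernel runs once per element of its global work size; the naive
// handle's dispatcher distributes the (index, thread_id) pairs over worker
// threads.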
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        bool reproducible) {
    for (auto i : get_all_algorithms_with_ncb(param)) {
        if (static_cast<AlgoBase*>(i)->usable_reproducible(
                    this, param, AlgoSelectionStrategy::HEURISTIC,
                    reproducible) &&
            ncb_algo_get_workspace(i, param) <= workspace_limit_in_bytes) {
            return i;
        }
    }
    return nullptr;
}

std::vector<ConvolutionImpl::Algorithm*>
ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    std::vector<Algorithm*> prefer_algos;
    for (auto&& i : algo_pack()) {
        if (i->usable(this, param, AlgoSelectionStrategy::FULL_RUN)) {
            if (i->is_preferred(this, param)) {
                prefer_algos.push_back(i);
            } else {
                ret.push_back(i);
            }
        }
    }
    std::reverse(prefer_algos.begin(), prefer_algos.end());
    //! preferred algos are inserted at the front
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
    return ret;
}
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    if (auto set = execution_policy().algorithm) {
        return set;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo =
                get_algorithm_heuristic_with_ncb(param, workspace_size);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}
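// The heuristic result is cached: as long as the size parameters compare
// bytewise equal to the previous call, the stored algorithm is reused and
// the search is skipped.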
const char* ConvolutionImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "F0";
}

/* ===================== ConvolutionBackwardData ===================== */

void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type =
        &fallback_deconv_algo_type_storage;

struct ConvolutionBackwardDataImpl::AlgoPack {
    AlgoDirect direct;
    AlgoMatrixMul matmul;
};
ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;

void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
                                       _megdnn_tensor_in diff,
                                       _megdnn_tensor_out grad,
                                       _megdnn_workspace workspace) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad,
                                                        workspace);
    }
    auto fparam = make_ncb_kern_param(filter, diff, grad, workspace);
    return exec_with_ncb_kern(fparam);
}

size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_workspace_with_ncb(fparam);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
                                                const TensorLayout& diff,
                                                const TensorLayout& grad) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::get_all_algorithms(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    auto ret = get_all_algorithms_with_ncb(fparam);
    megdnn_assert(!ret.empty(), "no usable conv bwd data algorithm");
    return ret;
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad, size_t workspace_limit_in_bytes,
        bool reproducible) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
                filter, diff, grad, workspace_limit_in_bytes, reproducible);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes,
                                            reproducible);
}

ConvolutionBackwardDataImpl::NCBKernSizeParam
ConvolutionBackwardDataImpl::make_ncb_kern_size_param(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else {
        megdnn_assert(param().format == Param::Format::NHWC,
                      "invalid conv format");
        spatial_pos = 1;
    }
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;
    std::swap(grad_fwd.dtype, diff_fwd.dtype);
    return {
            safe_u32(diff[0]),
            {{safe_u32(diff[spatial_pos]), safe_u32(diff[spatial_pos + 1])}},
            {{safe_u32(grad[spatial_pos]), safe_u32(grad[spatial_pos + 1])}},
            check_layout_fwd(grad_fwd, filter_fwd, diff_fwd),
            diff.dtype,
            filter.dtype,
            grad.dtype,
            diff,
            filter,
            grad,
            diff.stride[0],
            grad.stride[0],
            0,
            0,
            0,
            param().compute_mode,
    };
}
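// check_layout_fwd() validates these layouts as a forward problem (grad in
// the src position, diff in the dst position), which is why grad_fwd and
// diff_fwd swap dtypes before the call.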
ConvolutionBackwardDataImpl::NCBKernParam
ConvolutionBackwardDataImpl::make_ncb_kern_param(_megdnn_tensor_in filter,
                                                 _megdnn_tensor_in diff,
                                                 _megdnn_tensor_out grad,
                                                 _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(filter.layout, diff.layout, grad.layout);
    auto required_workspace_in_bytes = get_workspace_with_ncb(ret);
    megdnn_assert(workspace.size >= required_workspace_in_bytes,
                  "required workspace: %zu; provided workspace: %zu",
                  required_workspace_in_bytes, workspace.size);
    ret.filter_ptr = filter.raw_ptr;
    ret.diff_ptr = diff.raw_ptr;
    ret.grad_ptr = grad.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
void ConvolutionBackwardDataImpl::exec_with_ncb_kern(
        const NCBKernParam& param) {
    auto p1g = param;
    auto group = p1g.filter_meta.group;
    p1g.filter_meta.group = 1;
    auto algo = get_algorithm(p1g);
    auto kptr = ncb_1g_dispatch_kern(algo, p1g);
    if (algo == &naive_conv_backward_data || group == 1) {
        auto run = [kptr, param]() { kptr(param); };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    } else {
        megdnn_assert(p1g.filter_meta.format == Param::Format::NCHW ||
                              p1g.filter_meta.format == Param::Format::NHWC,
                      "invalid conv format");
        auto run = [kptr, p1g_orig = p1g, group]() {
            auto p1g = p1g_orig;
            ptrdiff_t istrd, fstrd, ostrd;
            fstrd = p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            istrd = p1g.filter_meta.ocpg * p1g.diff_type.size();
            ostrd = p1g.filter_meta.icpg * p1g.grad_type.size();
            p1g.diff_extra_mem_size =
                    (group - 1) * p1g.filter_meta.ocpg * p1g.diff_type.size();
            p1g.filter_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            p1g.grad_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size();
            if (p1g.filter_meta.format == Param::Format::NCHW) {
                istrd *= p1g.isz[0] * p1g.isz[1];
                ostrd *= p1g.osz[0] * p1g.osz[1];
                p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1];
                p1g.grad_extra_mem_size *= p1g.osz[0] * p1g.osz[1];
            } else {
                // must be NHWC. No action performed.
            }
            for (size_t i = 0; i < group; ++i) {
                kptr(p1g);
                incr_ptr(p1g.diff_ptr, istrd);
                incr_ptr(p1g.filter_ptr, fstrd);
                incr_ptr(p1g.grad_ptr, ostrd);
                p1g.diff_extra_mem_size -= istrd;
                p1g.filter_extra_mem_size -= fstrd;
                p1g.grad_extra_mem_size -= ostrd;
            }
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    }
}
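// Grouped deconvolution is executed as `group` single-group kernel calls:
// the diff/filter/grad pointers advance by the per-group byte strides
// computed above, and the *_extra_mem_size fields shrink in step so every
// call still sees the remaining valid memory.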
size_t ConvolutionBackwardDataImpl::get_workspace_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_workspace(get_algorithm(p1g), p1g);
    }
    return ncb_1g_get_workspace(get_algorithm(param), param);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_all_algorithms(p1g);
    }
    return ncb_1g_get_all_algorithms(param);
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        bool reproducible) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes,
                                              reproducible);
    }
    return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes,
                                          reproducible);
}
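// All three *_with_ncb helpers above normalize the group count to 1 before
// delegating to the ncb_1g_* entry points, mirroring the per-group loop in
// exec_with_ncb_kern.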
size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->type() == sm_fallback_deconv_algo_type) {
        return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
    }
    megdnn_assert(algo == &naive_conv_backward_data);
    return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->type() == sm_fallback_deconv_algo_type) {
        return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
    }
    if (algo == &naive_conv_backward_data) {
#define cb(_dt)                                                    \
    do {                                                           \
        if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
            MIDOUT_BEGIN(megdnn_fb_convbwd_float,                  \
                         midout_iv(DTypeTrait<_dt>::enumv)) {      \
                using ctype = DTypeTrait<_dt>::ctype;              \
                return kern_naive<ctype, ctype, ctype>;            \
            }                                                      \
            MIDOUT_END();                                          \
        }                                                          \
    } while (0);
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
#define cb(dt_src, dt_dst)                                            \
    do {                                                              \
        if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv &&   \
            param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \
            param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) {   \
            return kern_naive<DTypeTrait<dt_src>::ctype,              \
                              DTypeTrait<dt_src>::ctype,              \
                              DTypeTrait<dt_dst>::ctype>;             \
        }                                                             \
    } while (0);
        cb(dtype::Int8, dtype::Int32)
        cb(dtype::Quantized8Asymm, dtype::QuantizedS32)
        cb(dtype::QuantizedS8, dtype::QuantizedS32)
        megdnn_throw("unsupported data type on ConvolutionBackwardData");
#undef cb
    }
    megdnn_throw(
            megdnn_mangle("no suitable ConvolutionBackwardData algorithm"));
}
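// The naive backward kernel is selected per dtype triple: the float types
// come from MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT, and the Int8->Int32 and
// quantized combinations are listed explicitly.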
bool ConvolutionBackwardDataImpl::is_matrix_mul_preferred(
        const NCBKernSizeParam& param) {
    auto&& fm = param.filter_meta;
    auto OC = fm.ocpg, IC = fm.icpg;
    return (OC * IC >= 32) ||
           (fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.padding[0] == 0 &&
            fm.padding[1] == 0 && fm.stride[0] == 1 && fm.stride[1] == 1);
}
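// Heuristic sketch: matmul is preferred when the per-group channel product
// is large (presumably to amortize the GEMM setup cost) or when the filter
// is 1x1 with unit stride and no padding, in which case the deconvolution
// is exactly a matrix multiplication.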
std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
        const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    ret.reserve(2);
    ret.push_back(&naive_conv_backward_data);
    // insert from lowest to highest preference
    AlgoBase* cand[2] = {nullptr};
    if (param.filter_meta.group == 1 && param.filter_meta.dilation[0] == 1 &&
        param.filter_meta.dilation[1] == 1) {
        // we currently only have non-dilated algos
        if (param.filter_type.enumv() == DTypeEnum::Float32) {
            if (is_matrix_mul_preferred(param)) {
                cand[0] = &sm_algo_pack.direct;
                cand[1] = &sm_algo_pack.matmul;
            } else {
                cand[0] = &sm_algo_pack.matmul;
                cand[1] = &sm_algo_pack.direct;
            }
        } else {
            cand[0] = &sm_algo_pack.matmul;
        }
    }
    for (auto i : cand) {
        if (i && i->usable(this, param)) {
            ret.push_back(i);
        }
    }
    std::reverse(ret.begin(), ret.end());
    return ret;
}
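// Candidates are collected lowest-preference first and the vector reversed
// at the end, so index 0 holds the preferred algorithm and the always-usable
// naive kernel comes last.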
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        bool reproducible) {
    for (auto i : ncb_1g_get_all_algorithms(param)) {
        if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
            if (reproducible) {
                if (i->is_reproducible()) {
                    return i;
                }
            } else {
                return i;
            }
        }
    }
    megdnn_assert(0,
                  "no suitable algorithm found within given workspace limit");
}
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) {
    if (auto set = execution_policy().algorithm) {
        return set;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
                param, std::numeric_limits<size_t>::max());
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "FALLBACK_CONVOLUTION_BACKWARD_DATA_IMPL0";
}

// vim: syntax=cpp.doxygen

