
opr_impl.cpp

/**
 * \file dnn/src/fallback/convolution/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/convolution/algos.h"
#include "src/fallback/convolution/opr_impl.h"
#include "src/fallback/convolution/run_conv.h"
#include "src/naive/convolution/helper.h"
#include "src/naive/handle.h"

#include "midout.h"

#if MEGDNN_AARCH64 || MEGDNN_ARMV7
#include "src/arm_common/convolution/opr_impl.h"
#endif

#include <cstring>
#include <unordered_map>

MIDOUT_DECL(megdnn_fb_convbwd_float)

using namespace megdnn;
using namespace fallback;

namespace {
template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}
}  // namespace
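
// AlgoPack gathers every algorithm the fallback convolution can use: all
// packed conv_bias algorithms wrapped as AlgoDefault, plus the plain
// fallback and naive implementations, indexed by algorithm descriptor.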
class ConvolutionImpl::AlgoPack : NonCopyableObj {
    AlgoFallback algo_fallback;
    AlgoNaive algo_naive;
    SmallVector<std::unique_ptr<AlgoBase>> refhold;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        static CpuOprDelegationStorage<1> storage;
        auto conv_bias_opr = storage.get<ConvBias, 0>();
        auto&& conv_bias_algo =
                static_cast<ConvBiasImpl*>(conv_bias_opr)->get_all_packed_algo();
        for (auto&& algorithm : conv_bias_algo) {
            // fallback algo
            refhold.emplace_back(new AlgoDefault(algorithm));
            m_all_algos.emplace_back(refhold.back().get());
        }
        m_all_algos.emplace_back(&algo_fallback);
        m_all_algos.emplace_back(&algo_naive);
        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }

    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
const ConvolutionImpl::AlgoPack& ConvolutionImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::get_all_packed_algo() {
    return algo_pack().all_algos();
}

SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::select_algo_type(
        ConvAlgoTypePack target_type) {
    megdnn_assert(nr_type_contain(target_type.data_type),
                  "ConvBias algo selection only support one type");
    SmallVector<ConvolutionImpl::AlgoBase*> algos;
    for (auto&& algo : get_all_packed_algo()) {
        auto algo_type = algo->get_algo_type();
        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
            algo_type.algo_category == target_type.algo_category) {
            algos.push_back(algo);
        }
    }
    return algos;
}

bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
    return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
}

#define NCB_ALGO_FUNC(name, algo, param) \
    static_cast<AlgoBase*>(algo)->name(param)
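
// exec: run the selected NCB kernel when a non-naive algorithm fits into the
// provided workspace; otherwise defer to the naive reference implementation.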
void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                           _megdnn_tensor_out dst,
                           const PreprocessedFilter* preprocessed_filter,
                           _megdnn_workspace workspace) {
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter,
                                      workspace);
    auto&& algo = get_algorithm(fparam, workspace.size);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) {
        exec_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec(src, filter, dst,
                                            preprocessed_filter, workspace);
    }
}
void ConvolutionImpl::exec_preprocess(const TensorLayout& src_layout,
                                      _megdnn_tensor_in filter,
                                      const TensorLayout& dst_layout,
                                      PreprocessedFilter* preprocessed_filter,
                                      _megdnn_workspace workspace) {
    //! exec_preprocess currently only supports preprocessing weights before
    //! exec; src/dst are ignored, so they are set to nullptr
    TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout};
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter,
                                      workspace);
    //! no workspace_size limit is passed here, otherwise a matching algo may
    //! not be found
    auto&& algo = get_algorithm(fparam);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <=
                workspace.size) {
        exec_preprocess_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec_preprocess(
                src_layout, filter, dst_layout, preprocessed_filter, workspace);
    }
}
size_t ConvolutionImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    TensorLayoutArray layouts{src, filter, dst};
    HeuristicCache::Key key{this->handle(), this->get_opr_type(),
                            layouts.data(), layouts.size(), &this->param(),
                            sizeof(this->param())};
    auto rst = HeuristicCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }
    auto fparam =
            make_ncb_kern_size_param(src, filter, dst, preprocessed_filter);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_workspace_in_bytes(
                src, filter, dst, preprocessed_filter);
    } else {
        return NCB_ALGO_FUNC(get_workspace, algo, fparam);
    }
}
size_t ConvolutionImpl::get_preprocess_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_preprocess_workspace_in_bytes(
                src, filter, dst);
    } else {
        return NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam);
    }
}

SmallVector<TensorLayout> ConvolutionImpl::deduce_preprocessed_filter_layout(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::deduce_preprocessed_filter_layout(
                src, filter, dst);
    } else {
        return NCB_ALGO_FUNC(deduce_preprocessed_filter_layout, algo, fparam);
    }
}
std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto ret = get_all_algorithms_with_ncb(fparam);
    if (ret.empty()) {
        return naive::ConvolutionForwardImpl::get_all_algorithms_safe(src, filter,
                                                                      dst);
    }
    return ret;
}

std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto ret_safe = ConvolutionImpl::get_all_algorithms(src, filter, dst);
    return ret_safe;
}
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto result = get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
    if (result == nullptr) {
        result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
                src, filter, dst, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    return result;
}
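
// Pack the layouts into the NCB size param: batch size, input/output spatial
// extents (whose index depends on the layout format), dtypes, strides, and
// the number of dispatcher threads.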
ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW88 ||
        param().format == Param::Format::NCHW8 ||
        param().format == Param::Format::NCHW4 ||
        param().format == Param::Format::NCHW44_DOT ||
        param().format == Param::Format::NCHW44) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NHWC) {
        spatial_pos = 1;
    } else {
        megdnn_assert(0, "invalid conv format %d",
                      static_cast<int>(param().format));
    }
    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();
    return {safe_u32(src[0]),
            {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
            {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
            check_layout_fwd(src, filter, dst),
            src.dtype,
            filter.dtype,
            dst.dtype,
            src.stride[0],
            dst.stride[0],
            {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
            {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
            param().compute_mode,
            nr_threads,
            preprocessed_filter};
}
ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter,
        _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) = make_ncb_kern_size_param(
            src.layout, filter.layout, dst.layout, preprocessed_filter);
    ret.src_ptr = src.raw_ptr;
    ret.filter_ptr = filter.raw_ptr;
    ret.dst_ptr = dst.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
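
// Both NCB exec paths below share one dispatch scheme: each kernel is
// submitted to the handle's dispatcher global_size.total_size() times, and
// the flat index is decoded back into a CpuNDRange id inside the closure.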
void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,
                                                    Algorithm* algo) {
    auto&& kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
    auto&& fallback_handle = handle();
    for (auto&& kernel : kerns) {
        megdnn_assert(
                param.filter_meta.format == Param::Format::NCHW ||
                        param.filter_meta.format == Param::Format::NHWC ||
                        param.filter_meta.format == Param::Format::NCHW88 ||
                        param.filter_meta.format == Param::Format::NCHW44 ||
                        param.filter_meta.format == Param::Format::NCHW44_DOT,
                "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}

void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
                                         Algorithm* algo) {
    auto&& kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
    auto&& fallback_handle = handle();
    for (auto&& kernel : kerns) {
        megdnn_assert(
                param.filter_meta.format == Param::Format::NCHW ||
                        param.filter_meta.format == Param::Format::NHWC ||
                        param.filter_meta.format == Param::Format::NCHW88 ||
                        param.filter_meta.format == Param::Format::NCHW44 ||
                        param.filter_meta.format == Param::Format::NCHW44_DOT,
                "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}
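
// Heuristic selection walks the suggested algo categories in order; within a
// category it returns the first usable *preferred* algorithm, else the first
// usable one whose workspace fits the limit.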
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    auto algo_data_type = param.deduce_algo_data_type();
    auto suggest_category_order = suggest_algo_category_order(param);
    for (auto category : suggest_category_order) {
        auto&& origin_algos = select_algo_type({algo_data_type, category});
        ConvolutionImpl::Algorithm* heuristic_algo = nullptr;
        for (auto i : origin_algos) {
            bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
                    param, AlgoSelectionStrategy::HEURISTIC, positive_attr,
                    negative_attr);
            if (usable_attribute &&
                static_cast<AlgoBase*>(i)->get_workspace(param) <=
                        workspace_limit_in_bytes) {
                //! remember the first usable algo; it becomes the target if
                //! no preferred algo shows up in this category
                if (!heuristic_algo) {
                    heuristic_algo = i;
                }
                //! return the first preferred algo immediately
                if (i->is_preferred(param)) {
                    return i;
                }
            }
        }
        if (heuristic_algo) {
            return heuristic_algo;
        }
    }
    return nullptr;
}
std::vector<ConvolutionImpl::Algorithm*>
ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    std::vector<Algorithm*> prefer_algos;
    for (auto&& i : get_all_packed_algo()) {
        if (i->usable(param, AlgoSelectionStrategy::FULL_RUN)) {
            if (i->is_preferred(param)) {
                prefer_algos.push_back(i);
            } else {
                ret.push_back(i);
            }
        }
    }
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
    return ret;
}
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            }
            case Handle::HandleType::NAIVE: {
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_fwd_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}
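
// get_algorithm: an algorithm fixed by the execution policy wins; otherwise
// the last heuristic choice is cached, keyed on a byte-wise comparison of
// the size param, so repeated calls with identical shapes skip re-selection.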
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
                param, workspace_size, AlgoAttribute::DEFAULT,
                AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}
SmallVector<AlgoCategory> ConvolutionImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    static CpuOprDelegationStorage<1> storage;
    auto conv_bias_opr = storage.get<ConvBias, 0>();
    auto conv_bias_param =
            ConvolutionImpl::AlgoDefault::init_conv_bias_param(param);
    return static_cast<ConvBiasImpl*>(conv_bias_opr)
            ->suggest_algo_category_order(conv_bias_param);
}

const char* ConvolutionImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "F0";
}
ConvolutionImpl::AlgoDataType
ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const {
    if (src_type.enumv() == DTypeEnum::Float32) {
        return ConvolutionImpl::AlgoDataType::FLOAT32;
#if !MEGDNN_DISABLE_FLOAT16
    } else if (src_type.enumv() == DTypeEnum::Float16) {
        return ConvolutionImpl::AlgoDataType::FLOAT16;
#endif
    } else if (src_type.enumv() == DTypeEnum::Int8 ||
               src_type.enumv() == DTypeEnum::QuantizedS8) {
        if (dst_type.enumv() == DTypeEnum::Int16) {
            return ConvolutionImpl::AlgoDataType::INT8X8X16;
        } else {
            return ConvolutionImpl::AlgoDataType::QINT8X8X32;
        }
    } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
    } else if (src_type.enumv() == DTypeEnum::QuantizedS4) {
        return ConvolutionImpl::AlgoDataType::QINT4x4x32;
    } else {
        megdnn_throw(ssprintf("not support data type of %s * %s -> %s\n",
                              src_type.name(), filter_type.name(),
                              dst_type.name()));
    }
}
/* ===================== ConvolutionBackwardData ===================== */

class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
    AlgoNaive algo_naive;
    AlgoDirect algo_direct;
    AlgoMatrixMul algo_matmul;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        m_all_algos.emplace_back(&algo_matmul);
        m_all_algos.emplace_back(&algo_direct);
        m_all_algos.emplace_back(&algo_naive);
        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }

    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
const ConvolutionBackwardDataImpl::AlgoPack&
ConvolutionBackwardDataImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

SmallVector<ConvolutionBackwardDataImpl::AlgoBase*>
ConvolutionBackwardDataImpl::get_all_packed_algo() {
    return algo_pack().all_algos();
}
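
// Formats the fallback NCB kernels do not cover (NHWCD4, NCHW4, and
// quantized-int8 NCHW) are routed straight to the naive implementation in
// the entry points below.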
void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
                                       _megdnn_tensor_in diff,
                                       _megdnn_tensor_out grad,
                                       _megdnn_workspace workspace) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        (param().format == param::Convolution::Format::NCHW &&
         grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad,
                                                        workspace);
    }
    auto fparam = make_ncb_kern_param(filter, diff, grad, workspace);
    return exec_with_ncb_kern(fparam);
}
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    TensorLayoutArray layouts{filter, diff, grad};
    HeuristicCache::Key key{this->handle(), this->get_opr_type(),
                            layouts.data(), layouts.size(), &this->param(),
                            sizeof(this->param())};
    auto rst = HeuristicCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        (param().format == param::Convolution::Format::NCHW &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_workspace_with_ncb(fparam);
}
std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
                                                const TensorLayout& diff,
                                                const TensorLayout& grad) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        (param().format == param::Convolution::Format::NCHW &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_all_algorithms(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    auto ret = get_all_algorithms_with_ncb(fparam);
    return ret;
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter,
                                                     const TensorLayout& diff,
                                                     const TensorLayout& grad) {
    auto ret_safe =
            ConvolutionBackwardDataImpl::get_all_algorithms(filter, diff, grad);
    megdnn_assert(!ret_safe.empty(), "no usable conv bwd algorithm");
    return ret_safe;
}
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        (param().format == param::Convolution::Format::NCHW &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
                filter, diff, grad, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes,
                                            positive_attr, negative_attr);
}
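
// Backward-data size param: the layout check reuses the forward checker, so
// grad plays the role of src and diff the role of dst, with their dtypes
// swapped to match.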
ConvolutionBackwardDataImpl::NCBKernSizeParam
ConvolutionBackwardDataImpl::make_ncb_kern_size_param(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else {
        megdnn_assert(param().format == Param::Format::NHWC,
                      "invalid conv format");
        spatial_pos = 1;
    }
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;
    std::swap(grad_fwd.dtype, diff_fwd.dtype);
    return {
            safe_u32(diff[0]),
            {{safe_u32(diff[spatial_pos]), safe_u32(diff[spatial_pos + 1])}},
            {{safe_u32(grad[spatial_pos]), safe_u32(grad[spatial_pos + 1])}},
            check_layout_fwd(grad_fwd, filter_fwd, diff_fwd),
            diff.dtype,
            filter.dtype,
            grad.dtype,
            diff,
            filter,
            grad,
            diff.stride[0],
            grad.stride[0],
            0,
            0,
            0,
            param().compute_mode,
    };
}
ConvolutionBackwardDataImpl::NCBKernParam
ConvolutionBackwardDataImpl::make_ncb_kern_param(_megdnn_tensor_in filter,
                                                 _megdnn_tensor_in diff,
                                                 _megdnn_tensor_out grad,
                                                 _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(filter.layout, diff.layout, grad.layout);
    auto required_workspace_in_bytes = get_workspace_with_ncb(ret);
    megdnn_assert(workspace.size >= required_workspace_in_bytes,
                  "required workspace: %zu; provided workspace: %zu",
                  required_workspace_in_bytes, workspace.size);
    ret.filter_ptr = filter.raw_ptr;
    ret.diff_ptr = diff.raw_ptr;
    ret.grad_ptr = grad.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
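
// Algorithms only ever see group == 1: for grouped convolution the
// single-group kernel runs once per group, with diff/filter/grad pointers
// advanced by the per-group strides computed below.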
void ConvolutionBackwardDataImpl::exec_with_ncb_kern(
        const NCBKernParam& param) {
    auto p1g = param;
    auto group = p1g.filter_meta.group;
    p1g.filter_meta.group = 1;
    auto&& algo = get_algorithm(p1g);
    auto kptr = ncb_1g_dispatch_kern(algo, p1g);
    if (group == 1 || static_cast<AlgoBase*>(algo)->is_naive()) {
        auto run = [kptr, param]() { kptr(param); };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    } else {
        megdnn_assert(p1g.filter_meta.format == Param::Format::NCHW ||
                              p1g.filter_meta.format == Param::Format::NHWC,
                      "invalid conv format");
        auto run = [kptr, p1g_orig = p1g, group]() {
            auto p1g = p1g_orig;
            ptrdiff_t istrd, fstrd, ostrd;
            fstrd = p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            istrd = p1g.filter_meta.ocpg * p1g.diff_type.size();
            ostrd = p1g.filter_meta.icpg * p1g.grad_type.size();
            p1g.diff_extra_mem_size =
                    (group - 1) * p1g.filter_meta.ocpg * p1g.diff_type.size();
            p1g.filter_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            p1g.grad_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size();
            if (p1g.filter_meta.format == Param::Format::NCHW) {
                istrd *= p1g.isz[0] * p1g.isz[1];
                ostrd *= p1g.osz[0] * p1g.osz[1];
                p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1];
                p1g.grad_extra_mem_size *= p1g.osz[0] * p1g.osz[1];
            } else {
                // must be NHWC. No action performed.
            }
            for (size_t i = 0; i < group; ++i) {
                kptr(p1g);
                incr_ptr(p1g.diff_ptr, istrd);
                incr_ptr(p1g.filter_ptr, fstrd);
                incr_ptr(p1g.grad_ptr, ostrd);
                p1g.diff_extra_mem_size -= istrd;
                p1g.filter_extra_mem_size -= fstrd;
                p1g.grad_extra_mem_size -= ostrd;
            }
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    }
}
size_t ConvolutionBackwardDataImpl::get_workspace_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        auto algo = get_algorithm(p1g);
        return ncb_1g_get_workspace(algo, p1g);
    }
    auto algo = get_algorithm(param);
    return ncb_1g_get_workspace(algo, param);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_all_algorithms(p1g);
    }
    return ncb_1g_get_all_algorithms(param);
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes,
                                              positive_attr, negative_attr);
    }
    return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes,
                                          positive_attr, negative_attr);
}
size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
    }
    return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
    }
    megdnn_throw("no suitable ConvolutionBackwardData algorithm");
}
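
// Matrix-multiply backward data is preferred for reasonably wide layers
// (OC * IC >= 32) or for 1x1 stride-1 unpadded filters, where the
// deconvolution reduces to a plain GEMM.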
bool ConvolutionBackwardDataImpl::is_matrix_mul_preferred(
        const NCBKernSizeParam& param) {
    auto&& fm = param.filter_meta;
    auto OC = fm.ocpg, IC = fm.icpg;
    return (OC * IC >= 32) ||
           (fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.padding[0] == 0 &&
            fm.padding[1] == 0 && fm.stride[0] == 1 && fm.stride[1] == 1);
}
std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
        const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    std::vector<Algorithm*> prefer_algos;
    for (auto&& i : get_all_packed_algo()) {
        if (i->usable(this, param)) {
            if (i->is_preferred(param)) {
                prefer_algos.push_back(i);
            } else {
                ret.push_back(i);
            }
        }
    }
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
    return ret;
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    for (auto i : ncb_1g_get_all_algorithms(param)) {
        if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
            if (i->contain_attribute_all(positive_attr) &&
                !i->contain_attribute_any(negative_attr)) {
                return i;
            }
        }
    }
    megdnn_assert(0,
                  "no suitable algorithm found within given workspace limit");
}
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            }
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case Handle::HandleType::ARM_COMMON:
            case Handle::HandleType::AARCH64:
            case Handle::HandleType::ARMV7:
                return arm_common::ConvolutionBackwardDataImpl::
                        get_algo_from_desc(desc);
#endif
            case Handle::HandleType::NAIVE: {
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_bwd_data_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) {
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
                param, std::numeric_limits<size_t>::max(),
                AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "FALLBACK_CONVOLUTION_BACKWARD_DATA_IMPL0";
}

// vim: syntax=cpp.doxygen
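
The algorithm caches in both get_algorithm implementations above rely on the size param being trivially copyable, so that a byte-wise memcmp is a valid equality test. A minimal standalone sketch of that memoization pattern follows; the names (SizeParam, Algo, heuristic, get_algo_cached) are illustrative stand-ins, not MegDNN API:

#include <cstring>

struct SizeParam {  // stand-in for NCBKernSizeParam; trivially copyable
    unsigned n, ih, iw, oh, ow;
};

struct Algo {  // stand-in for ConvolutionImpl::Algorithm
    const char* name;
};

static Algo naive_algo{"NAIVE"};

static Algo* heuristic(const SizeParam&) {
    // a real heuristic would rank usable candidates here
    return &naive_algo;
}

static Algo* get_algo_cached(const SizeParam& param) {
    static Algo* prev_algo = nullptr;
    static SizeParam prev_param{};
    // re-run the heuristic only when the byte image of the param changes
    if (!prev_algo || std::memcmp(&prev_param, &param, sizeof(SizeParam))) {
        prev_algo = heuristic(param);
        prev_param = param;
    }
    return prev_algo;
}

As in the real code, the comparison also covers any padding bytes in the struct, so the cached param must be fully copied (here by assignment) rather than written field by field.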
