
opr_impl.cpp

/**
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/algos.h"
#include "src/fallback/conv_bias/conv1x1/algos.h"
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h"
#include "src/fallback/conv_bias/gi/fp32/algos.h"
#include "src/fallback/conv_bias/im2col/algos.h"
#include "src/fallback/convolution/opr_impl.h"
#include "src/naive/convolution/algorithms.h"
#include "src/naive/handle.h"

#if MEGDNN_X86
#include "src/x86/conv_bias/opr_impl.h"
#elif MEGDNN_AARCH64
#include "src/aarch64/conv_bias/opr_impl.h"
#elif MEGDNN_ARMV7
#include "src/armv7/conv_bias/opr_impl.h"
#endif

#include <cstring>

using namespace megdnn;
using namespace fallback;

namespace {
//! TODO: implement is_fallback_exclude_gi_or_naive
bool is_naive(const detail::Algorithm* algo) {
    return algo->handle_type() == Handle::HandleType::NAIVE;
}
}  // anonymous namespace
size_t megdnn::fallback::pack_size(param::ConvBias::Format format) {
    switch (format) {
        case param::ConvBias::Format::NCHW44:
        case param::ConvBias::Format::NCHW44_DOT:
        case param::ConvBias::Format::NCHW4:
            return 4_z;
        case param::ConvBias::Format::NCHW88:
            return 8_z;
        default:
            return 1_z;
    }
}
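//! Usage sketch (illustrative example, not part of the original source): the
//! returned value is the channel pack width implied by the layout, e.g.
//!     using Fmt = param::ConvBias::Format;
//!     pack_size(Fmt::NCHW44);  // 4: channels packed in groups of 4
//!     pack_size(Fmt::NCHW88);  // 8: channels packed in groups of 8
//!     pack_size(Fmt::NCHW);    // 1: no channel packing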
namespace {
template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}
}  // namespace
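//! Usage sketch (illustrative, assumed caller code): advance a typed pointer
//! by a raw byte offset, e.g. to step over one row of a strided tensor:
//!     float* row = base;
//!     incr_ptr(row, row_stride_in_bytes);  // row now points one row further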
#if MEGDNN_X86
#define SKIP_GEMV()
//! As we do not have a direct conv for int8x8x16 yet, disabling gemv here may
//! fall back to the naive implementation, which can be very slow, so we just
//! enable im2col for gemv in the x86 backend.
//! FIXME: remove this when we add direct conv support for int8x8x16
#else
#define SKIP_GEMV()                                                            \
    if (algo->algoset() == MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { \
        continue;                                                              \
    }
#endif
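//! Usage sketch (illustrative, not part of the original source): SKIP_GEMV()
//! is meant to be expanded inside a loop over matmul algos, e.g.
//!     for (auto&& algo : matmul_algos) {
//!         SKIP_GEMV();  // on non-x86, skips gemv-type algos via `continue`
//!         // ... register algo ...
//!     }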
class ConvBiasImpl::AlgoPack : NonCopyableObj {
    AlgoNaive algo_naive;
    SmallVector<std::unique_ptr<AlgoBase>> refhold;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;
    SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_gi_winograd_algos;
    AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44;
    AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44;
    AlgoF32DirectNCHW44 f32_direct_nchw44;
    AlgoF32Direct f32_direct;
    AlgoF32DirectStride2 f32_direct_stride2;
    AlgoF32DirectStride1 f32_direct_stride1;

public:
    AlgoPack() {
        // fallback gi fp32 algo
        m_all_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
        m_all_algos.emplace_back(&f32_chanel_wise_nchw44);
        m_all_algos.emplace_back(&f32_direct_nchw44);
        m_all_algos.emplace_back(&f32_direct_stride1);
        m_all_algos.emplace_back(&f32_direct_stride2);
        m_all_algos.emplace_back(&f32_direct);

        static CpuOprDelegationStorage<2> storage;
        auto matmul_opr = storage.get<MatrixMul, 0>();
        using MatmulFormat = param::MatrixMul::Format;
        auto&& matmul_algos =
                static_cast<fallback::MatrixMulImpl*>(matmul_opr)
                        ->select_algo_type({AlgoDataType::FLOAT32, MatmulFormat::MK4});
        for (auto&& algo : matmul_algos) {
            if (is_naive(algo))
                continue;
            for (uint32_t tile_size : {16, 8, 24, 32}) {
                refhold.emplace_back(new AlgoFP32WinogradF23_4x4(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                m_gi_winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF63_4x4(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                m_gi_winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                m_gi_winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                m_gi_winograd_algos.emplace_back(refhold.back().get());
                //! uncomment this when low precision mode is done
#if 0
                refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                m_gi_winograd_algos.emplace_back(refhold.back().get());
#endif
            }
        }
        matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr)
                               ->select_algo_type(
                                       {AlgoDataType::FLOAT32, MatmulFormat::DEFAULT});
        for (auto&& algo : matmul_algos) {
            if (is_naive(algo))
                continue;
            for (uint32_t tile_size : {16, 8, 24, 32}) {
                refhold.emplace_back(new AlgoFP32WinogradF63(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                m_gi_winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF54(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                m_gi_winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF45(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                m_gi_winograd_algos.emplace_back(refhold.back().get());
            }
        }
        for (auto&& algo : m_gi_winograd_algos) {
            m_all_algos.emplace_back(algo);
        }
        // end fallback gi fp32 algo

        refhold.emplace_back(new AlgoConv1x1Gemv());
        m_all_algos.emplace_back(refhold.back().get());

        matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr)
                               ->get_all_packed_algo();
        for (auto&& algo : matmul_algos) {
#if MEGDNN_X86
            //! As we do not have a direct conv for int8x8x16 yet, disabling
            //! gemv here may fall back to the naive implementation, which can
            //! be very slow, so we just enable im2col for gemv in the x86
            //! backend.
            //! FIXME: remove this when we add direct conv support for int8x8x16
#else
            if (algo->algoset() == MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
                continue;
            }
#endif
            //! As we do not have riscv64 postprocess yet, im2col and conv1x1
            //! cannot pass the CI tests, so we just disable all im2col and
            //! conv1x1 algos on riscv64.
            //! FIXME: remove this when postprocess is implemented for riscv64
#if !MEGDNN_RISCV64
            for (size_t ohw_tile_size : {192, 384, 96, 48, 24}) {
                refhold.emplace_back(new AlgoIm2col(
                        static_cast<MatrixMulImpl::AlgoBase*>(algo), ohw_tile_size));
                m_all_algos.emplace_back(refhold.back().get());
            }
            for (size_t oc_tile_size : {48, 24}) {
                refhold.emplace_back(new AlgoConv1x1(
                        static_cast<MatrixMulImpl::AlgoBase*>(algo), oc_tile_size));
                m_all_algos.emplace_back(refhold.back().get());
            }
#endif
#if 0
            //! As these algos may be very slow, they would make fastrun search
            //! slow, so we disable them; we keep them only for the
            //! strategyhelper tests.
            //! FIXME: find a better way to do this.
            refhold.emplace_back(new AlgoWinogradF32(
                    static_cast<MatrixMulImpl::AlgoBase*>(algo)));
            m_all_algos.emplace_back(refhold.back().get());
            refhold.emplace_back(new AlgoWinogradF32_4x4(
                    static_cast<MatrixMulImpl::AlgoBase*>(algo)));
            m_all_algos.emplace_back(refhold.back().get());
            refhold.emplace_back(new AlgoWinogradQS8(
                    static_cast<MatrixMulImpl::AlgoBase*>(algo)));
            m_all_algos.emplace_back(refhold.back().get());
            refhold.emplace_back(new AlgoWinogradQS8_8x8(
                    static_cast<MatrixMulImpl::AlgoBase*>(algo)));
            m_all_algos.emplace_back(refhold.back().get());
#endif
        }
        m_all_algos.emplace_back(&algo_naive);

        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }

    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() {
    return algo_pack().all_algos();
}

SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type(
        ConvAlgoTypePack target_type) {
    megdnn_assert(
            nr_type_contain(target_type.data_type),
            "ConvBias algo selection only support one type");
    SmallVector<ConvBiasImpl::AlgoBase*> algos;
    for (auto&& algo : get_all_packed_algo()) {
        auto algo_type = algo->get_algo_type();
        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
            algo_type.algo_category == target_type.algo_category) {
            algos.push_back(algo);
        }
    }
    return algos;
}

bool ConvBiasImpl::is_naive_algo(ConvBiasImpl::Algorithm* algo) {
    return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
}
#define NCB_ALGO_FUNC(name, algo, param) static_cast<AlgoBase*>(algo)->name(param)

void ConvBiasImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_in z, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) {
    check_exec(
            src.layout, filter.layout, bias.layout, z.layout, dst.layout,
            workspace.size, preprocessed_filter);
    auto fparam =
            make_ncb_kern_param(src, filter, bias, dst, workspace, preprocessed_filter);
    auto&& algo = get_algorithm(fparam, workspace.size);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) {
        exec_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvBiasForwardImpl::exec(
                src, filter, bias, z, dst, preprocessed_filter, workspace);
    }
}
void ConvBiasImpl::exec_preprocess(
        const TensorLayout& src_layout, _megdnn_tensor_in filter,
        _megdnn_tensor_in bias, const TensorLayout& z_layout,
        const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
        _megdnn_workspace workspace) {
    //! exec_preprocess currently only supports preprocessing weights and bias
    //! before exec; src/dst/z are ignored, so they are just set to nullptr
    TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout};
    auto fparam =
            make_ncb_kern_param(src, filter, bias, dst, workspace, preprocessed_filter);
    //! do not pass the workspace_size limit here, otherwise a matching algo
    //! may not be found
    auto&& algo = get_algorithm(fparam);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <= workspace.size) {
        exec_preprocess_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvBiasForwardImpl::exec_preprocess(
                src_layout, filter, bias, z_layout, dst_layout, preprocessed_filter,
                workspace);
    }
}
size_t ConvBiasImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    TensorLayoutArray layouts{src, filter, bias, z, dst};
    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
                            layouts.data(), layouts.size(),
                            &this->param(), sizeof(this->param())};
    auto rst = AlgorithmCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, preprocessed_filter);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvBiasForwardImpl::get_workspace_in_bytes(
                src, filter, bias, z, dst, preprocessed_filter);
    } else {
        return NCB_ALGO_FUNC(get_workspace, algo, fparam);
    }
}

size_t ConvBiasImpl::get_preprocess_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvBiasForwardImpl::get_preprocess_workspace_in_bytes(
                src, filter, bias, z, dst);
    } else {
        return NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam);
    }
}

SmallVector<TensorLayout> ConvBiasImpl::deduce_preprocessed_filter_layout(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvBiasForwardImpl::deduce_preprocessed_filter_layout(
                src, filter, bias, z, dst);
    } else {
        return NCB_ALGO_FUNC(deduce_preprocessed_filter_layout, algo, fparam);
    }
}

std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
    auto ret = get_all_algorithms_with_ncb(fparam);
    if (ret.empty()) {
        return naive::ConvBiasForwardImpl::get_all_algorithms_safe(
                src, filter, bias, z, dst);
    }
    return ret;
}

std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    auto ret_safe = ConvBiasImpl::get_all_algorithms(src, filter, bias, z, dst);
    return ret_safe;
}
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
    auto result = get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
    if (result == nullptr) {
        result = naive::ConvBiasForwardImpl::get_algorithm_heuristic(
                src, filter, bias, z, dst, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    return result;
}
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    if (ConvBiasImpl::param().format == Param::Format::NHWCD4) {
        return nullptr;
    }
    auto algo_data_type = param.deduce_algo_data_type();
    auto suggest_category_order = suggest_algo_category_order(param);
    for (auto category : suggest_category_order) {
        auto&& origin_algos = select_algo_type({algo_data_type, category});
        ConvBiasImpl::Algorithm* heuristic_algo = nullptr;
        for (auto i : origin_algos) {
            bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
                    param, AlgoSelectionStrategy::HEURISTIC, positive_attr,
                    negative_attr);
            if (usable_attribute && static_cast<AlgoBase*>(i)->get_workspace(param) <=
                                            workspace_limit_in_bytes) {
                //! remember the first usable algo; if no preferred algo is
                //! found, it becomes the target algo
                if (!heuristic_algo) {
                    heuristic_algo = i;
                }
                //! return the first preferred algo
                if (i->is_preferred(param)) {
                    return i;
                }
            }
        }
        if (heuristic_algo) {
            return heuristic_algo;
        }
    }
    return nullptr;
}
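//! Note (added summary, not in the original source): categories are tried in
//! the order given by suggest_algo_category_order(); within a category the
//! first *preferred* usable algo wins, otherwise the first usable one, and
//! only a category with no usable algo falls through to the next category.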
ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(
                v <= std::numeric_limits<uint32_t>::max(), "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW88 ||
        param().format == Param::Format::NCHW8 ||
        param().format == Param::Format::NCHW4 ||
        param().format == Param::Format::NCHW44 ||
        param().format == Param::Format::NCHW44_DOT ||
        param().format == Param::Format::NCHW ||
        param().format == Param::Format::NCHW32 ||
        param().format == Param::Format::NCHW64) {
        spatial_pos = 2;
    } else if (
            param().format == Param::Format::NHWC ||
            param().format == Param::Format::NHWCD4) {
        spatial_pos = 1;
    } else {
        megdnn_assert(0, "invalid conv format %d", static_cast<int>(param().format));
    }
    BiasMode bias_mode;
    //! a bias matching a channel-only dst (N == H == W == 1) is treated as
    //! BROADCAST_CHANNEL_BIAS
    bool dst_only_c = dst[0] == 1 && dst[spatial_pos] == 1 && dst[spatial_pos + 1] == 1;
    if (bias.ndim == 0) {
        bias_mode = BiasMode::NO_BIAS;
    } else if (bias.eq_shape(dst) && !dst_only_c) {
        bias_mode = BiasMode::BIAS;
    } else {
        //! only check the ndim here; the detailed shape check is in check_exec
        megdnn_assert(bias.ndim == dst.ndim);
        bias_mode = BiasMode::BROADCAST_CHANNEL_BIAS;
    }
    static_assert(
            sizeof(CanonizedFilterMeta) == sizeof(ConvolutionImpl::CanonizedFilterMeta),
            "sizeof CanonizedFilterMeta in convolution and conv_bias "
            "should be equal");
    auto&& fm = check_layout_fwd(src, filter, dst);
    auto& conv_fm = reinterpret_cast<ConvolutionImpl::CanonizedFilterMeta&>(fm);
    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();
    return {{safe_u32(src[0]),
             {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
             {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
             conv_fm,
             src.dtype,
             filter.dtype,
             dst.dtype,
             src.stride[0],
             dst.stride[0],
             {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
             {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
             param().compute_mode,
             nr_threads,
             reinterpret_cast<const ConvolutionForward::PreprocessedFilter*>(
                     preprocessed_filter)},
            bias.dtype,
            bias.stride[0],
            bias_mode,
            param().nonlineMode};
}
ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_out dst, _megdnn_workspace workspace,
        const PreprocessedFilter* preprocessed_filter) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) = make_ncb_kern_size_param(
            src.layout, filter.layout, bias.layout, dst.layout, preprocessed_filter);
    ret.src_ptr = src.get_ref_ptr();
    ret.filter_ptr = filter.get_ref_ptr();
    ret.bias_ptr = bias.get_ref_ptr();
    ret.dst_ptr = dst.get_ref_ptr();
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
void ConvBiasImpl::exec_with_ncb_kern(
        const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) {
    auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param);
    for (auto&& kernel : ncb_kerns) {
        auto run = [kernel, param](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(
                run, kernel.global_size.total_size());
    }
}

void ConvBiasImpl::exec_preprocess_with_ncb_kern(
        const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) {
    auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param);
    for (auto&& kernel : ncb_kerns) {
        auto run = [kernel, param](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(
                run, kernel.global_size.total_size());
    }
}
std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb(
        const NCBKernSizeParam& param) {
    MEGDNN_MARK_USED_VAR(param);
    std::vector<Algorithm*> algos;
    std::vector<Algorithm*> prefer_algos;
    for (auto&& algo : get_all_packed_algo()) {
        if (algo->usable(param, AlgoSelectionStrategy::FULL_RUN)) {
            if (algo->is_preferred(param)) {
                prefer_algos.push_back(algo);
            } else {
                algos.push_back(algo);
            }
        }
    }
    //! preferred algos are inserted at the beginning
    algos.insert(algos.begin(), prefer_algos.begin(), prefer_algos.end());
    return algos;
}
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            };
#if MEGDNN_X86
            case Handle::HandleType::X86:
                return x86::ConvBiasImpl::get_algo_from_desc(desc);
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
            case Handle::HandleType::ARM_COMMON:
                return arm_common::ConvBiasImpl::get_algo_from_desc(desc);
#if MEGDNN_AARCH64
            case Handle::HandleType::AARCH64:
                return aarch64::ConvBiasImpl::get_algo_from_desc(desc);
#else
            case Handle::HandleType::ARMV7:
                return armv7::ConvBiasImpl::get_algo_from_desc(desc);
#endif
#endif
            case Handle::HandleType::NAIVE: {
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_bias_fwd_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    if (ConvBiasImpl::param().format == Param::Format::NHWCD4) {
        return nullptr;
    }
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
                param, workspace_size, AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    auto IC = param.filter_meta.icpg;
    auto OC = param.filter_meta.ocpg;
    auto FH = param.filter_meta.spatial[0];
    auto FW = param.filter_meta.spatial[1];
    //! TODO: winograd is currently only supported in fast-run
    //! im2col + matmul
    bool im2col_prefer = (IC >= 32 || OC >= 32);
    //! quantized algos use matmul when the direct algo is unusable
    if (param.src_type.category() == DTypeCategory::QUANTIZED) {
        im2col_prefer = is_matmul_quantized_prefer(param);
    }
    //! conv1x1
    im2col_prefer |= (FH == 1 && FW == 1);
    if (im2col_prefer) {
        return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, AlgoCategory::NAIVE};
    } else {
        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL, AlgoCategory::NAIVE};
    }
}
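//! Illustrative examples (assumed inputs, not part of the original source):
//!   - a 1x1 conv: FH == FW == 1 forces im2col_prefer, so the order is
//!     {IM2COL, DIRECT, NAIVE};
//!   - a 3x3 float conv with IC == OC == 16: no condition fires, so the order
//!     is {DIRECT, IM2COL, NAIVE}.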
const char* ConvBiasImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "F0";
}

namespace megdnn {
namespace fallback {
size_t ConvBiasImpl::NCBKernParam::src_offset(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    size_t batch_offset = batch_id * inp_bs * src_type.size();
    size_t group_offset = group_pack_size * group_pack_id * filter_meta.icpg * isz[0] *
                          isz[1] * src_type.size();
    size_t channel_offset =
            channel_pack_size * channel_pack_id * isz[0] * isz[1] * src_type.size();
    return (batch_offset + group_offset + channel_offset);
}
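//! Worked example (illustrative numbers, not from the original source): for
//! float32 (4-byte elements) with icpg = 8, isz = {32, 32} and both pack
//! sizes equal to 1, the byte offset evaluates to
//!     batch_id * inp_bs * 4
//!         + group_pack_id * 8 * 32 * 32 * 4
//!         + channel_pack_id * 32 * 32 * 4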
template <typename T>
const T* ConvBiasImpl::NCBKernParam::src(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    return reinterpret_cast<T*>(
            reinterpret_cast<ptrdiff_t>(src_ptr.get_ptr()) +
            src_offset(
                    batch_id, group_pack_id, channel_pack_id, group_pack_size,
                    channel_pack_size));
}
size_t ConvBiasImpl::NCBKernParam::filter_offset(
        size_t group_pack_id, size_t pack_group_size) const {
    size_t group_offset = 0_z;
    switch (filter_meta.format) {
        case Param::Format::NCHW: {
            group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
                           filter_meta.ocpg * filter_meta.spatial[0] *
                           filter_meta.spatial[1] * filter_type.size();
            break;
        }
        case Param::Format::NCHW88: {
            size_t group = filter_meta.group;
            size_t icpg = filter_meta.icpg;
            size_t ocpg = filter_meta.ocpg;
            //! four formats of weight layout:
            //! 1. {oc/8, ic/8, fh, fw, 8, 8},
            //! 2. {g, oc/8, ic/8, fh, fw, 8, 8},
            //! 3. {g/8, fh, fw, 1, 1, 8}, 4. {oc/8, fh, fw, ic, 8}
            megdnn_assert(
                    (icpg % 8 == 0 && ocpg % 8 == 0) ||
                            (group % 8 == 0 && icpg == 1 && ocpg == 1 &&
                             pack_group_size > 1) ||
                            (group == 1 && ocpg % 8 == 0),
                    "The filter shape is invalid for nchw88");
            group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
                           filter_meta.ocpg * filter_meta.spatial[0] *
                           filter_meta.spatial[1] * filter_type.size();
            break;
        }
        case Param::Format::NCHW44_DOT:
        case Param::Format::NCHW44: {
            size_t group = filter_meta.group;
            size_t icpg = filter_meta.icpg;
            size_t ocpg = filter_meta.ocpg;
            //! four formats of weight layout:
            //! 1. {oc/4, ic/4, fh, fw, 4, 4},
            //! 2. {g, oc/4, ic/4, fh, fw, 4, 4},
            //! 3. {g/4, fh, fw, 1, 1, 4},
            //! 4. {oc/4, fh, fw, ic, 4}
            megdnn_assert(
                    (icpg % 4 == 0 && ocpg % 4 == 0) ||
                            (group % 4 == 0 && icpg == 1 && ocpg == 1 &&
                             pack_group_size > 1) ||
                            (group == 1 && ocpg % 4 == 0),
                    "The filter shape is invalid for nchw44");
            group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
                           filter_meta.ocpg * filter_meta.spatial[0] *
                           filter_meta.spatial[1] * filter_type.size();
            break;
        }
        default:
            megdnn_assert(0, "other filter formats are not supported yet");
    }
    return group_offset;
}
template <typename T>
const T* ConvBiasImpl::NCBKernParam::filter(
        size_t group_pack_id, size_t pack_group_size) const {
    size_t group_offset = filter_offset(group_pack_id, pack_group_size);
    return reinterpret_cast<T*>(
            reinterpret_cast<ptrdiff_t>(filter_ptr.get_ptr()) + group_offset);
}

size_t ConvBiasImpl::NCBKernParam::bias_offset(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    size_t batch_offset = 0_z;
    size_t group_offset = 0_z;
    size_t channel_offset = 0_z;
    if (bias_mode == BiasMode::BIAS) {
        batch_offset = batch_id * bias_bs * bias_type.size();
        group_offset = group_pack_size * group_pack_id * filter_meta.ocpg * osz[0] *
                       osz[1] * bias_type.size();
        channel_offset = channel_pack_size * channel_pack_id * osz[0] * osz[1] *
                         bias_type.size();
    } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
        group_offset =
                group_pack_size * group_pack_id * filter_meta.ocpg * bias_type.size();
        channel_offset = channel_pack_size * channel_pack_id * bias_type.size();
    }
    return (batch_offset + group_offset + channel_offset);
}

template <typename T>
const T* ConvBiasImpl::NCBKernParam::bias(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    return reinterpret_cast<T*>(
            reinterpret_cast<ptrdiff_t>(bias_ptr.get_ptr()) +
            bias_offset(
                    batch_id, group_pack_id, channel_pack_id, group_pack_size,
                    channel_pack_size));
}

size_t ConvBiasImpl::NCBKernParam::dst_offset(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    size_t batch_offset = batch_id * out_bs * dst_type.size();
    size_t group_offset = group_pack_size * group_pack_id * filter_meta.ocpg * osz[0] *
                          osz[1] * dst_type.size();
    size_t channel_offset =
            channel_pack_size * channel_pack_id * osz[0] * osz[1] * dst_type.size();
    return (batch_offset + group_offset + channel_offset);
}

template <typename T>
T* ConvBiasImpl::NCBKernParam::dst(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    return reinterpret_cast<T*>(
            reinterpret_cast<ptrdiff_t>(dst_ptr.get_ptr()) +
            dst_offset(
                    batch_id, group_pack_id, channel_pack_id, group_pack_size,
                    channel_pack_size));
}
#define INST(T)                                                      \
    template const T* ConvBiasImpl::NCBKernParam::src<T>(            \
            size_t batch_id, size_t group_id, size_t channel_id,     \
            size_t group_pack_size, size_t channel_pack_size) const; \
    template const T* ConvBiasImpl::NCBKernParam::bias<T>(           \
            size_t batch_id, size_t group_id, size_t channel_id,     \
            size_t group_pack_size, size_t channel_pack_size) const; \
    template const T* ConvBiasImpl::NCBKernParam::filter<T>(         \
            size_t group_id, size_t group_pack_size) const;          \
    template T* ConvBiasImpl::NCBKernParam::dst<T>(                  \
            size_t batch_id, size_t group_id, size_t channel_id,     \
            size_t group_pack_size, size_t channel_pack_size) const;

#define INST_DT(d) INST(DTypeTrait<d>::ctype)

MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT)
INST(void)
#undef INST
#undef INST_DT

}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen