
opr_impl.cpp 29 kB

/**
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/algos.h"
#include "src/fallback/conv_bias/conv1x1/algos.h"
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h"
#include "src/fallback/conv_bias/im2col/algos.h"
#include "src/fallback/convolution/opr_impl.h"
#include "src/naive/convolution/algorithms.h"
#include "src/naive/handle.h"

#if MEGDNN_X86
#include "src/x86/conv_bias/opr_impl.h"
#elif MEGDNN_AARCH64
#include "src/aarch64/conv_bias/opr_impl.h"
#elif MEGDNN_ARMV7
#include "src/armv7/conv_bias/opr_impl.h"
#endif

#include <cstring>

using namespace megdnn;
using namespace fallback;
size_t megdnn::fallback::pack_size(param::ConvBias::Format format) {
    switch (format) {
        case param::ConvBias::Format::NCHW44:
        case param::ConvBias::Format::NCHW44_DOT:
        case param::ConvBias::Format::NCHW4:
            return 4_z;
        case param::ConvBias::Format::NCHW88:
            return 8_z;
        default:
            return 1_z;
    }
}
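//! pack_size() returns the channel packing factor implied by the layout format:
//! NCHW44 / NCHW44_DOT / NCHW4 -> 4, NCHW88 -> 8, and every other format -> 1.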
namespace {
template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}
}  // namespace

#if MEGDNN_X86
#define SKIP_GEMV()
//! As we do not have a direct conv for int8x8x16 yet, disabling gemv here may
//! fall back to the naive implementation, which can be very slow, so we just
//! enable im2col for gemv in the x86 backend.
//! FIXME: remove it when we add direct conv support for int8x8x16
#else
#define SKIP_GEMV()                                                            \
    if (algo->algoset() == MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { \
        continue;                                                              \
    }
#endif
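//! Descriptive note: on x86, SKIP_GEMV() expands to nothing, so gemv matmul
//! algos are not filtered out (the im2col path can still be built on gemv);
//! on other backends it expands to a `continue` that skips ALGO_TYPE_GEMV
//! algos inside whatever matmul-algo loop it is placed in.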
class ConvBiasImpl::AlgoPack : NonCopyableObj {
    AlgoNaive algo_naive;
    SmallVector<std::unique_ptr<AlgoBase>> refhold;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        refhold.emplace_back(new AlgoConv1x1Gemv());
        m_all_algos.emplace_back(refhold.back().get());

        static CpuOprDelegationStorage<> storage;
        auto matmul_opr = storage.get<MatrixMul>();
        auto&& matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr)
                                      ->get_all_packed_algo();
        for (auto&& algo : matmul_algos) {
#if MEGDNN_X86
//! As we do not have a direct conv for int8x8x16 yet, disabling gemv here may
//! fall back to the naive implementation, which can be very slow, so we just
//! enable im2col for gemv in the x86 backend.
//! FIXME: remove it when we add direct conv support for int8x8x16
#else
            if (algo->algoset() == MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
                continue;
            }
#endif

//! As we do not have riscv64 postprocess yet, im2col and conv1x1 cannot pass
//! the CI tests, so we just disable all im2col and conv1x1 algos on riscv64.
//! FIXME: remove it when postprocess is implemented for riscv64
#if !MEGDNN_RISCV64
            for (size_t ohw_tile_size : {192, 384, 96, 48, 24}) {
                refhold.emplace_back(new AlgoIm2col(
                        static_cast<MatrixMulImpl::AlgoBase*>(algo), ohw_tile_size));
                m_all_algos.emplace_back(refhold.back().get());
            }
            for (size_t oc_tile_size : {48, 24}) {
                refhold.emplace_back(new AlgoConv1x1(
                        static_cast<MatrixMulImpl::AlgoBase*>(algo), oc_tile_size));
                m_all_algos.emplace_back(refhold.back().get());
            }
#endif

#if 0
            //! As these algos may be very slow, they would make the fastrun
            //! search slow, so we disable them; they are only kept for the
            //! strategyhelper test.
            //! FIXME: I do not know a better way to do it.
            refhold.emplace_back(new AlgoWinogradF32(
                    static_cast<MatrixMulImpl::AlgoBase*>(algo)));
            m_all_algos.emplace_back(refhold.back().get());
            refhold.emplace_back(new AlgoWinogradF32_4x4(
                    static_cast<MatrixMulImpl::AlgoBase*>(algo)));
            m_all_algos.emplace_back(refhold.back().get());
            refhold.emplace_back(new AlgoWinogradQS8(
                    static_cast<MatrixMulImpl::AlgoBase*>(algo)));
            m_all_algos.emplace_back(refhold.back().get());
            refhold.emplace_back(new AlgoWinogradQS8_8x8(
                    static_cast<MatrixMulImpl::AlgoBase*>(algo)));
            m_all_algos.emplace_back(refhold.back().get());
#endif
        }
        m_all_algos.emplace_back(&algo_naive);

        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }

    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() {
    return algo_pack().all_algos();
}
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type(
        ConvAlgoTypePack target_type) {
    megdnn_assert(
            nr_type_contain(target_type.data_type),
            "ConvBias algo selection only support one type");
    SmallVector<ConvBiasImpl::AlgoBase*> algos;
    for (auto&& algo : get_all_packed_algo()) {
        auto algo_type = algo->get_algo_type();
        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
            algo_type.algo_category == target_type.algo_category) {
            algos.push_back(algo);
        }
    }
    return algos;
}

bool ConvBiasImpl::is_naive_algo(ConvBiasImpl::Algorithm* algo) {
    return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
}
#define NCB_ALGO_FUNC(name, algo, param) static_cast<AlgoBase*>(algo)->name(param)

void ConvBiasImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_in z, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) {
    check_exec(
            src.layout, filter.layout, bias.layout, z.layout, dst.layout,
            workspace.size, preprocessed_filter);
    auto fparam =
            make_ncb_kern_param(src, filter, bias, dst, workspace, preprocessed_filter);
    auto&& algo = get_algorithm(fparam, workspace.size);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) {
        exec_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvBiasForwardImpl::exec(
                src, filter, bias, z, dst, preprocessed_filter, workspace);
    }
}

void ConvBiasImpl::exec_preprocess(
        const TensorLayout& src_layout, _megdnn_tensor_in filter,
        _megdnn_tensor_in bias, const TensorLayout& z_layout,
        const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
        _megdnn_workspace workspace) {
    //! exec_preprocess currently only supports preprocessing weights and bias
    //! before exec; src/dst/z are ignored, so they are just set to nullptr
    TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout};
    auto fparam =
            make_ncb_kern_param(src, filter, bias, dst, workspace, preprocessed_filter);
    //! do not pass a workspace_size limit here, otherwise a matching algo may
    //! not be found
    auto&& algo = get_algorithm(fparam);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <= workspace.size) {
        exec_preprocess_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvBiasForwardImpl::exec_preprocess(
                src_layout, filter, bias, z_layout, dst_layout, preprocessed_filter,
                workspace);
    }
}
size_t ConvBiasImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    TensorLayoutArray layouts{src, filter, bias, z, dst};
    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
                            layouts.data(), layouts.size(),
                            &this->param(), sizeof(this->param())};
    auto rst = AlgorithmCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, preprocessed_filter);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvBiasForwardImpl::get_workspace_in_bytes(
                src, filter, bias, z, dst, preprocessed_filter);
    } else {
        return NCB_ALGO_FUNC(get_workspace, algo, fparam);
    }
}

size_t ConvBiasImpl::get_preprocess_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvBiasForwardImpl::get_preprocess_workspace_in_bytes(
                src, filter, bias, z, dst);
    } else {
        return NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam);
    }
}

SmallVector<TensorLayout> ConvBiasImpl::deduce_preprocessed_filter_layout(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvBiasForwardImpl::deduce_preprocessed_filter_layout(
                src, filter, bias, z, dst);
    } else {
        return NCB_ALGO_FUNC(deduce_preprocessed_filter_layout, algo, fparam);
    }
}

std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
    auto ret = get_all_algorithms_with_ncb(fparam);
    if (ret.empty()) {
        return naive::ConvBiasForwardImpl::get_all_algorithms_safe(
                src, filter, bias, z, dst);
    }
    return ret;
}

std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    auto ret_safe = ConvBiasImpl::get_all_algorithms(src, filter, bias, z, dst);
    return ret_safe;
}
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
    auto result = get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
    if (result == nullptr) {
        result = naive::ConvBiasForwardImpl::get_algorithm_heuristic(
                src, filter, bias, z, dst, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    return result;
}

ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    if (ConvBiasImpl::param().format == Param::Format::NHWCD4) {
        return nullptr;
    }
    auto algo_data_type = param.deduce_algo_data_type();
    auto suggest_category_order = suggest_algo_category_order(param);
    for (auto category : suggest_category_order) {
        auto&& origin_algos = select_algo_type({algo_data_type, category});
        ConvBiasImpl::Algorithm* heuristic_algo = nullptr;
        for (auto i : origin_algos) {
            bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
                    param, AlgoSelectionStrategy::HEURISTIC, positive_attr,
                    negative_attr);
            if (usable_attribute && static_cast<AlgoBase*>(i)->get_workspace(param) <=
                                            workspace_limit_in_bytes) {
                //! remember the first usable algo; if no preferred algo is
                //! found, it will be chosen as the target algo
                if (!heuristic_algo) {
                    heuristic_algo = i;
                }
                //! choose the first preferred algo
                if (i->is_preferred(param)) {
                    return i;
                }
            }
        }
        if (heuristic_algo) {
            return heuristic_algo;
        }
    }
    return nullptr;
}
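//! Summary of the heuristic above: categories are tried in the order given by
//! suggest_algo_category_order(); within a category, the first algo that is
//! both preferred and usable wins, otherwise the first usable algo that fits
//! the workspace limit is returned; if no category yields a usable algo,
//! nullptr is returned and the caller falls back to the naive implementation.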
ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
        const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(
                v <= std::numeric_limits<uint32_t>::max(), "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW88 ||
        param().format == Param::Format::NCHW8 ||
        param().format == Param::Format::NCHW4 ||
        param().format == Param::Format::NCHW44 ||
        param().format == Param::Format::NCHW44_DOT ||
        param().format == Param::Format::NCHW ||
        param().format == Param::Format::NCHW32 ||
        param().format == Param::Format::NCHW64) {
        spatial_pos = 2;
    } else if (
            param().format == Param::Format::NHWC ||
            param().format == Param::Format::NHWCD4) {
        spatial_pos = 1;
    } else {
        megdnn_assert(0, "invalid conv format %d", static_cast<int>(param().format));
    }
    BiasMode bias_mode;
    if (bias.ndim == 0) {
        bias_mode = BiasMode::NO_BIAS;
    } else if (bias.eq_shape(dst)) {
        bias_mode = BiasMode::BIAS;
    } else {
        //! just check the ndim here; the detailed shape check is in check_exec
        megdnn_assert(bias.ndim == dst.ndim);
        bias_mode = BiasMode::BROADCAST_CHANNEL_BIAS;
    }
    static_assert(
            sizeof(CanonizedFilterMeta) == sizeof(ConvolutionImpl::CanonizedFilterMeta),
            "sizeof CanonizedFilterMeta in convolution and conv_bias "
            "should be equal");
    auto&& fm = check_layout_fwd(src, filter, dst);
    auto& conv_fm = reinterpret_cast<ConvolutionImpl::CanonizedFilterMeta&>(fm);
    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();
    return {{safe_u32(src[0]),
             {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
             {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
             conv_fm,
             src.dtype,
             filter.dtype,
             dst.dtype,
             src.stride[0],
             dst.stride[0],
             {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
             {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
             param().compute_mode,
             nr_threads,
             reinterpret_cast<const ConvolutionForward::PreprocessedFilter*>(
                     preprocessed_filter)},
            bias.dtype,
            bias.stride[0],
            bias_mode,
            param().nonlineMode};
}
ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_out dst, _megdnn_workspace workspace,
        const PreprocessedFilter* preprocessed_filter) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) = make_ncb_kern_size_param(
            src.layout, filter.layout, bias.layout, dst.layout, preprocessed_filter);
    ret.src_ptr = src.get_ref_ptr();
    ret.filter_ptr = filter.get_ref_ptr();
    ret.bias_ptr = bias.get_ref_ptr();
    ret.dst_ptr = dst.get_ref_ptr();
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
void ConvBiasImpl::exec_with_ncb_kern(
        const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) {
    auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param);
    for (auto&& kernel : ncb_kerns) {
        auto run = [kernel, param](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(
                run, kernel.global_size.total_size());
    }
}

void ConvBiasImpl::exec_preprocess_with_ncb_kern(
        const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) {
    auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param);
    for (auto&& kernel : ncb_kerns) {
        auto run = [kernel, param](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(
                run, kernel.global_size.total_size());
    }
}
std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb(
        const NCBKernSizeParam& param) {
    MEGDNN_MARK_USED_VAR(param);
    std::vector<Algorithm*> algos;
    std::vector<Algorithm*> prefer_algos;
    for (auto&& algo : get_all_packed_algo()) {
        if (algo->usable(param, AlgoSelectionStrategy::FULL_RUN)) {
            if (algo->is_preferred(param)) {
                prefer_algos.push_back(algo);
            } else {
                algos.push_back(algo);
            }
        }
    }
    //! preferred algos are inserted at the beginning of the list
    algos.insert(algos.begin(), prefer_algos.begin(), prefer_algos.end());
    return algos;
}
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            };

#if MEGDNN_X86
            case Handle::HandleType::X86:
                return x86::ConvBiasImpl::get_algo_from_desc(desc);
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
            case Handle::HandleType::ARM_COMMON:
                return arm_common::ConvBiasImpl::get_algo_from_desc(desc);
#if MEGDNN_AARCH64
            case Handle::HandleType::AARCH64:
                return aarch64::ConvBiasImpl::get_algo_from_desc(desc);
#else
            case Handle::HandleType::ARMV7:
                return armv7::ConvBiasImpl::get_algo_from_desc(desc);
#endif
#endif
            case Handle::HandleType::NAIVE: {
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_bias_fwd_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    if (ConvBiasImpl::param().format == Param::Format::NHWCD4) {
        return nullptr;
    }
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
                param, workspace_size, AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    auto IC = param.filter_meta.icpg;
    auto OC = param.filter_meta.ocpg;
    auto FH = param.filter_meta.spatial[0];
    auto FW = param.filter_meta.spatial[1];
    //! TODO: winograd is currently only supported in fast-run
    //! im2col + matmul
    bool im2col_prefer = (IC >= 32 || OC >= 32);
    //! quantized algos use matmul when the direct algo is unusable
    if (param.src_type.category() == DTypeCategory::QUANTIZED) {
        im2col_prefer = is_matmul_quantized_prefer(param);
    }
    //! conv1x1
    im2col_prefer |= (FH == 1 && FW == 1);
    if (im2col_prefer) {
        return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, AlgoCategory::NAIVE};
    } else {
        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL, AlgoCategory::NAIVE};
    }
}
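//! For example, a 1x1 convolution (FH == FW == 1) or a layer with IC >= 32 or
//! OC >= 32 yields {IM2COL, DIRECT, NAIVE}, while a small non-quantized 3x3
//! layer with IC < 32 and OC < 32 yields {DIRECT, IM2COL, NAIVE}.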
const char* ConvBiasImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "F0";
}
namespace megdnn {
namespace fallback {

size_t ConvBiasImpl::NCBKernParam::src_offset(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    size_t batch_offset = batch_id * inp_bs * src_type.size();
    size_t group_offset = group_pack_size * group_pack_id * filter_meta.icpg * isz[0] *
                          isz[1] * src_type.size();
    size_t channel_offset =
            channel_pack_size * channel_pack_id * isz[0] * isz[1] * src_type.size();
    return (batch_offset + group_offset + channel_offset);
}
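//! The offsets above are in bytes: batch_offset steps over whole input images
//! (inp_bs elements each), group_offset steps over icpg * isz[0] * isz[1]
//! elements per packed group, and channel_offset steps over one spatial plane
//! per packed channel. bias_offset() and dst_offset() below follow the same
//! pattern with ocpg/osz in place of icpg/isz.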
template <typename T>
const T* ConvBiasImpl::NCBKernParam::src(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    return reinterpret_cast<T*>(
            reinterpret_cast<ptrdiff_t>(src_ptr.get_ptr()) +
            src_offset(
                    batch_id, group_pack_id, channel_pack_id, group_pack_size,
                    channel_pack_size));
}
size_t ConvBiasImpl::NCBKernParam::filter_offset(
        size_t group_pack_id, size_t pack_group_size) const {
    size_t group_offset = 0_z;
    switch (filter_meta.format) {
        case Param::Format::NCHW: {
            group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
                           filter_meta.ocpg * filter_meta.spatial[0] *
                           filter_meta.spatial[1] * filter_type.size();
            break;
        }
        case Param::Format::NCHW88: {
            size_t group = filter_meta.group;
            size_t icpg = filter_meta.icpg;
            size_t ocpg = filter_meta.ocpg;
            //! four formats of weight layout:
            //! 1. {oc/8, ic/8, fh, fw, 8, 8},
            //! 2. {g, oc/8, ic/8, fh, fw, 8, 8},
            //! 3. {g/8, fh, fw, 1, 1, 8}, 4. {oc/8, fh, fw, ic, 8}
            megdnn_assert(
                    (icpg % 8 == 0 && ocpg % 8 == 0) ||
                            (group % 8 == 0 && icpg == 1 && ocpg == 1 &&
                             pack_group_size > 1) ||
                            (group == 1 && ocpg % 8 == 0),
                    "The filter shape is not valid for nchw88");
            group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
                           filter_meta.ocpg * filter_meta.spatial[0] *
                           filter_meta.spatial[1] * filter_type.size();
            break;
        }
        case Param::Format::NCHW44_DOT:
        case Param::Format::NCHW44: {
            size_t group = filter_meta.group;
            size_t icpg = filter_meta.icpg;
            size_t ocpg = filter_meta.ocpg;
            //! four formats of weight layout:
            //! 1. {oc/4, ic/4, fh, fw, 4, 4},
            //! 2. {g, oc/4, ic/4, fh, fw, 4, 4},
            //! 3. {g/4, fh, fw, 1, 1, 4},
            //! 4. {oc/4, fh, fw, ic, 4}
            megdnn_assert(
                    (icpg % 4 == 0 && ocpg % 4 == 0) ||
                            (group % 4 == 0 && icpg == 1 && ocpg == 1 &&
                             pack_group_size > 1) ||
                            (group == 1 && ocpg % 4 == 0),
                    "The filter shape is not valid for nchw44");
            group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
                           filter_meta.ocpg * filter_meta.spatial[0] *
                           filter_meta.spatial[1] * filter_type.size();
            break;
        }
        default:
            megdnn_assert(0, "other filter formats are not supported yet");
    }
    return group_offset;
}
template <typename T>
const T* ConvBiasImpl::NCBKernParam::filter(
        size_t group_pack_id, size_t pack_group_size) const {
    size_t group_offset = filter_offset(group_pack_id, pack_group_size);
    return reinterpret_cast<T*>(
            reinterpret_cast<ptrdiff_t>(filter_ptr.get_ptr()) + group_offset);
}

size_t ConvBiasImpl::NCBKernParam::bias_offset(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    size_t batch_offset = 0_z;
    size_t group_offset = 0_z;
    size_t channel_offset = 0_z;
    if (bias_mode == BiasMode::BIAS) {
        batch_offset = batch_id * bias_bs * bias_type.size();
        group_offset = group_pack_size * group_pack_id * filter_meta.ocpg * osz[0] *
                       osz[1] * bias_type.size();
        channel_offset = channel_pack_size * channel_pack_id * osz[0] * osz[1] *
                         bias_type.size();
    } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
        group_offset =
                group_pack_size * group_pack_id * filter_meta.ocpg * bias_type.size();
        channel_offset = channel_pack_size * channel_pack_id * bias_type.size();
    }
    return (batch_offset + group_offset + channel_offset);
}

template <typename T>
const T* ConvBiasImpl::NCBKernParam::bias(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    return reinterpret_cast<T*>(
            reinterpret_cast<ptrdiff_t>(bias_ptr.get_ptr()) +
            bias_offset(
                    batch_id, group_pack_id, channel_pack_id, group_pack_size,
                    channel_pack_size));
}

size_t ConvBiasImpl::NCBKernParam::dst_offset(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    size_t batch_offset = batch_id * out_bs * dst_type.size();
    size_t group_offset = group_pack_size * group_pack_id * filter_meta.ocpg * osz[0] *
                          osz[1] * dst_type.size();
    size_t channel_offset =
            channel_pack_size * channel_pack_id * osz[0] * osz[1] * dst_type.size();
    return (batch_offset + group_offset + channel_offset);
}

template <typename T>
T* ConvBiasImpl::NCBKernParam::dst(
        size_t batch_id, size_t group_pack_id, size_t channel_pack_id,
        size_t group_pack_size, size_t channel_pack_size) const {
    return reinterpret_cast<T*>(
            reinterpret_cast<ptrdiff_t>(dst_ptr.get_ptr()) +
            dst_offset(
                    batch_id, group_pack_id, channel_pack_id, group_pack_size,
                    channel_pack_size));
}

#define INST(T)                                                       \
    template const T* ConvBiasImpl::NCBKernParam::src<T>(             \
            size_t batch_id, size_t group_id, size_t channel_id,      \
            size_t group_pack_size, size_t channel_pack_size) const;  \
    template const T* ConvBiasImpl::NCBKernParam::bias<T>(            \
            size_t batch_id, size_t group_id, size_t channel_id,      \
            size_t group_pack_size, size_t channel_pack_size) const;  \
    template const T* ConvBiasImpl::NCBKernParam::filter<T>(          \
            size_t group_id, size_t group_pack_size) const;           \
    template T* ConvBiasImpl::NCBKernParam::dst<T>(                   \
            size_t batch_id, size_t group_id, size_t channel_id,      \
            size_t group_pack_size, size_t channel_pack_size) const;

#define INST_DT(d) INST(DTypeTrait<d>::ctype)

MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT)
INST(void)
#undef INST
#undef INST_DT

}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen