You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

opr_impl.h 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. #pragma once
  2. #include "include/megdnn/thin/function.h"
  3. #include "src/common/utils.h"
  4. #include "src/fallback/conv_bias/common.h"
  5. #include "src/fallback/convolution/opr_impl.h"
  6. #include "src/fallback/matrix_mul/opr_impl.h"
  7. #include "src/naive/conv_bias/opr_impl.h"
  8. #include <unordered_map>
namespace megdnn {
namespace fallback {

/*!
 * \brief get the pack size (channels packed per block) for the given format
 *
 * TODO: once format is removed from the param, this may be specified
 * directly via something like "opr::param::format".
 */
size_t pack_size(param::ConvBias::Format format);

/*!
 * \brief fallback conv bias forward impl
 *
 * Note: this operator class serves multiple purposes:
 *
 * 1. canonizing conv representations into NCBKernParam and NCBKernSizeParam;
 *    subclasses should implement by overriding the *_ncb methods
 * 2. providing a default impl for group conv by calling ncb_1g* methods
 * 3. providing a conv impl faster than naive under some cases
 * 4. providing a default impl for choosing a heuristic algorithm, by using
 *    the first algo that fits the workspace limit
 */
class ConvBiasImpl : public naive::ConvBiasForwardImpl {
public:
    using naive::ConvBiasForwardImpl::ConvBiasForwardImpl;
    using AlgoSelectionStrategy = detail::AlgoSelectionStrategy;
    using AlgoDataType = detail::AlgoDataType;

    //! implemented by exec_with_ncb_kern()
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
            _megdnn_tensor_in z, _megdnn_tensor_out dst, const PreprocessedFilter*,
            _megdnn_workspace workspace) override;

    bool is_thread_safe() const override { return true; }

    //! run the weight pre-processing kernels, filling \p preprocessed_filter
    //! for later exec() calls; implemented by exec_preprocess_with_ncb_kern()
    void exec_preprocess(
            const TensorLayout& src_layout, _megdnn_tensor_in filter,
            _megdnn_tensor_in bias, const TensorLayout& z_layout,
            const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
            _megdnn_workspace workspace) override;

    //! layouts of the tensors exec_preprocess() would produce
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) override;

    //! extra workspace needed while running exec_preprocess()
    size_t get_preprocess_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) override;

    //! implemented by get_workspace_with_ncb()
    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            const PreprocessedFilter*) override;

    //! implemented by get_all_algorithms_with_ncb()
    std::vector<Algorithm*> get_all_algorithms(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) override;
    std::vector<Algorithm*> get_all_algorithms_safe(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) override;

    //! implemented by get_algorithm_heuristic_with_ncb()
    Algorithm* get_algorithm_heuristic(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
            const AlgoAttribute& negative_attr) override;

    //! size param for kernels with non-contiguous batch
    struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam {
        NCBKernSizeParam() = default;
        NCBKernSizeParam(
                const ConvolutionImpl::NCBKernSizeParam& param, DType bias_type,
                ptrdiff_t bias_bs, BiasMode bias_mode, Param::NonlineMode nonlineMode)
                : ConvolutionImpl::NCBKernSizeParam(param),
                  bias_type{bias_type},
                  bias_bs{bias_bs},
                  bias_mode{bias_mode},
                  nonlineMode{nonlineMode} {}
        DType bias_type;
        //! stride for batch of bias
        ptrdiff_t bias_bs;
        BiasMode bias_mode;
        Param::NonlineMode nonlineMode;
    };

    //! memory param for kernels with non-contiguous batch
    struct NCBKernParam : public NCBKernSizeParam {
        NCBKernParam() = default;
        RefPtr src_ptr;
        RefPtr filter_ptr;
        RefPtr bias_ptr;
        RefPtr dst_ptr;
        void* workspace_ptr;
        size_t workspace_size;

        //! typed pointer to the whole src buffer; asserts dtype compatibility
        template <typename T>
        const T* src() const {
            src_type.assert_is_compatible_ctype<T>();
            return static_cast<const T*>(src_ptr.get_ptr());
        }

        //! when format is nchwxx, multiple channels are packed into one
        //! channel_pack_id; channel_pack_size is the number of packed channels.
        //! when format is nchwxx and channel wise, multiple groups are packed
        //! into one group_pack_id; group_pack_size is the number of groups
        //! packed together, e.g. weight shape is {g/8, 1, 1, Fh, Fw, 8}
        size_t src_offset(
                size_t batch_id, size_t group_pack_id, size_t channel_pack_id = 0,
                size_t group_pack_size = 1, size_t channel_pack_size = 1) const;
        size_t bias_offset(
                size_t batch_id, size_t group_pack_id, size_t channel_pack_id = 0,
                size_t group_pack_size = 1, size_t channel_pack_size = 1) const;
        size_t dst_offset(
                size_t batch_id, size_t group_pack_id, size_t channel_pack_id = 0,
                size_t group_pack_size = 1, size_t channel_pack_size = 1) const;

        //! typed pointers offset to a given (batch, group, channel) block;
        //! same packing semantics as the *_offset() helpers above
        template <typename T>
        const T* src(
                size_t batch_id, size_t group_pack_id, size_t channel_pack_id = 0,
                size_t group_pack_size = 1, size_t channel_pack_size = 1) const;
        template <typename T>
        const T* bias(
                size_t batch_id, size_t group_pack_id, size_t channel_pack_id = 0,
                size_t group_pack_size = 1, size_t channel_pack_size = 1) const;
        template <typename T>
        T* dst(size_t batch_id, size_t group_pack_id, size_t channel_pack_id = 0,
               size_t group_pack_size = 1, size_t channel_pack_size = 1) const;

        //! when format is nchwxx and channel wise, multiple groups are packed
        //! into one group_pack_id; pack_group_size is the number of groups
        //! packed together, e.g. weight shape is {g/8, 1, 1, Fh, Fw, 8}
        size_t filter_offset(size_t group_pack_id, size_t pack_group_size = 1_z) const;
        template <typename T>
        const T* filter(size_t group_pack_id, size_t pack_group_size = 1_z) const;

        //! typed pointer to the whole filter buffer; asserts dtype compatibility
        template <typename T>
        const T* filter() const {
            filter_type.assert_is_compatible_ctype<T>();
            return static_cast<const T*>(filter_ptr.get_ptr());
        }
        //! typed pointer to the whole bias buffer; asserts dtype compatibility
        template <typename T>
        const T* bias() const {
            bias_type.assert_is_compatible_ctype<T>();
            return static_cast<const T*>(bias_ptr.get_ptr());
        }
        //! typed pointer to the whole dst buffer; asserts dtype compatibility
        template <typename T>
        T* dst() const {
            dst_type.assert_is_compatible_ctype<T>();
            return static_cast<T*>(dst_ptr.get_ptr());
        }
        //! untyped workspace: no dtype check, caller chooses the view type
        template <typename T>
        T* workspace() const {
            return static_cast<T*>(workspace_ptr);
        }
    };

    /**
     * \brief Kernel run-time id; this information is used for getting the
     * work data
     */
    struct NCBKernIndex {
        size_t thread_id = 0;  //!< Thread id
        CpuNDRange ndrange_id;
    };

    //! moved from arm_common to fallback
    virtual bool is_matmul_quantized_prefer(
            const ConvBiasImpl::NCBKernSizeParam& ncb_param) const {
        MEGDNN_MARK_USED_VAR(ncb_param);
        return true;
    };

    using ncb_kern_t = thin_function<void(
            const NCBKernParam& param, const NCBKernIndex& ncb_index)>;
    struct NCBKern {
        ncb_kern_t kern;  //!< conv kern parallel ptr
        CpuNDRange global_size;
    };

    class AlgoBase : public Algorithm {
    public:
        AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::FALLBACK; }
        //! note the layout: generic fallback/GI algos start at 1 << 0;
        //! platform-specific algos start at 1 << 8 inside the #if branches
        enum class AlgoType : uint32_t {
            //! fallback
            FB_NAIVE = 1 << 0,
            FB_WINOGRAD_F32,
            FB_WINOGRAD_4X4_F32,
            FB_WINOGRAD_QS8,
            FB_WINOGRAD_8X8_QS8,
            FB_CONV1x1,
            FB_CONV1x1_GEMV,
            FB_IM2COL,
            GI_COMMON_WINOGRAD_F23_4X4_FP32,
            GI_COMMON_WINOGRAD_F63_FP32,
            GI_COMMON_WINOGRAD_F43_FP32,
            GI_COMMON_WINOGRAD_F63_4X4_FP32,
            GI_COMMON_WINOGRAD_F43_4X4_FP32,
            GI_COMMON_WINOGRAD_F54_FP32,
            GI_COMMON_WINOGRAD_F45_FP32,
            GI_COMMON_WINOGRAD_F23_4X4_NCHW44_F32,
            GI_COMMON_WINOGRAD_F43_4X4_NCHW44_F32,
            GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32,
            GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32,
            GI_COMMON_DIRECT_FP32,
            GI_COMMON_DIRECT_STRD1_FP32,
            GI_COMMON_DIRECT_STRD2_FP32,
            GI_COMMON_DIRECT_NCHW44_FP32,
            GI_COMMON_DIRECT_NCHW_NCHW44_FP32,
            GI_COMMON_DIRECT_NCHW_NCHW44_AGENT_FP32,
            GI_COMMON_CHWNWISE_NCHW44_F32,
#if MEGDNN_X86
            X86_DIRECT = 1 << 8,
            X86_DIRECT_STRD2,
            X86_WINOGRAD_F63_8x8_F32,
            X86_WINOGRAD_F23_8x8_F32,
            X86_MKLDNN,
            X86_CHANWISE_AVX2_STRD1_QINT8,
            X86_CHANWISE_AVX2_STRD2_QINT8,
            X86_DIRECT_AVX2_STRD1_INT8,
            X86_DIRECT_AVX2_STRD2_INT8,
            X86_MKLDNN_QINT8,
            X86_MKLDNN_MATMUL_QINT8,
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
            ARM_COMMON_WINOGRAD_F23_FP16 = 1 << 8,
            ARM_COMMON_WINOGRAD_F45_FP16,
            ARM_COMMON_WINOGRAD_F63_FP16,
            ARM_COMMON_WINOGRAD_F23_8X8_FP16,
            ARM_COMMON_DIRECT_FP16,
            ARM_COMMON_DIRECT_STRD1_FP16,
            ARM_COMMON_CHWNWISE_NCHW88_F16,
            ARM_COMMON_DIRECT_NCHW88_FP16,
            ARM_COMMON_DIRECT_STRD1_S8,
            ARM_COMMON_DIRECT_STRD2_S8,
            ARM_COMMON_DIRECT_NCHW44,
            ARM_COMMON_DIRECT_NCHW_NCHW44_S8,
            ARM_COMMON_CHANWISE_STRD1_NCHW44_S8,
            ARM_COMMON_CHANWISE_STRD2_NCHW44_S8,
            ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8,
            //! LARGE for large filter
            ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8,
            ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8,
            ARM_COMMON_DIRECT_STRD1_DOT_S8,
            ARM_COMMON_DIRECT_STRD2_DOT_S8,
            ARM_COMMON_DIRECT_NCHW44_DOT_S8,
            ARM_COMMON_WINOGRAD_F23_8X8_S8,
            ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32,
            ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8,
            ARM_COMMON_DIRECT_INT8X8X16,
            ARM_COMMON_DIRECT_NCHW44_INT8X8X16,
            ARM_COMMON_DIRECT_STRD2_INT8X8X16,
            ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16,
            ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16,
            ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16,
            ARM_COMMON_DIRECT_STRD1_QU8,
            ARM_COMMON_DIRECT_STRD2_QU8,
            ARM_COMMON_DIRECT_STRD1_DOT_QU8,
            ARM_COMMON_DIRECT_STRD2_DOT_QU8,
#if MEGDNN_AARCH64
            AARCH64_DIRECT_STRD2_FP16,
            AARCH64_DIRECT_STRD2_FP32,
            AARCH64_MATMUL_S8,
            AARCH64_MATMUL_QU8,
#else
            ARMV7_MATMUL_S8,
            ARMV7_MATMUL_QU8,
#endif  // MEGDNN_AARCH64
#endif
        };

        virtual ~AlgoBase() = default;
        //! whether this algo can handle \p param under the given selection
        //! strategy
        virtual bool usable(
                const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const = 0;
        virtual size_t get_workspace(const NCBKernSizeParam& param) const = 0;
        virtual SmallVector<NCBKern> dispatch_kerns(
                const NCBKernSizeParam& param) const = 0;
        //! kernels for weight pre-processing; empty by default (algo does not
        //! support pre-processing)
        virtual SmallVector<NCBKern> dispatch_preprocess_kerns(
                const NCBKernSizeParam&) const {
            return {};
        };
        //! get the layouts of the weight_preprocess dst
        virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
                const NCBKernSizeParam&) const {
            return {};
        };
        //! get the workspace needed during weight_preprocess
        virtual size_t get_preprocess_workspace(const NCBKernSizeParam&) const {
            return 0_z;
        };
        //! Temporarily used to identify whether the matmul algorithm is
        //! is_preferred.
        virtual bool is_preferred(const NCBKernSizeParam&) const { return false; }
        //! usable() combined with an attribute filter: the algo must contain
        //! all positive attributes and none of the negative ones
        bool usable_attribute(
                const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy,
                const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
                const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) const {
            return contain_attribute_all(positive_attr) &&
                   !contain_attribute_any(negative_attr) &&
                   usable(param, algo_selection_strategy);
        }
        //! get the type of the algo
        virtual ConvAlgoTypePack get_algo_type() const = 0;
        using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
    };
    using AlgoMapper = AlgoBase::Mapper;

    /**
     * \brief get all the algorithms for the opr.
     */
    virtual SmallVector<AlgoBase*> get_all_packed_algo();

    /**
     * \brief select algos matching the given algo type
     */
    SmallVector<AlgoBase*> select_algo_type(ConvAlgoTypePack algo_type);

    /**
     * \brief suggest the algo category order according to the param
     */
    virtual SmallVector<AlgoCategory> suggest_algo_category_order(
            const NCBKernSizeParam& param) const;

protected:
    virtual void exec_with_ncb_kern(
            const NCBKernParam& param, ConvBiasImpl::Algorithm* algo);
    virtual void exec_preprocess_with_ncb_kern(
            const NCBKernParam& param, Algorithm* algo);
    virtual std::vector<Algorithm*> get_all_algorithms_with_ncb(
            const NCBKernSizeParam& param);
    virtual Algorithm* get_algorithm_heuristic_with_ncb(
            const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
            const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr);
    const char* get_algorithm_set_name() const override;

private:
    class AlgoNaive;
    class AlgoIm2col;
    class AlgoConv1x1;
    class AlgoConv1x1Gemv;
    class AlgoWinogradF32;
    class AlgoWinogradF32_4x4;
    class AlgoWinogradQS8;
    class AlgoWinogradQS8_8x8;
    class AlgoFP32WinogradF23_4x4;
    class AlgoFP32WinogradF63;
    class AlgoFP32WinogradF43;
    class AlgoFP32WinogradF63_4x4;
    class AlgoFP32WinogradF43_4x4;
    class AlgoFP32WinogradF54;
    class AlgoFP32WinogradF45;
    class AlgoFP32WinogradF23_4x4_NCHW44;
    class AlgoFP32WinogradF43_4x4_NCHW44;
    class AlgoFP32WinogradF63_4x4_NCHW44;
    class AlgoFP32WinogradF73_4x4_NCHW44;
    class AlgoF32Direct;
    class AlgoF32DirectStride1;
    class AlgoF32DirectStride2;
    class AlgoF32DirectNCHWNCHW44;
    class AlgoF32DirectNCHWNCHW44AGENT;
    class AlgoF32ChannelWiseNCHW44;
    class AlgoF32DirectNCHW44;
    class AlgoPack;

    //! NOTE(review): these two appear to cache the most recent algorithm
    //! selection keyed by the size param — confirm in the .cpp before relying
    //! on it
    NCBKernSizeParam m_prev_selected_algo_sizep;
    Algorithm* m_prev_selected_algo = nullptr;

    bool is_naive_algo(ConvBiasImpl::Algorithm* algo);

    Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;

    //! get algorithm set by user or by heuristic
    Algorithm* get_algorithm(
            const NCBKernSizeParam& param,
            size_t workspace_size = std::numeric_limits<size_t>::max());

    NCBKernSizeParam make_ncb_kern_size_param(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& dst,
            const PreprocessedFilter* preprocessed_filter);

    NCBKernParam make_ncb_kern_param(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
            _megdnn_tensor_out dst, _megdnn_workspace workspace,
            const PreprocessedFilter* preprocessed_filter);

    static const AlgoPack& algo_pack();
};
  371. inline bool is_enable_filter_preprocess(const ConvBiasImpl::NCBKernSizeParam& param) {
  372. return param.preprocessed_filter && param.preprocessed_filter->tensors.size() >= 1;
  373. }
  374. } // namespace fallback
  375. } // namespace megdnn
//! unpack an NCBKernSizeParam into local variables at the expansion site:
//! N (batch), IC/OC (channels per group: icpg/ocpg), IH/IW, OH/OW (spatial
//! sizes), FH/FW (filter), SH/SW (stride), PH/PW (padding)
#define UNPACK_CONV_NCB_KERN_SIZES(_p)                                       \
    auto N = _p.n, IC = _p.filter_meta.icpg, IH = _p.isz[0], IW = _p.isz[1], \
         OC = _p.filter_meta.ocpg, OH = _p.osz[0], OW = _p.osz[1],           \
         FH = _p.filter_meta.spatial[0], FW = _p.filter_meta.spatial[1],     \
         SH = _p.filter_meta.stride[0], SW = _p.filter_meta.stride[1],       \
         PH = _p.filter_meta.padding[0], PW = _p.filter_meta.padding[1]

// vim: syntax=cpp.doxygen