

/**
 * \file dnn/src/x86/conv_bias/f32/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/x86/conv_bias/f32/algos.h"
#include <unordered_map>
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/convolution/img2col_helper.h"
#include "src/x86/conv_bias/f32/do_conv_stride2.h"
#include "src/x86/conv_bias/opr_impl.h"
#include "src/x86/conv_bias/postprocess_helper.h"
#include "src/x86/convolution/convolution_direct_special_cases.h"
#include "src/x86/handle.h"
#include "src/x86/profile.h"

#include "midout.h"

using namespace megdnn;
using namespace x86;

namespace {
bool need_dst_copy(const fallback::ConvBiasImpl::NCBKernSizeParam& param) {
    if (param.osz[0] % 8 != 0 || param.osz[1] % 8 != 0) {
        // If the output size is not a multiple of 8, we need to copy it.
        return true;
    }
    return false;
}

bool need_src_copy(const fallback::ConvBiasImpl::NCBKernSizeParam& param) {
    if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) {
        // If padding is not zero, we need a copy to eliminate the padding
        // effect.
        return true;
    }
    return need_dst_copy(param);
}

void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH,
                        size_t FW, size_t PH, size_t PW, size_t& IH2,
                        size_t& IW2, size_t& OH2, size_t& OW2) {
    MEGDNN_MARK_USED_VAR(PH);
    MEGDNN_MARK_USED_VAR(PW);
    OH2 = (OH + 7) & ~7;
    OW2 = (OW + 7) & ~7;
    IH2 = 2 * OH2 + FH - 2;
    IW2 = 2 * OW2 + FW - 2;
    // Because the stride is 2, sometimes IH == IH2 + 1 (likewise for IW).
    // Do a max update to handle this case.
    IH2 = std::max(IH2, IH);
    IW2 = std::max(IW2, IW);
}
}  // namespace
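
//! Illustrative trace of get_rectified_size for the stride-2 kernels,
//! assuming OH = OW = 13 and FH = FW = 3: the output is rounded up to the
//! next multiple of 8 (matching need_dst_copy), and producing OH2 rows at
//! stride 2 needs 2 * (OH2 - 1) + FH input rows:
//!     OH2 = (13 + 7) & ~7 = 16
//!     IH2 = 2 * 16 + 3 - 2 = 33   (then raised to at least IH by std::max)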
#define GET_KERN                                                               \
    auto fm = param.filter_meta;                                               \
    size_t N = param.n;                                                        \
    size_t IC = param.filter_meta.icpg;                                        \
    size_t OC = param.filter_meta.ocpg;                                        \
    size_t group = fm.group;                                                   \
    WorkspaceBundle bundle = get_bundle(param);                                \
    SmallVector<NCBKern> ret_kerns;                                            \
    if (m_large_group) {                                                       \
        auto exec_one_group = [bundle](                                        \
                                      const NCBKernParam& kern_param,         \
                                      const NCBKernIndex& ncb_index) mutable { \
            bundle.set(kern_param.workspace_ptr);                              \
            auto fm = kern_param.filter_meta;                                  \
            size_t IC = fm.icpg;                                               \
            size_t OC = fm.ocpg;                                               \
            for (size_t ic = 0; ic < IC; ic++) {                               \
                copy_padding_kern(bundle, kern_param, ncb_index,               \
                                  {ncb_index.thread_id, 0, ic});               \
            }                                                                  \
            for (size_t oc = 0; oc < OC; oc++) {                               \
                do_conv_kern(bundle, kern_param, ncb_index,                    \
                             {ncb_index.thread_id, 0, oc});                    \
            }                                                                  \
        };                                                                     \
        ret_kerns.push_back({exec_one_group, {group, N, 1_z}});                \
    } else {                                                                   \
        auto copy_padding = [bundle](const NCBKernParam& kern_param,           \
                                     const NCBKernIndex& ncb_index) mutable {  \
            bundle.set(kern_param.workspace_ptr);                              \
            copy_padding_kern(bundle, kern_param, ncb_index,                   \
                              ncb_index.ndrange_id);                           \
        };                                                                     \
        ret_kerns.push_back({copy_padding, {group, N, IC}});                   \
        auto do_conv = [bundle](const NCBKernParam& kern_param,                \
                                const NCBKernIndex& ncb_index) mutable {       \
            bundle.set(kern_param.workspace_ptr);                              \
            do_conv_kern(bundle, kern_param, ncb_index,                        \
                         ncb_index.ndrange_id);                                \
        };                                                                     \
        ret_kerns.push_back({do_conv, {group, N, OC}});                        \
    }                                                                          \
    return ret_kerns;
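
//! Parallelization note on GET_KERN: when there are at least as many groups
//! as threads (m_large_group), it emits one fused kernel over a
//! {group, N, 1} grid, so each thread pads and convolves a whole group out
//! of its private workspace slice; otherwise it emits separate padding and
//! conv kernels over {group, N, IC} / {group, N, OC} so the finer grid keeps
//! all threads busy.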
/* ===================== direct algo ===================== */
bool ConvBiasImpl::AlgoDirect::usable(
        const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    auto&& fm = param.filter_meta;
    bool available = fm.format == Param::Format::NCHW &&
                     fm.spatial_ndim == 2 &&
                     param.src_type.enumv() == DTypeEnum::Float32 &&
                     param.filter_type.enumv() == DTypeEnum::Float32 &&
                     param.dst_type.enumv() == DTypeEnum::Float32 &&
                     fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                     fm.spatial[0] <= 7 && fm.stride[0] == 1 &&
                     fm.stride[1] == 1;
    if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
        bool large_group = param.filter_meta.group >= param.nr_threads;
        available &= (large_group == m_large_group);
    }
    return available;
}
WorkspaceBundle ConvBiasImpl::AlgoDirect::get_bundle(
        const NCBKernSizeParam& param) const {
    auto&& fm = param.filter_meta;
    size_t nr_threads = param.nr_threads;
    size_t group = fm.group, batch = param.n;
    auto IC = fm.icpg, IH = param.isz[0], IW = param.isz[1];
    auto FH = fm.spatial[0], FW = fm.spatial[1];
    auto OH = param.osz[0], OW = param.osz[1];
    size_t OH2, OW2, IH2, IW2;
    get_rectified_img_size(IH, IW, FH, FW, OH, OW, fm.padding[0],
                           fm.padding[1], IH2, IW2, OH2, OW2);
    size_t part0 = 0u, part1 = 0u;
    if (IH != IH2 || IW != IW2) {
        part0 = m_large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads
                              : IC * IH2 * IW2 * sizeof(float) * group * batch;
    }
    if (OH != OH2 || OW != OW2) {
        part1 = OH2 * OW2 * sizeof(float) * nr_threads;
    }
    return {nullptr, {part0, part1}};
}

size_t ConvBiasImpl::AlgoDirect::get_workspace(
        const NCBKernSizeParam& param) const {
    return get_bundle(param).total_size_in_bytes();
}
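
//! Illustrative sizing of the bundle above, assuming IC = 16, 8 worker
//! threads, and rectified extents IH2 = IW2 = 34 in the large-group case:
//!     part0 = 16 * 34 * 34 * sizeof(float) * 8 = 591872 B (~578 KB)
//! part1 is only non-zero when OH/OW had to be rounded up, and holds one
//! OH2 * OW2 float plane per thread.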
//! Process one input channel copy padding
void ConvBiasImpl::AlgoDirect::copy_padding_kern(
        const WorkspaceBundle& bundle,
        const ConvBiasImpl::NCBKernParam& kern_param,
        const ConvBiasImpl::NCBKernIndex& ncb_index,
        const CpuNDRange& workspace_ids) {
    size_t IH = kern_param.isz[0];
    size_t IW = kern_param.isz[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t PH = kern_param.filter_meta.padding[0];
    size_t PW = kern_param.filter_meta.padding[1];
    size_t FH = kern_param.filter_meta.spatial[0];
    size_t FW = kern_param.filter_meta.spatial[1];
    size_t GROUP = kern_param.filter_meta.group;
    size_t OH2, OW2, IH2, IW2;
    get_rectified_img_size(IH, IW, FH, FW, OH, OW, PH, PW, IH2, IW2, OH2, OW2);
    bool rectify_src = (IH != IH2 || IW != IW2);
    size_t padding_group_size = IH2 * IW2 * IC;
    size_t batch_id = ncb_index.ndrange_id[1];
    size_t group_id = ncb_index.ndrange_id[0];
    size_t channel_id = workspace_ids[2];
    const float* sptr = static_cast<const float*>(
                                kern_param.src<float>(batch_id, group_id)) +
                        channel_id * IH * IW;
    //! Used to compute the workspace offset
    size_t workspace_group_id = workspace_ids[0],
           workspace_batch_id = workspace_ids[1],
           workspace_channel_id = workspace_ids[2];
    //! If the group is large, each thread has its own workspace, so group_id
    //! is set to thread_id
    if (rectify_src) {
        //! copy to sptr_base to eliminate the padding effect
        float* sptr_base = static_cast<float*>(bundle.get(0)) +
                           workspace_group_id * padding_group_size +
                           workspace_batch_id * GROUP * padding_group_size +
                           workspace_channel_id * IH2 * IW2;
        std::memset(sptr_base, 0, sizeof(float) * IH2 * IW2);
        rep(ih, IH) {
            std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW,
                        sizeof(float) * IW);
        }
    }
}
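
//! Layout sketch of one padded channel in the workspace (illustrative,
//! assuming PH = PW = 1): source row ih lands at offset
//! (ih + 1) * IW2 + 1 of sptr_base, leaving a zeroed one-pixel border that
//! the direct kernels can read without bounds checks.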
#define DISPATCH                                                    \
    if (is_supported(SIMDType::FMA)) {                              \
        DISPATCH_SIMD(fma)                                          \
    } else if (is_supported(SIMDType::AVX)) {                       \
        DISPATCH_SIMD(avx)                                          \
    } else if (is_supported(SIMDType::SSE)) {                       \
        DISPATCH_SIMD(sse)                                          \
    } else {                                                        \
        megdnn_throw(megdnn_mangle("no fma/avx/sse detected"));     \
    }

#define DISPATCH_SIMD(simd)                 \
    if (is_xcorr) {                         \
        DISPATCH_SIMD_MODE(simd, xcorr)     \
    } else {                                \
        DISPATCH_SIMD_MODE(simd, conv)      \
    }

#define DISPATCH_SIMD_MODE(simd, mode)                              \
    switch (FH) {                                                   \
        case 1:                                                     \
            DISPATCH_SIMD_MODE_FSIZE(simd, mode, 1);                \
            break;                                                  \
        case 2:                                                     \
            DISPATCH_SIMD_MODE_FSIZE(simd, mode, 2);                \
            break;                                                  \
        case 3:                                                     \
            DISPATCH_SIMD_MODE_FSIZE(simd, mode, 3);                \
            break;                                                  \
        case 4:                                                     \
            DISPATCH_SIMD_MODE_FSIZE(simd, mode, 4);                \
            break;                                                  \
        case 5:                                                     \
            DISPATCH_SIMD_MODE_FSIZE(simd, mode, 5);                \
            break;                                                  \
        case 6:                                                     \
            DISPATCH_SIMD_MODE_FSIZE(simd, mode, 6);                \
            break;                                                  \
        case 7:                                                     \
            DISPATCH_SIMD_MODE_FSIZE(simd, mode, 7);                \
            break;                                                  \
        default:                                                    \
            megdnn_throw(megdnn_mangle("unsupported filter size")); \
    }

#define DISPATCH_SIMD_MODE_FSIZE(simd, mode, fsize) \
    func = detail::convolution_##mode##_fh##fsize##_##simd;
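
//! For example, on an FMA-capable CPU running cross-correlation with FH == 3,
//! the DISPATCH chain above expands to
//!     func = detail::convolution_xcorr_fh3_fma;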
//! compute one output channel
void ConvBiasImpl::AlgoDirect::do_conv_kern(const WorkspaceBundle& bundle,
                                            const NCBKernParam& kern_param,
                                            const NCBKernIndex& ncb_index,
                                            const CpuNDRange& workspace_ids) {
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t IH = kern_param.isz[0];
    size_t IW = kern_param.isz[1];
    size_t FH = kern_param.filter_meta.spatial[0];
    size_t FW = kern_param.filter_meta.spatial[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t PH = kern_param.filter_meta.padding[0];
    size_t PW = kern_param.filter_meta.padding[1];
    auto is_xcorr = !kern_param.filter_meta.should_flip;
    size_t GROUP = kern_param.filter_meta.group;
    size_t OH2, OW2, IH2, IW2;
    get_rectified_img_size(IH, IW, FH, FW, OH, OW, PH, PW, IH2, IW2, OH2, OW2);
    bool rectify_src = (IH != IH2 || IW != IW2);
    bool rectify_dst = (OH != OH2 || OW != OW2);
    size_t padding_group_size = IH2 * IW2 * IC;
    //! Choose the compute kernel
    std::function<void(const float*, const float*, float*, size_t, size_t,
                       size_t, size_t, size_t)>
            func = nullptr;
    DISPATCH;
    size_t bias_offset = 0;
    if (kern_param.bias_mode == megdnn::BiasMode::BIAS) {
        bias_offset = OH * OW;
    } else if (kern_param.bias_mode ==
               megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
        bias_offset = 1_z;
    }
    size_t group_id = ncb_index.ndrange_id[0];
    size_t batch_id = ncb_index.ndrange_id[1];
    //! Used to compute the workspace offset
    size_t workspace_group_id = workspace_ids[0],
           workspace_batch_id = workspace_ids[1], oc = workspace_ids[2];
    const float* sptr = kern_param.src<float>(batch_id, group_id);
    const float* filter =
            kern_param.filter<float>(group_id) + oc * FH * FW * IC;
    const float* bias_ptr =
            kern_param.bias<float>(batch_id, group_id) + oc * bias_offset;
    float* dst = kern_param.dst<float>(batch_id, group_id) + oc * OH * OW;
    if (rectify_src) {
        sptr = static_cast<float*>(bundle.get(0)) +
               workspace_group_id * padding_group_size +
               workspace_batch_id * GROUP * padding_group_size;
    }
    float* dptr = nullptr;
    if (rectify_dst) {
        dptr = static_cast<float*>(bundle.get(1)) +
               ncb_index.thread_id * OH2 * OW2;
    } else {
        dptr = dst;
    }
    std::memset(dptr, 0, sizeof(float) * OH2 * OW2);
    rep(ic, IC) {
        func(sptr + ic * IH2 * IW2, filter + ic * FH * FW, dptr, IH2, IW2, OH2,
             OW2, FW);
    }
    if (rectify_dst) {
        rep(oh, OH) {
            std::memcpy(dst + oh * OW, dptr + oh * OW2, sizeof(float) * OW);
        }
    }
    PostProcess<dt_float32>::run(dst, const_cast<float*>(bias_ptr), dst,
                                 kern_param.bias_mode, kern_param.nonlineMode,
                                 kern_param.bias_type, kern_param.dst_type,
                                 1_z, 1_z, OH, OW);
}
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDirect::get_kimpls(
        const NCBKernSizeParam& param) const {
    GET_KERN;
}
/* ===================== direct-stride2 algo ===================== */
bool ConvBiasImpl::AlgoDirectStride2::usable(
        const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    auto&& fm = param.filter_meta;
    auto FH = fm.spatial[0];
    bool available =
            param.filter_meta.format == param::ConvBias::Format::NCHW &&
            param.src_type.enumv() == DTypeEnum::Float32 &&
            param.filter_type.enumv() == DTypeEnum::Float32 &&
            param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip &&
            fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
            fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
            FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7);
    if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
        bool large_group = param.filter_meta.group >= param.nr_threads;
        available &= (large_group == m_large_group);
    }
    return available;
}
WorkspaceBundle ConvBiasImpl::AlgoDirectStride2::get_bundle(
        const NCBKernSizeParam& param) const {
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(N);
    MEGDNN_MARK_USED_VAR(OC);
    MEGDNN_MARK_USED_VAR(SH);
    MEGDNN_MARK_USED_VAR(SW);
    size_t nr_threads = param.nr_threads;
    size_t group = param.filter_meta.group;
    size_t batch = param.n;
    size_t src_size = 0, dst_size = 0;
    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2);
    if (need_src_copy(param)) {
        src_size = m_large_group
                           ? IC * IH2 * IW2 * sizeof(float) * nr_threads
                           : IC * IH2 * IW2 * sizeof(float) * group * batch;
    }
    if (need_dst_copy(param)) {
        // only one dst plane is needed per thread
        dst_size = OH2 * OW2 * sizeof(float) * nr_threads;
    }
    return WorkspaceBundle(nullptr, {src_size, dst_size});
}

size_t ConvBiasImpl::AlgoDirectStride2::get_workspace(
        const NCBKernSizeParam& param) const {
    return get_bundle(param).total_size_in_bytes();
}
//! Process one input channel copy padding
void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern(
        const WorkspaceBundle& bundle,
        const ConvBiasImpl::NCBKernParam& kern_param,
        const ConvBiasImpl::NCBKernIndex& ncb_index,
        const CpuNDRange& workspace_ids) {
    size_t IH = kern_param.isz[0];
    size_t IW = kern_param.isz[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t PH = kern_param.filter_meta.padding[0];
    size_t PW = kern_param.filter_meta.padding[1];
    size_t FH = kern_param.filter_meta.spatial[0];
    size_t FW = kern_param.filter_meta.spatial[1];
    size_t GROUP = kern_param.filter_meta.group;
    size_t OH2, OW2, IH2, IW2;
    get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2);
    bool rectify_src = need_src_copy(kern_param);
    size_t padding_group_size = IH2 * IW2 * IC;
    size_t group_id = ncb_index.ndrange_id[0];
    size_t batch_id = ncb_index.ndrange_id[1];
    size_t channel_id = workspace_ids[2];
    const float* sptr = static_cast<const float*>(
                                kern_param.src<float>(batch_id, group_id)) +
                        channel_id * IH * IW;
    //! Used to compute the workspace offset
    size_t workspace_group_id = workspace_ids[0],
           workspace_batch_id = workspace_ids[1],
           workspace_channel_id = workspace_ids[2];
    if (rectify_src) {
        //! copy to sptr_base to eliminate the padding effect
        float* sptr_base = static_cast<float*>(bundle.get(0)) +
                           workspace_group_id * padding_group_size +
                           workspace_batch_id * GROUP * padding_group_size +
                           workspace_channel_id * IH2 * IW2;
        std::memset(sptr_base, 0, sizeof(float) * IH2 * IW2);
        rep(ih, IH) {
            std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW,
                        sizeof(float) * IW);
        }
    }
}
//! compute one output channel
void ConvBiasImpl::AlgoDirectStride2::do_conv_kern(
        const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
        const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t IH = kern_param.isz[0];
    size_t IW = kern_param.isz[1];
    size_t FH = kern_param.filter_meta.spatial[0];
    size_t FW = kern_param.filter_meta.spatial[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t PH = kern_param.filter_meta.padding[0];
    size_t PW = kern_param.filter_meta.padding[1];
    size_t GROUP = kern_param.filter_meta.group;
    size_t OH2, OW2, IH2, IW2;
    get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2);
    bool rectify_src = need_src_copy(kern_param);
    bool rectify_dst = need_dst_copy(kern_param);
    size_t padding_group_size = IH2 * IW2 * IC;
    //! Choose the compute kernel
    using Func = std::function<void(const float*, const float*, float*,
                                    size_t, size_t, size_t, size_t, size_t,
                                    size_t)>;
    Func func_no_add_dst = nullptr, func_add_dst = nullptr;
    if (FH == 2) {
        func_no_add_dst = conv_general_simd::do_conv_2x2_stride2<false>;
        func_add_dst = conv_general_simd::do_conv_2x2_stride2<true>;
    } else if (FH == 3) {
        func_no_add_dst = conv_general_simd::do_conv_3x3_stride2<false>;
        func_add_dst = conv_general_simd::do_conv_3x3_stride2<true>;
    } else if (FH == 5) {
        func_no_add_dst = conv_general_simd::do_conv_5x5_stride2<false>;
        func_add_dst = conv_general_simd::do_conv_5x5_stride2<true>;
    } else if (FH == 7) {
        func_no_add_dst = conv_general_simd::do_conv_7x7_stride2<false>;
        func_add_dst = conv_general_simd::do_conv_7x7_stride2<true>;
    }
    size_t bias_offset = 0;
    if (kern_param.bias_mode == megdnn::BiasMode::BIAS) {
        bias_offset = OH * OW;
    } else if (kern_param.bias_mode ==
               megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
        bias_offset = 1_z;
    }
    size_t group_id = ncb_index.ndrange_id[0];
    size_t batch_id = ncb_index.ndrange_id[1];
    //! Used to compute the workspace offset
    size_t workspace_group_id = workspace_ids[0],
           workspace_batch_id = workspace_ids[1], oc = workspace_ids[2];
    const float* sptr = kern_param.src<float>(batch_id, group_id);
    const float* filter =
            kern_param.filter<float>(group_id) + oc * FH * FW * IC;
    const float* bias_ptr =
            kern_param.bias<float>(batch_id, group_id) + oc * bias_offset;
    float* dst = kern_param.dst<float>(batch_id, group_id) + oc * OH * OW;
    if (rectify_src) {
        sptr = static_cast<float*>(bundle.get(0)) +
               workspace_group_id * padding_group_size +
               workspace_batch_id * GROUP * padding_group_size;
    }
    float* dptr = nullptr;
    if (rectify_dst) {
        dptr = static_cast<float*>(bundle.get(1)) +
               ncb_index.thread_id * OH2 * OW2;
    } else {
        dptr = dst;
    }
    func_no_add_dst(sptr, filter, dptr, IH2, IW2, OH2, OW2, 0, 0);
    for (size_t ic = 1; ic < IC; ++ic) {
        func_add_dst(sptr + ic * IH2 * IW2, filter + ic * FH * FW, dptr, IH2,
                     IW2, OH2, OW2, 0, 0);
    }
    if (rectify_dst) {
        rep(oh, OH) {
            std::memcpy(dst + oh * OW, dptr + oh * OW2, sizeof(float) * OW);
        }
    }
    PostProcess<dt_float32>::run(dst, const_cast<float*>(bias_ptr), dst,
                                 kern_param.bias_mode, kern_param.nonlineMode,
                                 kern_param.bias_type, kern_param.dst_type,
                                 1_z, 1_z, OH, OW);
}
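
//! Accumulation note: the first input channel uses the <false> (overwrite)
//! variant of the stride-2 kernel, so the destination plane needs no prior
//! memset; every further channel uses the <true> variant, which adds its
//! partial sums onto dptr.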
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDirectStride2::get_kimpls(
        const NCBKernSizeParam& param) const {
    GET_KERN;
}
/* ===================== matmul algo ===================== */
WorkspaceBundle ConvBiasImpl::AlgoMatrixMul::get_bundle(
        const NCBKernSizeParam& param) {
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(N);
    MEGDNN_MARK_USED_VAR(OC);
    auto IH2 = IH + 2 * PH;
    auto IW2 = IW + 2 * PW;
    bool can_matrix_mul_direct =
            (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);
    // temp space to store padding-free src (with 4 extra floats)
    // temp space to store unrolled matrix (with 4 extra floats)
    // workspace for the matrix mul opr
    size_t part0, part1, part2;
    if (can_matrix_mul_direct) {
        part0 = part1 = 0;
    } else {
        part0 = (IC * IH2 * IW2 + 4) * sizeof(float);
        part1 = (IC * FH * FW * OH * OW + 4) * sizeof(float);
    }
    {
        TensorLayout A_, B_, C_;
        A_ = TensorLayout({OC, IC * FH * FW}, dtype::Float32());
        B_ = TensorLayout({IC * FH * FW, OH * OW}, dtype::Float32());
        C_ = TensorLayout({OC, OH * OW}, dtype::Float32());
        part2 = get_matmul_opr()->get_workspace_in_bytes(A_, B_, C_);
    }
    return {nullptr, {part0, part1, part2}};
}
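
//! Illustrative sizing for a hypothetical 3x3 stride-1 conv, assuming
//! IC = 16, OC = 32, IH = IW = 32, PH = PW = 1 (so OH = OW = 32):
//!     part0 = (16 * 34 * 34 + 4) * sizeof(float) = 74000 B   (padded src)
//!     part1 = (16 * 9 * 32 * 32 + 4) * sizeof(float) ~ 576 KB (im2col matrix)
//! part2 is whatever the delegated MatrixMul opr requests for the
//! (32, 144) x (144, 1024) product.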
bool ConvBiasImpl::AlgoMatrixMul::is_preferred(
        const NCBKernSizeParam& param) const {
    auto&& fm = param.filter_meta;
    if (fm.dilation[0] != 1 || fm.dilation[1] != 1) {
        return false;
    }
    // single-channel conv should never use matrix mul
    if (fm.ocpg == 1 || fm.icpg == 1)
        return false;
    // 1x1 conv should always use matrix mul
    if (fm.spatial[0] == 1 && fm.spatial[1] == 1)
        return true;
    // if the stride is not 1x1, always use matrix mul
    if (fm.stride[0] != 1 || fm.stride[1] != 1)
        return true;
    int f = find_nearest_elem<int>(
            std::round(geometric_mean(fm.spatial[0], fm.spatial[1])),
            {2, 3, 4, 5, 6, 7});
    int oc = find_nearest_elem<int>(fm.ocpg, {4, 8, 16, 32, 64, 96, 128});
    int ic = find_nearest_elem<int>(fm.icpg, {4, 8, 16, 32, 64, 96, 128});
    int on = std::round(geometric_mean(param.osz[0], param.osz[1]));
    ProfileElement cur(f, oc, ic, on);
    auto H = static_cast<HandleImpl*>(inplace_cpu_handle().get());
    auto&& target = std::lower_bound(H->profile_cache().begin(),
                                     H->profile_cache().end(), cur);
    megdnn_assert_internal(target->f == cur.f);
    megdnn_assert_internal(target->oc == cur.oc);
    megdnn_assert_internal(target->ic == cur.ic);
    return on < target->on_threshold;
}

MatrixMul* ConvBiasImpl::AlgoMatrixMul::get_matmul_opr() {
    static CpuOprDelegationStorage<> storage;
    return storage.get<MatrixMul>();
}
void ConvBiasImpl::AlgoMatrixMul::kimpl(const NCBKernParam& param,
                                        const NCBKernIndex& ncb_index) {
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    auto IH2 = IH + 2 * PH;
    auto IW2 = IW + 2 * PW;
    size_t group_id = ncb_index.ndrange_id[0];
    bool is_xcorr = !param.filter_meta.should_flip;
    auto bundle = get_bundle(param);
    bundle.set(param.workspace_ptr);
    // bundle parts: 0 = padded src (src2), 1 = im2col matrix (B),
    // 2 = matmul workspace
    for (size_t n = 0; n < N; ++n) {
        float* src = const_cast<float*>(param.src<float>(n, group_id));
        float* dst = param.dst<float>(n, group_id);
        float* bias_ptr = static_cast<float*>(
                const_cast<void*>(param.bias<void>(n, group_id)));
        float *B, *src2;
        if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) {
            // special case: 1x1
            B = src;
        } else {
            src2 = static_cast<float*>(bundle.get(0));
            // copy src to src2
            float* src2_ptr = src2;
            const float* src_ptr = src;
            rep(ic, IC) {
                if (PH != 0) {
                    std::memset(src2_ptr, 0, sizeof(float) * PH * IW2);
                    src2_ptr += PH * IW2;
                }
                rep(ih, IH) {
                    if (PW != 0)
                        rep(pw, PW) *(src2_ptr++) = 0.0f;
                    std::memcpy(src2_ptr, src_ptr, sizeof(float) * IW);
                    src2_ptr += IW;
                    src_ptr += IW;
                    if (PW != 0)
                        rep(pw, PW) *(src2_ptr++) = 0.0f;
                }
                if (PH != 0) {
                    std::memset(src2_ptr, 0, sizeof(float) * PH * IW2);
                    src2_ptr += PH * IW2;
                }
            }
            B = static_cast<float*>(bundle.get(1));
            if (SH == 1 && SW == 1) {
                if (is_xcorr) {
                    img2col<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
                } else {
                    img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
                }
            } else {
                if (is_xcorr) {
                    img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2,
                                         FH, FW, SH, SW);
                } else {
                    img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2,
                                          FH, FW, SH, SW);
                }
            }
        }
        {
            TensorND A_, B_, C_;
            A_.layout = TensorLayout({OC, IC * FH * FW}, dtype::Float32());
            A_.raw_ptr = const_cast<float*>(param.filter<float>(group_id));
            B_.layout =
                    TensorLayout({IC * FH * FW, OH * OW}, dtype::Float32());
            B_.raw_ptr = B;
            C_.layout = TensorLayout({OC, OH * OW}, dtype::Float32());
            C_.raw_ptr = dst;
            Workspace workspace(static_cast<dt_byte*>(bundle.get(2)),
                                bundle.get_size(2));
            get_matmul_opr()->exec(A_, B_, C_, workspace);
        }
        PostProcess<float>::run(dst, bias_ptr, dst, param.bias_mode,
                                param.nonlineMode, param.bias_type,
                                param.dst_type, 1_z, OC, OH, OW);
    }
}
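
//! GEMM view of the convolution above: im2col unrolls the (possibly padded)
//! source into B with shape (IC * FH * FW, OH * OW), the group's filter is
//! read as A with shape (OC, IC * FH * FW), and C = A * B yields every
//! output pixel of the group in a single matrix multiply.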
#if MEGDNN_X86_WITH_MKL_DNN
static inline void mkldnn_fp32_conv_instance(
        const ConvBiasImpl::NCBKernParam& param, const uint32_t ocpg,
        const uint32_t icpg, const uint32_t group, const uint32_t in,
        const uint32_t ic, const uint32_t oc, const uint32_t ih,
        const uint32_t iw, const uint32_t kh, const uint32_t kw,
        const uint32_t pad_h, const uint32_t pad_w, const uint32_t stride_h,
        const uint32_t stride_w, const uint32_t oh, const uint32_t ow,
        std::vector<dnnl::primitive>& net,
        std::vector<std::unordered_map<int, dnnl::memory>>& net_args,
        dnnl::engine& eng_mkldnn) {
    dnnl::memory::dims src_shape = {in, ic, ih, iw};
    dnnl::memory::dims weight_shape = {oc, ic, kh, kw};
    dnnl::memory::dims bias_shape = {oc};
    dnnl::memory::dims dst_shape = {in, oc, oh, ow};
    dnnl::memory::dims strides_shape = {stride_h, stride_w};
    dnnl::memory::dims padding_shape = {pad_h, pad_w};

    auto user_src_desc =
            dnnl::memory::desc({src_shape}, dnnl::memory::data_type::f32,
                               dnnl::memory::format_tag::nChw8c);
    if (group == 1 && ic < 8) {
        user_src_desc =
                dnnl::memory::desc({src_shape}, dnnl::memory::data_type::f32,
                                   dnnl::memory::format_tag::nchw);
    }
    auto user_src_mem = dnnl::memory(user_src_desc, eng_mkldnn,
                                     const_cast<void*>(param.src_ptr));
    auto weight_tag = dnnl::memory::format_tag::OIhw8i8o;
    if (group > 1) {
        weight_shape = {group, ocpg, icpg, kh, kw};
        if (oc == group && ic == group) {
            weight_tag = dnnl::memory::format_tag::Goihw8g;
        } else {
            weight_tag = dnnl::memory::format_tag::gOIhw8i8o;
        }
    } else if (group == 1 && ic < 8) {
        weight_tag = dnnl::memory::format_tag::Ohwi8o;
    }
    auto user_weights_desc = dnnl::memory::desc(
            {weight_shape}, dnnl::memory::data_type::f32, weight_tag);
    auto user_weights_mem = dnnl::memory(user_weights_desc, eng_mkldnn,
                                         const_cast<void*>(param.filter_ptr));
    auto user_bias_desc = dnnl::memory::desc();
    if (param.bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
        user_bias_desc =
                dnnl::memory::desc({bias_shape}, dnnl::memory::data_type::f32,
                                   dnnl::memory::format_tag::x);
    }
    auto user_bias_mem = dnnl::memory(user_bias_desc, eng_mkldnn,
                                      const_cast<void*>(param.bias_ptr));
    auto user_dst_desc =
            dnnl::memory::desc({dst_shape}, dnnl::memory::data_type::f32,
                               dnnl::memory::format_tag::nChw8c);
    auto user_dst_mem = dnnl::memory(user_dst_desc, eng_mkldnn,
                                     const_cast<void*>(param.dst_ptr));
    auto conv_desc = dnnl::convolution_forward::desc(
            dnnl::prop_kind::forward_inference,
            dnnl::algorithm::convolution_auto, user_src_mem.get_desc(),
            user_weights_mem.get_desc(), user_bias_mem.get_desc(),
            user_dst_mem.get_desc(), strides_shape, padding_shape,
            padding_shape);

    dnnl::primitive_attr attr;
    if ((param.nonlineMode == NonlineMode::RELU ||
         param.nonlineMode == NonlineMode::SIGMOID) &&
        (param.bias_mode == megdnn::BiasMode::NO_BIAS ||
         param.bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS)) {
        auto post_tag = dnnl::algorithm::eltwise_linear;
        switch (param.nonlineMode) {
            case NonlineMode::RELU:
                post_tag = dnnl::algorithm::eltwise_relu;
                break;
            case NonlineMode::SIGMOID:
                post_tag = dnnl::algorithm::eltwise_logistic;
                break;
            default:
                megdnn_assert(0, "unsupported nonline mode %d\n",
                              static_cast<int>(param.nonlineMode));
        }
        dnnl::post_ops ops;
        ops.append_eltwise(1.f, post_tag, 0.f, 0.f);
        attr.set_post_ops(ops);
    }
    auto conv_prim_desc = dnnl::convolution_forward::primitive_desc(
            conv_desc, attr, eng_mkldnn);
    net.push_back(dnnl::convolution_forward(conv_prim_desc));
    net_args.push_back({{DNNL_ARG_SRC, user_src_mem},
                        {DNNL_ARG_WEIGHTS, user_weights_mem},
                        {DNNL_ARG_BIAS, user_bias_mem},
                        {DNNL_ARG_DST, user_dst_mem}});
}
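
//! The primitive above relies on oneDNN post-ops: append_eltwise fuses the
//! RELU/SIGMOID activation into the convolution itself, so only bias modes
//! and nonlinearities outside that set fall through to the manual
//! PostProcess path in kern_mkldnn_fp32 below.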
namespace {
struct NCBKernParamEqual {
    bool operator()(const fallback::ConvBiasImpl::NCBKernParam& x,
                    const fallback::ConvBiasImpl::NCBKernParam& y) const {
        bool flag = true;
        flag = flag && (x.src_ptr == y.src_ptr);
        flag = flag && (x.dst_ptr == y.dst_ptr);
        flag = flag && (x.filter_ptr == y.filter_ptr);
        flag = flag && (x.bias_ptr == y.bias_ptr);
        flag = flag && (x.isz == y.isz);
        flag = flag && (x.osz == y.osz);
        flag = flag && (x.src_type == y.src_type);
        flag = flag && (x.dst_type == y.dst_type);
        flag = flag && (x.filter_type == y.filter_type);
        flag = flag && (x.bias_type == y.bias_type);
        flag = flag && (x.filter_meta == y.filter_meta);
        flag = flag && (x.n == y.n);
        flag = flag && (x.bias_mode == y.bias_mode);
        flag = flag && (x.nonlineMode == y.nonlineMode);
        flag = flag && (x.bias_bs == y.bias_bs);
        return flag;
    }
};

struct NCBKernParamHash {
    std::size_t operator()(
            const fallback::ConvBiasImpl::NCBKernParam& param) const {
        std::size_t result = reinterpret_cast<std::size_t>(param.filter_ptr);
        result = result ^ (reinterpret_cast<std::size_t>(param.src_ptr) << 3);
        result = result ^ (reinterpret_cast<std::size_t>(param.dst_ptr) << 7);
        result = result ^ (static_cast<std::size_t>(param.n) << 11);
        return result;
    }
};
}  // namespace
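
//! These functors key a memoization cache: kern_mkldnn_fp32 below builds the
//! oneDNN primitive and its memory descriptors once per distinct
//! NCBKernParam and replays them on later calls, since primitive creation is
//! typically far more expensive than execution.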
void ConvBiasImpl::AlgoMkldnnConv::kern_mkldnn_fp32(const NCBKernParam& param,
                                                    const NCBKernIndex&) {
    const NCBKernParam& key = param;
    static std::unordered_map<NCBKernParam, std::vector<dnnl::primitive>,
                              NCBKernParamHash, NCBKernParamEqual>
            kern_net_map;
    static std::unordered_map<
            NCBKernParam, std::vector<std::unordered_map<int, dnnl::memory>>,
            NCBKernParamHash, NCBKernParamEqual>
            kern_net_arg_map;

    auto x86_handle = static_cast<HandleImpl*>(inplace_cpu_handle().get());
    megdnn_assert(x86_handle != nullptr, "x86 handle can not be null");
    auto eng_mkldnn = x86_handle->mkldnn_engine();
    auto stream_mkldnn = x86_handle->mkldnn_stream();
    auto&& fm = param.filter_meta;
    const uint32_t group = fm.group;
    const uint32_t in = param.n;
    const uint32_t ic = fm.icpg * group;
    const uint32_t oc = fm.ocpg * group;
    const uint32_t ih = param.isz[0];
    const uint32_t iw = param.isz[1];
    const uint32_t kh = fm.spatial[0];
    const uint32_t kw = fm.spatial[1];
    const uint32_t pad_h = fm.padding[0];
    const uint32_t pad_w = fm.padding[1];
    const uint32_t stride_h = fm.stride[0];
    const uint32_t stride_w = fm.stride[1];
    const uint32_t oh = param.osz[0];
    const uint32_t ow = param.osz[1];

    if (kern_net_map.find(key) == kern_net_map.end()) {
        std::vector<dnnl::primitive> net;
        std::vector<std::unordered_map<int, dnnl::memory>> net_args;
        mkldnn_fp32_conv_instance(param, fm.ocpg, fm.icpg, group, in, ic, oc,
                                  ih, iw, kh, kw, pad_h, pad_w, stride_h,
                                  stride_w, oh, ow, net, net_args,
                                  eng_mkldnn);
        kern_net_map[key] = net;
        kern_net_arg_map[key] = net_args;
    }
    const auto& net = kern_net_map[key];
    const auto& net_args = kern_net_arg_map[key];
    for (size_t i = 0; i < net.size(); ++i) {
        net.at(i).execute(stream_mkldnn, net_args.at(i));
    }
    stream_mkldnn.wait();
    if ((param.bias_mode == megdnn::BiasMode::NO_BIAS ||
         param.bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) &&
        (param.nonlineMode != NonlineMode::IDENTITY &&
         param.nonlineMode != NonlineMode::RELU &&
         param.nonlineMode != NonlineMode::SIGMOID)) {
        /**
         * NO_BIAS and BROADCAST_CHANNEL_BIAS have already been applied
         * inside the mkldnn conv, but activation functions that mkldnn does
         * not support must still be run here; no further bias op is needed.
         **/
        PostProcess<float>::run(param.dst_ptr,
                                const_cast<void*>(param.bias_ptr),
                                param.dst_ptr, megdnn::BiasMode::NO_BIAS,
                                param.nonlineMode, param.bias_type,
                                param.dst_type, in, oc, oh, ow);
    } else if (param.bias_mode == megdnn::BiasMode::BIAS) {
        PostProcess<float>::run(param.dst_ptr,
                                const_cast<void*>(param.bias_ptr),
                                param.dst_ptr, param.bias_mode,
                                param.nonlineMode, param.bias_type,
                                param.dst_type, in, oc, oh, ow);
    }
}
#endif
// vim: syntax=cpp.doxygen
