You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

convolution.cpp 52 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236
  1. /**
  2. * \file dnn/src/common/convolution.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs/nn.h"
  13. #include "src/common/utils.h"
  14. using namespace megdnn;
  15. namespace {
  16. template <typename Param>
  17. std::string get_errmsg(
  18. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
  19. const Param& param) {
  20. MEGDNN_MARK_USED_VAR(src);
  21. MEGDNN_MARK_USED_VAR(filter);
  22. MEGDNN_MARK_USED_VAR(dst);
  23. return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
  24. megdnn_layout_msg(dst) + ", " + "is_nchw=" +
  25. std::to_string(param.format == param::Convolution::Format::NCHW) + ", " +
  26. "is_xcorr=" +
  27. std::to_string((param.mode == Convolution::Mode::CROSS_CORRELATION)) + ", " +
  28. "pad_h=" + std::to_string(param.pad_h) + ", " +
  29. "pad_w=" + std::to_string(param.pad_w) + ", " +
  30. "stride_h=" + std::to_string(param.stride_h) + ", " +
  31. "stride_w=" + std::to_string(param.stride_w) + ", " +
  32. "dilate_h=" + std::to_string(param.dilate_h) + ", " +
  33. "dilate_w=" + std::to_string(param.dilate_w);
  34. }
  35. template <typename Param, typename Param::Format>
  36. uint32_t spatial_getter(uint32_t filter, const Param&) {
  37. return filter;
  38. }
//! Fill \p ret (group / ocpg / icpg / spatial sizes) from a filter layout in
//! plain NCHW or NHWC format.
//!
//! \param src_ndim ndim of the convolution input (spatial ndim + 2)
//! \param filter filter layout to canonize
//! \param param conv param; only format and sparse are consulted here
//! \param ret output meta; ret.dilation[] must already be set by the caller
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nchw_nhwc(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    megdnn_assert(
            param.format == Param::Format::NCHW || param.format == Param::Format::NHWC);
    auto img_ndim = src_ndim - 2;
    size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
    if (param.sparse == Param::Sparse::DENSE) {
        // dense: filter dims are (oc, ic, spatial...); +4 presumably covers a
        // packed variant — TODO confirm which caller passes img_ndim + 4 here
        megdnn_assert(
                filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // grp, oc, ic, dims[]
        ret.group = filter[0];
        flt_start = 1;  // skip the leading group dim
    }
    // block sizes stay 1: NCHW/NHWC filters are not channel-packed
    uint32_t ic_block_size = 1, oc_block_size = 1;
    if (param.format == Param::Format::NCHW) {
        // filter should be (oc, ic, fh, fw)
        flt_spatial_start = 2;
        ocpg_pos = 0;
        icpg_pos = 1;
    } else {
        megdnn_assert(
                param.format == Param::Format::NHWC, "invalid conv tensor format");
        // filter should be (oc, fh, fw, ic)
        flt_spatial_start = 1;
        ocpg_pos = 0;
        icpg_pos = 3;
    }
    ret.spatial_ndim = src_ndim - 2;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    // output/input channels per group
    ret.ocpg = filter[flt_start + ocpg_pos] * oc_block_size;
    ret.icpg = filter[flt_start + icpg_pos] * ic_block_size;
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        // spatial_getter is an identity for plain formats, so the NCHW tag is
        // also fine in the NHWC branch
        ret.spatial[i] = spatial_getter<Param, Param::Format::NCHW>(
                filter[i + flt_start + flt_spatial_start], param);
        // effective kernel extent with dilation applied
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
//! Canonize a filter layout in NHWCD4 format (channels packed by 4) for the
//! non-dot-product kernels; see make_canonized_filter_meta_nhwcd4_dot for the
//! quantized-dot variant.
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nhwcd4(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N H IC/4 W 4
     * Filter:
     *        OC/4, FH, FW, IC, 4 [dense]
     *        GROUP, OC/4, FH, FW, IC, 4 [group]
     *        GROUP/4, 1, FH, FW, 4 [chanwise]
     */
    megdnn_assert(param.format == Param::Format::NHWCD4);
    // NHWCD4 input is 5-dim: N, H, IC/4, W, 4 -> spatial ndim = ndim - 3
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    bool is_chanwise = false;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // chanwise layout is (GROUP/4, 1, FH, FW, 4): one fewer dim than the
        // general group layout and a literal 1 at index 1
        if (filter.ndim == img_ndim + 3 && filter[1] == 1) {
            is_chanwise = true;
            ret.group = filter[0] * 4;  // leading dim stores GROUP/4
        } else {
            ret.group = filter[0];
        }
        flt_start = 1;  // skip the leading group dim
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    if (is_chanwise) {
        // channel-wise: exactly one input and one output channel per group
        ret.ocpg = 1;
        ret.icpg = 1;
    } else {
        ret.ocpg = filter[flt_start] * 4;  // stored as OC/4
        ret.icpg = filter[flt_start + 3];  // IC is stored unpacked
    }
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        // effective kernel extent with dilation applied
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
//! Canonize a filter layout in NHWCD4 format for quantized dot-product
//! kernels (QuantizedS8 / Quantized8Asymm filters); differs from the plain
//! NHWCD4 variant in that IC is also packed by 4 for the non-chanwise cases.
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nhwcd4_dot(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N H IC/4 W 4
     * Filter:
     *        GROUP/4, 1, FH, FW, 4 [chanwise]
     *        OC/4, FH, FW, IC/4, 4, 4 [dense]
     *        GROUP, OC/4, FH, FW, IC/4, 4, 4 [group]
     */
    megdnn_assert(param.format == Param::Format::NHWCD4);
    // NHWCD4 input is 5-dim: N, H, IC/4, W, 4 -> spatial ndim = ndim - 3
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    bool is_chanwise = false;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // the shorter layout is the chanwise one (GROUP/4, 1, FH, FW, 4)
        if (filter.ndim == img_ndim + 3) {
            megdnn_assert(filter[1] == 1);
            is_chanwise = true;
            ret.group = filter[0] * 4;  // leading dim stores GROUP/4
        } else {
            ret.group = filter[0];
        }
        flt_start = 1;  // skip the leading group dim
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    if (is_chanwise) {
        // channel-wise: exactly one input and one output channel per group
        ret.ocpg = 1;
        ret.icpg = 1;
    } else {
        ret.ocpg = filter[flt_start] * 4;      // stored as OC/4
        ret.icpg = filter[flt_start + 3] * 4;  // stored as IC/4 (dot layout)
    }
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        // effective kernel extent with dilation applied
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
//! Canonize a filter layout for the channel-blocked NCHWxx family
//! (NCHW88 with pack_size=8, NCHW44 / NCHW44_DOT with pack_size=4).
//! Handles dense (oihw-packed and ohwi-packed), normal group, and depthwise
//! layouts; dilation must be 1 for these formats.
template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_nchwxx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N IC/pack_size, H, W, pack_size
     *
     ** NCHW44-DOT mode
     * filter:
     *        {OC/pack_size, IC/pack_size, FH, FW, pack_size(OC), pack_size(IC)}
     *        [dense]
     *        {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
     *         FH, FW, pack_size(OC), pack_size(IC)} [group]
     *
     * NCHW88 and NCHW44 mode
     * filter:
     *        {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)}
     *        [dense]
     *        {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
     *         FH, FW, pack_size(IC), pack_size(OC)} [group]
     *        {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan]
     */
    megdnn_assert(
            param.format == Param::Format::NCHW88 ||
            param.format == Param::Format::NCHW44 ||
            param.format == Param::Format::NCHW44_DOT);
    size_t img_ndim = 2;
    size_t flt_start = 0;
    size_t flt_spatial_start = 2;
    size_t pack_c_size = 0;
    if (param.sparse == Param::Sparse::DENSE) {
        if (filter.ndim == img_ndim + 4) {
            // oihw8i8o case; the last two dims may also both be 2*pack_size
            // (double-blocked variant), checked below
            megdnn_assert(
                    (filter[filter.ndim - 2] == pack_size &&
                     filter[filter.ndim - 1] == pack_size) ||
                            (filter[filter.ndim - 2] == 2 * pack_size &&
                             filter[filter.ndim - 1] == 2 * pack_size),
                    "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
                    filter[filter.ndim - 2], filter[filter.ndim - 1]);
            ret.group = 1;
            flt_start = 0;
            // choose the effective channel block from the trailing dims
            if (filter[filter.ndim - 2] == 2 * pack_size &&
                filter[filter.ndim - 1] == 2 * pack_size) {
                pack_c_size = 2 * pack_size;
            } else {
                pack_c_size = pack_size;
            }
            ret.ocpg = filter[flt_start] * pack_c_size;
            ret.icpg = filter[flt_start + 1] * pack_c_size;
        } else if (filter.ndim == img_ndim + 3) {
            // ohwi8o: only OC is packed; IC stored unpacked at index 3
            flt_start = 0;
            flt_spatial_start = 1;
            ret.group = 1;
            ret.ocpg = filter[flt_start] * pack_size;
            ret.icpg = filter[flt_start + 3];
        } else {
            megdnn_assert(0, "not support nchwxx filter dim = %zu", filter.ndim);
        }
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        flt_start = 1;  // skip the leading group dim
        auto filter_oc = filter[flt_start];
        auto filter_ic = filter[flt_start + 1];
        if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4)) {
            // Depthwise case goihw8g
            megdnn_assert(
                    filter.ndim == img_ndim + 4,
                    "bad filter ndim for group convolution: "
                    "spatial_ndim=%zu filter_ndim=%zu",
                    img_ndim, filter.ndim);
            megdnn_assert(
                    filter[filter.ndim - 1] == pack_size,
                    "last dim of filter must be %zu, but %zu", pack_size,
                    filter[filter.ndim - 1]);
            ret.group = filter[0] * pack_size;  // leading dim is GROUP/pack
            ret.ocpg = filter_oc;
            ret.icpg = filter_ic;
        } else {
            // norm group case goihw8i8o
            megdnn_assert(
                    filter.ndim == img_ndim + 5,
                    "bad filter ndim for group convolution: "
                    "spatial_ndim=%zu filter_ndim=%zu",
                    img_ndim, filter.ndim);
            megdnn_assert(
                    (filter[filter.ndim - 1] == pack_size &&
                     filter[filter.ndim - 2] == pack_size) ||
                            (filter[filter.ndim - 1] == 2 * pack_size &&
                             filter[filter.ndim - 2] == 2 * pack_size),
                    "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
                    filter[filter.ndim - 2], filter[filter.ndim - 1]);
            ret.group = filter[0];
            // double-blocked variant doubles the per-group channel counts
            if (filter[filter.ndim - 2] == 2 * pack_size &&
                filter[filter.ndim - 1] == 2 * pack_size) {
                ret.ocpg = filter_oc * 2 * pack_size;
                ret.icpg = filter_ic * 2 * pack_size;
            } else {
                ret.ocpg = filter_oc * pack_size;
                ret.icpg = filter_ic * pack_size;
            }
        }
    }
    ret.spatial_ndim = 2;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim "
            "for nchwxx; "
            "got input dim = %zu",
            src_ndim);
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        // blocked formats do not support dilated convolution
        megdnn_assert(
                dilation[i] == 1,
                "NCHWXX has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
//! Canonize a filter layout for the NCHWx family where only IC is packed
//! (NCHW4/8/32/64 and the mixed-format variants); OC is stored unpacked.
//! Dilation must be 1 for these formats.
template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_nchwx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N IC/pack_size, H, W, pack_size
     * filter:
     *        OC, IC/pack_size, FH, FW, pack_size [dense]
     *        GROUP, OC, IC/pack_size, FH, FW, pack_size [group]
     */
    megdnn_assert(
            param.format == Param::Format::NCHW4 ||
            param.format == Param::Format::NCHW8 ||
            param.format == Param::Format::NCHW32 ||
            param.format == Param::Format::NCHW4_NCHW ||
            param.format == Param::Format::NCHW4_NHWC ||
            param.format == Param::Format::NCHW4_NCHW32 ||
            param.format == Param::Format::NCHW32_NCHW4 ||
            param.format == Param::Format::NCHW64);
    // packed input is 5-dim: N, IC/pack, H, W, pack -> spatial ndim = ndim - 3
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 2;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = filter[0];
        flt_start = 1;  // skip the leading group dim
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim "
            "for nchw4; "
            "got input dim = %zu",
            src_ndim);
    ret.ocpg = filter[flt_start];                  // OC stored unpacked
    ret.icpg = filter[flt_start + 1] * pack_size;  // IC stored as IC/pack
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        // packed formats do not support dilated convolution
        megdnn_assert(
                dilation[i] == 1,
                "NCHW4 has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
//! Canonize a filter layout for CHWNx (currently CHWN4 only): IC is packed
//! by pack_size and OC is stored unpacked after the spatial dims.
//! Chanwise is not implemented; dilation must be 1.
template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_chwnx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: IC / pack_size, H, W, N, pack_size
     * Filter:
     *        IC / pack_size, FH, FW, OC, pack_size [dense]
     *        GROUP, icpg / pack_size, FH, FW, ocpg, pack_size [group]
     *        not implemented [chanwise]
     */
    megdnn_assert(param.format == Param::Format::CHWN4);
    // packed input is 5-dim -> spatial ndim = ndim - 3
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = filter[0];
        flt_start = 1;  // skip the leading group dim
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    ret.icpg = filter[flt_start] * pack_size;  // IC stored as IC/pack
    ret.ocpg = filter[flt_start + 3];          // OC follows FH, FW
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        // packed formats do not support dilated convolution
        megdnn_assert(
                dilation[i] == 1,
                "CHWNx has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
  469. } // namespace
  470. namespace megdnn {
//! Convert \p filter into a CanonizedFilterMeta: copies dtype/format/stride/
//! padding/dilation from the param, then dispatches on param().format to the
//! per-layout canonizer in the anonymous namespace above.
//!
//! \param src_ndim ndim of the convolution input tensor
//! \param filter filter layout; must be contiguous
//! \return fully-populated canonized filter meta
template <typename Parameter>
typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
        make_canonized_filter_meta(size_t src_ndim, const TensorLayout& filter) const {
    megdnn_assert_contiguous(filter);
    CanonizedFilterMeta ret;
    ret.dtype = filter.dtype;
    ret.format = param().format;
    // CONVOLUTION means true mathematical convolution, so the kernel must be
    // flipped; CROSS_CORRELATION uses the kernel as-is.
    if (param().mode == Mode::CONVOLUTION) {
        ret.should_flip = true;
    } else {
        megdnn_assert(param().mode == Mode::CROSS_CORRELATION, "invalid conv mode");
        ret.should_flip = false;
    }
    ret.stride[0] = param().stride_h;
    ret.stride[1] = param().stride_w;
    ret.padding[0] = param().pad_h;
    ret.padding[1] = param().pad_w;
    ret.dilation[0] = param().dilate_h;
    ret.dilation[1] = param().dilate_w;
    if (param().format == Param::Format::NHWCD4) {
        // quantized 8-bit filters use the dot-product NHWCD4 layout
        if (filter.dtype.enumv() == DTypeEnum::QuantizedS8 ||
            filter.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
            make_canonized_filter_meta_nhwcd4_dot<Parameter>(
                    src_ndim, filter, param(), ret);
        } else {
            make_canonized_filter_meta_nhwcd4<Parameter>(
                    src_ndim, filter, param(), ret);
        }
    } else if (
            param().format == Param::Format::NCHW4 ||
            param().format == Param::Format::NCHW4_NCHW ||
            param().format == Param::Format::NCHW4_NHWC ||
            param().format == Param::Format::NCHW4_NCHW32) {
        make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, param(), ret);
    } else if (param().format == Param::Format::NCHW8) {
        make_canonized_filter_meta_nchwx<8, Parameter>(src_ndim, filter, param(), ret);
    } else if (param().format == Param::Format::NCHW88) {
        make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, param(), ret);
    } else if (
            param().format == Param::Format::NCHW44 ||
            param().format == Param::Format::NCHW44_DOT) {
        make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, param(), ret);
    } else if (
            param().format == Param::Format::NCHW32 ||
            param().format == Param::Format::NCHW32_NCHW4) {
        make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, param(), ret);
    } else if (param().format == Param::Format::CHWN4) {
        make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter, param(), ret);
    } else if (param().format == Param::Format::NCHW64) {
        make_canonized_filter_meta_nchwx<64, Parameter>(src_ndim, filter, param(), ret);
    } else {
        // fallthrough: only the plain formats remain
        megdnn_assert(
                param().format == Param::Format::NHWC ||
                param().format == Param::Format::NCHW);
        make_canonized_filter_meta_nchw_nhwc<Parameter>(src_ndim, filter, param(), ret);
    }
    return ret;
}
//! Deduce the output dtype for a forward convolution, or validate a
//! caller-supplied one.
//!
//! \param src input dtype
//! \param filter filter dtype
//! \param[in,out] dst if invalid, set to the default supported dtype;
//!        otherwise asserted to be one of the supported dtypes
template <typename Parameter>
void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
        DType src, DType filter, DType& dst) const {
    // The first one will be the default choice.
    SmallVector<DType> supported_dst_dtype;
    // We rely on megdnn_assert(src.enumv() == filter.enumv()) here.
    if (src.category() == DTypeCategory::FLOAT) {
        // float in -> same float type out
        supported_dst_dtype.push_back(src);
    } else if (src.enumv() == DTypeEnum::Int8) {
        supported_dst_dtype = {dtype::Int32(), dtype::Int16()};
    } else if (
            src.enumv() == DTypeEnum::QuantizedS8 ||
            src.enumv() == DTypeEnum::Quantized8Asymm ||
            src.enumv() == DTypeEnum::QuantizedS4 ||
            src.enumv() == DTypeEnum::Quantized4Asymm) {
        // default: accumulate into QuantizedS32 with scale = src * filter
        supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(src, filter)));
        // also accept dst equal to src's enum, or the 4-bit<->8-bit
        // requantization pairs in either direction
        bool cond_dst = dst.valid() && (dst.enumv() == src.enumv() ||
                                        ((dst.enumv() == DTypeEnum::QuantizedS4 ||
                                          dst.enumv() == DTypeEnum::Quantized4Asymm) &&
                                         src.enumv() == DTypeEnum::QuantizedS8) ||
                                        ((src.enumv() == DTypeEnum::QuantizedS4 ||
                                          src.enumv() == DTypeEnum::Quantized4Asymm) &&
                                         dst.enumv() == DTypeEnum::QuantizedS8));
        if (cond_dst) {
            supported_dst_dtype.push_back(dst);
        }
        if (src.enumv() == DTypeEnum::QuantizedS8) {
            supported_dst_dtype.push_back(dtype::Float32());
        }
    } else if (src.enumv() == DTypeEnum::QuantizedS32) {
        //! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src)
        megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8);
        supported_dst_dtype.push_back(dtype::QuantizedS8(
                src.param<dtype::QuantizedS32>().scale /
                filter.param<dtype::QuantizedS8>().scale));
    } else {
        megdnn_throw(ssprintf(
                "unsupported input / filter DType: %s x %s", src.name(),
                filter.name()));
    }
    if (!dst.valid()) {
        // deduce: pick the default (first) supported dtype
        dst = supported_dst_dtype.at(0);
    } else {
        // validate: dst must match one of the supported dtypes
        bool dst_supported = false;
        for (auto&& dt : supported_dst_dtype) {
            if (dtype_almost_equal(dt, dst)) {
                dst_supported = true;
                break;
            }
        }
        // silence unused warning when megdnn_assert compiles to a no-op
        MEGDNN_MARK_USED_VAR(dst_supported);
        megdnn_assert(
                dst_supported, "unsupported Conv(%s, %s) -> %s", src.name(),
                filter.name(), dst.name());
    }
    // NOTE(review): the condition also accepts DEFAULT with any dtype; the
    // assert only rejects FLOAT32 compute mode on non-fp16/bf16 inputs when
    // float16 support is compiled in.
    megdnn_assert(
            (param().compute_mode == Param::ComputeMode::FLOAT32 ||
             param().compute_mode == Param::ComputeMode::DEFAULT)
#if !MEGDNN_DISABLE_FLOAT16
                    || src.enumv() == DTypeEnum::Float16 ||
                    src.enumv() == DTypeEnum::BFloat16
#endif
                    ,
            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
            "input / output.");
}
  595. template <typename Parameter>
  596. typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
  597. deduce_layout_fwd(
  598. const TensorLayout& src, const TensorLayout& filter,
  599. TensorLayout& dst) const {
  600. auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); };
  601. MEGDNN_MARK_USED_VAR(errmsg);
  602. megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str());
  603. megdnn_assert(
  604. ((src.dtype.enumv() == filter.dtype.enumv()) ||
  605. (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
  606. filter.dtype.enumv() == DTypeEnum::QuantizedS4)),
  607. "%s", errmsg().c_str());
  608. check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype);
  609. size_t img_dim;
  610. if (param().format == Param::Format::NCHW ||
  611. param().format == Param::Format::NHWC) {
  612. img_dim = src.ndim - 2;
  613. megdnn_assert(
  614. filter.ndim >= img_dim + 2 && filter.ndim <= img_dim + 6, "%s",
  615. errmsg().c_str());
  616. } else {
  617. megdnn_assert(
  618. param().format == Param::Format::NHWCD4 ||
  619. param().format == Param::Format::NCHW4 ||
  620. param().format == Param::Format::NCHW4_NCHW ||
  621. param().format == Param::Format::NCHW4_NHWC ||
  622. param().format == Param::Format::NCHW4_NCHW32 ||
  623. param().format == Param::Format::NCHW44 ||
  624. param().format == Param::Format::NCHW44_DOT ||
  625. param().format == Param::Format::NCHW8 ||
  626. param().format == Param::Format::NCHW32 ||
  627. param().format == Param::Format::NCHW32_NCHW4 ||
  628. param().format == Param::Format::NCHW88 ||
  629. param().format == Param::Format::CHWN4 ||
  630. param().format == Param::Format::NCHW64);
  631. img_dim = src.ndim - 3;
  632. if ((param().format == Param::Format::NCHW88 ||
  633. param().format == Param::Format::NCHW44_DOT ||
  634. param().format == Param::Format::NCHW44) &&
  635. filter.ndim == 5) {
  636. img_dim = src.ndim - 2;
  637. }
  638. megdnn_assert(
  639. filter.ndim == img_dim + 3 ||
  640. (filter.ndim == img_dim + 2 &&
  641. (param().format == Param::Format::NCHW88 ||
  642. param().format == Param::Format::NCHW44_DOT ||
  643. param().format == Param::Format::NCHW44)) ||
  644. filter.ndim == img_dim + 4 || filter.ndim == img_dim + 5,
  645. "%s", errmsg().c_str());
  646. if (param().format == Param::Format::NCHW4 ||
  647. param().format == Param::Format::NCHW4_NCHW ||
  648. param().format == Param::Format::NCHW4_NCHW32) {
  649. megdnn_assert(
  650. src.ndim == 5 &&
  651. (filter.ndim == 5 || filter.ndim == 6 ||
  652. filter.ndim == 7) &&
  653. src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
  654. "NCHW4/NCHW4_NCHW/NCHW4_NCHW32 require src and "
  655. "filter's ndim is "
  656. "5 or 6, and "
  657. "last shape "
  658. "is 4 "
  659. "but got src %s, filter %s",
  660. src.to_string().c_str(), filter.to_string().c_str());
  661. }
  662. if (param().format == Param::Format::NCHW8) {
  663. megdnn_assert(
  664. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  665. src[src.ndim - 1] == 8 && filter[filter.ndim - 1] == 8,
  666. "NCHW8 require src and filter's ndim is 5 or 6, and last "
  667. "shape is 8 "
  668. "but got src %s, filter %s",
  669. src.to_string().c_str(), filter.to_string().c_str());
  670. }
  671. if (param().format == Param::Format::NCHW32 ||
  672. param().format == Param::Format::NCHW32_NCHW4) {
  673. megdnn_assert(
  674. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  675. src[src.ndim - 1] == 32 && filter[filter.ndim - 1] == 32,
  676. "NCHW32/NCHW32_NCHW4 require src and filter's ndim "
  677. "is 5 or 6, and last "
  678. "shape is 32 "
  679. "but got src %s, filter %s",
  680. src.to_string().c_str(), filter.to_string().c_str());
  681. }
  682. if (param().format == Param::Format::NCHW88) {
  683. megdnn_assert(
  684. (src.ndim == 4 && filter.ndim == 5 &&
  685. filter[filter.ndim - 1] == 8) ||
  686. (src.ndim == 5 &&
  687. ((filter.ndim == 6 && filter[filter.ndim - 1] == 8) ||
  688. (filter.ndim == 7 && filter[filter.ndim - 1] == 8 &&
  689. filter[filter.ndim - 2] == 8)) &&
  690. src[src.ndim - 1] == 8),
  691. "NCHW88 require src ndim is 5 and filter's ndim is 6 "
  692. ", and last shape two is 8 but got src %s, filter %s",
  693. src.to_string().c_str(), filter.to_string().c_str());
  694. }
  695. if (param().format == Param::Format::NCHW44 ||
  696. param().format == Param::Format::NCHW44_DOT) {
  697. //! support nchw44 filter change to 88 for int8 winogradf23_88 using
  698. //! MK8 mamtul
  699. megdnn_assert(
  700. (src.ndim == 4 && filter.ndim == 5 &&
  701. filter[filter.ndim - 1] == 4) ||
  702. (src.ndim == 5 &&
  703. ((filter.ndim == 6 && (filter[filter.ndim - 1] == 4 ||
  704. filter[filter.ndim - 1] == 8)) ||
  705. (filter.ndim == 7 &&
  706. (filter[filter.ndim - 1] == 4 ||
  707. filter[filter.ndim - 1] == 8) &&
  708. (filter[filter.ndim - 2] == 4 ||
  709. filter[filter.ndim - 2] == 8))) &&
  710. src[src.ndim - 1] == 4),
  711. "NCHW44 require src ndim is 5 and filter's ndim is 6 "
  712. ", and last shape two is 4 but got src %s, filter %s",
  713. src.to_string().c_str(), filter.to_string().c_str());
  714. }
  715. if (param().format == Param::Format::CHWN4) {
  716. megdnn_assert(
  717. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  718. src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
  719. "CHWN4 require src and filter's ndim is 5 or 6, and last "
  720. "shape is 4 "
  721. "but got src %s, filter %s",
  722. src.to_string().c_str(), filter.to_string().c_str());
  723. }
  724. if (param().format == Param::Format::NCHW64) {
  725. megdnn_assert(
  726. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  727. src[src.ndim - 1] == 64 && filter[filter.ndim - 1] == 64,
  728. "NCHW64 require src and filter's ndim is 5 or 6, and "
  729. "last shape is 64 but got src %s, filter %s",
  730. src.to_string().c_str(), filter.to_string().c_str());
  731. }
  732. }
  733. megdnn_assert(img_dim == 2, "currently only convolution on 2D image is supported");
  734. auto cflt = make_canonized_filter_meta(src.ndim, filter);
  735. if (param().format == Param::Format::NCHW ||
  736. param().format == Param::Format::NHWC) {
  737. size_t src_or_dst_c_pos = 0;
  738. size_t src_or_dst_spatial_start = 0;
  739. if (param().format == Param::Format::NCHW) {
  740. src_or_dst_c_pos = 1;
  741. src_or_dst_spatial_start = 2;
  742. } else {
  743. megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
  744. src_or_dst_c_pos = 3;
  745. src_or_dst_spatial_start = 1;
  746. }
  747. megdnn_assert(
  748. cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s",
  749. errmsg().c_str());
  750. dst.ndim = src.ndim;
  751. dst[0] = src[0];
  752. dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
  753. for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
  754. dst[i + src_or_dst_spatial_start] = infer_conv_shape(
  755. src[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
  756. cflt.stride[i], cflt.padding[i]);
  757. }
  758. } else if (param().format == Param::Format::NCHW4) {
  759. megdnn_assert(
  760. src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu",
  761. src.ndim);
  762. megdnn_assert(
  763. cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
  764. errmsg().c_str(), cflt.icpg, cflt.group);
  765. dst.ndim = src.ndim;
  766. dst[0] = src[0];
  767. auto oc = cflt.ocpg * cflt.group;
  768. megdnn_assert(oc % 4 == 0);
  769. dst[1] = oc / 4;
  770. dst[2] = infer_conv_shape(
  771. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  772. dst[3] = infer_conv_shape(
  773. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  774. dst[4] = 4;
  775. } else if (param().format == Param::Format::NCHW8) {
  776. megdnn_assert(
  777. src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu",
  778. src.ndim);
  779. megdnn_assert(
  780. cflt.icpg * cflt.group == src[1] * 8, "%s icpg=%u group=%u",
  781. errmsg().c_str(), cflt.icpg, cflt.group);
  782. dst.ndim = src.ndim;
  783. dst[0] = src[0];
  784. auto oc = cflt.ocpg * cflt.group;
  785. megdnn_assert(oc % 8 == 0);
  786. dst[1] = oc / 8;
  787. dst[2] = infer_conv_shape(
  788. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  789. dst[3] = infer_conv_shape(
  790. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  791. dst[4] = 8;
  792. } else if (param().format == Param::Format::NCHW32) {
  793. megdnn_assert(
  794. src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu",
  795. src.ndim);
  796. megdnn_assert(
  797. cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
  798. errmsg().c_str(), cflt.icpg, cflt.group);
  799. dst.ndim = src.ndim;
  800. dst[0] = src[0];
  801. auto oc = cflt.ocpg * cflt.group;
  802. megdnn_assert(oc % 32 == 0);
  803. dst[1] = oc / 32;
  804. dst[2] = infer_conv_shape(
  805. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  806. dst[3] = infer_conv_shape(
  807. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  808. dst[4] = 32;
  809. } else if (param().format == Param::Format::NCHW88) {
  810. megdnn_assert(
  811. src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
  812. "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
  813. dst.ndim = 5;
  814. dst[0] = src[0];
  815. auto oc = cflt.ocpg * cflt.group;
  816. megdnn_assert(oc % 8 == 0);
  817. dst[1] = oc / 8;
  818. dst[2] = infer_conv_shape(
  819. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  820. dst[3] = infer_conv_shape(
  821. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  822. dst[4] = 8;
  823. if (cflt.group == 1) {
  824. megdnn_assert(
  825. cflt.icpg * cflt.group == src[1] * 8 ||
  826. (cflt.icpg * cflt.group == src[1]),
  827. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
  828. }
  829. } else if (
  830. param().format == Param::Format::NCHW44 ||
  831. param().format == Param::Format::NCHW44_DOT) {
  832. megdnn_assert(
  833. src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
  834. "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
  835. dst.ndim = 5;
  836. dst[0] = src[0];
  837. auto oc = cflt.ocpg * cflt.group;
  838. megdnn_assert(oc % 4 == 0);
  839. dst[1] = oc / 4;
  840. dst[2] = infer_conv_shape(
  841. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  842. dst[3] = infer_conv_shape(
  843. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  844. dst[4] = 4;
  845. if (cflt.group == 1) {
  846. megdnn_assert(
  847. cflt.icpg * cflt.group == src[1] * 4 ||
  848. (cflt.icpg * cflt.group == src[1]),
  849. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
  850. }
  851. } else if (param().format == Param::Format::CHWN4) {
  852. megdnn_assert(
  853. src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu",
  854. src.ndim);
  855. megdnn_assert(
  856. cflt.icpg * cflt.group == src[0] * 4, "%s icpg=%u group=%u",
  857. errmsg().c_str(), cflt.icpg, cflt.group);
  858. dst.ndim = src.ndim;
  859. dst[3] = src[3];
  860. auto oc = cflt.ocpg * cflt.group;
  861. megdnn_assert(oc % 4 == 0);
  862. dst[0] = oc / 4;
  863. dst[1] = infer_conv_shape(
  864. src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  865. dst[2] = infer_conv_shape(
  866. src[2], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  867. dst[4] = 4;
  868. } else if (param().format == Param::Format::NCHW4_NCHW) {
  869. megdnn_assert(
  870. src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu",
  871. src.ndim);
  872. megdnn_assert(
  873. cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
  874. errmsg().c_str(), cflt.icpg, cflt.group);
  875. dst.ndim = 4;
  876. dst[0] = src[0];
  877. auto oc = cflt.ocpg * cflt.group;
  878. dst[1] = oc;
  879. dst[2] = infer_conv_shape(
  880. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  881. dst[3] = infer_conv_shape(
  882. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  883. } else if (param().format == Param::Format::NCHW4_NHWC) {
  884. megdnn_assert(
  885. src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu",
  886. src.ndim);
  887. megdnn_assert(
  888. cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
  889. errmsg().c_str(), cflt.icpg, cflt.group);
  890. dst.ndim = 4;
  891. dst[0] = src[0];
  892. dst[1] = infer_conv_shape(
  893. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  894. dst[2] = infer_conv_shape(
  895. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  896. auto oc = cflt.ocpg * cflt.group;
  897. dst[3] = oc;
  898. } else if (param().format == Param::Format::NCHW4_NCHW32) {
  899. megdnn_assert(
  900. src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu",
  901. src.ndim);
  902. megdnn_assert(
  903. cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
  904. errmsg().c_str(), cflt.icpg, cflt.group);
  905. dst.ndim = src.ndim;
  906. dst[0] = src[0];
  907. auto oc = cflt.ocpg * cflt.group;
  908. megdnn_assert(oc % 32 == 0);
  909. dst[1] = oc / 32;
  910. dst[2] = infer_conv_shape(
  911. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  912. dst[3] = infer_conv_shape(
  913. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  914. dst[4] = 32;
  915. } else if (param().format == Param::Format::NCHW32_NCHW4) {
  916. megdnn_assert(
  917. src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu",
  918. src.ndim);
  919. megdnn_assert(
  920. cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
  921. errmsg().c_str(), cflt.icpg, cflt.group);
  922. dst.ndim = src.ndim;
  923. dst[0] = src[0];
  924. auto oc = cflt.ocpg * cflt.group;
  925. megdnn_assert(oc % 4 == 0);
  926. dst[1] = oc / 4;
  927. dst[2] = infer_conv_shape(
  928. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  929. dst[3] = infer_conv_shape(
  930. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  931. dst[4] = 4;
  932. } else if (param().format == Param::Format::NCHW64) {
  933. megdnn_assert(
  934. src.ndim == 5, "invalid src ndim for NCHW64, expected=5, got=%zu",
  935. src.ndim);
  936. megdnn_assert(
  937. cflt.icpg * cflt.group == src[1] * 64, "%s icpg=%u group=%u",
  938. errmsg().c_str(), cflt.icpg, cflt.group);
  939. dst.ndim = src.ndim;
  940. dst[0] = src[0];
  941. auto oc = cflt.ocpg * cflt.group;
  942. megdnn_assert(oc % 64 == 0);
  943. dst[1] = oc / 64;
  944. dst[2] = infer_conv_shape(
  945. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  946. dst[3] = infer_conv_shape(
  947. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  948. dst[4] = 64;
  949. } else {
  950. megdnn_assert(param().format == Param::Format::NHWCD4);
  951. megdnn_assert(
  952. src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu",
  953. src.ndim);
  954. megdnn_assert(
  955. cflt.icpg * cflt.group == src[2] * 4, "%s icpg=%u group=%u",
  956. errmsg().c_str(), cflt.icpg, cflt.group);
  957. dst.ndim = src.ndim;
  958. dst[0] = src[0];
  959. auto oc = cflt.ocpg * cflt.group;
  960. megdnn_assert(oc % 4 == 0);
  961. dst[2] = oc / 4;
  962. dst[1] = infer_conv_shape(
  963. src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  964. dst[3] = infer_conv_shape(
  965. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  966. megdnn_assert(src[4] == 4);
  967. dst[4] = 4;
  968. }
  969. if (!src.format.is_default() && !src.format.is_lowbit_aligned()) { // propagate
  970. dst.format = src.format;
  971. } else { // determined by dtype
  972. dst.format = TensorFormat(dst.dtype);
  973. }
  974. dst.init_contiguous_stride();
  975. return cflt;
  976. }
  977. /**
  978. * \warning: An explicit specialization shall be declared in a namespace
  979. * enclosing the specialized template. An explicit specialization whose
  980. * declarator-id is not qualified shall be declared in the nearest enclosing
  981. * namespace of the template, or, if the namespace is inline (7.3.1), any
  982. * namespace from its enclosing namespace set.
  983. * refer to:
  984. * https://stackoverflow.com/questions/25594644/warning-specialization-of-template-in-different-namespace
  985. */
  986. template <>
  987. ConvolutionBase<param::Convolution>::CanonizedFilterMeta ConvolutionBase<
  988. param::Convolution>::
  989. check_layout_fwd(
  990. const TensorLayout& src, const TensorLayout& filter,
  991. const TensorLayout& dst) const {
  992. megdnn_assert_contiguous(src);
  993. megdnn_assert_contiguous(filter);
  994. TensorLayout dst_expected;
  995. dst_expected.dtype = dst.dtype;
  996. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  997. megdnn_assert_eq_layout(dst_expected, dst);
  998. return ret;
  999. }
  1000. template <>
  1001. ConvolutionBase<param::ConvBias>::CanonizedFilterMeta ConvolutionBase<param::ConvBias>::
  1002. check_layout_fwd(
  1003. const TensorLayout& src, const TensorLayout& filter,
  1004. const TensorLayout& dst) const {
  1005. megdnn_assert_contiguous(src);
  1006. megdnn_assert_contiguous(filter);
  1007. TensorLayout dst_expected;
  1008. dst_expected.dtype = dst.dtype;
  1009. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  1010. megdnn_assert_eq_layout(dst_expected, dst);
  1011. return ret;
  1012. }
  1013. template <>
  1014. ConvolutionBase<param::BatchConvBias>::CanonizedFilterMeta ConvolutionBase<
  1015. param::BatchConvBias>::
  1016. check_layout_fwd(
  1017. const TensorLayout& src, const TensorLayout& filter,
  1018. const TensorLayout& dst) const {
  1019. megdnn_assert_contiguous(src);
  1020. megdnn_assert_contiguous(filter);
  1021. TensorLayout dst_expected;
  1022. dst_expected.dtype = dst.dtype;
  1023. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  1024. megdnn_assert_eq_layout(dst_expected, dst);
  1025. return ret;
  1026. }
//! Deduce (or validate, if already set) the forward output dtype from the
//! src and filter dtypes; delegates to the shared forward dtype checker.
void ConvolutionForward::deduce_dtype(DType src, DType filter, DType& dst) {
    check_or_deduce_dtype_fwd(src, filter, dst);
}
//! Deduce the forward output layout (shape, dtype, format, strides) from the
//! src and filter layouts; delegates to the shared forward layout deduction.
void ConvolutionForward::deduce_layout(
        const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
    deduce_layout_fwd(src, filter, dst);
}
  1034. ConvolutionForward::CanonizedFilterMeta ConvolutionForward::check_exec(
  1035. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
  1036. size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter) {
  1037. auto ret = check_layout_fwd(src, filter, dst);
  1038. auto required_workspace_in_bytes =
  1039. get_workspace_in_bytes(src, filter, dst, preprocessed_filter);
  1040. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1041. return ret;
  1042. }
  1043. ConvolutionBackwardData::CanonizedFilterMeta ConvolutionBackwardData::check_exec(
  1044. const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad,
  1045. size_t workspace_in_bytes) {
  1046. auto grad_fwd = grad;
  1047. auto filter_fwd = filter;
  1048. auto diff_fwd = diff;
  1049. std::swap(grad_fwd.dtype, diff_fwd.dtype);
  1050. grad_fwd.init_contiguous_stride();
  1051. diff_fwd.init_contiguous_stride();
  1052. auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd);
  1053. auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad);
  1054. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1055. return ret;
  1056. }
  1057. void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, DType& grad) {
  1058. SmallVector<DType> supported_dst_dtype;
  1059. if (filter.category() == diff.category() &&
  1060. filter.category() == DTypeCategory::FLOAT) {
  1061. supported_dst_dtype.push_back(filter);
  1062. } else if (filter.enumv() == DTypeEnum::Int8 && diff == filter) {
  1063. supported_dst_dtype.push_back(dtype::Int32());
  1064. } else if (
  1065. (filter.enumv() == DTypeEnum::QuantizedS8 &&
  1066. diff.enumv() == DTypeEnum::QuantizedS8) ||
  1067. (filter.enumv() == DTypeEnum::Quantized8Asymm &&
  1068. diff.enumv() == DTypeEnum::Quantized8Asymm)) {
  1069. supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff)));
  1070. if (grad.valid() && grad.enumv() == diff.enumv()) {
  1071. supported_dst_dtype.push_back(grad);
  1072. }
  1073. } else {
  1074. megdnn_throw(ssprintf(
  1075. "unsupported input / diff DType: %s x %s", filter.name(), diff.name()));
  1076. }
  1077. if (!grad.valid()) {
  1078. grad = supported_dst_dtype.at(0);
  1079. } else {
  1080. megdnn_assert(
  1081. vec_contains(supported_dst_dtype, grad),
  1082. "unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(),
  1083. grad.name());
  1084. }
  1085. megdnn_assert(
  1086. param().compute_mode != Param::ComputeMode::FLOAT32
  1087. #if !MEGDNN_DISABLE_FLOAT16
  1088. || filter.enumv() == DTypeEnum::Float16 ||
  1089. filter.enumv() == DTypeEnum::BFloat16
  1090. #endif
  1091. ,
  1092. "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
  1093. "input / output.");
  1094. }
  1095. void ConvolutionBackwardData::deduce_layout(
  1096. const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) {
  1097. auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); };
  1098. MEGDNN_MARK_USED_VAR(errmsg);
  1099. megdnn_assert_contiguous(filter);
  1100. megdnn_assert_contiguous(diff);
  1101. megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str());
  1102. megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str());
  1103. deduce_dtype(filter.dtype, diff.dtype, grad.dtype);
  1104. auto cflt = make_canonized_filter_meta(diff.ndim, filter);
  1105. auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) {
  1106. MEGDNN_MARK_USED_VAR(errmsg);
  1107. auto i = (out - 1) * stride + filter;
  1108. megdnn_assert(i > pad * 2, "%s", errmsg().c_str());
  1109. return i - pad * 2;
  1110. };
  1111. if (param().format == Param::Format::NCHW ||
  1112. param().format == Param::Format::NHWC) {
  1113. size_t src_or_dst_c_pos = 0;
  1114. size_t src_or_dst_spatial_start = 0;
  1115. if (param().format == Param::Format::NCHW) {
  1116. src_or_dst_c_pos = 1;
  1117. src_or_dst_spatial_start = 2;
  1118. } else {
  1119. megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
  1120. src_or_dst_c_pos = 3;
  1121. src_or_dst_spatial_start = 1;
  1122. }
  1123. megdnn_assert(
  1124. cflt.ocpg * cflt.group == diff[src_or_dst_c_pos], "%s",
  1125. errmsg().c_str());
  1126. grad.ndim = diff.ndim;
  1127. grad[0] = diff[0];
  1128. grad[src_or_dst_c_pos] = cflt.icpg * cflt.group;
  1129. for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
  1130. grad[i + src_or_dst_spatial_start] =
  1131. deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
  1132. cflt.stride[i], cflt.padding[i]);
  1133. }
  1134. } else if (param().format == Param::Format::NCHW4) {
  1135. megdnn_assert(
  1136. diff.ndim == 5, "valid diff ndim for NCHW4, expected=5, got=%zu",
  1137. diff.ndim);
  1138. megdnn_assert(cflt.group == 1, "%s", errmsg().c_str());
  1139. megdnn_assert(cflt.ocpg * cflt.group == diff[1] * 4, "%s", errmsg().c_str());
  1140. grad.ndim = diff.ndim;
  1141. grad[0] = diff[0];
  1142. auto ic = cflt.icpg * cflt.group;
  1143. megdnn_assert(ic % 4 == 0);
  1144. grad[1] = ic / 4;
  1145. grad[2] = deduce(
  1146. diff[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  1147. grad[3] = deduce(
  1148. diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  1149. megdnn_assert(diff[4] == 4);
  1150. grad[4] = 4;
  1151. } else {
  1152. megdnn_assert(param().format == Param::Format::NHWCD4);
  1153. megdnn_assert(
  1154. diff.ndim == 5, "valid diff ndim for NHWCD4, expected=5, got=%zu",
  1155. diff.ndim);
  1156. megdnn_assert(cflt.ocpg * cflt.group == diff[2] * 4, "%s", errmsg().c_str());
  1157. grad.ndim = diff.ndim;
  1158. grad[0] = diff[0];
  1159. auto ic = cflt.icpg * cflt.group;
  1160. megdnn_assert(ic % 4 == 0);
  1161. grad[2] = ic / 4;
  1162. grad[1] = deduce(
  1163. diff[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  1164. grad[3] = deduce(
  1165. diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  1166. megdnn_assert(diff[4] == 4);
  1167. grad[4] = 4;
  1168. }
  1169. grad.format = diff.format;
  1170. grad.init_contiguous_stride();
  1171. }
  1172. ConvolutionBackwardFilter::CanonizedFilterMeta ConvolutionBackwardFilter::check_exec(
  1173. const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
  1174. size_t workspace_in_bytes) {
  1175. megdnn_assert(
  1176. src.dtype.category() == DTypeCategory::FLOAT &&
  1177. diff.dtype.category() == DTypeCategory::FLOAT &&
  1178. grad.dtype.category() == DTypeCategory::FLOAT,
  1179. "only float type is supported for conv backward filter");
  1180. auto src_fwd = src;
  1181. auto diff_fwd = diff;
  1182. src_fwd.init_contiguous_stride();
  1183. diff_fwd.init_contiguous_stride();
  1184. auto ret = check_layout_fwd(src_fwd, grad, diff_fwd);
  1185. auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
  1186. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1187. return ret;
  1188. }
  1189. } // namespace megdnn
  1190. // vim: syntax=cpp.doxygen