
helper.h 39 kB

/**
 * \file dnn/src/naive/convolution/helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once

#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"

#include <cstring>

namespace megdnn {
namespace naive {
namespace convolution {
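//! Walks channels one at a time while tracking which group (cur_grp) and
//! which offset within that group (cur_off) the current channel belongs to.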
struct GroupCounter {
    const size_t grp_size;
    size_t cur_grp = 0, cur_off = 0;

    explicit GroupCounter(size_t grp_size) : grp_size{grp_size} {}

    void next() {
        if ((++cur_off) == grp_size) {
            cur_off = 0;
            ++cur_grp;
        }
    }
};
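//! A Strategy tells compute2d()/compute2d_hwcd4() what to do with each
//! (source, filter, accumulator) triple: init_dval() initializes the
//! accumulator, on() accumulates one product and write() stores the result.
//! StrategyFwd implements the forward pass: d += s * f.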
struct StrategyFwd {
    template <typename st, typename ft, typename ct>
    static void on(st& s, ft& f, ct& d, DType, DType, DType) {
        d += static_cast<ct>(s) * static_cast<ct>(f);
    }

    template <typename ct, typename dt>
    static void write(ct& d, dt& dst) {
        dst = static_cast<dt>(d);
    }

    template <typename dt>
    static void init_dval(dt& d) {
        d = static_cast<dt>(0);
    }
};
// Explicit specialization of a member function template is not allowed to
// happen in class scope; this is a defect of the C++ specification which is
// fixed in C++17. We work around it by marking the implementation as inline
// and moving it out of the class definition.
template <>
inline void StrategyFwd::on(dt_quint8& s, dt_quint8& f, dt_qint32& d,
                            DType src_dt, DType filt_dt, DType) {
    auto cast = [](const dt_quint8& val, DType dt) {
        return dt_qint32(static_cast<int32_t>(val.as_uint8()) -
                         dt.param<dtype::Quantized8Asymm>().zero_point);
    };
    d += cast(s, src_dt) * cast(f, filt_dt);
}

template <>
inline void StrategyFwd::on(dt_qint8& s, dt_qint8& f, dt_float32& d,
                            DType src_dt, DType filt_dt, DType) {
    auto cast = [](const dt_qint8& val, DType dt) {
        return dt.param<dtype::QuantizedS8>().dequantize(val);
    };
    d += cast(s, src_dt) * cast(f, filt_dt);
}

template <>
inline void StrategyFwd::on(dt_qint8& s, dt_qint8& f, dt_qint32& d, DType,
                            DType, DType) {
    auto cast = [](const dt_qint8& val) {
        return dt_qint32(static_cast<int32_t>(val.as_int8()));
    };
    d += cast(s) * cast(f);
}
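//! Backward pass w.r.t. the data: the same loop structure is reused with the
//! roles swapped, accumulating filter times output-diff into the source
//! gradient: s += f * d.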
struct StrategyBwdData {
    template <typename st, typename ft, typename dt>
    static void on(st& s, ft& f, dt& d, DType, DType, DType) {
        s += static_cast<st>(f) * static_cast<st>(d);
    }

    template <typename ct, typename dt>
    static void write(ct&, dt&) {}

    template <typename dt>
    static void init_dval(dt&) {}
};

template <>
inline void StrategyBwdData::on(int& s, signed char& f, signed char& d, DType,
                                DType, DType) {
    auto cast = [](signed char& val) {
        return static_cast<int32_t>(((megdnn::dt_qint8)val).as_int8());
    };
    s += cast(f) * cast(d);
}

template <>
inline void StrategyBwdData::on(dt_qint32& s, dt_quint8& f, dt_quint8& d, DType,
                                DType filt_dt, DType dst_dt) {
    auto cast = [](const dt_quint8& val, DType dt) {
        return dt_qint32(static_cast<int32_t>(val.as_uint8()) -
                         dt.param<dtype::Quantized8Asymm>().zero_point);
    };
    s += cast(f, filt_dt) * cast(d, dst_dt);
}

template <>
inline void StrategyBwdData::on(dt_qint32& s, dt_qint8& f, dt_qint8& d, DType,
                                DType, DType) {
    auto cast = [](const dt_qint8& val) {
        return dt_qint32(static_cast<int32_t>(val.as_int8()));
    };
    s += cast(f) * cast(d);
}
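//! Backward pass w.r.t. the filter: accumulates source times output-diff
//! into the filter gradient: f += s * d.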
struct StrategyBwdFlt {
    template <typename st, typename ft, typename dt>
    static void on(st& s, ft& f, dt& d, DType, DType, DType) {
        f += static_cast<ft>(s) * static_cast<ft>(d);
    }

    template <typename ct, typename dt>
    static void write(ct&, dt&) {}

    template <typename dt>
    static void init_dval(dt&) {}
};
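//! Default filter visitor: plain convolution uses the same filter for every
//! batch and output position, so the pointer is returned unchanged.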
struct ConvFilterVisitor {
    template <typename ftype>
    static ftype* get_current_ptr(ftype* fptr, size_t /* batch */,
                                  size_t /* oc */, size_t /* oh */,
                                  size_t /* ow */, size_t /* filter_sizes */) {
        return fptr;
    }
};
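//! Reference 2D convolution for the NCHW-like (possibly channel-packed)
//! formats: for every output element, walk the kernel window and the input
//! channels of the current group, and let Strategy decide what to accumulate
//! and where to write.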
template <typename stype, typename ftype, typename dtype, typename comp_type,
          class Strategy, typename FilterMeta,
          typename FilterVisitor = ConvFilterVisitor>
void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
               _megdnn_tensor_out dst, const FilterMeta& filter_meta) {
    size_t spatial_start, channel_pos, batch_pos;
    using Format = param::Convolution::Format;
    if (filter_meta.format == Format::NCHW ||
        filter_meta.format == Format::NCHW88 ||
        filter_meta.format == Format::NCHW44 ||
        filter_meta.format == Format::NCHW44_DOT ||
        filter_meta.format == Format::NCHW4 ||
        filter_meta.format == Format::NCHW4_NCHW ||
        filter_meta.format == Format::NCHW4_NCHW32 ||
        filter_meta.format == Format::NCHW8 ||
        filter_meta.format == Format::NCHW32 ||
        filter_meta.format == Format::NCHW32_NCHW4) {
        spatial_start = 2;
        channel_pos = 1;
        batch_pos = 0;
    } else if (filter_meta.format == Format::CHWN4) {
        spatial_start = 1;
        channel_pos = 0;
        batch_pos = 3;
    } else {
        megdnn_assert(filter_meta.format == Format::NHWC,
                      "invalid conv format");
        spatial_start = 1;
        channel_pos = 3;
        batch_pos = 0;
    }
    auto N = src.layout.shape[batch_pos], IH = src.layout.shape[spatial_start],
         IW = src.layout.shape[spatial_start + 1];
    auto FH = filter_meta.spatial[0], FW = filter_meta.spatial[1];
    auto OC = dst.layout.shape[channel_pos],
         OH = dst.layout.shape[spatial_start],
         OW = dst.layout.shape[spatial_start + 1];
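    // Packed layouts store shape[channel_pos] as channels / pack_size;
    // multiply by the pack size to recover the full channel count.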
    if (filter_meta.format == Format::NCHW4 ||
        filter_meta.format == Format::CHWN4 ||
        filter_meta.format == Format::NCHW44_DOT ||
        filter_meta.format == Format::NCHW44 ||
        filter_meta.format == Format::NCHW32_NCHW4) {
        OC *= 4;
    } else if (filter_meta.format == Format::NCHW8 ||
               filter_meta.format == Format::NCHW88) {
        OC *= 8;
    } else if (filter_meta.format == Format::NCHW32 ||
               filter_meta.format == Format::NCHW4_NCHW32) {
        OC *= 32;
    }
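    // Filter strides in elements: FS_G between groups, FS_OC between output
    // channels, FS_IC between input channels, FS_SPATIAL between (fh, fw)
    // kernel positions.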
    size_t FS_G, FS_OC, FS_IC, FS_SPATIAL;
    if (filter_meta.format == Format::NCHW ||
        filter_meta.format == Format::NCHW4 ||
        filter_meta.format == Format::NCHW4_NCHW ||
        filter_meta.format == Format::NCHW4_NCHW32 ||
        filter_meta.format == Format::NCHW8 ||
        filter_meta.format == Format::NCHW32 ||
        filter_meta.format == Format::NCHW32_NCHW4) {
        // g, oc, ic, fh, fw
        FS_SPATIAL = 1;
        FS_IC = FH * FW;
        FS_OC = FS_IC * filter_meta.icpg;
        FS_G = FS_OC * filter_meta.ocpg;
    } else if (filter_meta.format == Format::CHWN4) {
        // g, ic, fh, fw, oc, pack_size
        FS_SPATIAL = filter_meta.ocpg * 4;
        FS_IC = FH * FW * FS_SPATIAL;
        FS_OC = 4;
        FS_G = FS_IC * filter_meta.icpg;
    } else if (filter_meta.format == Format::NCHW88) {
        if (filter_meta.group > 1 && filter_meta.icpg == 1 &&
            src.layout.ndim == 5 && filter_meta.ocpg == 1) {
            FS_SPATIAL = 8;
            FS_IC = FH * FW * FS_SPATIAL;
            FS_OC = FS_IC * filter_meta.icpg;
            FS_G = FS_OC * filter_meta.ocpg;
        } else {
            if (src.layout.ndim == 4 && dst.layout.ndim == 5) {
                FS_IC = 8;
                FS_SPATIAL = filter_meta.icpg * FS_IC;
                FS_OC = FH * FW * FS_SPATIAL;
                FS_G = FS_OC * filter_meta.ocpg / 8;
            } else {
                FS_SPATIAL = 8 * 8;
                FS_IC = FH * FW * FS_SPATIAL;
                FS_OC = FS_IC * filter_meta.icpg / 8;
                FS_G = FS_OC * filter_meta.ocpg / 8;
            }
        }
    } else if (filter_meta.format == Format::NCHW44 ||
               filter_meta.format == Format::NCHW44_DOT) {
        if (filter_meta.group > 1 && filter_meta.icpg == 1 &&
            src.layout.ndim == 5 && filter_meta.ocpg == 1) {
            FS_SPATIAL = 4;
            FS_IC = FH * FW * FS_SPATIAL;
            FS_OC = FS_IC * filter_meta.icpg;
            FS_G = FS_OC * filter_meta.ocpg;
        } else {
            if (src.layout.ndim == 4 && dst.layout.ndim == 5) {
                FS_IC = 4;
                FS_SPATIAL = filter_meta.icpg * FS_IC;
                FS_OC = FH * FW * FS_SPATIAL;
                FS_G = FS_OC * filter_meta.ocpg / 4;
            } else {
                FS_SPATIAL = 4 * 4;
                FS_IC = FH * FW * FS_SPATIAL;
                FS_OC = FS_IC * filter_meta.icpg / 4;
                FS_G = FS_OC * filter_meta.ocpg / 4;
            }
        }
    } else {
        // g, oc, fh, fw, ic
        megdnn_assert(filter_meta.format == Format::NHWC);
        FS_IC = 1;
        FS_SPATIAL = filter_meta.icpg;
        FS_OC = FS_SPATIAL * FH * FW;
        FS_G = FS_OC * filter_meta.ocpg;
    }
    int ph = filter_meta.padding[0], pw = filter_meta.padding[1];
    size_t sh = filter_meta.stride[0], sw = filter_meta.stride[1];
    int dh = filter_meta.dilation[0], dw = filter_meta.dilation[1];
    stype* __restrict sptr = src.compatible_ptr<stype>();
    dtype* __restrict dptr = dst.compatible_ptr<dtype>();
    int h_offset = -ph, w_offset = -pw;
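    // should_flip selects true convolution rather than cross-correlation:
    // start at the far end of the dilated kernel and step backwards.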
    if (filter_meta.should_flip) {
        h_offset += filter_meta.dilated_spatial[0] - 1;
        w_offset += filter_meta.dilated_spatial[1] - 1;
        dh = -dh;
        dw = -dw;
    }
    auto get_linear_addr = [&filter_meta, &src](ptrdiff_t n, ptrdiff_t c,
                                                ptrdiff_t h, ptrdiff_t w,
                                                const TensorLayout& layout,
                                                bool is_output) -> ptrdiff_t {
        if (filter_meta.format == Format::NCHW) {
            return n * layout.stride[0] + c * layout.stride[1] +
                   h * layout.stride[2] + w * layout.stride[3];
        } else if (filter_meta.format == Format::NHWC) {
            return n * layout.stride[0] + h * layout.stride[1] +
                   w * layout.stride[2] + c * layout.stride[3];
        } else if (filter_meta.format == Format::NCHW8 ||
                   filter_meta.format == Format::NCHW88) {
            if (filter_meta.format == Format::NCHW88 && !is_output &&
                src.layout.ndim == 4) {
                return n * layout.stride[0] + c * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3];
            } else {
                return n * layout.stride[0] + (c / 8) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0b111) * layout.stride[4];
            }
        } else if (filter_meta.format == Format::NCHW44 ||
                   filter_meta.format == Format::NCHW44_DOT) {
            if (!is_output && src.layout.ndim == 4) {
                return n * layout.stride[0] + c * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3];
            } else {
                return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c % 4) * layout.stride[4];
            }
        } else if (filter_meta.format == Format::NCHW32) {
            return n * layout.stride[0] + (c >> 5) * layout.stride[1] +
                   h * layout.stride[2] + w * layout.stride[3] +
                   (c & 0x1F) * layout.stride[4];
        } else if (filter_meta.format == Format::NCHW32_NCHW4) {
            if (is_output) {
                return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0b11) * layout.stride[4];
            } else {
                return n * layout.stride[0] + (c >> 5) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0x1F) * layout.stride[4];
            }
        } else if (filter_meta.format == Format::CHWN4) {
            return (c / 4) * layout.stride[0] + h * layout.stride[1] +
                   w * layout.stride[2] + n * layout.stride[3] +
                   (c % 4) * layout.stride[4];
        } else if (filter_meta.format == Format::NCHW4_NCHW) {
            if (is_output) {
                return n * layout.stride[0] + c * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3];
            } else {
                return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0b11) * layout.stride[4];
            }
        } else if (filter_meta.format == Format::NCHW4_NCHW32) {
            if (is_output) {
                return n * layout.stride[0] + (c >> 5) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0x1F) * layout.stride[4];
            } else {
                return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                       h * layout.stride[2] + w * layout.stride[3] +
                       (c & 0b11) * layout.stride[4];
            }
        } else {
            megdnn_assert(filter_meta.format == Format::NCHW4,
                          "invalid conv format");
            return n * layout.stride[0] + (c / 4) * layout.stride[1] +
                   h * layout.stride[2] + w * layout.stride[3] +
                   (c & 0b11) * layout.stride[4];
        }
    };
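    // Element offset into the filter for the current (group, intra-group
    // output channel) pair, input channel ic (relative to the group's first
    // channel ic0) and kernel position (fh, fw), honoring each format's
    // packing.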
    auto get_filter_addr = [&](GroupCounter& gc_out, size_t ic, size_t ic0,
                               size_t fh, size_t fw) {
        if (filter_meta.format == Format::NCHW4 ||
            filter_meta.format == Format::NCHW4_NCHW ||
            filter_meta.format == Format::NCHW4_NCHW32) {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
                   (ic - ic0) / 4 * FS_IC * 4 +
                   (fh * FW + fw) * FS_SPATIAL * 4 + ((ic - ic0) & 0b11);
        } else if (filter_meta.format == Format::NCHW8) {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
                   (ic - ic0) / 8 * FS_IC * 8 +
                   (fh * FW + fw) * FS_SPATIAL * 8 + ((ic - ic0) & 0b111);
        } else if (filter_meta.format == Format::NCHW32 ||
                   filter_meta.format == Format::NCHW32_NCHW4) {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
                   (ic - ic0) / 32 * FS_IC * 32 +
                   (fh * FW + fw) * FS_SPATIAL * 32 + ((ic - ic0) & 0x1F);
        } else if (filter_meta.format == Format::CHWN4) {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
                   (ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL +
                   ((ic - ic0) % 4);
        } else if (filter_meta.format == Format::NCHW88 ||
                   filter_meta.format == Format::NCHW44) {
            size_t pack_c_size = 4_z;
            if (filter_meta.format == Format::NCHW88) {
                pack_c_size = 8_z;
            }
            if (src.layout.ndim == 4) {
                // ic < pack_c_size, input is plain nchw
                return gc_out.cur_grp * FS_G +
                       gc_out.cur_off / pack_c_size * FS_OC +
                       (fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC +
                       gc_out.cur_off % pack_c_size;
            } else if (filter_meta.group > 1 && filter_meta.icpg == 1 &&
                       filter_meta.ocpg == 1 && src.layout.ndim == 5) {
                // depthwise (dw) case
                return gc_out.cur_grp / pack_c_size * FS_G +
                       gc_out.cur_off * FS_OC + (ic - ic0) * FS_IC +
                       (fh * FW + fw) * FS_SPATIAL +
                       gc_out.cur_grp % pack_c_size;
            } else if (src.layout.ndim == 5) {
                // normal case
                return gc_out.cur_grp * FS_G +
                       gc_out.cur_off / pack_c_size * FS_OC +
                       (ic - ic0) / pack_c_size * FS_IC +
                       (fh * FW + fw) * FS_SPATIAL +
                       ((ic - ic0) % pack_c_size) * pack_c_size +
                       gc_out.cur_off % pack_c_size;
            } else {
                megdnn_throw(
                        "nchw88/nchw44 naive does not support this input "
                        "and output\n");
            }
        } else if (filter_meta.format == Format::NCHW44_DOT) {
            if (src.layout.ndim == 4) {
                // ic < 4, input is plain nchw
                return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC +
                       (fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC +
                       gc_out.cur_off % 4;
            } else if (filter_meta.group > 1 && filter_meta.icpg == 1 &&
                       filter_meta.ocpg == 1 && src.layout.ndim == 5) {
                // depthwise (dw) case
                return gc_out.cur_grp / 4 * FS_G + gc_out.cur_off * FS_OC +
                       (ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL +
                       gc_out.cur_grp % 4;
            } else if (src.layout.ndim == 5) {
                // normal case
                return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC +
                       (ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL +
                       (gc_out.cur_off % 4) * 4 + ((ic - ic0) % 4);
            } else {
                megdnn_throw(
                        "nchw44_dot naive does not support this input and "
                        "output\n");
            }
        } else {
            return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
                   (ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL;
        }
    };
    size_t filter_sizes = filter_meta.ocpg * filter_meta.icpg * FH * FW;
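    // For every output element: initialize the accumulator, walk the kernel
    // window while skipping padded (out-of-range) input coordinates, then
    // accumulate over the input channels of the current group and write the
    // result back.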
    for (size_t n = 0; n < N; ++n) {
        GroupCounter gc_out{filter_meta.ocpg};
        for (size_t oc = 0; oc < OC; ++oc, gc_out.next())
            for (size_t oh = 0; oh < OH; ++oh)
                for (size_t ow = 0; ow < OW; ++ow) {
                    comp_type dval = dptr[get_linear_addr(n, oc, oh, ow,
                                                          dst.layout, true)];
                    ftype* fptr_cur = FilterVisitor::template get_current_ptr(
                            fptr, n, oc, oh, ow, filter_sizes);
                    Strategy::init_dval(dval);
                    for (size_t fh = 0; fh < FH; ++fh)
                        for (size_t fw = 0; fw < FW; ++fw) {
                            size_t ih = sh * oh + fh * dh + h_offset,
                                   iw = sw * ow + fw * dw + w_offset;
                            // ih and iw are unsigned; on underflow they wrap
                            // to very large values, so this single comparison
                            // also rejects negative (padded) coordinates
                            if (ih < IH && iw < IW) {
                                size_t ic0 = gc_out.cur_grp * filter_meta.icpg,
                                       ic1 = ic0 + filter_meta.icpg;
                                for (size_t ic = ic0; ic < ic1; ++ic) {
                                    stype& sval = sptr[get_linear_addr(
                                            n, ic, ih, iw, src.layout, false)];
                                    ftype& fval = fptr_cur[get_filter_addr(
                                            gc_out, ic, ic0, fh, fw)];
                                    Strategy::on(sval, fval, dval,
                                                 src.layout.dtype,
                                                 filter_meta.dtype,
                                                 dst.layout.dtype);
                                }
                            }
                        }
                    Strategy::write(dval,
                                    dptr[get_linear_addr(n, oc, oh, ow,
                                                         dst.layout, true)]);
                }
    }
}
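//! Same reference loop for the NHWCD4 format, which packs channels in blocks
//! of four along the third axis, i.e. an (N, H, C/4, W, 4) activation layout;
//! the filter layout additionally depends on whether a dot-product kernel is
//! being emulated (see use_dot below).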
template <typename stype, typename ftype, typename dtype, typename comp_type,
          class Strategy, typename FilterMeta,
          typename FilterVisitor = ConvFilterVisitor>
void compute2d_hwcd4(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                     _megdnn_tensor_out dst, const FilterMeta& filter_meta) {
    // The filter's layout is (G, OC/4, FH, FW, IC, 4) when using mad
    // and (G, OC/4, FH, FW, IC/4, 4, 4) when using dot.
    bool use_dot = false;
    if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
        src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
        (src.layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&
         (filter.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
          filter.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm)))
        use_dot = true;
    using Format = param::Convolution::Format;
    megdnn_assert(filter_meta.format == Format::NHWCD4);
    auto N = src.layout.shape[0], IH = src.layout.shape[1],
         IW = src.layout.shape[3];
    auto FH = filter_meta.spatial[0], FW = filter_meta.spatial[1];
    auto OC = dst.layout.shape[2] * 4, OH = dst.layout.shape[1],
         OW = dst.layout.shape[3];
    int ph = filter_meta.padding[0], pw = filter_meta.padding[1];
    size_t sh = filter_meta.stride[0], sw = filter_meta.stride[1];
    int dh = filter_meta.dilation[0], dw = filter_meta.dilation[1];
    stype* __restrict sptr = src.compatible_ptr<stype>();
    ftype* __restrict fptr = filter.compatible_ptr<ftype>();
    dtype* __restrict dptr = dst.compatible_ptr<dtype>();
    megdnn_assert(!filter_meta.should_flip);
    int h_offset = -ph, w_offset = -pw;
    auto get_linear_addr = [](size_t n, size_t c, size_t h, size_t w,
                              const TensorLayout& layout) -> size_t {
        return n * layout.stride[0] + h * layout.stride[1] +
               (c / 4) * layout.stride[2] + w * layout.stride[3] +
               c % 4 * layout.stride[4];
    };
    size_t FS_G, FS_OCB, FS_SPATIAL;
    if (!use_dot && filter.layout.ndim == 5) {
        if (filter_meta.ocpg == 1 && filter_meta.icpg == 1) {
            // chanwise conv, (G/4, 1, FH, FW, 4)
            FS_G = filter.layout.stride[0];
            FS_OCB = 0;
            FS_SPATIAL = 4;
        } else {
            // dense conv, (OC/4, FH, FW, IC, 4)
            FS_G = 0;
            FS_OCB = filter.layout.stride[0];
            FS_SPATIAL = filter.layout.stride[2];
        }
    } else if (!use_dot && filter.layout.ndim == 6) {
        // group conv, (G, OC/4, FH, FW, IC, 4)
        FS_G = filter.layout.stride[0];
        FS_OCB = filter.layout.stride[1];
        FS_SPATIAL = filter.layout.stride[3];
    } else if (use_dot && filter.layout.ndim == 6) {
        // dense conv using dot, (OC/4, FH, FW, IC/4, 4, 4)
        FS_G = 0;
        FS_OCB = filter.layout.stride[0];
        FS_SPATIAL = filter.layout.stride[2];
    } else if (use_dot && filter.layout.ndim == 7) {
        // group conv using dot, (G, OC/4, FH, FW, IC/4, 4, 4)
        FS_G = filter.layout.stride[0];
        FS_OCB = filter.layout.stride[1];
        FS_SPATIAL = filter.layout.stride[3];
    } else if (use_dot && filter.layout.ndim == 5 && filter_meta.ocpg == 1 &&
               filter_meta.icpg == 1) {
        // chanwise conv, (G/4, 1, FH, FW, 4)
        FS_G = filter.layout.stride[0];
        FS_OCB = 0;
        FS_SPATIAL = 4;
    } else {
        megdnn_assert(0, "invalid filter layout");
    }
    auto get_filter_addr = [&use_dot, &FS_G, &FS_OCB, &FS_SPATIAL, &FW,
                            &filter_meta](size_t group, size_t offset,
                                          size_t fh, size_t fw,
                                          size_t c) -> size_t {
        if (filter_meta.ocpg == 1 && filter_meta.icpg == 1) {
            return (group / 4) * FS_G + (fh * FW + fw) * FS_SPATIAL +
                   (group % 4);
        } else if (!use_dot) {
            return group * FS_G + (offset / 4) * FS_OCB +
                   (fh * FW + fw) * FS_SPATIAL + c * 4 + (offset % 4);
        } else {
            megdnn_assert(use_dot);
            return group * FS_G + (offset / 4) * FS_OCB +
                   (fh * FW + fw) * FS_SPATIAL + (c / 4) * 16 +
                   (offset % 4) * 4 + (c % 4);
        }
    };
    size_t filter_sizes = filter_meta.ocpg * filter_meta.icpg * FH * FW;
    for (size_t n = 0; n < N; ++n) {
        GroupCounter gc_out{filter_meta.ocpg};
        for (size_t oc = 0; oc < OC; ++oc, gc_out.next())
            for (size_t oh = 0; oh < OH; ++oh)
                for (size_t ow = 0; ow < OW; ++ow) {
                    comp_type dval =
                            dptr[get_linear_addr(n, oc, oh, ow, dst.layout)];
                    Strategy::init_dval(dval);
                    ftype* fptr_cur = FilterVisitor::template get_current_ptr(
                            fptr, n, oc, oh, ow, filter_sizes);
                    for (size_t fh = 0; fh < FH; ++fh)
                        for (size_t fw = 0; fw < FW; ++fw) {
                            size_t ih = sh * oh + fh * dh + h_offset,
                                   iw = sw * ow + fw * dw + w_offset;
                            // ih and iw are unsigned; on underflow they wrap
                            // to very large values, so this single comparison
                            // also rejects negative (padded) coordinates
                            if (ih < IH && iw < IW) {
                                size_t ic0 = gc_out.cur_grp * filter_meta.icpg,
                                       ic1 = ic0 + filter_meta.icpg;
                                for (size_t ic = ic0; ic < ic1; ++ic) {
                                    stype& sval = sptr[get_linear_addr(
                                            n, ic, ih, iw, src.layout)];
                                    ftype& fval = fptr_cur[get_filter_addr(
                                            gc_out.cur_grp, gc_out.cur_off, fh,
                                            fw, ic - ic0)];
                                    Strategy::on(sval, fval, dval,
                                                 src.layout.dtype,
                                                 filter_meta.dtype,
                                                 dst.layout.dtype);
                                }
                            }
                        }
                    Strategy::write(
                            dval,
                            dptr[get_linear_addr(n, oc, oh, ow, dst.layout)]);
                }
    }
}

//! forward with only the filter pointer
template <typename stype, typename ftype, typename dtype, typename comp_type>
void forward(_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst,
             const Convolution::CanonizedFilterMeta& filter_meta) {
    megdnn_assert(filter_meta.spatial_ndim == 2);
    megdnn_assert(
            filter_meta.format == param::Convolution::Format::NCHW ||
            filter_meta.format == param::Convolution::Format::NHWC ||
            filter_meta.format == param::Convolution::Format::NCHW88 ||
            filter_meta.format == param::Convolution::Format::NCHW44 ||
            filter_meta.format == param::Convolution::Format::NCHW44_DOT ||
            filter_meta.format == param::Convolution::Format::NCHW4 ||
            filter_meta.format == param::Convolution::Format::NCHW4_NCHW ||
            filter_meta.format == param::Convolution::Format::NCHW4_NCHW32 ||
            filter_meta.format == param::Convolution::Format::NCHW32_NCHW4);
    compute2d<stype, ftype, dtype, comp_type, StrategyFwd>(
            src, const_cast<ftype*>(fptr), dst, filter_meta);
}

//! forward with full filter tensor (for API compatibility)
template <typename stype, typename ftype, typename dtype, typename comp_type>
void forward(_megdnn_tensor_in src, _megdnn_tensor_in filter,
             _megdnn_tensor_out dst,
             const Convolution::CanonizedFilterMeta& filter_meta) {
    if (filter_meta.format == param::Convolution::Format::NHWCD4) {
        return compute2d_hwcd4<stype, ftype, dtype, comp_type, StrategyFwd>(
                src, filter, dst, filter_meta);
    }
    return forward<stype, ftype, dtype, comp_type>(
            src, filter.compatible_ptr<ftype>(), dst, filter_meta);
}
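//! Gradient w.r.t. the input: reuse the forward loop structure with grad in
//! the src slot and diff in the dst slot, accumulating via StrategyBwdData.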
template <typename ftype, typename dtype, typename gtype>
void backward_data(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                   _megdnn_tensor_out grad,
                   const Convolution::CanonizedFilterMeta& filter_meta) {
    megdnn_assert(grad.layout.is_contiguous());
    memset(grad.raw_ptr, 0, grad.layout.span().dist_byte());
    megdnn_assert(filter_meta.spatial_ndim == 2);
    if (filter_meta.format == param::Convolution::Format::NHWCD4) {
        return compute2d_hwcd4<gtype, ftype, dtype, dtype, StrategyBwdData>(
                grad, filter, diff, filter_meta);
    }
    compute2d<gtype, ftype, dtype, dtype, StrategyBwdData>(
            grad, filter.compatible_ptr<ftype>(), diff, filter_meta);
}
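//! Gradient w.r.t. the filter: grad takes the filter slot and diff the dst
//! slot, accumulating via StrategyBwdFlt.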
template <typename stype, typename dtype, typename gtype>
void backward_filter(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                     _megdnn_tensor_out grad,
                     const Convolution::CanonizedFilterMeta& filter_meta) {
    megdnn_assert(grad.layout.is_contiguous());
    memset(grad.raw_ptr, 0, grad.layout.span().dist_byte());
    megdnn_assert(filter_meta.spatial_ndim == 2);
    compute2d<stype, gtype, dtype, dtype, StrategyBwdFlt>(
            src, grad.compatible_ptr<gtype>(), diff, filter_meta);
}
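//! Convolution followed by bias add: run the format-appropriate compute2d
//! variant, then add bias either elementwise (when bias matches dst in shape
//! and dtype) or broadcast per channel according to the layout format.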
template <typename stype, typename ftype, typename dtype, typename comp_type,
          typename FilterMeta, typename FilterVisitor = ConvFilterVisitor>
void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                  _megdnn_tensor_in bias, _megdnn_tensor_out dst,
                  dt_byte* /* workspace_ptr */, const FilterMeta& filter_meta) {
    megdnn_assert(filter_meta.spatial_ndim == 2);
    switch (filter_meta.format) {
        case param::Convolution::Format::NCHW:
        case param::Convolution::Format::NCHW88:
        case param::Convolution::Format::NCHW44:
        case param::Convolution::Format::NCHW44_DOT:
        case param::Convolution::Format::NHWC:
        case param::Convolution::Format::NCHW4:
        case param::Convolution::Format::NCHW4_NCHW:
        case param::Convolution::Format::NCHW4_NCHW32:
        case param::Convolution::Format::NCHW8:
        case param::Convolution::Format::NCHW32:
        case param::Convolution::Format::NCHW32_NCHW4:
        case param::Convolution::Format::CHWN4:
            compute2d<stype, ftype, dtype, comp_type, StrategyFwd, FilterMeta,
                      FilterVisitor>(src, filter.compatible_ptr<ftype>(), dst,
                                     filter_meta);
            break;
        case param::Convolution::Format::NHWCD4:
            compute2d_hwcd4<stype, ftype, dtype, comp_type, StrategyFwd,
                            FilterMeta, FilterVisitor>(src, filter, dst,
                                                       filter_meta);
            break;
        default:
            megdnn_assert_internal(0);
    }
    //! We cannot detect a missing bias from bias.raw_ptr: even when no bias
    //! is given, raw_ptr may be non-null, so check layout.ndim instead.
    if (bias.layout.ndim != 0) {
        if (dst.layout.eq_shape(bias.layout) &&
            dst.layout.dtype.enumv() == bias.layout.dtype.enumv()) {
            dtype* dst_ptr = dst.compatible_ptr<dtype>();
            dtype* bias_ptr = bias.compatible_ptr<dtype>();
            for (size_t i = 0; i < dst.layout.span().dist_elem(); i++) {
                comp_type val = static_cast<comp_type>(dst_ptr[0]) +
                                static_cast<comp_type>(bias_ptr[0]);
                dst_ptr[0] = val;
                dst_ptr++;
                bias_ptr++;
            }
            return;
        }
        using Format = param::ConvBias::Format;
        switch (filter_meta.format) {
            case Format::NCHW:
            case Format::NCHW4_NCHW: {
                int dst_batch = dst.layout.shape[0];
                int dst_channel = dst.layout.shape[1];
                int chann_stride = dst.layout.shape[2] * dst.layout.shape[3];
                dtype* dst_ptr = dst.compatible_ptr<dtype>();
                for (int batch = 0; batch < dst_batch; ++batch) {
                    for (int chan = 0; chan < dst_channel; ++chan) {
                        dtype bias_val = bias.compatible_ptr<dtype>()[chan];
                        for (int i = 0; i < chann_stride; ++i, ++dst_ptr) {
                            comp_type val = static_cast<comp_type>(dst_ptr[0]) +
                                            static_cast<comp_type>(bias_val);
                            dst_ptr[0] = val;
                        }
                    }
                }
                break;
            };
#define BIAS_ADD_NCHWx(_pack_size)                                        \
    do {                                                                  \
        megdnn_assert(dst.layout.is_contiguous());                        \
        int dst_batch = dst.layout.shape[0];                              \
        int dst_channel = dst.layout.shape[1] * (_pack_size);             \
        int chann_stride = dst.layout.shape[2] * dst.layout.shape[3];     \
        dtype* dst_ptr = dst.compatible_ptr<dtype>();                     \
        for (int batch = 0; batch < dst_batch; ++batch) {                 \
            for (int chan = 0; chan < dst_channel; ++chan) {              \
                dtype bias_val = bias.compatible_ptr<dtype>()[chan];      \
                for (int i = 0; i < chann_stride; ++i) {                  \
                    int idx = batch * dst_channel * chann_stride +        \
                              (chan / (_pack_size)) *                     \
                                      (chann_stride * (_pack_size)) +     \
                              i * (_pack_size) + chan % (_pack_size);     \
                    dst_ptr[idx] = static_cast<comp_type>(dst_ptr[idx]) + \
                                   static_cast<comp_type>(bias_val);      \
                }                                                         \
            }                                                             \
        }                                                                 \
    } while (0)
            case Format::NCHW44:
            case Format::NCHW44_DOT:
            case Format::NCHW32_NCHW4:
            case Format::NCHW4: {
                BIAS_ADD_NCHWx(4);
                break;
            };
            case Format::NCHW8: {
                BIAS_ADD_NCHWx(8);
                break;
            };
            case Format::NCHW4_NCHW32:
            case Format::NCHW32: {
                BIAS_ADD_NCHWx(32);
                break;
            };
            case Format::NCHW88: {
                BIAS_ADD_NCHWx(8);
                break;
            };
#define BIAS_ADD_CHWNx(_pack_size)                                            \
    do {                                                                      \
        megdnn_assert(dst.layout.is_contiguous());                            \
        int dst_batch = dst.layout.shape[3];                                  \
        int dst_channel = dst.layout.shape[0] * (_pack_size);                 \
        int chann_stride =                                                    \
                dst.layout.shape[1] * dst.layout.shape[2] * dst_batch;        \
        dtype* dst_ptr = dst.compatible_ptr<dtype>();                         \
        for (int chan = 0; chan < dst_channel; ++chan) {                      \
            dtype bias_val = bias.compatible_ptr<dtype>()[chan];              \
            for (int i = 0; i < chann_stride; ++i) {                          \
                int idx =                                                     \
                        (chan / (_pack_size)) * chann_stride * (_pack_size) + \
                        i * (_pack_size) + chan % (_pack_size);               \
                dst_ptr[idx] = static_cast<comp_type>(dst_ptr[idx]) +         \
                               static_cast<comp_type>(bias_val);              \
            }                                                                 \
        }                                                                     \
    } while (0)
            case Format::CHWN4: {
                BIAS_ADD_CHWNx(4);
                break;
            }
            case Format::NHWC: {
                int dst_nhw = dst.layout.shape[0] * dst.layout.shape[1] *
                              dst.layout.shape[2];
                int dst_channel = dst.layout.shape[3];
                dtype* dst_ptr = dst.compatible_ptr<dtype>();
                for (int nhw = 0; nhw < dst_nhw; ++nhw) {
                    for (int chan = 0; chan < dst_channel; ++chan, ++dst_ptr) {
                        dtype bias_val = bias.compatible_ptr<dtype>()[chan];
                        comp_type val = static_cast<comp_type>(dst_ptr[0]) +
                                        static_cast<comp_type>(bias_val);
                        dst_ptr[0] = val;
                    }
                }
                break;
            };
            case Format::NHWCD4: {
                dtype* bias_ptr = bias.compatible_ptr<dtype>();
                dtype* dst_ptr = dst.compatible_ptr<dtype>();
                for (size_t n = 0; n < dst.layout[0]; n++) {
                    for (size_t h = 0; h < dst.layout[1]; h++) {
                        for (size_t cb = 0; cb < dst.layout[2]; cb++) {
                            for (size_t w = 0; w < dst.layout[3]; w++) {
                                for (size_t i = 0; i < 4; i++) {
                                    auto ptr = dst_ptr +
                                               n * dst.layout.stride[0] +
                                               h * dst.layout.stride[1] +
                                               cb * dst.layout.stride[2] +
                                               w * dst.layout.stride[3] +
                                               i * dst.layout.stride[4];
                                    comp_type val =
                                            static_cast<comp_type>(ptr[0]) +
                                            static_cast<comp_type>(
                                                    bias_ptr[cb * 4 + i]);
                                    ptr[0] = val;
                                }
                            }
                        }
                    }
                }
                break;
            };
            default:
                megdnn_assert_internal(0);
        }
    }
}

} // namespace convolution
} // namespace naive
} // namespace megdnn

// vim: syntax=cpp.doxygen
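For readers new to the Strategy plumbing above, here is a minimal, self-contained sketch of the same pattern in plain standalone C++. All names in it (Accumulate, naive_conv1d) are illustrative and not part of MegEngine: it is a 1D analogue where the strategy supplies init_dval()/on()/write() and the loop mirrors the structure of compute2d(), with one batch, one channel, stride 1 and no padding.

#include <cstdio>
#include <vector>

// Forward strategy, mirroring StrategyFwd above: accumulate s * f into d.
struct Accumulate {
    static void on(float s, float f, float& d) { d += s * f; }
    static void init_dval(float& d) { d = 0.f; }
    static void write(float d, float& dst) { dst = d; }
};

// 1D analogue of compute2d(): slide the filter over src; the Strategy
// decides how each (source, filter) pair updates the accumulator.
template <class Strategy>
void naive_conv1d(const std::vector<float>& src,
                  const std::vector<float>& filt, std::vector<float>& dst) {
    size_t OW = src.size() - filt.size() + 1;
    dst.resize(OW);
    for (size_t ow = 0; ow < OW; ++ow) {
        float dval;
        Strategy::init_dval(dval);  // reset the accumulator
        for (size_t fw = 0; fw < filt.size(); ++fw)
            Strategy::on(src[ow + fw], filt[fw], dval);  // accumulate
        Strategy::write(dval, dst[ow]);  // store the result
    }
}

int main() {
    std::vector<float> src{1, 2, 3, 4, 5}, filt{1, 0, -1}, dst;
    naive_conv1d<Accumulate>(src, filt, dst);
    for (float v : dst) std::printf("%g ", v);  // prints: -2 -2 -2
    std::printf("\n");
}

Swapping in a different strategy (say, one that accumulates into the filter slot instead) changes the operator from forward convolution to one of the backward passes without touching the loop nest, which is exactly how StrategyBwdData and StrategyBwdFlt reuse compute2d() above.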

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, visit the MegStudio platform.