You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

img2col_helper.h 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. /**
  2. * \file dnn/src/fallback/convolution/img2col_helper.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/common/utils.h"
  12. namespace {
  13. template <bool is_xcorr, typename dtype>
  14. void img2col_stride(const dtype* __restrict src, dtype* __restrict dst,
  15. const int OC, const int OH, const int OW, const int IC,
  16. const int IH, const int IW, const int FH, const int FW,
  17. const int SH, const int SW) {
  18. megdnn_ignore(OC);
  19. size_t i = 0;
  20. rep(ic, IC) {
  21. rep(fh, FH) {
  22. rep(fw, FW) {
  23. rep(oh, OH) {
  24. rep(ow, OW) {
  25. int fh2, fw2;
  26. if (is_xcorr) {
  27. fh2 = fh;
  28. fw2 = fw;
  29. } else {
  30. fh2 = FH - fh - 1;
  31. fw2 = FW - fw - 1;
  32. }
  33. dst[i++] = src[ic * IH * IW + (oh * SH + fh2) * IW +
  34. (ow * SW + fw2)];
  35. }
  36. }
  37. }
  38. }
  39. }
  40. }
  41. //! Added for im2col matmul multithreading.
  42. //
  43. template <bool is_xcorr, typename dtype>
  44. void img2col_stride_nchw4(const dtype* __restrict src, dtype* __restrict dst,
  45. const int OC, const int OH, const int OW, const int IC,
  46. const int IH, const int IW, const int FH, const int FW,
  47. const int SH, const int SW, const int cur_index,
  48. const int block_size) {
  49. MEGDNN_MARK_USED_VAR(OC);
  50. MEGDNN_MARK_USED_VAR(OH);
  51. int start_h = cur_index / OW;
  52. int cur_remain_w = cur_index % OW;
  53. int end_h = (cur_index + block_size) / OW;
  54. int end_remain_w = (cur_index + block_size) % OW;
  55. bool same_line = false;
  56. if (start_h == end_h) {
  57. same_line = true;
  58. }
  59. size_t newIC = IC / 4;
  60. size_t i = 0;
  61. if (sizeof(dtype) != 1) {
  62. if (same_line) {
  63. rep(ic, newIC) {
  64. rep(fh, FH) {
  65. rep(fw, FW) {
  66. int fh2, fw2;
  67. if (is_xcorr) {
  68. fh2 = fh;
  69. fw2 = fw;
  70. } else {
  71. fh2 = FH - fh - 1;
  72. fw2 = FW - fw - 1;
  73. }
  74. for (int w = cur_remain_w; w < end_remain_w; w++) {
  75. size_t index = 4 * (ic * IH * IW +
  76. (start_h * SH + fh2) * IW +
  77. (w * SW + fw2));
  78. dst[i++] = src[index];
  79. dst[i++] = src[index + 1];
  80. dst[i++] = src[index + 2];
  81. dst[i++] = src[index + 3];
  82. }
  83. }
  84. }
  85. }
  86. } else {
  87. rep(ic, newIC) {
  88. rep(fh, FH) {
  89. rep(fw, FW) {
  90. int fh2, fw2;
  91. if (is_xcorr) {
  92. fh2 = fh;
  93. fw2 = fw;
  94. } else {
  95. fh2 = FH - fh - 1;
  96. fw2 = FW - fw - 1;
  97. }
  98. for (int w = cur_remain_w; w < OW; w++) {
  99. size_t index =4 * (ic * IH * IW +
  100. (start_h * SH + fh2) * IW +
  101. (w * SW + fw2));
  102. dst[i++] = src[index + 0];
  103. dst[i++] = src[index + 1];
  104. dst[i++] = src[index + 2];
  105. dst[i++] = src[index + 3];
  106. }
  107. for (int h = start_h + 1; h < end_h; h++) {
  108. rep(ow, OW) {
  109. size_t index = 4 * (ic * IH * IW +
  110. (h * SH + fh2) * IW +
  111. (ow * SW + fw2));
  112. dst[i++] = src[index + 0];
  113. dst[i++] = src[index + 1];
  114. dst[i++] = src[index + 2];
  115. dst[i++] = src[index + 3];
  116. }
  117. }
  118. for (int w = 0; w < end_remain_w; w++) {
  119. size_t index = 4 * (ic * IH * IW +
  120. (end_h * SH + fh2) * IW +
  121. (w * SW + fw2));
  122. dst[i++] = src[index + 0];
  123. dst[i++] = src[index + 1];
  124. dst[i++] = src[index + 2];
  125. dst[i++] = src[index + 3];
  126. }
  127. }
  128. }
  129. }
  130. }
  131. } else {
  132. uint32_t* output = nullptr;
  133. const uint32_t* uint32_src =
  134. static_cast<const uint32_t*>(static_cast<const void*>(src));
  135. output = static_cast<uint32_t*>(static_cast<void*>(dst));
  136. if (same_line) {
  137. rep(ic, newIC) {
  138. rep(fh, FH) {
  139. rep(fw, FW) {
  140. int fh2, fw2;
  141. if (is_xcorr) {
  142. fh2 = fh;
  143. fw2 = fw;
  144. } else {
  145. fh2 = FH - fh - 1;
  146. fw2 = FW - fw - 1;
  147. }
  148. size_t index =
  149. (ic * IH * IW + (start_h * SH + fh2) * IW +
  150. (cur_remain_w * SW + fw2));
  151. for (int w = cur_remain_w; w < end_remain_w; w++) {
  152. output[i++] = uint32_src[index];
  153. index += SW;
  154. }
  155. }
  156. }
  157. }
  158. } else {
  159. rep(ic, newIC) {
  160. rep(fh, FH) {
  161. rep(fw, FW) {
  162. int fh2, fw2;
  163. if (is_xcorr) {
  164. fh2 = fh;
  165. fw2 = fw;
  166. } else {
  167. fh2 = FH - fh - 1;
  168. fw2 = FW - fw - 1;
  169. }
  170. size_t index = ic * IH * IW +
  171. (start_h * SH + fh2) * IW +
  172. cur_remain_w * SW + fw2;
  173. for (int w = cur_remain_w; w < OW; w++) {
  174. output[i++] = uint32_src[index];
  175. index += SW;
  176. }
  177. for (int h = start_h + 1; h < end_h; h++) {
  178. index = ic * IH * IW + (h * SH + fh2) * IW + fw2;
  179. rep(ow, OW) {
  180. output[i++] = uint32_src[index];
  181. index += SW;
  182. }
  183. }
  184. index = ic * IH * IW + (end_h * SH + fh2) * IW + fw2;
  185. for (int w = 0; w < end_remain_w; w++) {
  186. output[i++] = uint32_src[index];
  187. index += SW;
  188. }
  189. }
  190. }
  191. }
  192. }
  193. }
  194. }
  195. template <bool is_xcorr, typename dtype>
  196. void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst,
  197. const int OC, const int OH, const int OW, const int IC,
  198. const int IH, const int IW, const int FH, const int FW,
  199. const int SH, const int SW, const int cur_index,
  200. const int block_size) {
  201. MEGDNN_MARK_USED_VAR(OC);
  202. MEGDNN_MARK_USED_VAR(OH);
  203. MEGDNN_MARK_USED_VAR(SH);
  204. MEGDNN_MARK_USED_VAR(SW);
  205. int start_h = cur_index / OW;
  206. int cur_remain_w = cur_index % OW;
  207. int end_h = (cur_index + block_size) / OW;
  208. int end_remain_w = (cur_index + block_size) % OW;
  209. bool same_line = false;
  210. if (start_h == end_h) {
  211. same_line = true;
  212. }
  213. size_t newIC = IC / 4;
  214. size_t i = 0;
  215. if (sizeof(dtype) != 1) {
  216. if (same_line) {
  217. rep(ic, newIC) {
  218. rep(fh, FH) {
  219. rep(fw, FW) {
  220. int fh2, fw2;
  221. if (is_xcorr) {
  222. fh2 = fh;
  223. fw2 = fw;
  224. } else {
  225. fh2 = FH - fh - 1;
  226. fw2 = FW - fw - 1;
  227. }
  228. for (int w = cur_remain_w; w < end_remain_w; w++) {
  229. size_t index =
  230. 4 * (ic * IH * IW + (start_h + fh2) * IW +
  231. (w + fw2));
  232. dst[i++] = src[index];
  233. dst[i++] = src[index + 1];
  234. dst[i++] = src[index + 2];
  235. dst[i++] = src[index + 3];
  236. }
  237. }
  238. }
  239. }
  240. } else {
  241. rep(ic, newIC) {
  242. rep(fh, FH) {
  243. rep(fw, FW) {
  244. int fh2, fw2;
  245. if (is_xcorr) {
  246. fh2 = fh;
  247. fw2 = fw;
  248. } else {
  249. fh2 = FH - fh - 1;
  250. fw2 = FW - fw - 1;
  251. }
  252. for (int w = cur_remain_w; w < OW; w++) {
  253. size_t index = ic * IH * IW + (start_h + fh2) * IW +
  254. (w + fw2);
  255. dst[i++] = src[4 * index];
  256. dst[i++] = src[4 * index + 1];
  257. dst[i++] = src[4 * index + 2];
  258. dst[i++] = src[4 * index + 3];
  259. }
  260. for (int h = start_h + 1; h < end_h; h++) {
  261. rep(ow, OW) {
  262. size_t index =
  263. 4 * (ic * IH * IW + (h + fh2) * IW +
  264. (ow + fw2));
  265. dst[i++] = src[index + 0];
  266. dst[i++] = src[index + 1];
  267. dst[i++] = src[index + 2];
  268. dst[i++] = src[index + 3];
  269. }
  270. }
  271. for (int w = 0; w < end_remain_w; w++) {
  272. size_t index = 4 * (ic * IH * IW +
  273. (end_h + fh2) * IW + (w + fw2));
  274. dst[i++] = src[index + 0];
  275. dst[i++] = src[index + 1];
  276. dst[i++] = src[index + 2];
  277. dst[i++] = src[index + 3];
  278. }
  279. }
  280. }
  281. }
  282. }
  283. } else {
  284. uint32_t* output = nullptr;
  285. const uint32_t* uint32_src =
  286. static_cast<const uint32_t*>(static_cast<const void*>(src));
  287. output = static_cast<uint32_t*>(static_cast<void*>(dst));
  288. if (same_line) {
  289. rep(ic, newIC) {
  290. rep(fh, FH) {
  291. rep(fw, FW) {
  292. int fh2, fw2;
  293. if (is_xcorr) {
  294. fh2 = fh;
  295. fw2 = fw;
  296. } else {
  297. fh2 = FH - fh - 1;
  298. fw2 = FW - fw - 1;
  299. }
  300. for (int w = cur_remain_w; w < end_remain_w; w++) {
  301. size_t index = (ic * IH * IW +
  302. (start_h + fh2) * IW + (w + fw2));
  303. output[i++] = uint32_src[index];
  304. }
  305. }
  306. }
  307. }
  308. } else {
  309. rep(ic, newIC) {
  310. rep(fh, FH) {
  311. rep(fw, FW) {
  312. int fh2, fw2;
  313. if (is_xcorr) {
  314. fh2 = fh;
  315. fw2 = fw;
  316. } else {
  317. fh2 = FH - fh - 1;
  318. fw2 = FW - fw - 1;
  319. }
  320. for (int w = cur_remain_w; w < OW; w++) {
  321. size_t index = ic * IH * IW + (start_h + fh2) * IW +
  322. (w + fw2);
  323. output[i++] = uint32_src[index];
  324. }
  325. for (int h = start_h + 1; h < end_h; h++) {
  326. rep(ow, OW) {
  327. size_t index = (ic * IH * IW + (h + fh2) * IW +
  328. (ow + fw2));
  329. output[i++] = uint32_src[index];
  330. }
  331. }
  332. for (int w = 0; w < end_remain_w; w++) {
  333. size_t index = (ic * IH * IW + (end_h + fh2) * IW +
  334. (w + fw2));
  335. output[i++] = uint32_src[index];
  336. }
  337. }
  338. }
  339. }
  340. }
  341. }
  342. }
  343. template <bool is_xcorr, typename dtype>
  344. void img2col_stride(const dtype* __restrict src, dtype* __restrict dst,
  345. const int OC, const int OH, const int OW, const int IC,
  346. const int IH, const int IW, const int FH, const int FW,
  347. const int SH, const int SW, const int cur_index,
  348. const int block_size) {
  349. MEGDNN_MARK_USED_VAR(OC);
  350. MEGDNN_MARK_USED_VAR(OH);
  351. int start_h = cur_index / OW;
  352. int cur_remain_w = cur_index % OW;
  353. int end_h = (cur_index + block_size) / OW;
  354. int end_remain_w = (cur_index + block_size) % OW;
  355. bool same_line = false;
  356. if (start_h == end_h) {
  357. same_line = true;
  358. }
  359. size_t i = 0;
  360. if (same_line) {
  361. rep(ic, IC) {
  362. rep(fh, FH) {
  363. rep(fw, FW) {
  364. int fh2, fw2;
  365. if (is_xcorr) {
  366. fh2 = fh;
  367. fw2 = fw;
  368. } else {
  369. fh2 = FH - fh - 1;
  370. fw2 = FW - fw - 1;
  371. }
  372. for (int w = cur_remain_w; w < end_remain_w; w++) {
  373. dst[i++] =
  374. src[ic * IH * IW + (start_h * SH + fh2) * IW +
  375. (w * SW + fw2)];
  376. }
  377. }
  378. }
  379. }
  380. } else {
  381. rep(ic, IC) {
  382. rep(fh, FH) {
  383. rep(fw, FW) {
  384. int fh2, fw2;
  385. if (is_xcorr) {
  386. fh2 = fh;
  387. fw2 = fw;
  388. } else {
  389. fh2 = FH - fh - 1;
  390. fw2 = FW - fw - 1;
  391. }
  392. for (int w = cur_remain_w; w < OW; w++) {
  393. dst[i++] =
  394. src[ic * IH * IW + (start_h * SH + fh2) * IW +
  395. (w * SW + fw2)];
  396. }
  397. for (int h = start_h + 1; h < end_h; h++) {
  398. rep(ow, OW) {
  399. dst[i++] = src[ic * IH * IW + (h * SH + fh2) * IW +
  400. (ow * SW + fw2)];
  401. }
  402. }
  403. for (int w = 0; w < end_remain_w; w++) {
  404. dst[i++] = src[ic * IH * IW + (end_h * SH + fh2) * IW +
  405. (w * SW + fw2)];
  406. }
  407. }
  408. }
  409. }
  410. }
  411. }
  412. template <bool is_xcorr, typename dtype>
  413. void img2col(const dtype* __restrict src, dtype* __restrict dst, const int OC,
  414. const int OH, const int OW, const int IC, const int IH,
  415. const int IW, const int FH, const int FW, const int cur_index,
  416. const int block_size) {
  417. MEGDNN_MARK_USED_VAR(OC);
  418. MEGDNN_MARK_USED_VAR(OH);
  419. int start_h = cur_index / OW;
  420. int cur_remain_w = cur_index % OW;
  421. int end_h = (cur_index + block_size) / OW;
  422. int end_remain_w = (cur_index + block_size) % OW;
  423. bool same_line = false;
  424. if (start_h == end_h) {
  425. same_line = true;
  426. }
  427. int i = 0;
  428. if (same_line) {
  429. rep(ic, IC) {
  430. rep(fh, FH) {
  431. rep(fw, FW) {
  432. int fh2, fw2;
  433. if (is_xcorr) {
  434. fh2 = fh;
  435. fw2 = fw;
  436. } else {
  437. fh2 = FH - fh - 1;
  438. fw2 = FW - fw - 1;
  439. }
  440. for (int w = cur_remain_w; w < end_remain_w; w++) {
  441. dst[i++] = src[ic * IH * IW + (start_h + fh2) * IW +
  442. (w + fw2)];
  443. }
  444. }
  445. }
  446. }
  447. } else {
  448. rep(ic, IC) {
  449. rep(fh, FH) {
  450. rep(fw, FW) {
  451. int fh2, fw2;
  452. if (is_xcorr) {
  453. fh2 = fh;
  454. fw2 = fw;
  455. } else {
  456. fh2 = FH - fh - 1;
  457. fw2 = FW - fw - 1;
  458. }
  459. for (int w = cur_remain_w; w < OW; w++) {
  460. dst[i++] = src[ic * IH * IW + (start_h + fh2) * IW +
  461. (w + fw2)];
  462. }
  463. for (int h = start_h + 1; h < end_h; h++) {
  464. rep(ow, OW) {
  465. dst[i++] = src[ic * IH * IW + (h + fh2) * IW +
  466. (ow + fw2)];
  467. }
  468. }
  469. for (int w = 0; w < end_remain_w; w++) {
  470. dst[i++] = src[ic * IH * IW + (end_h + fh2) * IW +
  471. (w + fw2)];
  472. }
  473. }
  474. }
  475. }
  476. }
  477. }
  478. template <bool is_xcorr, typename dtype>
  479. void img2col(const dtype* src, dtype* dst, size_t /* OC */, size_t OH,
  480. size_t OW, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW) {
  481. size_t offset = (4 - OW % 4) % 4;
  482. size_t i = 0;
  483. rep(ic, IC) {
  484. rep(fh, FH) {
  485. rep(fw, FW) {
  486. rep(oh, OH) {
  487. size_t ow = 0;
  488. for (; ow < OW; ow += 4) {
  489. size_t fh2, fw2;
  490. if (is_xcorr) {
  491. fh2 = fh;
  492. fw2 = fw;
  493. } else {
  494. fh2 = FH - fh - 1;
  495. fw2 = FW - fw - 1;
  496. }
  497. dst[i++] = src[ic * IH * IW + (oh + fh2) * IW +
  498. (ow + fw2) + 0];
  499. dst[i++] = src[ic * IH * IW + (oh + fh2) * IW +
  500. (ow + fw2) + 1];
  501. dst[i++] = src[ic * IH * IW + (oh + fh2) * IW +
  502. (ow + fw2) + 2];
  503. dst[i++] = src[ic * IH * IW + (oh + fh2) * IW +
  504. (ow + fw2) + 3];
  505. }
  506. i -= offset;
  507. }
  508. }
  509. }
  510. }
  511. }
  512. } // anonymous namespace
  513. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台