
relayout_format.cpp

/**
 * \file dnn/src/common/relayout_format.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "megdnn/oprs.h"
#include "megdnn/tensor_format.h"
#include "src/common/utils.h"
using namespace megdnn;
void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
                                       TensorLayout& dst) {
    using Param = param::RelayoutFormat;
    switch (param().mode) {
        case Param::Mode::NCHW_NHWCD4:
        case Param::Mode::NCHW_NHWCD4I:
            dst.ndim = 5;
            dst[0] = src[0];
            dst[1] = src[2];
            dst[2] = (src[1] + 3) / 4;
            dst[3] = src[3];
            dst[4] = 4;
            break;
        case Param::Mode::NCHW_NCHW4_IC_SMALL:
            dst.ndim = 5;
            megdnn_assert(src[1] <= 4_z, "ic should be less equal 4");
            dst[0] = src[0];
            dst[1] = div_ceil(src[1], 4_z);
            dst[2] = src[2];
            dst[3] = src[3];
            dst[4] = 4;
            break;
        case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
            megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4");
            megdnn_assert(src[1] <= 4_z, "ic should be less equal 4");
            dst.ndim = 5;
            dst[0] = src[0];
            dst[1] = div_ceil(src[1], 4_z);
            dst[2] = src[2];
            dst[3] = src[3];
            dst[4] = 4;
            break;
        case Param::Mode::NCHW_NCHW88:
            dst.ndim = 5;
            dst[0] = src[0];
            dst[1] = div_ceil(src[1], 8_z);
            dst[2] = src[2];
            dst[3] = src[3];
            dst[4] = 8;
            break;
        case Param::Mode::NCHW88_NCHW:
            dst.ndim = 4;
            dst[0] = src[0];
            dst[1] = src[1] * 8;
            dst[2] = src[2];
            dst[3] = src[3];
            break;
        case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT:
            megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4");
            dst.ndim = 6;
            megdnn_assert(src[0] % 8 == 0,
                          "NCHW_NCHW88_CONV_DENSE_WEIGHT out channel must "
                          "align to 8");
            dst[0] = src[0] / 8;
            dst[1] = div_ceil(src[1], 8_z);
            dst[2] = src[2];
            dst[3] = src[3];
            dst[4] = 8;
            dst[5] = 8;
            break;
        case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT:
            megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5");
            dst.ndim = 6;
            dst[0] = div_ceil(src[0], 8_z);
            dst[1] = src[1];
            dst[2] = src[2];
            dst[3] = src[3];
            dst[4] = src[4];
            dst[5] = 8;
            break;
        case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT:
            megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5");
            dst.ndim = 7;
            dst[0] = src[0];
            megdnn_assert(src[1] % 8 == 0,
                          "NCHW_NCHW88_CONV_GROUP_WEIGHT out channel must "
                          "align to 8");
            dst[1] = src[1] / 8;
            dst[2] = div_ceil(src[2], 8_z);
            dst[3] = src[3];
            dst[4] = src[4];
            dst[5] = 8;
            dst[6] = 8;
            break;
        case Param::Mode::NHWC_NHWCD4:
        case Param::Mode::NHWC_NHWCD4I:
            megdnn_assert(src.ndim == 4);
            //! the channel dimension must be divisible by 4
            megdnn_assert(src[3] % 4 == 0);
            dst.ndim = 5;
            dst[0] = src[0];
            dst[1] = src[1];
            dst[2] = src[3] / 4;
            dst[3] = src[2];
            dst[4] = 4;
            break;
        case Param::Mode::NHWCD4_NHWC:
            megdnn_assert(src.ndim == 5);
            dst.ndim = 4;
            dst[0] = src[0];
            dst[1] = src[1];
            dst[2] = src[3];
            dst[3] = src[2] * 4;
            break;
        case Param::Mode::NHWCD4_NCHW:
        case Param::Mode::NHWCD4I_NCHW:
            megdnn_assert(src.ndim == 5);
            dst.ndim = 4;
            dst[0] = src[0];
            dst[1] = src[2] * 4;
            dst[2] = src[1];
            dst[3] = src[3];
            break;
        case Param::Mode::INTER_WEIGHT_DENSE:
        case Param::Mode::INTER_WEIGHT_DENSEI:
            megdnn_assert(src.ndim == 4);
            megdnn_assert(src[0] % 4 == 0);
            dst.ndim = 5;
            dst[0] = src[0] / 4;
            dst[1] = src[2];
            dst[2] = src[3];
            dst[3] = round_up<size_t>(src[1], 4);
            dst[4] = 4;
            break;
        case Param::Mode::INTER_WEIGHT_GROUP:
        case Param::Mode::INTER_WEIGHT_GROUPI:
            // group conv filter
            megdnn_assert(src.ndim == 5);
            megdnn_assert(src[1] % 4 == 0 && src[2] % 4 == 0);
            dst.ndim = 6;
            dst[0] = src[0];
            dst[1] = src[1] / 4;
            dst[2] = src[3];
            dst[3] = src[4];
            dst[4] = src[2];
            dst[5] = 4;
            break;
        case Param::Mode::INTER_WEIGHT_CHAN:
        case Param::Mode::INTER_WEIGHT_CHANI:
            megdnn_assert(src.ndim == 5 && src[1] == 1 && src[2] == 1);
            // chanwise conv filter
            dst.ndim = 5;
            dst[0] = src[0] / 4;
            dst[1] = 1;
            dst[2] = src[3];
            dst[3] = src[4];
            dst[4] = 4;
            break;
        case Param::Mode::INTER_WEIGHT_DENSEI_DOT:
            megdnn_assert(src.ndim == 4);
            megdnn_assert(src[0] % 4 == 0);
            dst.ndim = 6;
            dst[0] = src[0] / 4;
            dst[1] = src[2];
            dst[2] = src[3];
            dst[3] = div_ceil<size_t>(src[1], 4);
            dst[4] = 4;
            dst[5] = 4;
            break;
        case Param::Mode::INTER_WEIGHT_GROUPI_DOT:
            megdnn_assert(src.ndim == 5);
            megdnn_assert(src[1] % 4 == 0 && src[2] % 4 == 0);
            dst.ndim = 7;
            dst[0] = src[0];
            dst[1] = src[1] / 4;
            dst[2] = src[3];
            dst[3] = src[4];
            dst[4] = src[2] / 4;
            dst[5] = 4;
            dst[6] = 4;
            break;
        case Param::Mode::NCHW4_CHWN4:
            megdnn_assert(src.ndim == 5);
            megdnn_assert(src[4] == 4);
            dst.ndim = 5;
            dst[0] = src[1];
            dst[1] = src[2];
            dst[2] = src[3];
            dst[3] = src[0];
            dst[4] = src[4];
            break;
        case Param::Mode::CHWN4_NCHW4:
            megdnn_assert(src.ndim == 5);
            megdnn_assert(src[4] == 4);
            dst.ndim = 5;
            dst[0] = src[3];
            dst[1] = src[0];
            dst[2] = src[1];
            dst[3] = src[2];
            dst[4] = src[4];
            break;
        case Param::Mode::NCHW_NCHW4: {
            megdnn_assert(src.ndim == 4);
            const size_t group = param().group;
            megdnn_assert(src[1] % group == 0);
            const size_t icpg = src[1] / group;
            dst.ndim = 5;
            dst[0] = src[0];
            dst[1] = group * div_ceil<size_t>(icpg, 4);
            dst[2] = src[2];
            dst[3] = src[3];
            dst[4] = 4;
        } break;
        case Param::Mode::NCHW_NCHW4_WEIGHT: {
            if (src.ndim == 4) {
                //! dense case
                dst.ndim = 5;
                dst[0] = div_ceil<size_t>(src[0], 4) * 4;
                dst[1] = div_ceil<size_t>(src[1], 4);
                dst[2] = src[2];
                dst[3] = src[3];
                dst[4] = 4;
            } else if (src.ndim == 5) {
                //! group case
                dst.ndim = 6;
                dst[0] = src[0];
                dst[1] = div_ceil<size_t>(src[1], 4) * 4;
                dst[2] = div_ceil<size_t>(src[2], 4);
                dst[3] = src[3];
                dst[4] = src[4];
                dst[5] = 4;
            }
        } break;
        case Param::Mode::NCHW4_NCHW:
            megdnn_assert(src.ndim == 5);
            dst.ndim = 4;
            dst[0] = src[0];
            dst[1] = param().oc == 0 ? src[1] * 4 : param().oc;
            dst[2] = src[2];
            dst[3] = src[3];
            megdnn_assert(dst[1] % param().group == 0);
            break;
        default:
            megdnn_assert(0, "Invalid RelayoutFormat Mode");
            break;
    }
    TensorFormat dst_fmt;
    deduce_format(src.format, dst_fmt);
    dst.format = dst_fmt;
    if (!dst.dtype.valid()) {
        dst.dtype = src.dtype;
    }
    dst.init_contiguous_stride();
}
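// Illustrative examples of the shape deduction above (the input shapes are
// assumptions for illustration, not part of the original source): with mode
// NCHW_NCHW4 and group == 1, an NCHW src of (2, 3, 8, 8) deduces
// dst = (2, 1, 8, 8, 4) because div_ceil(3, 4) == 1; with mode NCHW_NCHW88,
// an NCHW src of (1, 20, 7, 7) deduces dst = (1, 3, 7, 7, 8) because
// div_ceil(20, 8) == 3. The rounded-up channels correspond to the padded
// exec_workspace built in deduce_exec_layout below.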
void RelayoutFormat::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
    deduce_layout_fwd(src, dst);
}
void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
    size_t align = handle()->image2d_pitch_alignment();
    auto vendor_type = handle()->vendor_type();
    using Param = param::RelayoutFormat;
#define CHECK_SRC(_expect)                                                   \
    megdnn_assert(src == _expect, "invalid src format: expect=%s got=%s",    \
                  _expect.to_string().c_str(), src.to_string().c_str())
    switch (param().mode) {
        case Param::Mode::NHWC_NHWCD4:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = src;
            break;
        case Param::Mode::NHWCD4_NHWC:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = src;
            break;
        case Param::Mode::NHWC_NHWCD4I:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type);
            break;
        case Param::Mode::NCHW_NHWCD4:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = src;
            break;
        case Param::Mode::NCHW_NHWCD4I:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type);
            break;
        case Param::Mode::NHWCD4I_NCHW:
            CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align, vendor_type));
            dst = DefaultTensorFormat::make();
            break;
        case Param::Mode::NHWCD4_NCHW:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = src;
            break;
        case Param::Mode::INTER_WEIGHT_DENSE:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = src;
            break;
        case Param::Mode::INTER_WEIGHT_DENSEI:
        case Param::Mode::INTER_WEIGHT_DENSEI_DOT:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = Image2DPack4TensorFormat::make_raw(3, align, vendor_type);
            break;
        case Param::Mode::INTER_WEIGHT_GROUP:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = src;
            break;
        case Param::Mode::INTER_WEIGHT_GROUPI:
        case Param::Mode::INTER_WEIGHT_GROUPI_DOT:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = Image2DPack4TensorFormat::make_raw(4, align, vendor_type);
            break;
        case Param::Mode::INTER_WEIGHT_CHAN:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = src;
            break;
        case Param::Mode::INTER_WEIGHT_CHANI:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = Image2DPack4TensorFormat::make_raw(1, align, vendor_type);
            break;
        case Param::Mode::NCHW4_CHWN4:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = src;
            break;
        case Param::Mode::CHWN4_NCHW4:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = src;
            break;
        case Param::Mode::NCHW4_NCHW:
        case Param::Mode::NCHW_NCHW4:
        case Param::Mode::NCHW_NCHW4_WEIGHT:
        case Param::Mode::NCHW_NCHW88:
        case Param::Mode::NCHW88_NCHW:
        case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT:
        case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT:
        case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT:
        case Param::Mode::NCHW_NCHW4_IC_SMALL:
        case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
            CHECK_SRC(DefaultTensorFormat::make());
            dst = src;
            break;
        default:
            megdnn_throw("Invalid relayout format mode");
            break;
    }
    if (!dst.is_default() &&
        (handle()->type() != Handle::HandleType::NAIVE)) {
        megdnn_throw(
                "Only naive and opencl handle support "
                "Image2DPack4TensorFormat; try to export MGB_USE_MEGDNN_DBG=2 "
                "and also export CUDA_VISIBLE_DEVICES='' in a CUDA env "
                "to enable the naive handle");
    }
#undef CHECK_SRC
}
void RelayoutFormat::check_layout_fwd(const TensorLayout& src,
                                      const TensorLayout& dst) {
    TensorLayout dst_expected;
    dst_expected.dtype = dst.dtype;
    deduce_layout_fwd(src, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
}
void RelayoutFormat::check_exec(const TensorLayout& src,
                                const TensorLayout& dst,
                                size_t workspace_in_bytes) {
    check_layout_fwd(src, dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
                                        const TensorLayout& dst,
                                        TensorLayout& exec_workspace,
                                        TensorLayout& exec_src,
                                        TensorLayout& exec_dst) {
    check_layout_fwd(src, dst);
    using Param = param::RelayoutFormat;
    switch (param().mode) {
        case Param::Mode::NCHW_NCHW88:
            // nchw to nchw8c
            {
                exec_workspace = TensorLayout(
                        {src[0], round_up(src[1], 8_z), src[2], src[3]},
                        src.dtype, src.format);
                exec_src = exec_workspace
                                   .reshape({src[0], div_ceil(src[1], 8_z), 8,
                                             src[2], src[3]})
                                   .dimshuffle({0, 1, 3, 4, 2});
                exec_dst = dst;
            }
            break;
        case Param::Mode::NCHW_NCHW4:
            // nchw to nchw4
            {
                const size_t group = param().group;
                const size_t icpg = src[1] / group;
                exec_workspace = TensorLayout(
                        {src[0], group * round_up(icpg, 4_z), src[2], src[3]},
                        src.dtype, src.format);
                exec_src = exec_workspace
                                   .reshape({src[0], group * div_ceil(icpg, 4_z),
                                             4, src[2], src[3]})
                                   .dimshuffle({0, 1, 3, 4, 2});
                exec_dst = dst;
            }
            break;
        case Param::Mode::NCHW_NCHW4_WEIGHT:
            // nchw to nchw4_weight
            {
                if (src.ndim == 4) {
                    exec_workspace = TensorLayout(
                            {round_up(src[0], 4_z), round_up(src[1], 4_z),
                             src[2], src[3]},
                            src.dtype, src.format);
                    exec_src = exec_workspace
                                       .reshape({round_up(src[0], 4_z),
                                                 div_ceil(src[1], 4_z), 4,
                                                 src[2], src[3]})
                                       .dimshuffle({0, 1, 3, 4, 2});
                    exec_dst = dst;
                } else if (src.ndim == 5) {
                    exec_workspace = TensorLayout(
                            {src[0], round_up(src[1], 4_z),
                             round_up(src[2], 4_z), src[3], src[4]},
                            src.dtype, src.format);
                    exec_src = exec_workspace
                                       .reshape({src[0], round_up(src[1], 4_z),
                                                 div_ceil(src[2], 4_z), 4,
                                                 src[3], src[4]})
                                       .dimshuffle({0, 1, 2, 4, 5, 3});
                    exec_dst = dst;
                }
            }
            break;
        case Param::Mode::NCHW4_NCHW:
            // nchw4 to nchw
            {
                exec_workspace =
                        TensorLayout({src[0], src[1] * 4, src[2], src[3]},
                                     src.dtype, src.format)
                                .reshape({src[0], src[1], 4, src[2], src[3]})
                                .dimshuffle({0, 1, 3, 4, 2});
                exec_src = src;
                exec_dst = dst;
            }
            break;
        case Param::Mode::NCHW88_NCHW:
            // nchw8c to nchw
            exec_src = src;
            exec_dst = dst.reshape({dst[0], dst[1] / 8, 8, dst[2], dst[3]})
                               .dimshuffle({0, 1, 3, 4, 2});
            break;
        case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT:
            // oihw to oihw8i8o
            {
                megdnn_assert(src.ndim == 4);
                megdnn_assert(src[0] % 8 == 0);
                exec_workspace = TensorLayout(
                        {src[0], round_up(src[1], 8_z), src[2], src[3]},
                        src.dtype, src.format);
                exec_src = exec_workspace
                                   .reshape({src[0] / 8, 8,
                                             div_ceil(src[1], 8_z), 8,
                                             src[2], src[3]})
                                   .dimshuffle({0, 2, 4, 5, 3, 1});
                exec_dst = dst;
            }
            break;
        case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT:
            // goihw to goihw8g
            {
                megdnn_assert(src.ndim == 5);
                exec_workspace = TensorLayout(
                        {round_up(src[0], 8_z), src[1], src[2], src[3], src[4]},
                        src.dtype, src.format);
                exec_src = exec_workspace
                                   .reshape({div_ceil(src[0], 8_z), 8, src[1],
                                             src[2], src[3], src[4]})
                                   .dimshuffle({0, 2, 3, 4, 5, 1});
                exec_dst = dst;
            }
            break;
        case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT:
            // goihw to goihw8i8o
            {
                megdnn_assert(src.ndim == 5);
                megdnn_assert(src[1] % 8 == 0);
                exec_workspace = TensorLayout(
                        {src[0], src[1], round_up(src[2], 8_z), src[3], src[4]},
                        src.dtype, src.format);
                exec_src = exec_workspace
                                   .reshape({src[0], src[1] / 8, 8,
                                             div_ceil(src[2], 8_z), 8, src[3],
                                             src[4]})
                                   .dimshuffle({0, 1, 3, 5, 6, 4, 2});
                exec_dst = dst;
            }
            break;
        case Param::Mode::NCHW_NCHW4_IC_SMALL:
        case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
            // nchw to nchw4c or oihw to oihw4i
            {
                exec_workspace = TensorLayout(
                        {src[0], round_up(src[1], 4_z), src[2], src[3]},
                        src.dtype, src.format);
                exec_src = exec_workspace
                                   .reshape({src[0], div_ceil(src[1], 4_z), 4,
                                             src[2], src[3]})
                                   .dimshuffle({0, 1, 3, 4, 2});
                exec_dst = dst;
            }
            break;
        case Param::Mode::NCHW_NHWCD4:
        case Param::Mode::NCHW_NHWCD4I:
            // src is {N, C, H, W}
            // dst is {N, H, CB, W, 4}
            exec_src = src;
            exec_src[1] = (exec_src[1] + 3) / 4 * 4;
            exec_src.stride[0] = exec_src[1] * exec_src.stride[1];
            exec_src = exec_src.dimshuffle({0, 2, 3, 1});
            exec_src = exec_src.reshape({exec_src[0], exec_src[1], exec_src[2],
                                         exec_src[3] / 4, 4})
                               .dimshuffle({0, 1, 3, 2, 4});
            exec_dst = dst;
            break;
        case Param::Mode::NHWC_NHWCD4:
        case Param::Mode::NHWC_NHWCD4I:
            // src is {N, H, W, C}
            // dst is {N, H, CB, W, 4}
            exec_src = src.reshape({src[0], src[1], src[2], src[3] / 4, 4})
                               .dimshuffle({0, 1, 3, 2, 4});
            exec_dst = dst;
            break;
        case Param::Mode::NHWCD4_NHWC:
            // src is {N, H, CB, W, 4}
            // dst is {N, H, W, C}
            exec_src = src;
            exec_dst = dst.reshape({dst[0], dst[1], dst[2], dst[3] / 4, 4})
                               .dimshuffle({0, 1, 3, 2, 4});
            break;
        case Param::Mode::NHWCD4_NCHW:
        case Param::Mode::NHWCD4I_NCHW:
            exec_src = src;
            exec_dst = dst.reshape({dst[0], dst[1] / 4, 4, dst[2], dst[3]})
                               .dimshuffle({0, 3, 1, 4, 2});
            break;
        case Param::Mode::INTER_WEIGHT_DENSE:
        case Param::Mode::INTER_WEIGHT_DENSEI:
            // src is {OC, IC, FH, FW}
            // dst is {OCB, FH, FW, IC, 4}
            exec_src = src.reshape({src[0] / 4, 4, src[1], src[2], src[3]})
                               .dimshuffle({0, 3, 4, 2, 1});
            exec_dst = dst;
            // dst[3] may be rounded up, set it to the real ic
            exec_dst.shape[3] = src[1];
            break;
        case Param::Mode::INTER_WEIGHT_GROUP:
        case Param::Mode::INTER_WEIGHT_GROUPI:
            // group conv filter
            // src is {G, ocpg, icpg, fh, fw}
            // dst is {G, ocpgb, fh, fw, icpg, 4}
            exec_src =
                    src.reshape({src[0], src[1] / 4, 4, src[2], src[3], src[4]})
                            .dimshuffle({0, 1, 4, 5, 3, 2});
            exec_dst = dst;
            break;
        case Param::Mode::INTER_WEIGHT_CHAN:
        case Param::Mode::INTER_WEIGHT_CHANI:
            megdnn_assert(src.ndim == 5);
            megdnn_assert(src[1] == 1 && src[2] == 1);
            // chanwise conv filter
            megdnn_assert(src[0] % 4 == 0);
            exec_src = src.reshape({src[0] / 4, 4, 1, src[3], src[4]})
                               .dimshuffle({0, 2, 3, 4, 1});
            exec_dst = dst;
            break;
        case Param::Mode::INTER_WEIGHT_DENSEI_DOT:
            // src is {oc, ic, fh, fw}
            // dst is {oc/4, fh, fw, ic/4, 4, 4}
            exec_src = src;
            exec_src[1] = round_up<size_t>(src[1], 4);
            exec_src.stride[0] = exec_src.stride[1] * exec_src[1];
            exec_src = exec_src.reshape({exec_src[0] / 4, 4, exec_src[1] / 4, 4,
                                         exec_src[2], exec_src[3]})
                               .dimshuffle({0, 4, 5, 2, 1, 3});
            exec_dst = dst;
            break;
        case Param::Mode::INTER_WEIGHT_GROUPI_DOT:
            // src is {G, ocpg, icpg, fh, fw}
            // dst is {G, ocpg/4, fh, fw, icpg/4, 4, 4}
            exec_src = src.reshape({src[0], src[1] / 4, 4, src[2] / 4, 4,
                                    src[3], src[4]})
                               .dimshuffle({0, 1, 5, 6, 3, 2, 4});
            exec_dst = dst;
            break;
        case Param::Mode::NCHW4_CHWN4:
            // src is {N, C/4, H, W, 4}
            // dst is {C/4, H, W, N, 4}
            exec_src = src.dimshuffle({1, 2, 3, 0, 4});
            exec_dst = dst;
            break;
        case Param::Mode::CHWN4_NCHW4:
            // src is {C/4, H, W, N, 4}
            // dst is {N, C/4, H, W, 4}
            exec_src = src.dimshuffle({3, 0, 1, 2, 4});
            exec_dst = dst;
            break;
        default:
            megdnn_assert(0, "Invalid RelayoutFormat Mode");
    }
}
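// Worked example (the shape is assumed for illustration, not part of the
// original source): for NCHW_NCHW4 with group == 1 and src = (2, 3, 8, 8),
// exec_workspace is the padded contiguous layout {2, 4, 8, 8}; reshaping it
// to {2, 1, 4, 8, 8} and applying dimshuffle({0, 1, 3, 4, 2}) yields
// exec_src with shape (2, 1, 8, 8, 4) and strides (256, 256, 8, 1, 64), so
// a plain relayout from exec_src into the contiguous exec_dst (= dst)
// performs the actual NCHW -> NCHW4 transpose.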
// vim: syntax=cpp.doxygen
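
The deduce_exec_layout cases above express every repack as a zero-padded workspace plus a reshape and dimshuffle, so the actual data movement is an ordinary relayout copy. Below is a minimal standalone sketch of the same index math for the NCHW -> NCHW4 case; it is independent of MegDNN, and the toy shape and variable names are assumptions made only for illustration.

// Standalone sketch (not MegDNN code): element (n, c, h, w) of a zero-padded
// NCHW tensor lands at (n, c / 4, h, w, c % 4) in the NCHW4 result.
#include <cstdio>
#include <vector>

int main() {
    const size_t N = 1, C = 3, H = 2, W = 2;  // assumed toy shape
    const size_t Cb = (C + 3) / 4;            // div_ceil(C, 4)
    std::vector<float> nchw(N * C * H * W);
    for (size_t i = 0; i < nchw.size(); ++i)
        nchw[i] = static_cast<float>(i);

    // dst shape: (N, Cb, H, W, 4), zero-padded along the packed channel dim
    std::vector<float> nchw4(N * Cb * H * W * 4, 0.f);
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w) {
                    size_t src_idx = ((n * C + c) * H + h) * W + w;
                    size_t dst_idx =
                            (((n * Cb + c / 4) * H + h) * W + w) * 4 + c % 4;
                    nchw4[dst_idx] = nchw[src_idx];
                }
    std::printf("packed %zu elements into (%zu, %zu, %zu, %zu, 4)\n",
                nchw.size(), N, Cb, H, W);
    return 0;
}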
