You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

relayout_format.cpp 27 kB


  1. /**
  2. * \file dnn/src/common/relayout_format.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs.h"
  13. #include "megdnn/tensor_format.h"
  14. #include "src/common/utils.h"
  15. using namespace megdnn;
  16. void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
  17. TensorLayout& dst) {
  18. using Param = param::RelayoutFormat;
  19. switch (param().mode) {
  20. case Param::Mode::NCHW_NHWCD4:
  21. case Param::Mode::NCHW_NHWCD4I:
  22. dst.ndim = 5;
  23. dst[0] = src[0];
  24. dst[1] = src[2];
  25. dst[2] = (src[1] + 3) / 4;
  26. dst[3] = src[3];
  27. dst[4] = 4;
  28. break;
  29. case Param::Mode::NCHW_NCHW4_IC_SMALL:
  30. dst.ndim = 5;
  31. megdnn_assert(src[1] <= 4_z, "ic should be less equal 4");
  32. dst[0] = src[0];
  33. dst[1] = div_ceil(src[1], 4_z);
  34. dst[2] = src[2];
  35. dst[3] = src[3];
  36. dst[4] = 4;
  37. break;
  38. case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
  39. megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4");
  40. megdnn_assert(src[1] <= 4_z, "ic should be less equal 4");
  41. dst.ndim = 5;
  42. dst[0] = src[0];
  43. dst[1] = div_ceil(src[1], 4_z);
  44. dst[2] = src[2];
  45. dst[3] = src[3];
  46. dst[4] = 4;
  47. break;
  48. case Param::Mode::NCHW_NCHW88:
  49. dst.ndim = 5;
  50. dst[0] = src[0];
  51. dst[1] = div_ceil(src[1], 8_z);
  52. dst[2] = src[2];
  53. dst[3] = src[3];
  54. dst[4] = 8;
  55. break;
  56. case Param::Mode::NCHW88_NCHW:
  57. dst.ndim = 4;
  58. dst[0] = src[0];
  59. dst[1] = src[1] * 8;
  60. dst[2] = src[2];
  61. dst[3] = src[3];
  62. break;
  63. case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT:
  64. megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4");
  65. dst.ndim = 6;
  66. megdnn_assert(src[0] % 8 == 0,
  67. "NCHW_NCHW88_CONV_DENSE_WEIGHT out channel must "
  68. "align to 8");
  69. dst[0] = src[0] / 8;
  70. dst[1] = div_ceil(src[1], 8_z);
  71. dst[2] = src[2];
  72. dst[3] = src[3];
  73. dst[4] = 8;
  74. dst[5] = 8;
  75. break;
  76. case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT:
  77. megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5");
  78. dst.ndim = 6;
  79. dst[0] = div_ceil(src[0], 8_z);
  80. dst[1] = src[1];
  81. dst[2] = src[2];
  82. dst[3] = src[3];
  83. dst[4] = src[4];
  84. dst[5] = 8;
  85. break;
  86. case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT:
  87. megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5");
  88. dst.ndim = 7;
  89. dst[0] = src[0];
  90. megdnn_assert(src[1] % 8 == 0,
  91. "NCHW_NCHW88_CONV_GROUP_WEIGHT out channel must "
  92. "align to 8");
  93. dst[1] = src[1] / 8;
  94. dst[2] = div_ceil(src[2], 8_z);
  95. dst[3] = src[3];
  96. dst[4] = src[4];
  97. dst[5] = 8;
  98. dst[6] = 8;
  99. break;
  100. case Param::Mode::NHWC_NHWCD4:
  101. case Param::Mode::NHWC_NHWCD4I:
  102. megdnn_assert(src.ndim == 4);
  103. //! channel mod 4 should == 4
  104. megdnn_assert(src[3] % 4 == 0);
  105. dst.ndim = 5;
  106. dst[0] = src[0];
  107. dst[1] = src[1];
  108. dst[2] = src[3] / 4;
  109. dst[3] = src[2];
  110. dst[4] = 4;
  111. break;
  112. case Param::Mode::NHWCD4_NHWC:
  113. megdnn_assert(src.ndim == 5);
  114. dst.ndim = 4;
  115. dst[0] = src[0];
  116. dst[1] = src[1];
  117. dst[2] = src[3];
  118. dst[3] = src[2] * 4;
  119. break;
  120. case Param::Mode::NHWCD4_NCHW:
  121. case Param::Mode::NHWCD4I_NCHW:
  122. megdnn_assert(src.ndim == 5);
  123. dst.ndim = 4;
  124. dst[0] = src[0];
  125. dst[1] = src[2] * 4;
  126. dst[2] = src[1];
  127. dst[3] = src[3];
  128. break;
  129. case Param::Mode::INTER_WEIGHT_DENSE:
  130. case Param::Mode::INTER_WEIGHT_DENSEI:
  131. megdnn_assert(src.ndim == 4);
  132. megdnn_assert(src[0] % 4 == 0);
  133. dst.ndim = 5;
  134. dst[0] = src[0] / 4;
  135. dst[1] = src[2];
  136. dst[2] = src[3];
  137. dst[3] = round_up<size_t>(src[1], 4);
  138. dst[4] = 4;
  139. break;
  140. case Param::Mode::INTER_WEIGHT_GROUP:
  141. case Param::Mode::INTER_WEIGHT_GROUPI:
  142. // group conv filter
  143. megdnn_assert(src.ndim == 5);
  144. megdnn_assert(src[1] % 4 == 0 && src[2] % 4 == 0);
  145. dst.ndim = 6;
  146. dst[0] = src[0];
  147. dst[1] = src[1] / 4;
  148. dst[2] = src[3];
  149. dst[3] = src[4];
  150. dst[4] = src[2];
  151. dst[5] = 4;
  152. break;
  153. case Param::Mode::INTER_WEIGHT_CHAN:
  154. case Param::Mode::INTER_WEIGHT_CHANI:
  155. megdnn_assert(src.ndim == 5 && src[1] == 1 && src[2] == 1);
  156. // chanwise conv filter
  157. dst.ndim = 5;
  158. dst[0] = src[0] / 4;
  159. dst[1] = 1;
  160. dst[2] = src[3];
  161. dst[3] = src[4];
  162. dst[4] = 4;
  163. break;
  164. case Param::Mode::INTER_WEIGHT_DENSEI_DOT:
  165. megdnn_assert(src.ndim == 4);
  166. megdnn_assert(src[0] % 4 == 0);
  167. dst.ndim = 6;
  168. dst[0] = src[0] / 4;
  169. dst[1] = src[2];
  170. dst[2] = src[3];
  171. dst[3] = div_ceil<size_t>(src[1], 4);
  172. dst[4] = 4;
  173. dst[5] = 4;
  174. break;
  175. case Param::Mode::INTER_WEIGHT_GROUPI_DOT:
  176. megdnn_assert(src.ndim == 5);
  177. megdnn_assert(src[1] % 4 == 0 && src[2] % 4 == 0);
  178. dst.ndim = 7;
  179. dst[0] = src[0];
  180. dst[1] = src[1] / 4;
  181. dst[2] = src[3];
  182. dst[3] = src[4];
  183. dst[4] = src[2] / 4;
  184. dst[5] = 4;
  185. dst[6] = 4;
  186. break;
  187. case Param::Mode::NCHW4_CHWN4:
  188. megdnn_assert(src.ndim == 5);
  189. megdnn_assert(src[4] == 4);
  190. dst.ndim = 5;
  191. dst[0] = src[1];
  192. dst[1] = src[2];
  193. dst[2] = src[3];
  194. dst[3] = src[0];
  195. dst[4] = src[4];
  196. break;
  197. case Param::Mode::CHWN4_NCHW4:
  198. megdnn_assert(src.ndim == 5);
  199. megdnn_assert(src[4] == 4);
  200. dst.ndim = 5;
  201. dst[0] = src[3];
  202. dst[1] = src[0];
  203. dst[2] = src[1];
  204. dst[3] = src[2];
  205. dst[4] = src[4];
  206. break;
  207. case Param::Mode::NCHW_NCHW4: {
  208. megdnn_assert(src.ndim == 4);
  209. const size_t group = param().group;
  210. megdnn_assert(src[1] % group == 0);
  211. const size_t icpg = src[1] / group;
  212. dst.ndim = 5;
  213. dst[0] = src[0];
  214. dst[1] = group * div_ceil<size_t>(icpg, 4);
  215. dst[2] = src[2];
  216. dst[3] = src[3];
  217. dst[4] = 4;
  218. }; break;
  219. case Param::Mode::NCHW_NCHW4_WEIGHT:;
  220. {
  221. if (src.ndim == 4) {
  222. //! dense case
  223. dst.ndim = 5;
  224. dst[0] = div_ceil<size_t>(src[0], 4) * 4;
  225. dst[1] = div_ceil<size_t>(src[1], 4);
  226. dst[2] = src[2];
  227. dst[3] = src[3];
  228. dst[4] = 4;
  229. } else if (src.ndim == 5) {
  230. //! group case
  231. dst.ndim = 6;
  232. dst[0] = src[0];
  233. dst[1] = div_ceil<size_t>(src[1], 4) * 4;
  234. dst[2] = div_ceil<size_t>(src[2], 4);
  235. dst[3] = src[3];
  236. dst[4] = src[4];
  237. dst[5] = 4;
  238. }
  239. };
  240. break;
  241. case Param::Mode::NCHW4_NCHW:
  242. megdnn_assert(src.ndim == 5);
  243. dst.ndim = 4;
  244. dst[0] = src[0];
  245. dst[1] = param().oc == 0 ? src[1] * 4 : param().oc;
  246. dst[2] = src[2];
  247. dst[3] = src[3];
  248. megdnn_assert(dst[1] % param().group == 0);
  249. break;
  250. case Param::Mode::NCHW_NCHW64:
  251. megdnn_assert(src.ndim == 4);
  252. dst.ndim = 5;
  253. dst[0] = src[0];
  254. dst[1] = div_ceil(src[1], 64_z);
  255. dst[2] = src[2];
  256. dst[3] = src[3];
  257. dst[4] = 64;
  258. break;
  259. case Param::Mode::NCHW64_NCHW:
  260. megdnn_assert(src.ndim == 5);
  261. dst.ndim = 4;
  262. dst[0] = src[0];
  263. dst[1] = param().oc == 0 ? src[1] * 64 : param().oc;
  264. dst[2] = src[2];
  265. dst[3] = src[3];
  266. break;
  267. default:
  268. megdnn_assert(0, "Invalid RelayoutFormat Mode");
  269. break;
  270. }
  271. TensorFormat dst_fmt;
  272. deduce_format(src.format, dst_fmt);
  273. dst.format = dst_fmt;
  274. if (!dst.dtype.valid()) {
  275. dst.dtype = src.dtype;
  276. }
  277. dst.init_contiguous_stride();
  278. }
  279. void RelayoutFormat::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
  280. deduce_layout_fwd(src, dst);
  281. }
  282. void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
  283. size_t align = handle()->image2d_pitch_alignment();
  284. auto vendor_type = handle()->vendor_type();
  285. using Param = param::RelayoutFormat;
  286. #define CHECK_SRC(_expect) \
  287. megdnn_assert(src == _expect, "invalid src format: expect=%s got=%s", \
  288. _expect.to_string().c_str(), src.to_string().c_str())
  289. switch (param().mode) {
  290. case Param::Mode::NHWC_NHWCD4:
  291. CHECK_SRC(DefaultTensorFormat::make());
  292. dst = src;
  293. break;
  294. case Param::Mode::NHWCD4_NHWC:
  295. CHECK_SRC(DefaultTensorFormat::make());
  296. dst = src;
  297. break;
  298. case Param::Mode::NHWC_NHWCD4I:
  299. CHECK_SRC(DefaultTensorFormat::make());
  300. dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type);
  301. break;
  302. case Param::Mode::NCHW_NHWCD4:
  303. CHECK_SRC(DefaultTensorFormat::make());
  304. dst = src;
  305. break;
  306. case Param::Mode::NCHW_NHWCD4I:
  307. CHECK_SRC(DefaultTensorFormat::make());
  308. dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type);
  309. break;
  310. case Param::Mode::NHWCD4I_NCHW:
  311. CHECK_SRC(
  312. Image2DPack4TensorFormat::make_raw(2, align, vendor_type));
  313. dst = DefaultTensorFormat::make();
  314. break;
  315. case Param::Mode::NHWCD4_NCHW:
  316. CHECK_SRC(DefaultTensorFormat::make());
  317. dst = src;
  318. break;
  319. case Param::Mode::INTER_WEIGHT_DENSE:
  320. CHECK_SRC(DefaultTensorFormat::make());
  321. dst = src;
  322. break;
  323. case Param::Mode::INTER_WEIGHT_DENSEI:
  324. case Param::Mode::INTER_WEIGHT_DENSEI_DOT:
  325. CHECK_SRC(DefaultTensorFormat::make());
  326. dst = Image2DPack4TensorFormat::make_raw(3, align, vendor_type);
  327. break;
  328. case Param::Mode::INTER_WEIGHT_GROUP:
  329. CHECK_SRC(DefaultTensorFormat::make());
  330. dst = src;
  331. break;
  332. case Param::Mode::INTER_WEIGHT_GROUPI:
  333. case Param::Mode::INTER_WEIGHT_GROUPI_DOT:
  334. CHECK_SRC(DefaultTensorFormat::make());
  335. dst = Image2DPack4TensorFormat::make_raw(4, align, vendor_type);
  336. break;
  337. case Param::Mode::INTER_WEIGHT_CHAN:
  338. CHECK_SRC(DefaultTensorFormat::make());
  339. dst = src;
  340. break;
  341. case Param::Mode::INTER_WEIGHT_CHANI:
  342. CHECK_SRC(DefaultTensorFormat::make());
  343. dst = Image2DPack4TensorFormat::make_raw(1, align, vendor_type);
  344. break;
  345. case Param::Mode::NCHW4_CHWN4:
  346. CHECK_SRC(DefaultTensorFormat::make());
  347. dst = src;
  348. break;
  349. case Param::Mode::CHWN4_NCHW4:
  350. CHECK_SRC(DefaultTensorFormat::make());
  351. dst = src;
  352. break;
  353. case Param::Mode::NCHW4_NCHW:
  354. case Param::Mode::NCHW_NCHW4:
  355. case Param::Mode::NCHW_NCHW4_WEIGHT:
  356. case Param::Mode::NCHW_NCHW88:
  357. case Param::Mode::NCHW88_NCHW:
  358. case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT:
  359. case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT:
  360. case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT:
  361. case Param::Mode::NCHW_NCHW4_IC_SMALL:
  362. case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
  363. CHECK_SRC(DefaultTensorFormat::make());
  364. dst = src;
  365. break;
  366. case Param::Mode::NCHW_NCHW64:
  367. dst = src;
  368. break;
  369. case Param::Mode::NCHW64_NCHW:
  370. dst = src;
  371. break;
  372. default:
  373. megdnn_throw("Invalid relayout format mode");
  374. break;
  375. }
  376. if (dst.type() == TensorFormat::Type::IMAGE2D_PACK4 &&
  377. (
  378. handle()->type() != Handle::HandleType::NAIVE)) {
  379. #if MEGDNN_ENABLE_MANGLING
  380. megdnn_throw(
  381. "Only naive and opencl handle support "
  382. "Image2DPack4TensorFormat, try build with debug for get more "
  383. "info");
  384. #else
  385. megdnn_throw(
  386. "Only naive and opencl handle support "
  387. "Image2DPack4TensorFormat, try to export MGB_USE_MEGDNN_DBG=2 "
  388. "and also export CUDA_VISIBLE_DEVICES=\'\' at CUDA env"
  389. "to enable naive handle");
  390. #endif
  391. }
  392. #undef CHECK_SRC
  393. }
  394. void RelayoutFormat::check_layout_fwd(const TensorLayout& src,
  395. const TensorLayout& dst) {
  396. TensorLayout dst_expected;
  397. dst_expected.dtype = dst.dtype;
  398. deduce_layout_fwd(src, dst_expected);
  399. megdnn_assert_eq_layout(dst_expected, dst);
  400. }
  401. void RelayoutFormat::check_exec(const TensorLayout& src,
  402. const TensorLayout& dst,
  403. size_t workspace_in_bytes) {
  404. check_layout_fwd(src, dst);
  405. auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
  406. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  407. }
  408. void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
  409. const TensorLayout& dst,
  410. TensorLayout& exec_workspace,
  411. TensorLayout& exec_src,
  412. TensorLayout& exec_dst) {
  413. check_layout_fwd(src, dst);
  414. using Param = param::RelayoutFormat;
  415. switch (param().mode) {
  416. case Param::Mode::NCHW_NCHW88:
  417. // nchw to nchw8c
  418. {
  419. exec_workspace = TensorLayout(
  420. {src[0], round_up(src[1], 8_z), src[2], src[3]},
  421. src.dtype, src.format);
  422. exec_src = exec_workspace
  423. .reshape({src[0], div_ceil(src[1], 8_z), 8,
  424. src[2], src[3]})
  425. .dimshuffle({0, 1, 3, 4, 2});
  426. exec_dst = dst;
  427. }
  428. break;
  429. case Param::Mode::NCHW_NCHW4:
  430. // nchw to nchw4
  431. {
  432. const size_t group = param().group;
  433. const size_t icpg = src[1] / group;
  434. exec_workspace = TensorLayout(
  435. {src[0], group * round_up(icpg, 4_z), src[2], src[3]},
  436. src.dtype, src.format);
  437. exec_src =
  438. exec_workspace
  439. .reshape({src[0], group * div_ceil(icpg, 4_z),
  440. 4, src[2], src[3]})
  441. .dimshuffle({0, 1, 3, 4, 2});
  442. exec_dst = dst;
  443. }
  444. break;
  445. case Param::Mode::NCHW_NCHW4_WEIGHT:
  446. // nchw to nchw4_weight
  447. {
  448. if (src.ndim == 4) {
  449. exec_workspace = TensorLayout(
  450. {round_up(src[0], 4_z), round_up(src[1], 4_z),
  451. src[2], src[3]},
  452. src.dtype, src.format);
  453. exec_src = exec_workspace
  454. .reshape({round_up(src[0], 4_z),
  455. div_ceil(src[1], 4_z), 4,
  456. src[2], src[3]})
  457. .dimshuffle({0, 1, 3, 4, 2});
  458. exec_dst = dst;
  459. } else if (src.ndim == 5) {
  460. exec_workspace = TensorLayout(
  461. {src[0], round_up(src[1], 4_z),
  462. round_up(src[2], 4_z), src[3], src[4]},
  463. src.dtype, src.format);
  464. exec_src = exec_workspace
  465. .reshape({src[0], round_up(src[1], 4_z),
  466. div_ceil(src[2], 4_z), 4,
  467. src[3], src[4]})
  468. .dimshuffle({0, 1, 2, 4, 5, 3});
  469. exec_dst = dst;
  470. }
  471. }
  472. break;
  473. case Param::Mode::NCHW4_NCHW:
  474. // nchw to nchw4
  475. {
  476. megdnn_assert(src.format == dst.format);
  477. exec_workspace =
  478. TensorLayout({src[0], src[1] * 4, src[2], src[3]},
  479. dst.dtype, dst.format);
  480. exec_src = src.dimshuffle({0, 1, 4, 2, 3});
  481. exec_dst = dst;
  482. }
  483. break;
  484. case Param::Mode::NCHW88_NCHW:
  485. // nchw8c to nchw
  486. exec_src = src;
  487. exec_dst = dst.reshape({dst[0], dst[1] / 8, 8, dst[2], dst[3]})
  488. .dimshuffle({0, 1, 3, 4, 2});
  489. break;
  490. case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT:
  491. // oihw to oihw8i8o
  492. {
  493. megdnn_assert(src.ndim == 4);
  494. megdnn_assert(src[0] % 8 == 0);
  495. exec_workspace = TensorLayout(
  496. {src[0], round_up(src[1], 8_z), src[2], src[3]},
  497. src.dtype, src.format);
  498. exec_src =
  499. exec_workspace
  500. .reshape({src[0] / 8, 8, div_ceil(src[1], 8_z),
  501. 8, src[2], src[3]})
  502. .dimshuffle({0, 2, 4, 5, 3, 1});
  503. exec_dst = dst;
  504. }
  505. break;
  506. case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT:
  507. // goihw to goihw8g
  508. {
  509. megdnn_assert(src.ndim == 5);
  510. exec_workspace = TensorLayout(
  511. {round_up(src[0], 8_z), src[1], src[2], src[3], src[4]},
  512. src.dtype, src.format);
  513. exec_src = exec_workspace
  514. .reshape({div_ceil(src[0], 8_z), 8, src[1],
  515. src[2], src[3], src[4]})
  516. .dimshuffle({0, 2, 3, 4, 5, 1});
  517. exec_dst = dst;
  518. }
  519. break;
  520. case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT:
  521. // goihw to goihw8i8o
  522. {
  523. megdnn_assert(src.ndim == 5);
  524. megdnn_assert(src[1] % 8 == 0);
  525. exec_workspace = TensorLayout(
  526. {src[0], src[1], round_up(src[2], 8_z), src[3], src[4]},
  527. src.dtype, src.format);
  528. exec_src = exec_workspace
  529. .reshape({src[0], src[1] / 8, 8,
  530. div_ceil(src[2], 8_z), 8, src[3],
  531. src[4]})
  532. .dimshuffle({0, 1, 3, 5, 6, 4, 2});
  533. exec_dst = dst;
  534. }
  535. break;
  536. case Param::Mode::NCHW_NCHW4_IC_SMALL:
  537. case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
  538. // nchw to nchw4c or oihw to oihw4i
  539. {
  540. exec_workspace = TensorLayout(
  541. {src[0], round_up(src[1], 4_z), src[2], src[3]},
  542. src.dtype, src.format);
  543. exec_src = exec_workspace
  544. .reshape({src[0], div_ceil(src[1], 4_z), 4,
  545. src[2], src[3]})
  546. .dimshuffle({0, 1, 3, 4, 2});
  547. exec_dst = dst;
  548. }
  549. break;
  550. case Param::Mode::NCHW_NHWCD4:
  551. case Param::Mode::NCHW_NHWCD4I:
  552. // src is {N, C, H, W}
  553. // dst is {N, H, CB, W, 4}
  554. exec_src = src;
  555. exec_src[1] = (exec_src[1] + 3) / 4 * 4;
  556. exec_src.stride[0] = exec_src[1] * exec_src.stride[1];
  557. exec_src = exec_src.dimshuffle({0, 2, 3, 1});
  558. exec_src = exec_src.reshape({exec_src[0], exec_src[1], exec_src[2],
  559. exec_src[3] / 4, 4})
  560. .dimshuffle({0, 1, 3, 2, 4});
  561. exec_dst = dst;
  562. break;
  563. case Param::Mode::NHWC_NHWCD4:
  564. case Param::Mode::NHWC_NHWCD4I:
  565. // src is {N, H, W, C},
  566. // dst is {N, H, CB, W, 4}
  567. exec_src = src.reshape({src[0], src[1], src[2], src[3] / 4, 4})
  568. .dimshuffle({0, 1, 3, 2, 4});
  569. exec_dst = dst;
  570. break;
  571. case Param::Mode::NHWCD4_NHWC:
  572. // src is {N, H, CB, W, 4}
  573. // dst is {N, H, W, C},
  574. exec_src = src;
  575. exec_dst = dst.reshape({dst[0], dst[1], dst[2], dst[3] / 4, 4})
  576. .dimshuffle({0, 1, 3, 2, 4});
  577. break;
  578. case Param::Mode::NHWCD4_NCHW:
  579. case Param::Mode::NHWCD4I_NCHW:
  580. exec_src = src;
  581. exec_dst = dst.reshape({dst[0], dst[1] / 4, 4, dst[2], dst[3]})
  582. .dimshuffle({0, 3, 1, 4, 2});
  583. break;
  584. case Param::Mode::INTER_WEIGHT_DENSE:
  585. case Param::Mode::INTER_WEIGHT_DENSEI:
  586. // src is {OC, IC, FH, FW}
  587. // dst is {OCB, FH, FW, IC, 4}
  588. exec_src = src.reshape({src[0] / 4, 4, src[1], src[2], src[3]})
  589. .dimshuffle({0, 3, 4, 2, 1});
  590. exec_dst = dst;
  591. // dst[3] may be round_uped, set to the real ic
  592. exec_dst.shape[3] = src[1];
  593. break;
  594. case Param::Mode::INTER_WEIGHT_GROUP:
  595. case Param::Mode::INTER_WEIGHT_GROUPI:
  596. // group conv filter
  597. // src is {G, ocpg, icpg, fh, fw}
  598. // dst is {G, ocpgb, fh, fw, icpg, 4}
  599. exec_src =
  600. src.reshape({src[0], src[1] / 4, 4, src[2], src[3], src[4]})
  601. .dimshuffle({0, 1, 4, 5, 3, 2});
  602. exec_dst = dst;
  603. break;
  604. case Param::Mode::INTER_WEIGHT_CHAN:
  605. case Param::Mode::INTER_WEIGHT_CHANI:
  606. megdnn_assert(src.ndim == 5);
  607. megdnn_assert(src[1] == 1 && src[2] == 1);
  608. // chanwise conv filter
  609. megdnn_assert(src[0] % 4 == 0);
  610. exec_src = src.reshape({src[0] / 4, 4, 1, src[3], src[4]})
  611. .dimshuffle({0, 2, 3, 4, 1});
  612. exec_dst = dst;
  613. break;
  614. case Param::Mode::INTER_WEIGHT_DENSEI_DOT:
  615. // src is {oc, ic, fh , fw}
  616. // dst is {oc/4, fh, fw, ic/4, 4, 4}
  617. exec_src = src;
  618. exec_src[1] = round_up<size_t>(src[1], 4);
  619. exec_src.stride[0] = exec_src.stride[1] * exec_src[1];
  620. exec_src = exec_src.reshape({exec_src[0] / 4, 4, exec_src[1] / 4, 4,
  621. exec_src[2], exec_src[3]})
  622. .dimshuffle({0, 4, 5, 2, 1, 3});
  623. exec_dst = dst;
  624. break;
  625. case Param::Mode::INTER_WEIGHT_GROUPI_DOT:
  626. // src is {G, ocpg, icpg, fh, fw}
  627. // dst is {G, ocpg/4, fh, fw, icpg/4, 4, 4}
  628. exec_src = src.reshape({src[0], src[1] / 4, 4, src[2] / 4, 4,
  629. src[3], src[4]})
  630. .dimshuffle({0, 1, 5, 6, 3, 2, 4});
  631. exec_dst = dst;
  632. break;
  633. case Param::Mode::NCHW4_CHWN4:
  634. // src is {N, C/4, H, W, 4}
  635. // dst is {C/4, H, W, N, 4}
  636. exec_src = src.dimshuffle({1, 2, 3, 0, 4});
  637. exec_dst = dst;
  638. break;
  639. case Param::Mode::CHWN4_NCHW4:
  640. // src is {C/4, H, W, N, 4}
  641. // dst is {N, C/4, H, W, 4}
  642. exec_src = src.dimshuffle({3, 0, 1, 2, 4});
  643. exec_dst = dst;
  644. break;
  645. case Param::Mode::NCHW_NCHW64:
  646. // src is {N, C, H, W}
  647. // dst is {N, C/64, H, W, 64}
  648. exec_workspace = TensorLayout(
  649. {src[0], round_up(src[1], 64_z), src[2], src[3]},
  650. src.dtype);
  651. exec_src = exec_workspace
  652. .reshape({src[0], div_ceil(src[1], 64_z), 64,
  653. src[2], src[3]})
  654. .dimshuffle({0, 1, 3, 4, 2});
  655. exec_dst = dst;
  656. break;
  657. case Param::Mode::NCHW64_NCHW:
  658. // src is {N, C/64, H, W, 64}
  659. // dst is {N, C, H, W}
  660. exec_workspace = TensorLayout({src[0], src[1] * 64, src[2], src[3]},
  661. dst.dtype);
  662. exec_src = src.dimshuffle({0, 1, 4, 2, 3});
  663. exec_dst = dst;
  664. break;
  665. default:
  666. megdnn_assert(0, "Invalid RelayoutFormat Mode");
  667. }
  668. }
  669. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。