
tensor_format.cpp 22 kB

/**
 * \file dnn/src/common/tensor_format.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/tensor_format.h"
#include "megdnn/basic_types.h"
#include "src/common/utils.h"

#include <unordered_map>

using namespace megdnn;
using namespace megdnn::detail;

namespace {
DefaultTensorFormat* default_tensor_format_obj;
}
/* ===================== TensorFormat ===================== */

TensorFormat TensorFormat::deserialize(const std::string& bin,
                                       const Handle* handle) {
    using Type = TensorFormat::Type;
    auto type = reinterpret_cast<const Type*>(bin.data());
    switch (*type) {
        case Type::DEFAULT:
            return DefaultTensorFormat::deserialize(handle, type + 1,
                                                    bin.size() - sizeof(Type));
        case Type::IMAGE2D_PACK4:
            return Image2DPack4TensorFormat::deserialize(
                    handle, type + 1, bin.size() - sizeof(Type));
        case Type::FOURBITS_ALIGNED_TO_BYTE:
            return FourBitsAlignedToBytesTensorFormat::deserialize(
                    handle, type + 1, bin.size() - sizeof(Type));
        default:
            megdnn_throw("invalid tensor format type in deserialize");
    }
}
TensorFormat::Format() : m_impl{DefaultTensorFormat::make().m_impl} {}

std::string TensorFormat::to_string() const {
    return m_impl->to_string();
}

std::string TensorFormat::serialize() const {
    std::string ret;
    ret.reserve(32);
    ret.assign(sizeof(Type), '\0');
    *reinterpret_cast<Type*>(&ret[0]) = type();
    m_impl->serialize_append(ret);
    return ret;
}

void TensorFormat::on_bad_cvt(Type dst_type) const {
    MEGDNN_MARK_USED_VAR(dst_type);
    megdnn_throw(ssprintf("can not convert tensor format %s to %d",
                          impl()->to_string().c_str(),
                          static_cast<int>(dst_type)));
}

bool TensorFormat::is_default() const {
    return m_impl == default_tensor_format_obj;
}
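/*
 * A minimal serialization round-trip sketch, built only from the
 * functions above; `handle` is assumed to be a valid megdnn Handle*
 * created elsewhere:
 *
 *   TensorFormat fmt = DefaultTensorFormat::make();
 *   std::string bin = fmt.serialize();  // Type tag followed by impl payload
 *   TensorFormat restored = TensorFormat::deserialize(bin, handle);
 *   megdnn_assert(restored.is_default());
 */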
/* ===================== DefaultFormat ===================== */

void DefaultTensorFormat::assert_valid(const TensorLayout& layout) const {
    megdnn_assert(
            !layout.dtype.valid() || !layout.dtype.is_low_bit(),
            "DefaultTensorFormat does not support low-bits tensor(dtype:%s)",
            layout.dtype.name());
}

size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const {
    assert_valid(layout);
    if (!layout.ndim)
        return 0;
    megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM);
    size_t accum = 1;
    SafeMultiplies<size_t> mul;
    for (size_t i = layout.ndim; i; --i) {
        layout.stride[i - 1] = accum;
        accum = mul(accum, layout.shape[i - 1]);
    }
    return accum;
}
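/*
 * The loop above assigns standard row-major strides. A self-contained
 * sketch of the same computation in plain C++, illustrative only (no
 * megdnn types, no overflow checking):
 *
 *   #include <cstddef>
 *
 *   // e.g. shape {2, 3, 4} -> strides {12, 4, 1}; returns 24 elements
 *   size_t row_major_strides(const size_t* shape, ptrdiff_t* stride,
 *                            size_t ndim) {
 *       size_t accum = 1;
 *       for (size_t i = ndim; i; --i) {
 *           stride[i - 1] = static_cast<ptrdiff_t>(accum);
 *           accum *= shape[i - 1];
 *       }
 *       return accum;
 *   }
 */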
bool DefaultTensorFormat::is_contiguous_spec(const TensorLayout& layout) const {
    assert_valid(layout);
    return layout.is_physical_contiguous();
}

TensorLayout DefaultTensorFormat::collapse_contiguous_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    megdnn_assert(layout.ndim);
    TensorLayout res{layout};

    // remove all dims with shape 1
    for (int i = static_cast<int>(res.ndim) - 1; i >= 0 && res.ndim >= 2; --i) {
        if (!res.shape[i]) {
            // empty tensor
            res.ndim = 1;
            res.shape[0] = 0;
            res.stride[0] = 1;
            return res;
        }
        if (res.shape[i] == 1)
            res.remove_axis_inplace(i);
    }

    if (res.ndim == 1) {
        if (res.shape[0] <= 1) {
            // make it the "most canonical" contiguous layout for scalars or
            // empty tensors
            res.stride[0] = 1;
        }
        return res;
    }

    megdnn_assert(res.ndim && res.shape[res.ndim - 1]);
    for (int i = static_cast<int>(res.ndim) - 2; i >= 0; --i) {
        megdnn_assert(res.shape[i]);
        if (res.stride[i] ==
            res.stride[i + 1] * static_cast<ptrdiff_t>(res.shape[i + 1])) {
            res.shape[i] *= res.shape[i + 1];
            res.stride[i] = res.stride[i + 1];
            res.remove_axis_inplace(i + 1);
        }
    }
    return res;
}
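/*
 * Worked example for collapse_contiguous_spec, obtained by tracing the
 * code above: a layout with shape (2, 1, 3) and strides (3, 3, 1) first
 * drops the shape-1 axis, leaving shape (2, 3) / strides (3, 1); since
 * stride[0] == stride[1] * shape[1], the two remaining axes merge into
 * the single contiguous axis shape (6) / stride (1).
 */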
TensorLayout::Span DefaultTensorFormat::span_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    if (layout.ndim == 0)
        return {0, 0, 0, 0};
    ptrdiff_t low_elem = 0;
    size_t high_elem = 0;
    for (size_t i = 0; i < layout.ndim; ++i) {
        auto shape_val = layout.shape[i];
        if (!shape_val) {
            return {0, 0, 0, 0};
        }
        auto stride_val = layout.stride[i];
        if (stride_val > 0) {
            high_elem += (shape_val - 1) * stride_val;
        } else {
            low_elem += (shape_val - 1) * stride_val;
        }
    }
    ++high_elem;
    ptrdiff_t low_byte;
    if (low_elem < 0) {
        low_byte = low_elem * layout.dtype.size();
    } else {
        low_byte = 0;
    }
    size_t high_byte = layout.dtype.size(high_elem);
    return TensorLayout::Span(low_elem, low_byte, high_elem, high_byte);
}
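/*
 * Worked example for span_spec, traced from the loop above: a reversed
 * float32 vector with shape (4,) and stride (-1,) contributes
 * (4 - 1) * -1 = -3 to low_elem and nothing to high_elem, so the span is
 * low_elem = -3, high_elem = 1: all elements live in
 * [first + (-3), first + 1), i.e. low_byte = -12 and high_byte = 4 for
 * the 4-byte dtype.
 */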
std::string DefaultTensorFormat::to_string() const {
    return "default{}";
}

void DefaultTensorFormat::serialize_append(std::string&) const {}

TensorFormat DefaultTensorFormat::deserialize(const Handle* handle,
                                              const void* buf, size_t size) {
    MEGDNN_MARK_USED_VAR(handle);
    MEGDNN_MARK_USED_VAR(buf);
    megdnn_assert(!size);
    return make();
}

TensorFormat DefaultTensorFormat::make() {
    // use static storage so the object is accessible in global destructing
    // phase
    static std::aligned_storage_t<sizeof(DefaultTensorFormat),
                                  alignof(DefaultTensorFormat)>
            storage;
    static DefaultTensorFormat* obj = default_tensor_format_obj =
            new (&storage) DefaultTensorFormat{};
    return impl_to_tensor_format(obj);
}
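/*
 * make() deliberately never destroys its object: it is placement-new'ed
 * into static storage, so TensorFormat handles remain valid even while
 * other globals are being destructed. A minimal sketch of the same
 * leaky-singleton pattern, illustrative only (`Foo` is a stand-in type):
 *
 *   #include <new>
 *   #include <type_traits>
 *
 *   struct Foo { int x = 42; };
 *
 *   Foo* leaky_foo_singleton() {
 *       static std::aligned_storage_t<sizeof(Foo), alignof(Foo)> storage;
 *       static Foo* obj = new (&storage) Foo{};  // never destructed
 *       return obj;
 *   }
 */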
/* ===================== Image2DTensorFormatBase ===================== */

Image2DTensorFormatBase::Image2DTensorFormatBase(Type type, size_t align_axis,
                                                 size_t align_size_in_elements)
        : ImplBase(type), m_align_axis(align_axis) {
    megdnn_assert(align_size_in_elements && align_axis);
    m_align_size_in_elements_log2 = __builtin_ctz(align_size_in_elements);
    megdnn_assert(
            (1u << m_align_size_in_elements_log2) == align_size_in_elements,
            "align size not power of 2: %zu", align_size_in_elements);
}

void Image2DTensorFormatBase::serialize_append(std::string& result) const {
    SerializePack pack;
    pack.align_axis = m_align_axis;
    megdnn_assert(pack.align_axis == m_align_axis);  // detect overflow
    result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
}

size_t Image2DTensorFormatBase::image_height(const TensorLayout& layout) const {
    size_t accum = 1;
    for (int i = m_align_axis - 1; i >= 0; --i) {
        if (layout.stride[i] == 0) {
            // this dimension is broadcasted
        } else {
            accum *= layout.shape[i];
        }
    }
    return accum;
}

size_t Image2DTensorFormatBase::image_width_elems(
        const TensorLayout& layout) const {
    size_t high_elem = 0;
    for (size_t i = m_align_axis; i < layout.ndim; ++i) {
        high_elem += (layout.shape[i] - 1) * layout.stride[i];
    }
    return high_elem + 1;
}

std::string Image2DTensorFormatBase::to_string() const {
    return ssprintf("I2D{%zu,%d}", m_align_axis,
                    1 << m_align_size_in_elements_log2);
}
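/*
 * How the 2D image view follows from the two functions above: axes
 * [0, align_axis) are flattened into the image height and axes
 * [align_axis, ndim) into the image width. For example, with
 * align_axis = 2, a layout of shape (2, 3, 4, 5) whose inner two axes
 * are contiguous maps to a 2 * 3 = 6-row image with 4 * 5 = 20 elements
 * per row.
 */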
/* ===================== Image2DPackedTensorFormatBase ===================== */

template <size_t PIXEL_SIZE>
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_width(
        const TensorLayout& layout) const {
    auto ret = image_width_elems(layout);
    megdnn_assert(ret % PIXEL_SIZE == 0);
    return ret / PIXEL_SIZE;
}

template <size_t PIXEL_SIZE>
void Image2DPackedTensorFormatBase<PIXEL_SIZE>::assert_valid(
        const TensorLayout& layout) const {
    auto m_align_axis = align_axis();
    megdnn_assert(!(layout.shape[layout.ndim - 1] % PIXEL_SIZE),
                  "bad shape: %zu", layout.shape[layout.ndim - 1]);
    megdnn_assert(layout.dtype.valid() && layout.ndim > m_align_axis);
    ptrdiff_t first_non_zero_stride = 0;
    for (int i = layout.ndim - 1; i >= 0; --i) {
        megdnn_assert(layout.shape[i] && layout.stride[i] >= 0);
        if (i < static_cast<int>(m_align_axis) && !first_non_zero_stride) {
            first_non_zero_stride = layout.stride[i];
        }
    }
    size_t mask =
            image_pitch_alignment_in_bytes(
                    align_size_in_elements(layout.dtype.size_log()), layout) -
            1;
    megdnn_assert(!(first_non_zero_stride & mask),
                  "first stride is %d, but alignment is %zu",
                  static_cast<int>(first_non_zero_stride), mask + 1);
}
template <size_t PIXEL_SIZE>
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_row_pitch(
        const TensorLayout& layout) const {
    for (int i = align_axis() - 1; i >= 0; --i) {
        // find a non-broadcast axis
        if (auto s = layout.stride[i]) {
            return layout.dtype.size(s);
        }
    }
    // use width for all broadcasted case
    size_t alignment_in_bytes_log2 = align_size_in_elements_log2();
    if (m_vendor_type == Handle::HandleVendorType::MALI) {
        alignment_in_bytes_log2 +=
                __builtin_ctz(layout.dtype.size() * PIXEL_SIZE);
    }
    return get_aligned_power2<size_t>(
            layout.dtype.size(image_width_elems(layout)),
            1 << alignment_in_bytes_log2);
}

template <size_t PIXEL_SIZE>
size_t
Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_pitch_alignment_in_bytes(
        size_t align_size_in_elements, const TensorLayout& layout) const {
    return m_vendor_type == Handle::HandleVendorType::MALI
                   ? (align_size_in_elements * layout.dtype.size() * PIXEL_SIZE)
                   : align_size_in_elements;
}

template <size_t PIXEL_SIZE>
TensorLayout::Span Image2DPackedTensorFormatBase<PIXEL_SIZE>::span_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    size_t size = image_height(layout) * image_row_pitch(layout);
    auto mask = (1 << layout.dtype.size_log()) - 1;
    megdnn_assert(!(size & mask), "unaligned size: %zu", size);
    return {0, 0, size >> layout.dtype.size_log(), size};
}

template <size_t PIXEL_SIZE>
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::init_contiguous_stride(
        TensorLayout& layout) const {
    auto m_align_axis = align_axis();
    if (!layout.ndim)
        return 0;
    megdnn_assert(layout.dtype.valid() && layout.ndim > m_align_axis,
                  "dtype=%s ndim=%zu align=%zu", layout.dtype.name(),
                  layout.ndim, m_align_axis);
    size_t align_size = image_pitch_alignment_in_bytes(
            align_size_in_elements(layout.dtype.size_log()), layout);
    size_t accum = 1;
    SafeMultiplies<size_t> mul;
    for (size_t i = layout.ndim; i; --i) {
        if (i == m_align_axis) {
            accum = get_aligned_power2<size_t>(accum, align_size);
        }
        layout.stride[i - 1] = accum;
        accum = mul(accum, layout.shape[i - 1]);
    }
    assert_valid(layout);
    return accum;
}
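/*
 * Worked example for the aligned stride computation above, tracing the
 * loop and assuming align_axis = 2 with a pitch alignment of 64
 * elements: for shape (2, 3, 8), the innermost stride is 1 and accum
 * reaches 8 after the width axes; at i == align_axis, accum is rounded
 * up to 64, so the final strides are (192, 64, 1) and the function
 * returns 2 * 192 = 384 elements.
 */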
template <size_t PIXEL_SIZE>
bool Image2DPackedTensorFormatBase<PIXEL_SIZE>::is_contiguous_spec(
        const TensorLayout& layout) const {
    megdnn_assert(layout.dtype.valid());
    size_t align_size = image_pitch_alignment_in_bytes(
            align_size_in_elements(layout.dtype.size_log()), layout);
    ptrdiff_t expected = 1;
    int height_axis = static_cast<int>(align_axis() - 1);
    for (int i = layout.ndim - 1; i >= 0; --i) {
        if (i == height_axis) {
            expected = megdnn::get_aligned_power2<size_t>(expected, align_size);
        }
        if (layout.shape[i] != 1 && layout.stride[i] != expected) {
            if (i == height_axis) {
                // allow row pitch to be larger than minimal required
                auto s = layout.stride[i];
                if (!s) {
                    // broadcast is not contiguous
                    return false;
                }
                size_t mask =
                        image_pitch_alignment_in_bytes(
                                align_size_in_elements(layout.dtype.size_log()),
                                layout) -
                        1;
                megdnn_assert(s > expected && !(s & mask),
                              "invalid row pitch: %d; layout: %s",
                              static_cast<int>(s), layout.to_string().c_str());
                expected = s;
            } else {
                return false;
            }
        }
        expected *= layout.shape[i];
    }
    // empty tensors are not contiguous
    return expected != 0;
}
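/*
 * Contiguity example traced from the loop above (align_axis = 2, pitch
 * alignment 64 elements, shape (2, 3, 8)): the minimal strides
 * (192, 64, 1) are contiguous; strides (384, 128, 1) are also accepted
 * because the height axis may carry any aligned row pitch above the
 * minimum, while a row pitch of 72 would trip the alignment assertion.
 */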
template <size_t PIXEL_SIZE>
TensorLayout Image2DPackedTensorFormatBase<PIXEL_SIZE>::collapse_contiguous_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    TensorLayout res{layout};
    int new_axis = align_axis();

    // remove all dims with shape 1
    for (int i = static_cast<int>(res.ndim) - 1; i >= 0 && res.ndim >= 3; --i) {
        if (i == new_axis && static_cast<int>(res.ndim) == new_axis + 1) {
            // i is the only width dim
            continue;
        }
        if (i == new_axis - 1 && !i) {
            // new_axis == 1 && i == 0, i is the only height dim
            continue;
        }
        if (res.shape[i] == 1) {
            res.remove_axis_inplace(i);
            if (i < new_axis)
                new_axis -= 1;
        }
    }

    megdnn_assert(res.ndim >= 2);
    auto contig_with_next = [&](size_t i) {
        return res.stride[i] ==
               res.stride[i + 1] * static_cast<ptrdiff_t>(res.shape[i + 1]);
    };

    for (int i = static_cast<int>(res.ndim) - 2; i >= new_axis; --i) {
        megdnn_assert(res.shape[i]);
        if (contig_with_next(i)) {
            // remove next axis
            res.shape[i] *= res.shape[i + 1];
            res.stride[i] = res.stride[i + 1];
            res.remove_axis_inplace(i + 1);
        }
    }
    for (int i = new_axis - 2; i >= 0; --i) {
        megdnn_assert(res.shape[i]);
        if (contig_with_next(i)) {
            res.shape[i] *= res.shape[i + 1];
            res.stride[i] = res.stride[i + 1];
            res.remove_axis_inplace(i + 1);
            if (i <= new_axis - 2)
                new_axis -= 1;
        }
    }
    res.format = change_axis(new_axis);
    return res;
}

namespace megdnn {
namespace detail {
template class Image2DPackedTensorFormatBase<4>;
}  // namespace detail
}  // namespace megdnn
/* ===================== LowbitsTensorFormatBase ===================== */

template <size_t SIZE_NBITS>
LowbitsTensorFormatBase<SIZE_NBITS>::LowbitsTensorFormatBase(
        Type type, size_t align_size_in_bits)
        : ImplBase(type), m_align_size_in_bits(align_size_in_bits) {
    megdnn_assert(!(m_align_size_in_bits % SIZE_NBITS),
                  "align size(%zu) must be a multiple of element size(%zu)",
                  m_align_size_in_bits, SIZE_NBITS);
    m_align_size_in_elements = m_align_size_in_bits / SIZE_NBITS;
}
template <size_t SIZE_NBITS>
std::string LowbitsTensorFormatBase<SIZE_NBITS>::to_string() const {
    return ssprintf("LOWBITS{%zu,%zu}", SIZE_NBITS, m_align_size_in_bits);
}

template <size_t SIZE_NBITS>
void LowbitsTensorFormatBase<SIZE_NBITS>::assert_valid(
        const TensorLayout& layout) const {
    megdnn_assert(layout.dtype.valid() && layout.dtype.is_low_bit() &&
                  layout.dtype.low_bit() == SIZE_NBITS);
    bool has_dim_unity_stride = false;
    for (int i = layout.ndim - 1; i >= 0; --i) {
        if (!has_dim_unity_stride && layout.stride[i] == 1)
            has_dim_unity_stride = true;
        megdnn_assert(
                layout.stride[i] >= 0 &&
                        (layout.stride[i] % m_align_size_in_elements == 0 ||
                         layout.stride[i] == 1),
                "bad stride: %zu", layout.stride[i]);
    }
    megdnn_assert(has_dim_unity_stride, "innermost dim not contiguous");
}

template <size_t SIZE_NBITS>
void LowbitsTensorFormatBase<SIZE_NBITS>::serialize_append(
        std::string& result) const {
    SerializePack pack;
    pack.align_size_in_bits = m_align_size_in_bits;
    megdnn_assert(pack.align_size_in_bits ==
                  m_align_size_in_bits);  // detect overflow
    result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
}

template <size_t SIZE_NBITS>
TensorLayout::Span LowbitsTensorFormatBase<SIZE_NBITS>::span_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    if (layout.ndim == 0)
        return {0, 0, 0, 0};
    size_t high_elem = 0;
    for (size_t i = 0; i < layout.ndim; ++i) {
        auto shape_val = layout.shape[i];
        if (!shape_val) {
            return {0, 0, 0, 0};
        }
        auto stride_val = layout.stride[i];
        megdnn_assert(stride_val >= 0,
                      "lowbit tensors shouldn't have negative strides");
        high_elem += (shape_val - 1) * stride_val;
    }
    ++high_elem;
    size_t high_byte = layout.dtype.size(high_elem);
    return TensorLayout::Span(0, 0, high_elem, high_byte);
}

template <size_t SIZE_NBITS>
size_t LowbitsTensorFormatBase<SIZE_NBITS>::init_contiguous_stride(
        TensorLayout& layout) const {
    if (!layout.ndim)
        return 0;
    megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM);
    size_t accum = 1;
    SafeMultiplies<size_t> mul;
    for (size_t i = layout.ndim; i; --i) {
        layout.stride[i - 1] = accum;
        auto multiplier = layout.shape[i - 1];
        if (i == layout.ndim)
            multiplier = round_up(multiplier, m_align_size_in_elements);
        accum = mul(accum, multiplier);
    }
    return accum;
}
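/*
 * Worked example for the low-bit stride computation above, tracing the
 * loop for SIZE_NBITS = 4 with byte alignment (8 bits -> 2 elements):
 * for shape (2, 3), the innermost dim gets stride 1 and its extent 3 is
 * rounded up to 4, so the outer stride becomes 4 and the function
 * returns 2 * 4 = 8 four-bit elements, i.e. 4 bytes of storage.
 */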
template <size_t SIZE_NBITS>
bool LowbitsTensorFormatBase<SIZE_NBITS>::is_contiguous_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    ptrdiff_t expected = 1;
    for (int i = static_cast<int>(layout.ndim) - 1; i >= 0; --i) {
        if (layout.shape[i] != 1 && layout.stride[i] != expected)
            return false;
        auto multiplier = layout.shape[i];
        if (i == layout.ndim - 1)
            multiplier = round_up(multiplier, m_align_size_in_elements);
        expected *= multiplier;
    }
    return expected != 0;
}

template <size_t SIZE_NBITS>
TensorLayout LowbitsTensorFormatBase<SIZE_NBITS>::collapse_contiguous_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    TensorLayout res{layout};
    for (int i = static_cast<int>(res.ndim) - 1; i >= 0; --i) {
        if (!res.shape[i]) {
            // empty tensor
            res.ndim = 1;
            res.shape[0] = 0;
            res.stride[0] = 1;
            return res;
        }
        if (res.shape[i] == 1) {
            res.remove_axis_inplace(i);
        }
    }
    megdnn_assert(res.ndim && res.shape[res.ndim - 1]);
    for (int i = static_cast<int>(res.ndim) - 2; i >= 0; --i) {
        megdnn_assert(res.shape[i]);
        if (res.stride[i] ==
            res.stride[i + 1] * static_cast<ptrdiff_t>(res.shape[i + 1])) {
            res.shape[i] *= res.shape[i + 1];
            res.stride[i] = res.stride[i + 1];
            res.remove_axis_inplace(i + 1);
        }
    }
    return res;
}

namespace megdnn {
namespace detail {
template class LowbitsTensorFormatBase<4>;
}  // namespace detail
}  // namespace megdnn
/* ===================== Image2DPack4TensorFormat ===================== */

TensorFormat Image2DPack4TensorFormat::make_raw(
        size_t align_axis, size_t align_size_in_elements,
        Handle::HandleVendorType vendor_type) {
    static std::mutex mtx;
    static std::unordered_map<uint64_t,
                              std::unique_ptr<Image2DPack4TensorFormat>>
            cache;
    megdnn_assert(std::max(align_axis, align_size_in_elements) <=
                  std::numeric_limits<uint32_t>::max());
    MEGDNN_LOCK_GUARD(mtx);
    auto&& ptr = cache[(static_cast<uint64_t>(align_axis) << 32) |
                       align_size_in_elements];
    if (!ptr) {
        ptr.reset(new Image2DPack4TensorFormat{
                align_axis, align_size_in_elements, vendor_type});
    }
    return impl_to_tensor_format(ptr.get());
}

TensorFormat Image2DPack4TensorFormat::make(size_t align_axis,
                                            const Handle* handle) {
    return make_raw(align_axis, handle->image2d_pitch_alignment(),
                    handle->vendor_type());
}

TensorFormat Image2DPack4TensorFormat::deserialize(const Handle* handle,
                                                   const void* buf,
                                                   size_t size) {
    megdnn_assert(size == sizeof(SerializePack));
    auto pack = *static_cast<const SerializePack*>(buf);
    return make(pack.align_axis, handle);
}

TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const {
    return make_raw(axis, align_size_in_elements(), vendor());
}
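/*
 * make_raw() memoizes one immutable format object per
 * (align_axis, align_size_in_elements) pair, packing both 32-bit values
 * into a single uint64_t map key. A self-contained sketch of the same
 * keyed-singleton pattern, illustrative only (`Widget` is a stand-in
 * type):
 *
 *   #include <cstdint>
 *   #include <memory>
 *   #include <mutex>
 *   #include <unordered_map>
 *
 *   struct Widget { uint32_t a, b; };
 *
 *   Widget* get_widget(uint32_t a, uint32_t b) {
 *       static std::mutex mtx;
 *       static std::unordered_map<uint64_t, std::unique_ptr<Widget>> cache;
 *       std::lock_guard<std::mutex> lock{mtx};
 *       auto&& ptr = cache[(static_cast<uint64_t>(a) << 32) | b];
 *       if (!ptr)
 *           ptr.reset(new Widget{a, b});
 *       return ptr.get();  // stable address for the process lifetime
 *   }
 */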
/* ===================== FourBitsAlignedToBytesTensorFormat ===================== */

TensorFormat FourBitsAlignedToBytesTensorFormat::make(
        size_t align_size_in_bits) {
    static std::mutex mtx;
    static std::unordered_map<
            uint32_t, std::unique_ptr<FourBitsAlignedToBytesTensorFormat>>
            cache;
    megdnn_assert(!(align_size_in_bits % 4));
    MEGDNN_LOCK_GUARD(mtx);
    auto&& ptr = cache[static_cast<uint32_t>(align_size_in_bits)];
    if (!ptr) {
        ptr.reset(new FourBitsAlignedToBytesTensorFormat{align_size_in_bits});
    }
    return impl_to_tensor_format(ptr.get());
}

TensorFormat FourBitsAlignedToBytesTensorFormat::deserialize(const Handle*,
                                                             const void* buf,
                                                             size_t size) {
    megdnn_assert(size == sizeof(SerializePack));
    auto pack = *static_cast<const SerializePack*>(buf);
    return make(pack.align_size_in_bits);
}

// vim: syntax=cpp.doxygen
