You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_format.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. /**
  2. * \file dnn/include/megdnn/tensor_format.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "megdnn/handle.h"
  14. #include "megdnn/internal/visibility_prologue.h"
  15. namespace megdnn {
  16. enum class TensorFormat::Type {
  17. DEFAULT = 0, //!< see DefaultTensorFormat
  18. IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat
  19. LOWBITS_ALIGNED_TO_BYTE = 2, //!<
  20. };
  21. class TensorFormat::ImplBase {
  22. public:
  23. using Type = TensorFormat::Type;
  24. virtual void assert_valid(const TensorLayout& layout) const = 0;
  25. virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0;
  26. virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0;
  27. virtual TensorLayout collapse_contiguous_spec(
  28. const TensorLayout& layout) const = 0;
  29. virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0;
  30. //! a human-readable string description of this TensorFormat
  31. virtual std::string to_string() const = 0;
  32. virtual void serialize_append(std::string& result) const = 0;
  33. Type type() const { return m_type; }
  34. protected:
  35. ImplBase(Type type) : m_type{type} {}
  36. virtual ~ImplBase() = default;
  37. static TensorFormat impl_to_tensor_format(ImplBase* impl) { return {impl}; }
  38. private:
  39. Type m_type;
  40. };
  41. TensorFormat::Type TensorFormat::type() const {
  42. return m_impl->type();
  43. }
  44. //! default tensor format that imposes no stride constraints
  45. class DefaultTensorFormat final : public TensorFormat::ImplBase {
  46. public:
  47. static constexpr Type TYPE = Type::DEFAULT;
  48. DefaultTensorFormat() : ImplBase(TYPE) {}
  49. void assert_valid(const TensorLayout& layout) const override;
  50. size_t init_contiguous_stride(TensorLayout& layout) const override;
  51. /*!
  52. * \brief A tensor is contiguous if logical offset in row-major of any
  53. * element always equals to its physical offset (i.e. offset considering
  54. * strides).
  55. *
  56. * Empty tensors are not considered to be contiguous.
  57. */
  58. bool is_contiguous_spec(const TensorLayout& layout) const override;
  59. TensorLayout collapse_contiguous_spec(
  60. const TensorLayout& layout) const override;
  61. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  62. std::string to_string() const override;
  63. void serialize_append(std::string& result) const override;
  64. static TensorFormat make();
  65. static TensorFormat deserialize(const Handle* handle, const void* buf,
  66. size_t size);
  67. };
  68. namespace detail {
  69. /*!
  70. * \brief 2D image with requirement on row stride
  71. *
  72. * \p align_axis is the axis to be aligned, also the first axis of image width.
  73. * More precisely speaking, `stride[align_axis-1] * dtype.size()` must divide \p
  74. * align_size_in_elements. Axes from 0 to align_axis-1 would be considered as
  75. * the height of the image, and other axes are the width.
  76. *
  77. * Empty tensors and negative strides are not allowed. Only contiguous or
  78. * broadcasted cases are allowed.
  79. *
  80. * Note: if `stride[align_axis - 1]` is larger than minimal value, it is still
  81. * considered as contiguous.
  82. */
  83. class Image2DTensorFormatBase : public TensorFormat::ImplBase {
  84. size_t m_align_axis, m_align_size_in_elements_log2;
  85. protected:
  86. Image2DTensorFormatBase(Type type, size_t align_axis,
  87. size_t align_size_in_elements);
  88. virtual ~Image2DTensorFormatBase() = default;
  89. public:
  90. /*!
  91. * \brief get alignment requirement in elements
  92. * \param div_log2 the result would be divided by `(1 << div_log2)`
  93. */
  94. size_t align_size_in_elements(size_t div_log2 = 0) const {
  95. return 1 << (m_align_size_in_elements_log2 > div_log2
  96. ? m_align_size_in_elements_log2 - div_log2
  97. : 0);
  98. }
  99. size_t align_axis() const { return m_align_axis; }
  100. size_t align_size_in_elements_log2() const {
  101. return m_align_size_in_elements_log2;
  102. }
  103. std::string to_string() const override;
  104. //! modify the align axis and return a new TensorFormat
  105. virtual TensorFormat change_axis(size_t axis) const = 0;
  106. //! number of dtype elems in each row, considering strides
  107. size_t image_width_elems(const TensorLayout& layout) const;
  108. //! number of rows
  109. size_t image_height(const TensorLayout& layout) const;
  110. void serialize_append(std::string& result) const override;
  111. protected:
  112. struct SerializePack {
  113. uint8_t align_axis;
  114. };
  115. };
  116. template <size_t PIXEL_SIZE>
  117. class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase {
  118. Handle::HandleVendorType m_vendor_type = Handle::HandleVendorType::NOT_SPEC;
  119. /*!
  120. * \brief get fix alignment requirement in bytes, consider m_vendor_type,
  121. * for example on MALI, CL_DEVICE_IMAGE_PITCH_ALIGNMENT means image_width
  122. * align COUNT, but mdl needs align size in byte, which equal to
  123. * (image_width algin count) * sizeof(data_type) * pixel_size
  124. */
  125. size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements,
  126. const TensorLayout& layout) const;
  127. protected:
  128. Image2DPackedTensorFormatBase(Type type, size_t align_axis,
  129. size_t align_size_in_elements,
  130. Handle::HandleVendorType vendor_type)
  131. : detail::Image2DTensorFormatBase(type, align_axis,
  132. align_size_in_elements),
  133. m_vendor_type(vendor_type) {}
  134. virtual ~Image2DPackedTensorFormatBase() = default;
  135. Handle::HandleVendorType vendor() const { return m_vendor_type; }
  136. public:
  137. /*!
  138. * \brief image width in logical pixels exclude padding
  139. *
  140. * It is the number of accessible elems (in dtype) divided by PIXEL_SIZE.
  141. *
  142. * \see image_row_pitch()
  143. */
  144. size_t image_width(const TensorLayout& layout) const;
  145. size_t image_row_pitch(const TensorLayout& layout) const;
  146. //! raise exception if preconditions violated
  147. void assert_valid(const TensorLayout& layout) const override;
  148. //! span for image must include the padding at the last row
  149. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  150. size_t init_contiguous_stride(TensorLayout& layout) const override;
  151. bool is_contiguous_spec(const TensorLayout& layout) const override;
  152. TensorLayout collapse_contiguous_spec(
  153. const TensorLayout& layout) const override;
  154. };
  155. using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
  156. /*!
  157. * \brief used for tensors storing lowbit data
  158. *
  159. * \param m_size_nbits size in bits of elements in the tensor
  160. * \param m_align_size_in_bits aligned size in bits
  161. * \param m_align_size_in_elements aligned size in elements
  162. */
  163. class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase {
  164. size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements;
  165. protected: //?
  166. LowbitsAlignedTensorFormatBase(Type type, size_t size_nbits,
  167. size_t align_size_in_bits);
  168. virtual ~LowbitsAlignedTensorFormatBase() = default;
  169. public:
  170. size_t align_size_in_bits() const { return m_align_size_in_bits; }
  171. size_t size_nbits() const { return m_size_nbits; }
  172. std::string to_string() const override;
  173. //! raise exception if given layout is illegal
  174. void assert_valid(const TensorLayout& layout) const override;
  175. void serialize_append(std::string& result) const override;
  176. //! span for lowbit tensor must include the padding at the innermost
  177. //! dimemsion that make lowbit tensor be aligned to bytes
  178. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  179. size_t init_contiguous_stride(TensorLayout& layout) const override;
  180. bool is_contiguous_spec(const TensorLayout& layout) const override;
  181. TensorLayout collapse_contiguous_spec(
  182. const TensorLayout& layout) const override;
  183. protected:
  184. struct SerializePack {
  185. uint8_t size_nbits;
  186. uint8_t align_size_in_bits;
  187. };
  188. };
  189. } // namespace detail
  190. /*!
  191. * \brief 2D image that requires stride of width to be aligned, and pack 4 elems
  192. * into a pixel
  193. *
  194. * This is used for OpenCL.
  195. */
  196. class Image2DPack4TensorFormat final
  197. : public detail::Image2DPack4TensorFormatBase {
  198. public:
  199. static constexpr Type TYPE = Type::IMAGE2D_PACK4;
  200. //! for internal usage or test purposes
  201. static TensorFormat make_raw(size_t align_axis,
  202. size_t align_size_in_elements,
  203. Handle::HandleVendorType vendor_type =
  204. Handle::HandleVendorType::NOT_SPEC);
  205. static TensorFormat make(size_t align_axis, const Handle* handle);
  206. /*!
  207. * \brief deserialize on a handle
  208. *
  209. * Note that the alignment may be different if deserialized on another
  210. * handle
  211. */
  212. static TensorFormat deserialize(const Handle* handle, const void* buf,
  213. size_t size);
  214. static bool is_valid_image(const TensorLayout& layout) {
  215. if (layout.format.type() == TYPE) {
  216. layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(
  217. layout);
  218. return true;
  219. }
  220. return false;
  221. }
  222. TensorFormat change_axis(size_t axis) const override;
  223. private:
  224. Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements,
  225. Handle::HandleVendorType vendor_type)
  226. : detail::Image2DPack4TensorFormatBase(
  227. TYPE, align_axis, align_size_in_elements, vendor_type) {}
  228. };
  229. /*!
  230. * \brief Tensor for storing 4bit data that requires stride corresponding to
  231. * non-innermost dimension to be aligned to bytes, and pack 2 elems into a byte
  232. */
  233. class LowbitsAlignedToBytesTensorFormat final
  234. : public detail::LowbitsAlignedTensorFormatBase {
  235. public:
  236. static constexpr Type TYPE = Type::LOWBITS_ALIGNED_TO_BYTE;
  237. static constexpr size_t BYTE_IN_BITS = 8;
  238. static TensorFormat make(size_t size_nbits);
  239. static TensorFormat deserialize(const Handle* handle, const void* buf,
  240. size_t size);
  241. static bool is_valid_layout(const TensorLayout& layout) {
  242. if (layout.format.type() == TYPE) {
  243. layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>()
  244. .assert_valid(layout);
  245. return true;
  246. }
  247. return false;
  248. }
  249. private:
  250. LowbitsAlignedToBytesTensorFormat(size_t size_nbits)
  251. : detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits,
  252. BYTE_IN_BITS) {}
  253. };
  254. } // namespace megdnn
  255. #include "megdnn/internal/visibility_epilogue.h"
  256. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台