You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_format.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. /**
  2. * \file dnn/include/megdnn/tensor_format.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "megdnn/handle.h"
  14. #include "megdnn/internal/visibility_prologue.h"
  15. namespace megdnn {
  16. enum class TensorFormat::Type {
  17. DEFAULT = 0, //!< see DefaultTensorFormat
  18. IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat
  19. LOWBITS_ALIGNED_TO_BYTE = 2, //!<
  20. };
  21. class TensorFormat::ImplBase {
  22. public:
  23. using Type = TensorFormat::Type;
  24. virtual void assert_valid(const TensorLayout& layout) const = 0;
  25. virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0;
  26. virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0;
  27. virtual TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const = 0;
  28. virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0;
  29. //! a human-readable string description of this TensorFormat
  30. virtual std::string to_string() const = 0;
  31. virtual void serialize_append(std::string& result) const = 0;
  32. Type type() const { return m_type; }
  33. protected:
  34. ImplBase(Type type) : m_type{type} {}
  35. virtual ~ImplBase() = default;
  36. static TensorFormat impl_to_tensor_format(ImplBase* impl) { return {impl}; }
  37. private:
  38. Type m_type;
  39. };
  40. TensorFormat::Type TensorFormat::type() const {
  41. return m_impl->type();
  42. }
  43. //! default tensor format that imposes no stride constraints
  44. class DefaultTensorFormat final : public TensorFormat::ImplBase {
  45. public:
  46. static constexpr Type TYPE = Type::DEFAULT;
  47. DefaultTensorFormat() : ImplBase(TYPE) {}
  48. void assert_valid(const TensorLayout& layout) const override;
  49. size_t init_contiguous_stride(TensorLayout& layout) const override;
  50. /*!
  51. * \brief A tensor is contiguous if logical offset in row-major of any
  52. * element always equals to its physical offset (i.e. offset considering
  53. * strides).
  54. *
  55. * Empty tensors are not considered to be contiguous.
  56. */
  57. bool is_contiguous_spec(const TensorLayout& layout) const override;
  58. TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
  59. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  60. std::string to_string() const override;
  61. void serialize_append(std::string& result) const override;
  62. static TensorFormat make();
  63. static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);
  64. };
  65. namespace detail {
  66. /*!
  67. * \brief 2D image with requirement on row stride
  68. *
  69. * \p align_axis is the axis to be aligned, also the first axis of image width.
  70. * More precisely speaking, `stride[align_axis-1] * dtype.size()` must divide \p
  71. * align_size_in_elements. Axes from 0 to align_axis-1 would be considered as
  72. * the height of the image, and other axes are the width.
  73. *
  74. * Empty tensors and negative strides are not allowed. Only contiguous or
  75. * broadcasted cases are allowed.
  76. *
  77. * Note: if `stride[align_axis - 1]` is larger than minimal value, it is still
  78. * considered as contiguous.
  79. */
  80. class Image2DTensorFormatBase : public TensorFormat::ImplBase {
  81. size_t m_align_axis, m_align_size_in_elements_log2;
  82. protected:
  83. Image2DTensorFormatBase(
  84. Type type, size_t align_axis, size_t align_size_in_elements);
  85. virtual ~Image2DTensorFormatBase() = default;
  86. public:
  87. /*!
  88. * \brief get alignment requirement in elements
  89. * \param div_log2 the result would be divided by `(1 << div_log2)`
  90. */
  91. size_t align_size_in_elements(size_t div_log2 = 0) const {
  92. return 1 << (m_align_size_in_elements_log2 > div_log2
  93. ? m_align_size_in_elements_log2 - div_log2
  94. : 0);
  95. }
  96. size_t align_axis() const { return m_align_axis; }
  97. size_t align_size_in_elements_log2() const { return m_align_size_in_elements_log2; }
  98. std::string to_string() const override;
  99. //! modify the align axis and return a new TensorFormat
  100. virtual TensorFormat change_axis(size_t axis) const = 0;
  101. //! number of dtype elems in each row, considering strides
  102. size_t image_width_elems(const TensorLayout& layout) const;
  103. //! number of rows
  104. size_t image_height(const TensorLayout& layout) const;
  105. void serialize_append(std::string& result) const override;
  106. protected:
  107. struct SerializePack {
  108. uint8_t align_axis;
  109. };
  110. };
  111. template <size_t PIXEL_SIZE>
  112. class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase {
  113. Handle::HandleVendorType m_vendor_type = Handle::HandleVendorType::NOT_SPEC;
  114. /*!
  115. * \brief get fix alignment requirement in bytes, consider m_vendor_type,
  116. * for example on MALI, CL_DEVICE_IMAGE_PITCH_ALIGNMENT means image_width
  117. * align COUNT, but mdl needs align size in byte, which equal to
  118. * (image_width algin count) * sizeof(data_type) * pixel_size
  119. */
  120. size_t image_pitch_alignment_in_bytes(
  121. size_t align_size_in_elements, const TensorLayout& layout) const;
  122. protected:
  123. Image2DPackedTensorFormatBase(
  124. Type type, size_t align_axis, size_t align_size_in_elements,
  125. Handle::HandleVendorType vendor_type)
  126. : detail::Image2DTensorFormatBase(type, align_axis, align_size_in_elements),
  127. m_vendor_type(vendor_type) {}
  128. virtual ~Image2DPackedTensorFormatBase() = default;
  129. Handle::HandleVendorType vendor() const { return m_vendor_type; }
  130. public:
  131. /*!
  132. * \brief image width in logical pixels exclude padding
  133. *
  134. * It is the number of accessible elems (in dtype) divided by PIXEL_SIZE.
  135. *
  136. * \see image_row_pitch()
  137. */
  138. size_t image_width(const TensorLayout& layout) const;
  139. size_t image_row_pitch(const TensorLayout& layout) const;
  140. //! raise exception if preconditions violated
  141. void assert_valid(const TensorLayout& layout) const override;
  142. //! span for image must include the padding at the last row
  143. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  144. size_t init_contiguous_stride(TensorLayout& layout) const override;
  145. bool is_contiguous_spec(const TensorLayout& layout) const override;
  146. TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
  147. };
  148. using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
  149. /*!
  150. * \brief used for tensors storing lowbit data
  151. *
  152. * \param m_size_nbits size in bits of elements in the tensor
  153. * \param m_align_size_in_bits aligned size in bits
  154. * \param m_align_size_in_elements aligned size in elements
  155. */
  156. class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase {
  157. size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements;
  158. protected: //?
  159. LowbitsAlignedTensorFormatBase(
  160. Type type, size_t size_nbits, size_t align_size_in_bits);
  161. virtual ~LowbitsAlignedTensorFormatBase() = default;
  162. public:
  163. size_t align_size_in_bits() const { return m_align_size_in_bits; }
  164. size_t size_nbits() const { return m_size_nbits; }
  165. std::string to_string() const override;
  166. //! raise exception if given layout is illegal
  167. void assert_valid(const TensorLayout& layout) const override;
  168. void serialize_append(std::string& result) const override;
  169. //! span for lowbit tensor must include the padding at the innermost
  170. //! dimemsion that make lowbit tensor be aligned to bytes
  171. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  172. size_t init_contiguous_stride(TensorLayout& layout) const override;
  173. bool is_contiguous_spec(const TensorLayout& layout) const override;
  174. TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
  175. protected:
  176. struct SerializePack {
  177. uint8_t size_nbits;
  178. uint8_t align_size_in_bits;
  179. };
  180. };
  181. } // namespace detail
  182. /*!
  183. * \brief 2D image that requires stride of width to be aligned, and pack 4 elems
  184. * into a pixel
  185. *
  186. * This is used for OpenCL.
  187. */
  188. class Image2DPack4TensorFormat final : public detail::Image2DPack4TensorFormatBase {
  189. public:
  190. static constexpr Type TYPE = Type::IMAGE2D_PACK4;
  191. //! for internal usage or test purposes
  192. static TensorFormat make_raw(
  193. size_t align_axis, size_t align_size_in_elements,
  194. Handle::HandleVendorType vendor_type = Handle::HandleVendorType::NOT_SPEC);
  195. static TensorFormat make(size_t align_axis, const Handle* handle);
  196. /*!
  197. * \brief deserialize on a handle
  198. *
  199. * Note that the alignment may be different if deserialized on another
  200. * handle
  201. */
  202. static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);
  203. static bool is_valid_image(const TensorLayout& layout) {
  204. if (layout.format.type() == TYPE) {
  205. layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(layout);
  206. return true;
  207. }
  208. return false;
  209. }
  210. TensorFormat change_axis(size_t axis) const override;
  211. private:
  212. Image2DPack4TensorFormat(
  213. size_t align_axis, size_t align_size_in_elements,
  214. Handle::HandleVendorType vendor_type)
  215. : detail::Image2DPack4TensorFormatBase(
  216. TYPE, align_axis, align_size_in_elements, vendor_type) {}
  217. };
  218. /*!
  219. * \brief Tensor for storing 4bit data that requires stride corresponding to
  220. * non-innermost dimension to be aligned to bytes, and pack 2 elems into a byte
  221. */
  222. class LowbitsAlignedToBytesTensorFormat final
  223. : public detail::LowbitsAlignedTensorFormatBase {
  224. public:
  225. static constexpr Type TYPE = Type::LOWBITS_ALIGNED_TO_BYTE;
  226. static constexpr size_t BYTE_IN_BITS = 8;
  227. static TensorFormat make(size_t size_nbits);
  228. static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);
  229. static bool is_valid_layout(const TensorLayout& layout) {
  230. if (layout.format.type() == TYPE) {
  231. layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>().assert_valid(
  232. layout);
  233. return true;
  234. }
  235. return false;
  236. }
  237. private:
  238. LowbitsAlignedToBytesTensorFormat(size_t size_nbits)
  239. : detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, BYTE_IN_BITS) {}
  240. };
  241. } // namespace megdnn
  242. #include "megdnn/internal/visibility_epilogue.h"
  243. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。