You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tensor_format.h 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. #pragma once
  2. #include "megdnn/basic_types.h"
  3. #include "megdnn/handle.h"
  4. #include "megdnn/internal/visibility_prologue.h"
  5. namespace megdnn {
  6. enum class TensorFormat::Type {
  7. DEFAULT = 0, //!< see DefaultTensorFormat
  8. IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat
  9. LOWBITS_ALIGNED_TO_BYTE = 2, //!<
  10. };
  11. class TensorFormat::ImplBase {
  12. public:
  13. using Type = TensorFormat::Type;
  14. virtual void assert_valid(const TensorLayout& layout) const = 0;
  15. virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0;
  16. virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0;
  17. virtual TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const = 0;
  18. virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0;
  19. //! a human-readable string description of this TensorFormat
  20. virtual std::string to_string() const = 0;
  21. virtual void serialize_append(std::string& result) const = 0;
  22. Type type() const { return m_type; }
  23. protected:
  24. ImplBase(Type type) : m_type{type} {}
  25. virtual ~ImplBase() = default;
  26. static TensorFormat impl_to_tensor_format(ImplBase* impl) { return {impl}; }
  27. private:
  28. Type m_type;
  29. };
  30. TensorFormat::Type TensorFormat::type() const {
  31. return m_impl->type();
  32. }
  33. //! default tensor format that imposes no stride constraints
  34. class DefaultTensorFormat final : public TensorFormat::ImplBase {
  35. public:
  36. static constexpr Type TYPE = Type::DEFAULT;
  37. DefaultTensorFormat() : ImplBase(TYPE) {}
  38. void assert_valid(const TensorLayout& layout) const override;
  39. size_t init_contiguous_stride(TensorLayout& layout) const override;
  40. /*!
  41. * \brief A tensor is contiguous if logical offset in row-major of any
  42. * element always equals to its physical offset (i.e. offset considering
  43. * strides).
  44. *
  45. * Empty tensors are not considered to be contiguous.
  46. */
  47. bool is_contiguous_spec(const TensorLayout& layout) const override;
  48. TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
  49. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  50. std::string to_string() const override;
  51. void serialize_append(std::string& result) const override;
  52. static TensorFormat make();
  53. static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);
  54. };
  55. namespace detail {
  56. /*!
  57. * \brief 2D image with requirement on row stride
  58. *
  59. * \p align_axis is the axis to be aligned, also the first axis of image width.
  60. * More precisely speaking, `stride[align_axis-1] * dtype.size()` must divide \p
  61. * align_size_in_elements. Axes from 0 to align_axis-1 would be considered as
  62. * the height of the image, and other axes are the width.
  63. *
  64. * Empty tensors and negative strides are not allowed. Only contiguous or
  65. * broadcasted cases are allowed.
  66. *
  67. * Note: if `stride[align_axis - 1]` is larger than minimal value, it is still
  68. * considered as contiguous.
  69. */
  70. class Image2DTensorFormatBase : public TensorFormat::ImplBase {
  71. size_t m_align_axis, m_align_size_in_elements_log2;
  72. protected:
  73. Image2DTensorFormatBase(
  74. Type type, size_t align_axis, size_t align_size_in_elements);
  75. virtual ~Image2DTensorFormatBase() = default;
  76. public:
  77. /*!
  78. * \brief get alignment requirement in elements
  79. * \param div_log2 the result would be divided by `(1 << div_log2)`
  80. */
  81. size_t align_size_in_elements(size_t div_log2 = 0) const {
  82. return 1 << (m_align_size_in_elements_log2 > div_log2
  83. ? m_align_size_in_elements_log2 - div_log2
  84. : 0);
  85. }
  86. size_t align_axis() const { return m_align_axis; }
  87. size_t align_size_in_elements_log2() const { return m_align_size_in_elements_log2; }
  88. std::string to_string() const override;
  89. //! modify the align axis and return a new TensorFormat
  90. virtual TensorFormat change_axis(size_t axis) const = 0;
  91. //! number of dtype elems in each row, considering strides
  92. size_t image_width_elems(const TensorLayout& layout) const;
  93. //! number of rows
  94. size_t image_height(const TensorLayout& layout) const;
  95. void serialize_append(std::string& result) const override;
  96. protected:
  97. struct SerializePack {
  98. uint8_t align_axis;
  99. };
  100. };
  101. template <size_t PIXEL_SIZE>
  102. class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase {
  103. Handle::HandleVendorType m_vendor_type = Handle::HandleVendorType::NOT_SPEC;
  104. /*!
  105. * \brief get fix alignment requirement in bytes, consider m_vendor_type,
  106. * for example on MALI, CL_DEVICE_IMAGE_PITCH_ALIGNMENT means image_width
  107. * align COUNT, but mdl needs align size in byte, which equal to
  108. * (image_width algin count) * sizeof(data_type) * pixel_size
  109. */
  110. size_t image_pitch_alignment_in_bytes(
  111. size_t align_size_in_elements, const TensorLayout& layout) const;
  112. protected:
  113. Image2DPackedTensorFormatBase(
  114. Type type, size_t align_axis, size_t align_size_in_elements,
  115. Handle::HandleVendorType vendor_type)
  116. : detail::Image2DTensorFormatBase(type, align_axis, align_size_in_elements),
  117. m_vendor_type(vendor_type) {}
  118. virtual ~Image2DPackedTensorFormatBase() = default;
  119. Handle::HandleVendorType vendor() const { return m_vendor_type; }
  120. public:
  121. /*!
  122. * \brief image width in logical pixels exclude padding
  123. *
  124. * It is the number of accessible elems (in dtype) divided by PIXEL_SIZE.
  125. *
  126. * \see image_row_pitch()
  127. */
  128. size_t image_width(const TensorLayout& layout) const;
  129. size_t image_row_pitch(const TensorLayout& layout) const;
  130. //! raise exception if preconditions violated
  131. void assert_valid(const TensorLayout& layout) const override;
  132. //! span for image must include the padding at the last row
  133. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  134. size_t init_contiguous_stride(TensorLayout& layout) const override;
  135. bool is_contiguous_spec(const TensorLayout& layout) const override;
  136. TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
  137. };
  138. using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
  139. /*!
  140. * \brief used for tensors storing lowbit data
  141. *
  142. * \param m_size_nbits size in bits of elements in the tensor
  143. * \param m_align_size_in_bits aligned size in bits
  144. * \param m_align_size_in_elements aligned size in elements
  145. */
  146. class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase {
  147. size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements;
  148. protected: //?
  149. LowbitsAlignedTensorFormatBase(
  150. Type type, size_t size_nbits, size_t align_size_in_bits);
  151. virtual ~LowbitsAlignedTensorFormatBase() = default;
  152. public:
  153. size_t align_size_in_bits() const { return m_align_size_in_bits; }
  154. size_t size_nbits() const { return m_size_nbits; }
  155. std::string to_string() const override;
  156. //! raise exception if given layout is illegal
  157. void assert_valid(const TensorLayout& layout) const override;
  158. void serialize_append(std::string& result) const override;
  159. //! span for lowbit tensor must include the padding at the innermost
  160. //! dimemsion that make lowbit tensor be aligned to bytes
  161. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  162. size_t init_contiguous_stride(TensorLayout& layout) const override;
  163. bool is_contiguous_spec(const TensorLayout& layout) const override;
  164. TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
  165. protected:
  166. struct SerializePack {
  167. uint8_t size_nbits;
  168. uint8_t align_size_in_bits;
  169. };
  170. };
  171. } // namespace detail
  172. /*!
  173. * \brief 2D image that requires stride of width to be aligned, and pack 4 elems
  174. * into a pixel
  175. *
  176. * This is used for OpenCL.
  177. */
  178. class Image2DPack4TensorFormat final : public detail::Image2DPack4TensorFormatBase {
  179. public:
  180. static constexpr Type TYPE = Type::IMAGE2D_PACK4;
  181. //! for internal usage or test purposes
  182. static TensorFormat make_raw(
  183. size_t align_axis, size_t align_size_in_elements,
  184. Handle::HandleVendorType vendor_type = Handle::HandleVendorType::NOT_SPEC);
  185. static TensorFormat make(size_t align_axis, const Handle* handle);
  186. /*!
  187. * \brief deserialize on a handle
  188. *
  189. * Note that the alignment may be different if deserialized on another
  190. * handle
  191. */
  192. static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);
  193. static bool is_valid_image(const TensorLayout& layout) {
  194. if (layout.format.type() == TYPE) {
  195. layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(layout);
  196. return true;
  197. }
  198. return false;
  199. }
  200. TensorFormat change_axis(size_t axis) const override;
  201. private:
  202. Image2DPack4TensorFormat(
  203. size_t align_axis, size_t align_size_in_elements,
  204. Handle::HandleVendorType vendor_type)
  205. : detail::Image2DPack4TensorFormatBase(
  206. TYPE, align_axis, align_size_in_elements, vendor_type) {}
  207. };
  208. /*!
  209. * \brief Tensor for storing 4bit data that requires stride corresponding to
  210. * non-innermost dimension to be aligned to bytes, and pack 2 elems into a byte
  211. */
  212. class LowbitsAlignedToBytesTensorFormat final
  213. : public detail::LowbitsAlignedTensorFormatBase {
  214. public:
  215. static constexpr Type TYPE = Type::LOWBITS_ALIGNED_TO_BYTE;
  216. static constexpr size_t BYTE_IN_BITS = 8;
  217. static TensorFormat make(size_t size_nbits);
  218. static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);
  219. static bool is_valid_layout(const TensorLayout& layout) {
  220. if (layout.format.type() == TYPE) {
  221. layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>().assert_valid(
  222. layout);
  223. return true;
  224. }
  225. return false;
  226. }
  227. private:
  228. LowbitsAlignedToBytesTensorFormat(size_t size_nbits)
  229. : detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, BYTE_IN_BITS) {}
  230. };
  231. } // namespace megdnn
  232. #include "megdnn/internal/visibility_epilogue.h"
  233. // vim: syntax=cpp.doxygen