You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_format.h 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. /**
  2. * \file dnn/include/megdnn/tensor_format.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "megdnn/handle.h"
  14. #include "megdnn/internal/visibility_prologue.h"
  15. namespace megdnn {
  16. enum class TensorFormat::Type {
  17. DEFAULT = 0, //!< see DefaultTensorFormat
  18. IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat
  19. };
  20. class TensorFormat::ImplBase {
  21. public:
  22. using Type = TensorFormat::Type;
  23. virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0;
  24. virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0;
  25. virtual TensorLayout collapse_contiguous_spec(
  26. const TensorLayout& layout) const = 0;
  27. virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0;
  28. //! a human-readable string description of this TensorFormat
  29. virtual std::string to_string() const = 0;
  30. virtual void serialize_append(std::string& result) const = 0;
  31. Type type() const { return m_type; }
  32. protected:
  33. ImplBase(Type type) : m_type{type} {}
  34. virtual ~ImplBase() = default;
  35. static TensorFormat impl_to_tensor_format(ImplBase* impl) { return {impl}; }
  36. private:
  37. Type m_type;
  38. };
  39. TensorFormat::Type TensorFormat::type() const {
  40. return m_impl->type();
  41. }
  42. //! default tensor format that imposes no stride constraints
  43. class DefaultTensorFormat final : public TensorFormat::ImplBase {
  44. public:
  45. static constexpr Type TYPE = Type::DEFAULT;
  46. DefaultTensorFormat() : ImplBase(TYPE) {}
  47. size_t init_contiguous_stride(TensorLayout& layout) const override;
  48. /*!
  49. * \brief A tensor is contiguous if logical offset in row-major of any
  50. * element always equals to its physical offset (i.e. offset considering
  51. * strides).
  52. *
  53. * Empty tensors are not considered to be contiguous.
  54. */
  55. bool is_contiguous_spec(const TensorLayout& layout) const override;
  56. TensorLayout collapse_contiguous_spec(
  57. const TensorLayout& layout) const override;
  58. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  59. std::string to_string() const override;
  60. void serialize_append(std::string& result) const override;
  61. static TensorFormat make();
  62. static TensorFormat deserialize(const Handle* handle, const void* buf,
  63. size_t size);
  64. };
  65. namespace detail {
  66. /*!
  67. * \brief 2D image with requirement on row stride
  68. *
  69. * \p align_axis is the axis to be aligned, also the first axis of image width.
  70. * More precisely speaking, `stride[align_axis-1] * dtype.size()` must divide \p
  71. * align_size_in_elements. Axes from 0 to align_axis-1 would be considered as
  72. * the height of the image, and other axes are the width.
  73. *
  74. * Empty tensors and negative strides are not allowed. Only contiguous or
  75. * broadcasted cases are allowed.
  76. *
  77. * Note: if `stride[align_axis - 1]` is larger than minimal value, it is still
  78. * considered as contiguous.
  79. */
  80. class Image2DTensorFormatBase : public TensorFormat::ImplBase {
  81. size_t m_align_axis, m_align_size_in_elements_log2;
  82. protected:
  83. Image2DTensorFormatBase(Type type, size_t align_axis,
  84. size_t align_size_in_elements);
  85. virtual ~Image2DTensorFormatBase() = default;
  86. public:
  87. /*!
  88. * \brief get alignment requirement in elements
  89. * \param div_log2 the result would be divided by `(1 << div_log2)`
  90. */
  91. size_t align_size_in_elements(size_t div_log2 = 0) const {
  92. return 1 << (m_align_size_in_elements_log2 > div_log2
  93. ? m_align_size_in_elements_log2 - div_log2
  94. : 0);
  95. }
  96. size_t align_axis() const { return m_align_axis; }
  97. size_t align_size_in_elements_log2() const {
  98. return m_align_size_in_elements_log2;
  99. }
  100. std::string to_string() const override;
  101. //! modify the align axis and return a new TensorFormat
  102. virtual TensorFormat change_axis(size_t axis) const = 0;
  103. //! number of dtype elems in each row, considering strides
  104. size_t image_width_elems(const TensorLayout& layout) const;
  105. //! number of rows
  106. size_t image_height(const TensorLayout& layout) const;
  107. void serialize_append(std::string& result) const override;
  108. protected:
  109. struct SerializePack {
  110. uint8_t align_axis;
  111. };
  112. };
  113. template <size_t PIXEL_SIZE>
  114. class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase {
  115. Handle::HandleVendorType m_vendor_type = Handle::HandleVendorType::NOT_SPEC;
  116. /*!
  117. * \brief get fix alignment requirement in bytes, consider m_vendor_type,
  118. * for example on MALI, CL_DEVICE_IMAGE_PITCH_ALIGNMENT means image_width
  119. * align COUNT, but mdl needs align size in byte, which equal to
  120. * (image_width algin count) * sizeof(data_type) * pixel_size
  121. */
  122. size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements,
  123. const TensorLayout& layout) const;
  124. protected:
  125. Image2DPackedTensorFormatBase(Type type, size_t align_axis,
  126. size_t align_size_in_elements,
  127. Handle::HandleVendorType vendor_type)
  128. : detail::Image2DTensorFormatBase(type, align_axis,
  129. align_size_in_elements),
  130. m_vendor_type(vendor_type) {}
  131. virtual ~Image2DPackedTensorFormatBase() = default;
  132. Handle::HandleVendorType vendor() const { return m_vendor_type; }
  133. public:
  134. /*!
  135. * \brief image width in logical pixels exclude padding
  136. *
  137. * It is the number of accessible elems (in dtype) divided by PIXEL_SIZE.
  138. *
  139. * \see image_row_pitch()
  140. */
  141. size_t image_width(const TensorLayout& layout) const;
  142. //! raise exception if preconditions violated
  143. void assert_valid(const TensorLayout& layout) const;
  144. size_t image_row_pitch(const TensorLayout& layout) const;
  145. //! span for image must include the padding at the last row
  146. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  147. size_t init_contiguous_stride(TensorLayout& layout) const override;
  148. bool is_contiguous_spec(const TensorLayout& layout) const override;
  149. TensorLayout collapse_contiguous_spec(
  150. const TensorLayout& layout) const override;
  151. };
  152. using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
  153. } // namespace detail
  154. /*!
  155. * \brief 2D image that requires stride of width to be aligned, and pack 4 elems
  156. * into a pixel
  157. *
  158. * This is used for OpenCL.
  159. */
  160. class Image2DPack4TensorFormat final
  161. : public detail::Image2DPack4TensorFormatBase {
  162. public:
  163. static constexpr Type TYPE = Type::IMAGE2D_PACK4;
  164. //! for internal usage or test purposes
  165. static TensorFormat make_raw(size_t align_axis,
  166. size_t align_size_in_elements,
  167. Handle::HandleVendorType vendor_type =
  168. Handle::HandleVendorType::NOT_SPEC);
  169. static TensorFormat make(size_t align_axis, const Handle* handle);
  170. /*!
  171. * \brief deserialize on a handle
  172. *
  173. * Note that the alignment may be different if deserialized on another
  174. * handle
  175. */
  176. static TensorFormat deserialize(const Handle* handle, const void* buf,
  177. size_t size);
  178. static bool is_valid_image(const TensorLayout& layout) {
  179. if (layout.format.type() == TYPE) {
  180. layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(
  181. layout);
  182. return true;
  183. }
  184. return false;
  185. }
  186. TensorFormat change_axis(size_t axis) const override;
  187. private:
  188. Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements,
  189. Handle::HandleVendorType vendor_type)
  190. : detail::Image2DPack4TensorFormatBase(
  191. TYPE, align_axis, align_size_in_elements, vendor_type) {}
  192. };
  193. } // namespace megdnn
  194. #include "megdnn/internal/visibility_epilogue.h"
  195. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台