You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_format.h 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. /**
  2. * \file dnn/include/megdnn/tensor_format.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "megdnn/internal/visibility_prologue.h"
  14. namespace megdnn {
  15. enum class TensorFormat::Type {
  16. DEFAULT = 0, //!< see DefaultTensorFormat
  17. IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat
  18. };
  19. class TensorFormat::ImplBase {
  20. public:
  21. using Type = TensorFormat::Type;
  22. virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0;
  23. virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0;
  24. virtual TensorLayout collapse_contiguous_spec(
  25. const TensorLayout& layout) const = 0;
  26. virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0;
  27. //! a human-readable string description of this TensorFormat
  28. virtual std::string to_string() const = 0;
  29. virtual void serialize_append(std::string& result) const = 0;
  30. Type type() const { return m_type; }
  31. protected:
  32. ImplBase(Type type) : m_type{type} {}
  33. ~ImplBase() = default;
  34. static TensorFormat impl_to_tensor_format(ImplBase* impl) { return {impl}; }
  35. private:
  36. Type m_type;
  37. };
  38. TensorFormat::Type TensorFormat::type() const {
  39. return m_impl->type();
  40. }
  41. //! default tensor format that imposes no stride constraints
  42. class DefaultTensorFormat final : public TensorFormat::ImplBase {
  43. public:
  44. static constexpr Type TYPE = Type::DEFAULT;
  45. DefaultTensorFormat() : ImplBase(TYPE) {}
  46. size_t init_contiguous_stride(TensorLayout& layout) const override;
  47. /*!
  48. * \brief A tensor is contiguous if logical offset in row-major of any
  49. * element always equals to its physical offset (i.e. offset considering
  50. * strides).
  51. *
  52. * Empty tensors are not considered to be contiguous.
  53. */
  54. bool is_contiguous_spec(const TensorLayout& layout) const override;
  55. TensorLayout collapse_contiguous_spec(
  56. const TensorLayout& layout) const override;
  57. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  58. std::string to_string() const override;
  59. void serialize_append(std::string& result) const override;
  60. static TensorFormat make();
  61. static TensorFormat deserialize(const Handle* handle, const void* buf,
  62. size_t size);
  63. };
  64. namespace detail {
  65. /*!
  66. * \brief 2D image with requirement on row stride
  67. *
  68. * \p align_axis is the axis to be aligned, also the first axis of image width.
  69. * More precisely speaking, `stride[align_axis-1] * dtype.size()` must divide \p
  70. * align_size_in_byte. Axes from 0 to align_axis-1 would be considered as the
  71. * height of the image, and other axes are the width.
  72. *
  73. * Empty tensors and negative strides are not allowed. Only contiguous or
  74. * broadcasted cases are allowed.
  75. *
  76. * Note: if `stride[align_axis - 1]` is larger than minimal value, it is still
  77. * considered as contiguous.
  78. */
  79. class Image2DTensorFormatBase : public TensorFormat::ImplBase {
  80. size_t m_align_axis, m_align_size_in_byte_log2;
  81. protected:
  82. Image2DTensorFormatBase(Type type, size_t align_axis,
  83. size_t align_size_in_byte);
  84. ~Image2DTensorFormatBase() = default;
  85. public:
  86. /*!
  87. * \brief get alignment requirement in bytes
  88. * \param div_log2 the result would be divided by `(1 << div_log2)`
  89. */
  90. size_t align_size_in_byte(size_t div_log2 = 0) const {
  91. return 1 << (m_align_size_in_byte_log2 > div_log2
  92. ? m_align_size_in_byte_log2 - div_log2
  93. : 0);
  94. }
  95. size_t align_axis() const { return m_align_axis; }
  96. size_t init_contiguous_stride(TensorLayout& layout) const override;
  97. bool is_contiguous_spec(const TensorLayout& layout) const override;
  98. TensorLayout collapse_contiguous_spec(
  99. const TensorLayout& layout) const override;
  100. //! span for image must include the padding at the last row
  101. TensorLayout::Span span_spec(const TensorLayout& layout) const override;
  102. std::string to_string() const override;
  103. //! raise exception if preconditions violated
  104. virtual void assert_valid(const TensorLayout& layout) const;
  105. //! modify the align axis and return a new TensorFormat
  106. virtual TensorFormat change_axis(size_t axis) const = 0;
  107. //! number of dtype elems in each row, considering strides
  108. size_t image_width_elems(const TensorLayout& layout) const;
  109. //! number of rows
  110. size_t image_height(const TensorLayout& layout) const;
  111. //! delta of addresses of consecutive rows (in bytes)
  112. size_t image_row_pitch(const TensorLayout& layout) const;
  113. void serialize_append(std::string& result) const override;
  114. protected:
  115. struct SerializePack {
  116. uint8_t align_axis;
  117. };
  118. };
  119. template <size_t PIXEL_SIZE>
  120. class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase {
  121. protected:
  122. using Image2DTensorFormatBase::Image2DTensorFormatBase;
  123. ~Image2DPackedTensorFormatBase() = default;
  124. public:
  125. /*!
  126. * \brief image width in logical pixels exclude padding
  127. *
  128. * It is the number of accessible elems (in dtype) divided by PIXEL_SIZE.
  129. *
  130. * \see image_row_pitch()
  131. */
  132. size_t image_width(const TensorLayout& layout) const;
  133. void assert_valid(const TensorLayout& layout) const override;
  134. };
  135. using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
  136. } // namespace detail
  137. /*!
  138. * \brief 2D image that requires stride of width to be aligned, and pack 4 elems
  139. * into a pixel
  140. *
  141. * This is used for OpenCL.
  142. */
  143. class Image2DPack4TensorFormat final
  144. : public detail::Image2DPack4TensorFormatBase {
  145. public:
  146. static constexpr Type TYPE = Type::IMAGE2D_PACK4;
  147. //! for internal usage or test purposes
  148. static TensorFormat make_raw(size_t align_axis, size_t align_size_in_byte);
  149. static TensorFormat make(size_t align_axis, const Handle* handle);
  150. /*!
  151. * \brief deserialize on a handle
  152. *
  153. * Note that the alignment may be different if deserialized on another
  154. * handle
  155. */
  156. static TensorFormat deserialize(const Handle* handle, const void* buf,
  157. size_t size);
  158. static bool is_valid_image(const TensorLayout& layout) {
  159. if (layout.format.type() == TYPE) {
  160. layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(
  161. layout);
  162. return true;
  163. }
  164. return false;
  165. }
  166. TensorFormat change_axis(size_t axis) const override;
  167. private:
  168. Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_byte)
  169. : detail::Image2DPack4TensorFormatBase(TYPE, align_axis,
  170. align_size_in_byte) {}
  171. };
  172. } // namespace megdnn
  173. #include "megdnn/internal/visibility_epilogue.h"
  174. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台