
tensor_iter.h 5.8 kB

/**
 * \file dnn/include/megdnn/tensor_iter.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/basic_types.h"
#include "megdnn/dtype.h"
#include "megdnn/internal/visibility_prologue.h"

namespace megdnn {

template <typename T>
class TypeRef {
public:
    using dtype = T&;
    static T& get(T* _ptr, size_t _offset) {
        T& ret = _ptr[_offset];
        return ret;
    }
};

template <>
class TypeRef<dt_quint4> {
private:
    uint8_t* ptr = nullptr;
    size_t offset = 0;

public:
    using dtype = TypeRef<dt_quint4>;
    dt_quint4 val = dt_quint4(0);

    TypeRef(dt_quint4* _ptr, size_t _offset);

    void operator=(const uint8_t _);
    void operator=(const dt_quint4& _) { *this = _.as_uint8(); }
    void operator=(const TypeRef<dt_quint4>& _) { *this = _.val.as_uint8(); }

    operator dt_quint4() const { return val; }
    operator uint8_t() const { return val.as_uint8(); }

    static TypeRef<dt_quint4> get(dt_quint4* _ptr, size_t _offset) {
        return TypeRef<dt_quint4>(_ptr, _offset);
    }
};

template <>
class TypeRef<dt_qint4> {
private:
    int8_t* ptr = nullptr;
    size_t offset = 0;

public:
    using dtype = TypeRef<dt_qint4>;
    dt_qint4 val = dt_qint4(0);

    TypeRef(dt_qint4* _ptr, size_t _offset);

    void operator=(const int8_t _);
    void operator=(const dt_qint4& _) { *this = _.as_int8(); }
    void operator=(const TypeRef<dt_qint4>& _) { *this = _.val.as_int8(); }

    operator dt_qint4() const { return val; }
    operator int8_t() const { return val.as_int8(); }

    static TypeRef<dt_qint4> get(dt_qint4* _ptr, size_t _offset) {
        return TypeRef<dt_qint4>(_ptr, _offset);
    }
};

/*!
 * \brief helper for iterating on a tensor with arbitrary layout
 * \tparam ctype tensor element plain data type
 * \tparam valonly whether only the value is needed (so the logical index does
 *      not need to be maintained)
 */
template <typename ctype, bool valonly>
class TensorIter {
    TensorND m_tensor;

public:
    class Iter {
        MEGDNN_NORETURN void on_access_idx_valonly_true() const;

        ctype* m_ptr = nullptr;
        TensorLayout m_layout;
        ptrdiff_t m_axis_reset_stride[TensorShape::MAX_NDIM],
                m_offset = 0;  //!< physical offset in buffer
        //! offset in each axis
        size_t m_axis_offset[TensorShape::MAX_NDIM],
                m_logical_offset = 0,  //!< contiguous logical offset
                m_tot_nr_elems = 0;    //!< total number of elements (max logical offset)

    public:
        Iter() {
            memset(m_axis_reset_stride, 0, sizeof(m_axis_reset_stride));
            memset(m_axis_offset, 0, sizeof(m_axis_offset));
        }

        /*!
         * \brief create an iterator
         */
        static Iter make(ctype* ptr, const TensorLayout& layout, size_t offset);

        static Iter make(TensorND& t, size_t offset) {
            return make(t.ptr<ctype>(), t.layout, offset);
        }

        //! access element without boundary check
        typename TypeRef<ctype>::dtype operator*() {
            return TypeRef<ctype>::get(m_ptr, m_offset);
        };

        Iter& operator++() {
            if ((++m_logical_offset) == m_tot_nr_elems)
                return *this;
            auto mem_offset = m_offset;
            for (int axis = m_layout.ndim - 1;; axis--) {
                size_t& ax_offset = ++m_axis_offset[axis];
                if (ax_offset < m_layout.shape[axis]) {
                    mem_offset += m_layout.stride[axis];
                    break;
                } else {
                    ax_offset = 0;
                    mem_offset -= m_axis_reset_stride[axis];
                }
            }
            m_offset = mem_offset;
            return *this;
        }

        //! whether the current value is valid
        bool valid() const { return m_logical_offset < m_tot_nr_elems; }

        //! whether the current position is at the end of the buffer
        bool at_end() const { return m_logical_offset == m_tot_nr_elems; }

        //! get logical index; valonly must be false
        const size_t* idx() const {
            if (valonly)
                on_access_idx_valonly_true();
            return m_axis_offset;
        }

        /*!
         * \brief memory address offset, measured in number of elements
         */
        size_t offset() const { return m_offset; }

        /*!
         * \brief number of elements from the first element
         */
        size_t logical_offset() const { return m_logical_offset; }

        bool operator!=(const Iter& rhs) const {
            return m_logical_offset != rhs.m_logical_offset;
        }
    };

    TensorIter() = default;

    TensorIter(const TensorND& tensor) : m_tensor(tensor) {}

    Iter begin() const { return Iter::make(const_cast<TensorND&>(m_tensor), 0); }

    Iter end() const {
        return Iter::make(
                const_cast<TensorND&>(m_tensor), m_tensor.layout.total_nr_elems());
    }
};

/*!
 * \brief iterate over elements of a tensor; only access tensor value
 */
template <typename ctype>
TensorIter<ctype, true> tensor_iter_valonly(const TensorND& t) {
    return {t};
}

/*!
 * \brief iterate over elements of a tensor, retaining logical index
 */
template <typename ctype>
TensorIter<ctype, false> tensor_iter(const TensorND& t) {
    return {t};
}

}  // namespace megdnn

#include "megdnn/internal/visibility_epilogue.h"

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
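The constructors and the raw-integer assignment operators of the two 4-bit specializations above are only declared in this header; their definitions live elsewhere in megdnn. As an illustration of the proxy-reference pattern they implement (two 4-bit elements packed per byte, selected by the parity of the element offset), here is a hypothetical, self-contained sketch; the name PackedU4Ref and the exact packing order are assumptions for illustration, not the actual megdnn definitions:

    #include <cstddef>
    #include <cstdint>

    // Hypothetical stand-in for the out-of-line parts of TypeRef<dt_quint4>:
    // a reference proxy for the 4-bit element at `offset` in a buffer that
    // packs two elements per byte.
    struct PackedU4Ref {
        uint8_t* ptr;   // underlying byte buffer
        size_t offset;  // element (nibble) index, not byte index

        PackedU4Ref(uint8_t* p, size_t off) : ptr(p), offset(off) {}

        // Read: select the low or high nibble by offset parity.
        operator uint8_t() const {
            uint8_t byte = ptr[offset >> 1];
            return (offset & 1) ? uint8_t(byte >> 4) : uint8_t(byte & 0xF);
        }

        // Write: replace only the addressed nibble, keeping its neighbor.
        PackedU4Ref& operator=(uint8_t v) {
            uint8_t& byte = ptr[offset >> 1];
            if (offset & 1)
                byte = uint8_t((byte & 0x0F) | ((v & 0xF) << 4));
            else
                byte = uint8_t((byte & 0xF0) | (v & 0xF));
            return *this;
        }
    };

This is also why the iterator's operator*() returns TypeRef<ctype>::dtype rather than a plain ctype&: for sub-byte types a real reference does not exist, so reads and writes have to go through such a proxy.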
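For reference, a minimal usage sketch of the iteration helpers declared above, assuming a dt_float32 TensorND that has already been allocated and filled elsewhere; the functions sum_elements and zero_diagonal are illustrative names, not part of megdnn:

    #include "megdnn/tensor_iter.h"

    using namespace megdnn;

    // Sum all elements of a float tensor, whatever its strides are.
    // valonly = true: the per-axis logical index is not maintained.
    float sum_elements(const TensorND& t) {
        float sum = 0.f;
        for (auto&& v : tensor_iter_valonly<dt_float32>(t))
            sum += v;
        return sum;
    }

    // Zero the main diagonal of a 2-D tensor. valonly = false, so idx()
    // may be used to read the per-axis logical coordinates.
    void zero_diagonal(const TensorND& t) {
        for (auto it = tensor_iter<dt_float32>(t).begin(); it.valid(); ++it) {
            const size_t* idx = it.idx();
            if (idx[0] == idx[1])
                *it = 0.f;
        }
    }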

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has a GPU device and that its driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.