
tensor_iter.h

#pragma once

#include "megdnn/basic_types.h"
#include "megdnn/dtype.h"

#include "megdnn/internal/visibility_prologue.h"

namespace megdnn {

//! reference-like accessor for a tensor element of type T; the generic version
//! simply returns a plain T& at the given offset
template <typename T>
class TypeRef {
public:
    using dtype = T&;
    static T& get(T* _ptr, size_t _offset) {
        T& ret = _ptr[_offset];
        return ret;
    }
};

//! specialization for 4-bit quantized uint: element access goes through a
//! value proxy instead of a plain reference (the constructor and the uint8_t
//! assignment operator are defined out of line)
template <>
class TypeRef<dt_quint4> {
private:
    uint8_t* ptr = nullptr;
    size_t offset = 0;

public:
    using dtype = TypeRef<dt_quint4>;
    dt_quint4 val = dt_quint4(0);

    TypeRef(dt_quint4* _ptr, size_t _offset);

    void operator=(const uint8_t _);
    void operator=(const dt_quint4& _) { *this = _.as_uint8(); }
    void operator=(const TypeRef<dt_quint4>& _) { *this = _.val.as_uint8(); }

    operator dt_quint4() const { return val; }
    operator uint8_t() const { return val.as_uint8(); }

    static TypeRef<dt_quint4> get(dt_quint4* _ptr, size_t _offset) {
        return TypeRef<dt_quint4>(_ptr, _offset);
    }
};

//! specialization for 4-bit quantized int, mirroring TypeRef<dt_quint4>
template <>
class TypeRef<dt_qint4> {
private:
    int8_t* ptr = nullptr;
    size_t offset = 0;

public:
    using dtype = TypeRef<dt_qint4>;
    dt_qint4 val = dt_qint4(0);

    TypeRef(dt_qint4* _ptr, size_t _offset);

    void operator=(const int8_t _);
    void operator=(const dt_qint4& _) { *this = _.as_int8(); }
    void operator=(const TypeRef<dt_qint4>& _) { *this = _.val.as_int8(); }

    operator dt_qint4() const { return val; }
    operator int8_t() const { return val.as_int8(); }

    static TypeRef<dt_qint4> get(dt_qint4* _ptr, size_t _offset) {
        return TypeRef<dt_qint4>(_ptr, _offset);
    }
};
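/*!
 * Illustrative usage sketch for the 4-bit proxies (an assumption-laden example,
 * not part of the API: it presumes a dt_qint4 tensor buffer and the out-of-line
 * TypeRef constructor and operator= defined in the corresponding source file):
 *
 * \code
 *     dt_qint4* buf = tensor.ptr<dt_qint4>();      // hypothetical TensorND buffer
 *     auto ref = TypeRef<dt_qint4>::get(buf, 3);   // proxy for element 3
 *     int8_t cur = ref;                            // read via conversion to int8_t
 *     ref = static_cast<int8_t>(cur + 1);          // write back through operator=
 * \endcode
 */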
/*!
 * \brief helper for iterating over a tensor with arbitrary layout
 * \tparam ctype tensor element plain data type
 * \tparam valonly whether only the value is needed (so the logical index does
 *      not need to be maintained)
 */
template <typename ctype, bool valonly>
class TensorIter {
    TensorND m_tensor;

public:
    class Iter {
        MEGDNN_NORETURN void on_access_idx_valonly_true() const;

        ctype* m_ptr = nullptr;
        TensorLayout m_layout;
        ptrdiff_t m_axis_reset_stride[TensorShape::MAX_NDIM],
                m_offset = 0;  //!< physical offset in buffer
        //! offset in each axis
        size_t m_axis_offset[TensorShape::MAX_NDIM],
                m_logical_offset = 0,  //!< contiguous logical offset
                m_tot_nr_elems = 0;    //!< total number of elements (max logical offset)

    public:
        Iter() {
            memset(m_axis_reset_stride, 0, sizeof(m_axis_reset_stride));
            memset(m_axis_offset, 0, sizeof(m_axis_offset));
        }

        /*!
         * \brief create an iterator
         */
        static Iter make(ctype* ptr, const TensorLayout& layout, size_t offset);

        static Iter make(TensorND& t, size_t offset) {
            return make(t.ptr<ctype>(), t.layout, offset);
        }

        //! access element without boundary check
        typename TypeRef<ctype>::dtype operator*() {
            return TypeRef<ctype>::get(m_ptr, m_offset);
        }

        Iter& operator++() {
            if ((++m_logical_offset) == m_tot_nr_elems)
                return *this;
            auto mem_offset = m_offset;
            // odometer-style increment: bump the innermost axis first; when an
            // axis wraps around, reset it (moving the memory offset back to the
            // start of that axis via m_axis_reset_stride) and carry into the
            // next outer axis
            for (int axis = m_layout.ndim - 1;; axis--) {
                size_t& ax_offset = ++m_axis_offset[axis];
                if (ax_offset < m_layout.shape[axis]) {
                    mem_offset += m_layout.stride[axis];
                    break;
                } else {
                    ax_offset = 0;
                    mem_offset -= m_axis_reset_stride[axis];
                }
            }
            m_offset = mem_offset;
            return *this;
        }

        //! whether the current value is valid
        bool valid() const { return m_logical_offset < m_tot_nr_elems; }

        //! whether the current position is at the end of the buffer
        bool at_end() const { return m_logical_offset == m_tot_nr_elems; }

        //! get logical index; valonly must be false
        const size_t* idx() const {
            if (valonly)
                on_access_idx_valonly_true();
            return m_axis_offset;
        }

        /*!
         * \brief memory address offset, measured in number of elements
         */
        size_t offset() const { return m_offset; }

        /*!
         * \brief number of elements from the first element
         */
        size_t logical_offset() const { return m_logical_offset; }

        bool operator!=(const Iter& rhs) const {
            return m_logical_offset != rhs.m_logical_offset;
        }
    };

    TensorIter() = default;

    TensorIter(const TensorND& tensor) : m_tensor(tensor) {}

    Iter begin() const { return Iter::make(const_cast<TensorND&>(m_tensor), 0); }

    Iter end() const {
        return Iter::make(
                const_cast<TensorND&>(m_tensor), m_tensor.layout.total_nr_elems());
    }
};
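/*!
 * Illustrative sketch (assumptions, not part of the header): walking the Iter
 * interface manually and using idx() to recover the logical coordinates.
 * Presumes a dt_float32 TensorND named t and valonly == false so that the
 * logical index is maintained.
 *
 * \code
 *     TensorIter<dt_float32, false> it_range{t};
 *     for (auto it = it_range.begin(); !it.at_end(); ++it) {
 *         const size_t* coord = it.idx();  // logical index along each axis
 *         dt_float32& v = *it;             // plain reference for non-4-bit ctypes
 *         v += coord[0];                   // e.g. add the outermost coordinate
 *     }
 * \endcode
 */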
/*!
 * \brief iterate over elements of a tensor; only access tensor value
 */
template <typename ctype>
TensorIter<ctype, true> tensor_iter_valonly(const TensorND& t) {
    return {t};
}

/*!
 * \brief iterate over elements of a tensor, retaining logical index
 */
template <typename ctype>
TensorIter<ctype, false> tensor_iter(const TensorND& t) {
    return {t};
}
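/*!
 * Illustrative sketch (assumptions, not part of the header): because Iter
 * provides operator*, operator++ and operator!=, and TensorIter provides
 * begin()/end(), the helpers above work in a range-based for loop. Presumes a
 * dt_float32 TensorND named t where only the values matter:
 *
 * \code
 *     for (auto&& v : tensor_iter_valonly<dt_float32>(t)) {
 *         v *= 2;  // v binds to dt_float32& for plain ctypes
 *     }
 * \endcode
 */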
}  // namespace megdnn

#include "megdnn/internal/visibility_epilogue.h"

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}