You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

index.h 1.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  #pragma once
  #include <cassert>
  #include <limits>
  #include <random>

  #include "megdnn/basic_types.h"
  #include "test/common/rng.h"
  4. namespace megdnn {
  5. namespace test {
  6. /**
  7. * array: index in the array form
  8. * linear: a single index number by assuming contiguous layout
  9. * offset: the memory offset in nr elements (can be negative)
  10. *
  11. * dtype is ignored.
  12. */
  13. class Index {
  14. public:
  15. Index(TensorLayout layout, size_t linear);
  16. Index(TensorLayout layout, TensorShape array);
  17. std::string to_string() const;
  18. TensorShape array() const { return m_array; }
  19. TensorLayout layout() const { return m_layout; }
  20. size_t linear_index() const { return m_linear; }
  21. ptrdiff_t offset() const { return m_offset; }
  22. /**
  23. * Add a universal offset to all return values to make the minimal
  24. * offset zero.
  25. */
  26. size_t positive_offset() const { return m_offset - m_layout.span().low_elem; }
  27. private:
  28. TensorLayout m_layout;
  29. size_t m_linear;
  30. TensorShape m_array;
  31. ptrdiff_t m_offset;
  32. void linear_to_array();
  33. void array_to_linear();
  34. void array_to_offset();
  35. };
  36. class IndexRNG final : public RNG {
  37. size_t& m_size;
  38. std::mt19937_64 m_rng;
  39. public:
  40. IndexRNG(size_t& sz, size_t seed) : m_size{sz}, m_rng(seed) {}
  41. void gen(const TensorND& tensor) override {
  42. std::uniform_int_distribution<int> dist(-static_cast<int>(m_size), m_size - 1);
  43. auto ptr = tensor.ptr<int>() + tensor.layout.span().low_elem;
  44. for (size_t i = 0; i < tensor.layout.span().dist_elem(); ++i)
  45. ptr[i] = dist(m_rng);
  46. }
  47. };
  48. } // namespace test
  49. } // namespace megdnn
  50. // vim: syntax=cpp.doxygen