You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

named_tensor.h 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
#pragma once

#include "megdnn/internal/defs.h"
#include "megdnn/opr_param_defs.h"

#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <string>

#include "megdnn/thin/small_vector.h"

#include "megdnn/internal/visibility_prologue.h"
  8. namespace megdnn {
  9. class Dimension {
  10. public:
  11. enum class Name : char {
  12. N = 'N', // Batch size
  13. C = 'C', // input channel
  14. H = 'H', // input height
  15. W = 'W', // input width
  16. G = 'G', // group
  17. K = 'K', // output channel
  18. R = 'R', // filter height
  19. S = 'S', // filter width
  20. P = 'P', // output height
  21. Q = 'Q', // output width
  22. };
  23. static constexpr uint32_t UNDETERMINED_EXTENT =
  24. std::numeric_limits<uint32_t>::max();
  25. static const Name NAME_ALL[];
  26. static const int NR_NAMES;
  27. Dimension() = default;
  28. Dimension(const std::string& expr);
  29. Dimension(Name name, uint32_t stride, uint32_t extent = UNDETERMINED_EXTENT)
  30. : m_name{name}, m_stride{stride}, m_extent{extent} {}
  31. Dimension(const Dimension& rhs) { operator=(rhs); }
  32. Dimension& operator=(const Dimension& rhs);
  33. bool operator==(const Dimension& rhs) const;
  34. bool operator<(const Dimension& rhs) const;
  35. Dimension operator*(const Dimension& rhs) const;
  36. Dimension operator/(const Dimension& rhs) const;
  37. std::string to_string() const;
  38. Name name() const { return m_name; }
  39. uint32_t extent() const { return m_extent; }
  40. uint32_t stride() const { return m_stride; }
  41. private:
  42. Name m_name;
  43. uint32_t m_stride;
  44. uint32_t m_extent;
  45. };
  46. struct NamedTensorShape {
  47. using Format = param::ConvBias::Format;
  48. static constexpr size_t MAX_NDIM = MEGDNN_MAX_NDIM;
  49. std::array<Dimension, MAX_NDIM> dims;
  50. size_t ndim = 0;
  51. NamedTensorShape() = default;
  52. NamedTensorShape(const NamedTensorShape& rhs) = default;
  53. NamedTensorShape(const SmallVector<Dimension>& init_shape);
  54. NamedTensorShape(std::initializer_list<Dimension> init_shape);
  55. std::string to_string() const;
  56. bool eq_shape(const NamedTensorShape& rhs) const;
  57. Dimension& operator[](size_t i) { return dims[i]; }
  58. Dimension operator[](size_t i) const { return dims[i]; }
  59. NamedTensorShape static make_named_tensor_shape(Format format);
  60. };
  61. } // namespace megdnn
  62. #include "megdnn/internal/visibility_epilogue.h"
  63. // vim: syntax=cpp.doxygen