You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mesh_indexing.h 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #pragma once
  2. #include "megdnn/basic_types.h"
  3. #include "megdnn/oprs/general.h"
  4. #include "rng.h"
  5. #include "test/common/indexing_multi_axis_vec.h"
  6. #include "test/common/opr_proxy.h"
  7. namespace megdnn {
  8. namespace test {
  // Defines the OprProxy<__opr> specialization for an operator whose exec()
  // is called as exec(tensors[0], index_desc, tensors[1], workspace).
  //
  // The specialization derives from OprProxyIndexingMultiAxisVecHelper and
  // inherits its constructors.
  //   - exec(): asks the operator for its workspace size using
  //     tensors[1].layout, `axes`, tensors.size() - 2 and the literal 1, wraps
  //     that many bytes in a WorkspaceWrapper, then runs the operator with the
  //     descriptor returned by make_index_desc(tensors).
  //     NOTE(review): tensors.size() - 2 presumably counts the index tensors
  //     that follow the first two entries -- confirm against the helper.
  //   - deduce_layout(): forwards to the operator with layouts[0], the result
  //     of make_index_layout(layouts), and layouts[1].
  //
  // Only block comments may be used inside the macro body: a `//` comment
  // would swallow the line-continuation backslash and the following line.
  #define MESH_INDEXING_LIKE_OPR_PROXY(__opr) \
      template <> \
      struct OprProxy<__opr> : public OprProxyIndexingMultiAxisVecHelper { \
          using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper; \
          void exec(__opr* opr, const TensorNDArray& tensors) const { \
              const auto workspace_bytes = opr->get_workspace_in_bytes( \
                      tensors[1].layout, axes, tensors.size() - 2, 1); \
              WorkspaceWrapper wrapper(opr->handle(), workspace_bytes); \
              opr->exec( \
                      tensors[0], make_index_desc(tensors), tensors[1], \
                      wrapper.workspace()); \
          } \
          void deduce_layout(__opr* opr, TensorLayoutArray& layouts) { \
              MEGDNN_MARK_USED_VAR(opr); \
              MEGDNN_MARK_USED_VAR(layouts); \
              opr->deduce_layout( \
                      layouts[0], make_index_layout(layouts), layouts[1]); \
          } \
      };
  // Defines the OprProxy<__opr> specialization for an operator whose exec()
  // is called as exec(tensors[0], tensors[1], index_desc, workspace).
  //
  // The specialization derives from OprProxyIndexingMultiAxisVecHelper and
  // inherits its constructors.
  //   - exec(): asks the operator for its workspace size using
  //     tensors[1].layout, `axes`, tensors.size() - 2 and the literal 1, wraps
  //     that many bytes in a WorkspaceWrapper, then runs the operator with the
  //     descriptor returned by make_index_desc(tensors).
  //   - deduce_layout(): intentionally a no-op; both arguments are ignored.
  //
  // Only block comments may be used inside the macro body: a `//` comment
  // would swallow the line-continuation backslash and the following line.
  #define MESH_MODIFY_LIKE_OPR_PROXY(__opr) \
      template <> \
      struct OprProxy<__opr> : public OprProxyIndexingMultiAxisVecHelper { \
          using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper; \
          void exec(__opr* opr, const TensorNDArray& tensors) const { \
              const auto workspace_bytes = opr->get_workspace_in_bytes( \
                      tensors[1].layout, axes, tensors.size() - 2, 1); \
              WorkspaceWrapper wrapper(opr->handle(), workspace_bytes); \
              opr->exec( \
                      tensors[0], tensors[1], make_index_desc(tensors), \
                      wrapper.workspace()); \
          } \
          void deduce_layout(__opr*, TensorLayoutArray&) {} \
      };
  41. MESH_INDEXING_LIKE_OPR_PROXY(MeshIndexing);
  42. MESH_INDEXING_LIKE_OPR_PROXY(BatchedMeshIndexing);
  43. MESH_MODIFY_LIKE_OPR_PROXY(IncrMeshIndexing);
  44. MESH_MODIFY_LIKE_OPR_PROXY(BatchedIncrMeshIndexing);
  45. MESH_MODIFY_LIKE_OPR_PROXY(SetMeshIndexing);
  46. MESH_MODIFY_LIKE_OPR_PROXY(BatchedSetMeshIndexing);
  47. #undef MESH_PROXY_COMMON
  48. #undef MESH_INDEXING_LIKE_OPR_PROXY
  49. #undef MESH_MODIFY_LIKE_OPR_PROXY
  50. namespace mesh_indexing {
  51. class NoReplacementIndexRNG final : public RNG {
  52. size_t& m_size;
  53. std::mt19937_64 m_rng;
  54. public:
  55. NoReplacementIndexRNG(size_t& sz, size_t seed) : m_size{sz}, m_rng(seed) {}
  56. void gen(const TensorND& tensor) override {
  57. std::vector<int> seq;
  58. for (size_t i = 0; i < m_size; ++i) {
  59. seq.push_back(i);
  60. }
  61. size_t stride = static_cast<size_t>(tensor.layout.stride[0]);
  62. size_t size = tensor.layout[0];
  63. if (tensor.layout.ndim == 1) {
  64. stride = tensor.layout[0];
  65. size = 1;
  66. }
  67. megdnn_assert(stride <= m_size);
  68. auto ptr = tensor.ptr<int>();
  69. for (size_t n = 0; n < size; ++n) {
  70. std::set<int> used;
  71. COMPAT_RANDOM(seq.begin(), seq.end());
  72. for (size_t step = 0; step < stride; ++step) {
  73. megdnn_assert(used.size() < m_size);
  74. ptr[n * stride + step] = seq[step];
  75. used.insert(seq[step]);
  76. }
  77. }
  78. }
  79. };
  80. } // namespace mesh_indexing
  81. } // namespace test
  82. } // namespace megdnn