|
- #pragma once
-
- #include "megdnn/basic_types.h"
- #include "megdnn/oprs/general.h"
- #include "rng.h"
- #include "test/common/indexing_multi_axis_vec.h"
- #include "test/common/opr_proxy.h"
-
- namespace megdnn {
- namespace test {
-
- #define MESH_INDEXING_LIKE_OPR_PROXY(__opr) \
- template <> \
- struct OprProxy<__opr> : public OprProxyIndexingMultiAxisVecHelper { \
- using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper; \
- void exec(__opr* opr, const TensorNDArray& tensors) const { \
- WorkspaceWrapper W( \
- opr->handle(), \
- opr->get_workspace_in_bytes( \
- tensors[1].layout, axes, tensors.size() - 2, 1)); \
- opr->exec( \
- tensors[0], make_index_desc(tensors), tensors[1], W.workspace()); \
- } \
- void deduce_layout(__opr* opr, TensorLayoutArray& layouts) { \
- MEGDNN_MARK_USED_VAR(opr); \
- MEGDNN_MARK_USED_VAR(layouts); \
- opr->deduce_layout(layouts[0], make_index_layout(layouts), layouts[1]); \
- } \
- };
-
- #define MESH_MODIFY_LIKE_OPR_PROXY(__opr) \
- template <> \
- struct OprProxy<__opr> : public OprProxyIndexingMultiAxisVecHelper { \
- using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper; \
- void exec(__opr* opr, const TensorNDArray& tensors) const { \
- WorkspaceWrapper W( \
- opr->handle(), \
- opr->get_workspace_in_bytes( \
- tensors[1].layout, axes, tensors.size() - 2, 1)); \
- opr->exec( \
- tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); \
- } \
- void deduce_layout(__opr*, TensorLayoutArray&) {} \
- };
-
- MESH_INDEXING_LIKE_OPR_PROXY(MeshIndexing);
- MESH_INDEXING_LIKE_OPR_PROXY(BatchedMeshIndexing);
- MESH_MODIFY_LIKE_OPR_PROXY(IncrMeshIndexing);
- MESH_MODIFY_LIKE_OPR_PROXY(BatchedIncrMeshIndexing);
- MESH_MODIFY_LIKE_OPR_PROXY(SetMeshIndexing);
- MESH_MODIFY_LIKE_OPR_PROXY(BatchedSetMeshIndexing);
-
- #undef MESH_PROXY_COMMON
- #undef MESH_INDEXING_LIKE_OPR_PROXY
- #undef MESH_MODIFY_LIKE_OPR_PROXY
-
- namespace mesh_indexing {
- class NoReplacementIndexRNG final : public RNG {
- size_t& m_size;
- std::mt19937_64 m_rng;
-
- public:
- NoReplacementIndexRNG(size_t& sz, size_t seed) : m_size{sz}, m_rng(seed) {}
-
- void gen(const TensorND& tensor) override {
- std::vector<int> seq;
- for (size_t i = 0; i < m_size; ++i) {
- seq.push_back(i);
- }
- size_t stride = static_cast<size_t>(tensor.layout.stride[0]);
- size_t size = tensor.layout[0];
- if (tensor.layout.ndim == 1) {
- stride = tensor.layout[0];
- size = 1;
- }
- megdnn_assert(stride <= m_size);
-
- auto ptr = tensor.ptr<int>();
- for (size_t n = 0; n < size; ++n) {
- std::set<int> used;
- COMPAT_RANDOM(seq.begin(), seq.end());
- for (size_t step = 0; step < stride; ++step) {
- megdnn_assert(used.size() < m_size);
- ptr[n * stride + step] = seq[step];
- used.insert(seq[step]);
- }
- }
- }
- };
- } // namespace mesh_indexing
-
- } // namespace test
- } // namespace megdnn
|