You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

mesh_indexing.h 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. /**
  2. * \file dnn/test/common/mesh_indexing.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  #include <algorithm>
  #include <random>
  #include <vector>

  #include "megdnn/basic_types.h"
  #include "megdnn/oprs/general.h"
  #include "rng.h"
  #include "test/common/indexing_multi_axis_vec.h"
  #include "test/common/opr_proxy.h"
  17. namespace megdnn {
  18. namespace test {
  /**
   * Generates an OprProxy specialization for the operator type \p __opr,
   * built on OprProxyIndexingMultiAxisVecHelper (whose constructors are
   * inherited).
   *
   * exec():
   *   - sizes the workspace from tensors[1].layout, the helper's `axes`,
   *     `tensors.size() - 2` and the constant 1
   *     (NOTE(review): presumably "number of index tensors" and "index ndim";
   *     confirm against the operator's get_workspace_in_bytes signature);
   *   - calls opr->exec(tensors[0], <index desc>, tensors[1], workspace),
   *     where the index desc is built by make_index_desc(tensors).
   *
   * deduce_layout():
   *   - forwards to opr->deduce_layout(layouts[0], <index layout>, layouts[1]),
   *     where the index layout is built by make_index_layout(layouts).
   *
   * Only block comments may be used inside the macro body: a line comment
   * would swallow the line-continuation backslash.
   */
  #define MESH_INDEXING_LIKE_OPR_PROXY(__opr) \
      template <> \
      struct OprProxy<__opr> : public OprProxyIndexingMultiAxisVecHelper { \
          using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper; \
          void exec(__opr* opr, const TensorNDArray& tensors) const { \
              const auto workspace_bytes = opr->get_workspace_in_bytes( \
                      tensors[1].layout, axes, tensors.size() - 2, 1); \
              WorkspaceWrapper W(opr->handle(), workspace_bytes); \
              opr->exec( \
                      tensors[0], make_index_desc(tensors), tensors[1], \
                      W.workspace()); \
          } \
          void deduce_layout(__opr* opr, TensorLayoutArray& layouts) { \
              /* both parameters are used, so no MEGDNN_MARK_USED_VAR needed */ \
              opr->deduce_layout( \
                      layouts[0], make_index_layout(layouts), layouts[1]); \
          } \
      };
  /**
   * Generates an OprProxy specialization for the operator type \p __opr,
   * built on OprProxyIndexingMultiAxisVecHelper (whose constructors are
   * inherited).
   *
   * exec():
   *   - sizes the workspace from tensors[1].layout, the helper's `axes`,
   *     `tensors.size() - 2` and the constant 1
   *     (NOTE(review): meaning of the last two arguments is not visible
   *     here; confirm against the operator's get_workspace_in_bytes);
   *   - calls opr->exec(tensors[0], tensors[1], <index desc>, workspace),
   *     where the index desc is built by make_index_desc(tensors).
   *
   * deduce_layout() is intentionally empty for these operators.
   *
   * Only block comments may be used inside the macro body: a line comment
   * would swallow the line-continuation backslash.
   */
  #define MESH_MODIFY_LIKE_OPR_PROXY(__opr) \
      template <> \
      struct OprProxy<__opr> : public OprProxyIndexingMultiAxisVecHelper { \
          using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper; \
          void exec(__opr* opr, const TensorNDArray& tensors) const { \
              const auto workspace_bytes = opr->get_workspace_in_bytes( \
                      tensors[1].layout, axes, tensors.size() - 2, 1); \
              WorkspaceWrapper scratch(opr->handle(), workspace_bytes); \
              opr->exec( \
                      tensors[0], tensors[1], make_index_desc(tensors), \
                      scratch.workspace()); \
          } \
          void deduce_layout(__opr*, TensorLayoutArray&) {} \
      };
  51. MESH_INDEXING_LIKE_OPR_PROXY(MeshIndexing);
  52. MESH_INDEXING_LIKE_OPR_PROXY(BatchedMeshIndexing);
  53. MESH_MODIFY_LIKE_OPR_PROXY(IncrMeshIndexing);
  54. MESH_MODIFY_LIKE_OPR_PROXY(BatchedIncrMeshIndexing);
  55. MESH_MODIFY_LIKE_OPR_PROXY(SetMeshIndexing);
  56. MESH_MODIFY_LIKE_OPR_PROXY(BatchedSetMeshIndexing);
  57. #undef MESH_PROXY_COMMON
  58. #undef MESH_INDEXING_LIKE_OPR_PROXY
  59. #undef MESH_MODIFY_LIKE_OPR_PROXY
  60. namespace mesh_indexing {
  61. class NoReplacementIndexRNG final : public RNG {
  62. size_t& m_size;
  63. std::mt19937_64 m_rng;
  64. public:
  65. NoReplacementIndexRNG(size_t& sz, size_t seed) : m_size{sz}, m_rng(seed) {}
  66. void gen(const TensorND& tensor) override {
  67. std::vector<int> seq;
  68. for (size_t i = 0; i < m_size; ++i) {
  69. seq.push_back(i);
  70. }
  71. size_t stride = static_cast<size_t>(tensor.layout.stride[0]);
  72. size_t size = tensor.layout[0];
  73. if (tensor.layout.ndim == 1) {
  74. stride = tensor.layout[0];
  75. size = 1;
  76. }
  77. megdnn_assert(stride <= m_size);
  78. auto ptr = tensor.ptr<int>();
  79. for (size_t n = 0; n < size; ++n) {
  80. std::set<int> used;
  81. COMPAT_RANDOM(seq.begin(), seq.end());
  82. for (size_t step = 0; step < stride; ++step) {
  83. megdnn_assert(used.size() < m_size);
  84. ptr[n * stride + step] = seq[step];
  85. used.insert(seq[step]);
  86. }
  87. }
  88. }
  89. };
  90. } // namespace mesh_indexing
  91. } // namespace test
  92. } // namespace megdnn

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。