You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mesh_indexing.h 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. /**
  2. * \file dnn/test/common/mesh_indexing.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "megdnn/oprs/general.h"
  14. #include "rng.h"
  15. #include "test/common/indexing_multi_axis_vec.h"
  16. #include "test/common/opr_proxy.h"
  17. namespace megdnn {
  18. namespace test {
  /*!
   * \brief Define the OprProxy<__opr> specialization for an indexing-style
   *      operator.
   *
   * The generated proxy inherits OprProxyIndexingMultiAxisVecHelper and its
   * constructors.
   *
   * exec() sizes a workspace from tensors[1].layout, the inherited `axes`
   * and `tensors.size() - 2` (presumably the number of index tensors --
   * TODO confirm against the helper), then calls
   *      opr->exec(tensors[0], make_index_desc(tensors), tensors[1], workspace).
   *
   * deduce_layout() forwards to
   *      opr->deduce_layout(layouts[0], make_index_layout(layouts), layouts[1]).
   *
   * No comments are placed inside the macro body because every line is joined
   * by a trailing backslash.
   */
  #define MESH_INDEXING_LIKE_OPR_PROXY(__opr) \
      template <> \
      struct OprProxy<__opr> : public OprProxyIndexingMultiAxisVecHelper { \
          using OprProxyIndexingMultiAxisVecHelper:: \
                  OprProxyIndexingMultiAxisVecHelper; \
          void exec(__opr* opr, const TensorNDArray& tensors) const { \
              WorkspaceWrapper W(opr->handle(), opr->get_workspace_in_bytes( \
                                                        tensors[1].layout, axes, \
                                                        tensors.size() - 2)); \
              opr->exec(tensors[0], make_index_desc(tensors), tensors[1], \
                        W.workspace()); \
          } \
          void deduce_layout(__opr* opr, TensorLayoutArray& layouts) { \
              opr->deduce_layout(layouts[0], make_index_layout(layouts), \
                                 layouts[1]); \
          } \
      };
  /*!
   * \brief Define the OprProxy<__opr> specialization for a modify-style
   *      operator.
   *
   * The generated proxy inherits OprProxyIndexingMultiAxisVecHelper and its
   * constructors.
   *
   * exec() sizes a workspace from tensors[1].layout, the inherited `axes`
   * and `tensors.size() - 2`, then calls
   *      opr->exec(tensors[0], tensors[1], make_index_desc(tensors), workspace).
   *
   * deduce_layout() is intentionally a no-op for these operators.
   */
  #define MESH_MODIFY_LIKE_OPR_PROXY(__opr) \
      template <> \
      struct OprProxy<__opr> : public OprProxyIndexingMultiAxisVecHelper { \
          using OprProxyIndexingMultiAxisVecHelper:: \
                  OprProxyIndexingMultiAxisVecHelper; \
          void exec(__opr* opr, const TensorNDArray& tensors) const { \
              const auto workspace_bytes = opr->get_workspace_in_bytes( \
                      tensors[1].layout, axes, tensors.size() - 2); \
              WorkspaceWrapper workspace_holder(opr->handle(), workspace_bytes); \
              opr->exec(tensors[0], tensors[1], make_index_desc(tensors), \
                        workspace_holder.workspace()); \
          } \
          void deduce_layout(__opr*, TensorLayoutArray&) {} \
      };
  52. MESH_INDEXING_LIKE_OPR_PROXY(MeshIndexing);
  53. MESH_INDEXING_LIKE_OPR_PROXY(BatchedMeshIndexing);
  54. MESH_MODIFY_LIKE_OPR_PROXY(IncrMeshIndexing);
  55. MESH_MODIFY_LIKE_OPR_PROXY(BatchedIncrMeshIndexing);
  56. MESH_MODIFY_LIKE_OPR_PROXY(SetMeshIndexing);
  57. MESH_MODIFY_LIKE_OPR_PROXY(BatchedSetMeshIndexing);
  58. #undef MESH_PROXY_COMMON
  59. #undef MESH_INDEXING_LIKE_OPR_PROXY
  60. #undef MESH_MODIFY_LIKE_OPR_PROXY
  61. namespace mesh_indexing {
  62. class NoReplacementIndexRNG final : public RNG {
  63. size_t& m_size;
  64. std::mt19937_64 m_rng;
  65. public:
  66. NoReplacementIndexRNG(size_t& sz, size_t seed) : m_size{sz}, m_rng(seed) {}
  67. void gen(const TensorND& tensor) override {
  68. std::vector<int> seq;
  69. for (size_t i = 0; i < m_size; ++i) {
  70. seq.push_back(i);
  71. }
  72. size_t stride = static_cast<size_t>(tensor.layout.stride[0]);
  73. size_t size = tensor.layout[0];
  74. if (tensor.layout.ndim == 1) {
  75. stride = tensor.layout[0];
  76. size = 1;
  77. }
  78. megdnn_assert(stride <= m_size);
  79. auto ptr = tensor.ptr<int>();
  80. for (size_t n = 0; n < size; ++n) {
  81. std::set<int> used;
  82. COMPAT_RANDOM(seq.begin(), seq.end());
  83. for (size_t step = 0; step < stride; ++step) {
  84. megdnn_assert(used.size() < m_size);
  85. ptr[n * stride + step] = seq[step];
  86. used.insert(seq[step]);
  87. }
  88. }
  89. }
  90. };
  91. } // namespace mesh_indexing
  92. } // namespace test
  93. } // namespace megdnn

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台