You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

argsort.cpp 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
/**
 * \file dnn/test/cuda/argsort.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
  11. #include "test/cuda/fixture.h"
  12. #include "test/common/checker.h"
  13. #include "test/common/rng.h"
  14. #include "test/common/tensor.h"
  15. #include "../src/cuda/argsort/opr_impl.h"
  16. using namespace megdnn;
  17. using namespace test;
  18. namespace {
  19. class ArgsortRNG final : public RNG {
  20. bool m_rev_order = false;
  21. DType m_dtype;
  22. template <typename T>
  23. void fill(T* ptr, int n) {
  24. if (m_rev_order) {
  25. for (int i = 0; i < n; ++i)
  26. ptr[i] = static_cast<T>(n / 2 - i);
  27. } else {
  28. for (int i = 0; i < n; ++i)
  29. ptr[i] = static_cast<T>(i - n / 2);
  30. COMPAT_RANDOM(ptr, ptr + n);
  31. }
  32. }
  33. void gen(const TensorND& tensor) override {
  34. auto n = tensor.layout.total_nr_elems();
  35. if (m_dtype == dtype::Float32{}) {
  36. fill(tensor.ptr<dt_float32>(), n);
  37. } else {
  38. megdnn_assert(m_dtype == dtype::Int32{});
  39. fill(tensor.ptr<dt_int32>(), n);
  40. }
  41. }
  42. public:
  43. ArgsortRNG(DType dt) : m_dtype{dt} {}
  44. void set_rev_order(bool flag) { m_rev_order = flag; }
  45. };
  46. void run_forward_test(Handle* handle, DType dtype) {
  47. Checker<ArgsortForward> checker(handle);
  48. using Param = Argsort::Param;
  49. using Order = Param::Order;
  50. ArgsortRNG rng{dtype};
  51. checker.set_dtype(2, dtype::Int32());
  52. checker.set_dtype(0, dtype).set_rng(0, &rng);
  53. for (size_t i = 3; i < 10240; i *= 2) {
  54. Param param;
  55. param.order = Order::ASCENDING;
  56. checker.set_param(param).execs({{3, i + 1}, {}, {}});
  57. param.order = Order::DESCENDING;
  58. checker.set_param(param).execs({{3, i - 1}, {}, {}});
  59. checker.set_param(param).execs({{13, i + 3}, {}, {}});
  60. }
  61. {
  62. // reverse sort large array
  63. constexpr size_t N = 200003;
  64. rng.set_rev_order(true);
  65. Param param;
  66. param.order = Order::ASCENDING;
  67. checker.set_param(param).execs({{1, N}, {}, {}});
  68. }
  69. }
  70. void run_backward_test(Handle* handle, DType dtype) {
  71. class IdxRng final : public RNG {
  72. void gen(const TensorND& tensor) override {
  73. auto ptr = tensor.ptr<dt_int32>();
  74. auto m = tensor.layout[0], n = tensor.layout[1];
  75. for (size_t i = 0; i < m; ++i) {
  76. for (size_t j = 0; j < n; ++j) {
  77. ptr[j] = j;
  78. }
  79. COMPAT_RANDOM(ptr, ptr + n);
  80. ptr += n;
  81. }
  82. }
  83. } rng;
  84. Checker<ArgsortBackward> checker(handle);
  85. checker.set_dtype(1, dtype::Int32()).set_rng(1, &rng);
  86. checker.set_dtype(0, dtype);
  87. checker.set_dtype(2, dtype);
  88. for (size_t i = 16; i < 4096; i *= 2) {
  89. checker.execs({{3, i}, {3, i}, {3, i}});
  90. checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 3}});
  91. checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 7}});
  92. }
  93. }
  94. } // anonymous namespace
  95. TEST_F(CUDA, ARGSORT_FORWARD_F32) {
  96. run_forward_test(handle_cuda(), dtype::Float32{});
  97. }
  98. TEST_F(CUDA, ARGSORT_FORWARD_I32) {
  99. run_forward_test(handle_cuda(), dtype::Int32{});
  100. }
  101. TEST_F(CUDA, ARGSORT_BACKWARD_F32) {
  102. run_backward_test(handle_cuda(), dtype::Float32{});
  103. }
  104. TEST_F(CUDA, ARGSORT_BACKWARD_I32) {
  105. run_backward_test(handle_cuda(), dtype::Int32{});
  106. }
  107. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台