You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

argsort.cpp 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  #include "test/rocm/fixture.h"
  #include "test/common/checker.h"
  #include "test/common/rng.h"
  #include "test/common/tensor.h"
  #include "../src/rocm/argsort/opr_impl.h"

  #include <cstddef>
  6. using namespace megdnn;
  7. using namespace test;
  8. namespace {
  9. class ArgsortRNG final : public RNG {
  10. bool m_rev_order = false;
  11. DType m_dtype;
  12. template <typename T>
  13. void fill(T* ptr, int n) {
  14. if (m_rev_order) {
  15. for (int i = 0; i < n; ++i)
  16. ptr[i] = static_cast<T>(n / 2 - i);
  17. } else {
  18. for (int i = 0; i < n; ++i)
  19. ptr[i] = static_cast<T>(i - n / 2);
  20. COMPAT_RANDOM(ptr, ptr + n);
  21. }
  22. }
  23. void gen(const TensorND& tensor) override {
  24. auto n = tensor.layout.total_nr_elems();
  25. if (m_dtype == dtype::Float32{}) {
  26. fill(tensor.ptr<dt_float32>(), n);
  27. } else {
  28. megdnn_assert(m_dtype == dtype::Int32{});
  29. fill(tensor.ptr<dt_int32>(), n);
  30. }
  31. }
  32. public:
  33. ArgsortRNG(DType dt) : m_dtype{dt} {}
  34. void set_rev_order(bool flag) { m_rev_order = flag; }
  35. };
  36. void run_forward_test(Handle* handle, DType dtype) {
  37. Checker<ArgsortForward> checker(handle);
  38. using Param = Argsort::Param;
  39. using Order = Param::Order;
  40. ArgsortRNG rng{dtype};
  41. checker.set_dtype(2, dtype::Int32());
  42. checker.set_dtype(0, dtype).set_rng(0, &rng);
  43. for (size_t i = 3; i < 10240; i *= 2) {
  44. Param param;
  45. param.order = Order::ASCENDING;
  46. checker.set_param(param).execs({{3, i + 1}, {}, {}});
  47. param.order = Order::DESCENDING;
  48. checker.set_param(param).execs({{3, i - 1}, {}, {}});
  49. checker.set_param(param).execs({{13, i + 3}, {}, {}});
  50. }
  51. {
  52. // reverse sort large array
  53. constexpr size_t N = 200003;
  54. rng.set_rev_order(true);
  55. Param param;
  56. param.order = Order::ASCENDING;
  57. checker.set_param(param).execs({{1, N}, {}, {}});
  58. }
  59. }
  60. void run_backward_test(Handle* handle, DType dtype) {
  61. class IdxRng final : public RNG {
  62. void gen(const TensorND& tensor) override {
  63. auto ptr = tensor.ptr<dt_int32>();
  64. auto m = tensor.layout[0], n = tensor.layout[1];
  65. for (size_t i = 0; i < m; ++i) {
  66. for (size_t j = 0; j < n; ++j) {
  67. ptr[j] = j;
  68. }
  69. COMPAT_RANDOM(ptr, ptr + n);
  70. ptr += n;
  71. }
  72. }
  73. } rng;
  74. Checker<ArgsortBackward> checker(handle);
  75. checker.set_dtype(1, dtype::Int32()).set_rng(1, &rng);
  76. checker.set_dtype(0, dtype);
  77. checker.set_dtype(2, dtype);
  78. for (size_t i = 16; i < 4096; i *= 2) {
  79. checker.execs({{3, i}, {3, i}, {3, i}});
  80. checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 3}});
  81. checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 7}});
  82. }
  83. }
  84. } // anonymous namespace
  85. TEST_F(ROCM, ARGSORT_FORWARD_F32) {
  86. run_forward_test(handle_rocm(), dtype::Float32{});
  87. }
  88. TEST_F(ROCM, ARGSORT_FORWARD_I32) {
  89. run_forward_test(handle_rocm(), dtype::Int32{});
  90. }
  91. TEST_F(ROCM, ARGSORT_BACKWARD_F32) {
  92. run_backward_test(handle_rocm(), dtype::Float32{});
  93. }
  94. TEST_F(ROCM, ARGSORT_BACKWARD_I32) {
  95. run_backward_test(handle_rocm(), dtype::Int32{});
  96. }
  97. // vim: syntax=cpp.doxygen