You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

topk.cpp 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. #include "test/common/topk.h"
  2. #include "megdnn/dtype.h"
  3. #include "megdnn/oprs/general.h"
  4. #include "test/common/checker.h"
  5. using namespace megdnn;
  6. using namespace test;
  7. namespace {
  8. class EqualValueRng final : public RNG {
  9. std::mt19937_64 m_rng{23};
  10. public:
  11. void gen(const TensorND& tensor) override {
  12. memset(tensor.raw_ptr(), 0, tensor.layout.span().dist_byte());
  13. ASSERT_EQ(2u, tensor.layout.ndim);
  14. size_t m = tensor.layout[0], n = tensor.layout[1];
  15. for (size_t i = 0; i < m; ++i) {
  16. int pos0 = m_rng() % n, pos1;
  17. do {
  18. pos1 = m_rng() % n;
  19. } while (pos0 == pos1);
  20. pos0 += i * n;
  21. pos1 += i * n;
  22. #define CASE(ev, dt) \
  23. case DTypeEnum::ev: { \
  24. auto p = tensor.ptr<dt>(); \
  25. p[pos0] = p[pos1] = static_cast<dt>(-1); \
  26. break; \
  27. }
  28. switch (tensor.layout.dtype.enumv()) {
  29. CASE(Float32, float);
  30. CASE(Int32, int);
  31. DNN_INC_FLOAT16(CASE(Float16, half_float::half));
  32. default:
  33. megdnn_throw("bad dtype");
  34. }
  35. }
  36. #undef CASE
  37. }
  38. };
  39. } // namespace
  40. template <typename Dtype>
  41. void test::run_topk_test(Handle* handle) {
  42. Checker<TopK> checker{handle};
  43. using Mode = TopK::Param::Mode;
  44. bool tie_breaking_mode = false;
  45. Mode cur_mode;
  46. auto output_canonizer = [&](const CheckerHelper::TensorValueArray& arr) {
  47. if (cur_mode == Mode::KTH_ONLY) {
  48. return;
  49. }
  50. auto pinp = arr[0].ptr<typename DTypeTrait<Dtype>::ctype>();
  51. auto pval = arr[1].ptr<typename DTypeTrait<Dtype>::ctype>();
  52. auto pidx = arr.at(2).ptr<int>();
  53. size_t m = arr[1].layout[0], n = arr[1].layout[1];
  54. using idx_val = std::pair<int, typename DTypeTrait<Dtype>::ctype>;
  55. std::vector<idx_val> data(n);
  56. auto compare = [](const idx_val& it1, const idx_val& it2) {
  57. return (it1.second > it2.second);
  58. };
  59. for (size_t i = 0; i < m; ++i) {
  60. if (cur_mode == Mode::VALUE_IDX_NOSORT) {
  61. // sort output pairs to canonize
  62. for (size_t j = 0; j < n; ++j) {
  63. data[j].first = pidx[i * n + j];
  64. data[j].second = pval[i * n + j];
  65. }
  66. std::sort(data.begin(), data.end(), compare);
  67. for (size_t j = 0; j < n; ++j) {
  68. pidx[i * n + j] = data[j].first;
  69. pval[i * n + j] = data[j].second;
  70. }
  71. }
  72. if (tie_breaking_mode) {
  73. // check if indices are correct and mark all indices to be zero
  74. for (size_t j = 0; j < n; ++j) {
  75. auto idx = pidx[i * n + j];
  76. auto val = pval[i * n + j];
  77. // + 0 can change the type, such as changing half to float
  78. ASSERT_EQ(pinp[i * arr[0].layout[1] + idx] + 0, val + 0);
  79. pidx[i * n + j] = 0;
  80. }
  81. }
  82. }
  83. };
  84. auto run = [&](int k, size_t m, size_t n, Mode mode, int lda = 0) {
  85. if (::testing::Test::HasFailure()) {
  86. return;
  87. }
  88. cur_mode = mode;
  89. checker.set_proxy(k);
  90. checker.set_param(mode);
  91. TensorLayout layout{{m, n}, Dtype{}};
  92. if (lda) {
  93. layout.stride[0] = lda;
  94. }
  95. checker.set_output_canonizer(output_canonizer);
  96. if (mode == Mode::KTH_ONLY) {
  97. checker.execl({layout, {}});
  98. } else {
  99. checker.execl({layout, {}, {}});
  100. }
  101. if (!checker.prev_succ()) {
  102. fprintf(stderr, "topk failed for (%zu,%zu):%d mode=%d cont=%d tie=%d\n", m,
  103. n, k, static_cast<int>(mode), !lda, tie_breaking_mode);
  104. return;
  105. }
  106. };
  107. std::unique_ptr<IIDRNG> rng0;
  108. std::unique_ptr<RNG> rngf16;
  109. std::unique_ptr<NoReplacementRNG> rng1;
  110. switch (DTypeTrait<Dtype>::enumv) {
  111. case DTypeEnum::Float32: {
  112. rng0 = std::make_unique<UniformFloatRNG>(-100.f, 100.f);
  113. rng1 = std::make_unique<NoReplacementRNG>(rng0.get());
  114. checker.set_rng(0, rng1.get());
  115. break;
  116. }
  117. case DTypeEnum::Int32: {
  118. rng0 = std::make_unique<UniformIntRNG>(INT_MIN, INT_MAX);
  119. rng1 = std::make_unique<NoReplacementRNG>(rng0.get());
  120. checker.set_rng(0, rng1.get());
  121. break;
  122. }
  123. case DTypeEnum::Float16: {
  124. rngf16 = std::make_unique<Float16PeriodicalRNG>();
  125. checker.set_rng(0, rngf16.get());
  126. break;
  127. }
  128. default: {
  129. megdnn_throw(
  130. ssprintf("only float32,int32 and float16 supported for "
  131. "cuda and opencl topk"));
  132. }
  133. }
  134. for (auto mode : {Mode::KTH_ONLY, Mode::VALUE_IDX_NOSORT, Mode::VALUE_IDX_SORTED}) {
  135. run(1, 1, 1, mode);
  136. run(-1, 1, 1, mode);
  137. run(1, 23, 1, mode);
  138. run(1, 23, 100, mode);
  139. run(-1, 23, 100, mode);
  140. run(5, 23, 100, mode);
  141. run(-7, 23, 100, mode);
  142. run(23, 3, 50001, mode);
  143. run(5, 123, 3, mode); // equiv to sort
  144. run(-5, 123, 3, mode); // equiv to rev sort
  145. run(5, 3, 1231, mode, 2000); // non contig
  146. //! opencl does not support large batch. fix it in the future.
  147. #if MGB_CUDA
  148. run(3, 70000, 5, mode, 10); // non contig
  149. #endif
  150. }
  151. // special case to check if tie-break is correct
  152. auto tie_rng = std::make_unique<EqualValueRng>();
  153. tie_breaking_mode = true;
  154. checker.set_rng(0, tie_rng.get());
  155. for (auto mode : {Mode::VALUE_IDX_NOSORT, Mode::VALUE_IDX_SORTED}) {
  156. run(3, 1, 5, mode);
  157. run(3, 25, 4567, mode);
  158. run(8, 132, 10, mode);
  159. }
  160. }
  161. namespace megdnn {
  162. namespace test {
  163. #define INST(t) template void run_topk_test<t>(Handle*)
  164. INST(dtype::Float32);
  165. INST(dtype::Int32);
  166. DNN_INC_FLOAT16(INST(dtype::Float16));
  167. #undef INST
  168. } // namespace test
  169. } // namespace megdnn
  170. // vim: syntax=cpp.doxygen