You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

topk.cpp 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. /**
  2. * \file dnn/test/common/topk.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/topk.h"
  12. #include "megdnn/dtype.h"
  13. #include "megdnn/oprs/general.h"
  14. #include "test/common/checker.h"
  15. using namespace megdnn;
  16. using namespace test;
  17. namespace {
  18. class EqualValueRng final : public RNG {
  19. std::mt19937_64 m_rng{23};
  20. public:
  21. void gen(const TensorND& tensor) override {
  22. memset(tensor.raw_ptr, 0, tensor.layout.span().dist_byte());
  23. ASSERT_EQ(2u, tensor.layout.ndim);
  24. size_t m = tensor.layout[0], n = tensor.layout[1];
  25. for (size_t i = 0; i < m; ++i) {
  26. int pos0 = m_rng() % n, pos1;
  27. do {
  28. pos1 = m_rng() % n;
  29. } while (pos0 == pos1);
  30. pos0 += i * n;
  31. pos1 += i * n;
  32. #define CASE(ev, dt) \
  33. case DTypeEnum::ev: { \
  34. auto p = tensor.ptr<dt>(); \
  35. p[pos0] = p[pos1] = static_cast<dt>(-1); \
  36. break; \
  37. }
  38. switch (tensor.layout.dtype.enumv()) {
  39. CASE(Float32, float);
  40. CASE(Int32, int);
  41. DNN_INC_FLOAT16(CASE(Float16, half_float::half));
  42. default:
  43. megdnn_throw("bad dtype");
  44. }
  45. }
  46. #undef CASE
  47. }
  48. };
  49. } // namespace
/**
 * \brief Generic correctness test driver for the TopK operator.
 *
 * \tparam Dtype dtype under test (dtype::Float32, dtype::Int32 or
 *         dtype::Float16, see the explicit instantiations at file end)
 * \param handle megdnn handle the Checker executes the operator on
 *
 * Sweeps all three TopK modes over a mix of shapes, positive/negative k and
 * contiguous/non-contiguous inputs, then re-runs a subset with an input full
 * of ties (EqualValueRng) to verify tie-breaking yields valid results.
 */
template <typename Dtype>
void test::run_topk_test(Handle* handle) {
    Checker<TopK> checker{handle};
    using Mode = TopK::Param::Mode;
    // true while running the EqualValueRng cases: output indices are then
    // ambiguous, so the canonizer validates and erases them instead of
    // letting the checker compare them directly
    bool tie_breaking_mode = false;
    Mode cur_mode;
    // Rewrite the checker's outputs into a canonical form so that equally
    // valid results (unsorted output order, different tie choices) still
    // compare equal between the tested backend and the reference.
    auto output_canonizer = [&](const CheckerHelper::TensorValueArray& arr) {
        if (cur_mode == Mode::KTH_ONLY) {
            // only a value tensor is produced; nothing to canonize
            return;
        }
        // arr[0] = input, arr[1] = output values, arr[2] = output indices
        auto pinp = arr[0].ptr<typename DTypeTrait<Dtype>::ctype>();
        auto pval = arr[1].ptr<typename DTypeTrait<Dtype>::ctype>();
        auto pidx = arr.at(2).ptr<int>();
        size_t m = arr[1].layout[0], n = arr[1].layout[1];
        using idx_val = std::pair<int, typename DTypeTrait<Dtype>::ctype>;
        std::vector<idx_val> data(n);
        // order (index, value) pairs by value, descending
        auto compare = [](const idx_val& it1, const idx_val& it2) {
            return (it1.second > it2.second);
        };
        for (size_t i = 0; i < m; ++i) {
            if (cur_mode == Mode::VALUE_IDX_NOSORT) {
                // sort output pairs to canonize
                for (size_t j = 0; j < n; ++j) {
                    data[j].first = pidx[i * n + j];
                    data[j].second = pval[i * n + j];
                }
                std::sort(data.begin(), data.end(), compare);
                for (size_t j = 0; j < n; ++j) {
                    pidx[i * n + j] = data[j].first;
                    pval[i * n + j] = data[j].second;
                }
            }
            if (tie_breaking_mode) {
                // check if indices are correct and mark all indices to be zero
                // NOTE(review): indexing pinp with i * arr[0].layout[1]
                // assumes contiguous input rows; the tie-breaking runs below
                // all use lda == 0, which satisfies this — confirm if new
                // strided tie cases are added.
                for (size_t j = 0; j < n; ++j) {
                    auto idx = pidx[i * n + j];
                    auto val = pval[i * n + j];
                    // + 0 can change the type, such as changing half to float
                    ASSERT_EQ(pinp[i * arr[0].layout[1] + idx] + 0, val + 0);
                    pidx[i * n + j] = 0;
                }
            }
        }
    };
    // One checker run on an (m, n) input with selection size k; a non-zero
    // lda overrides the row stride, making the input non-contiguous.
    // Negative k presumably selects from the opposite end of the order (see
    // the "equiv to rev sort" case below) — confirm against TopK docs.
    auto run = [&](int k, size_t m, size_t n, Mode mode, int lda = 0) {
        if (::testing::Test::HasFailure()) {
            // stop piling further failures on top of an earlier one
            return;
        }
        cur_mode = mode;
        checker.set_proxy(k);
        checker.set_param(mode);
        TensorLayout layout{{m, n}, Dtype{}};
        if (lda) {
            layout.stride[0] = lda;
        }
        checker.set_output_canonizer(output_canonizer);
        if (mode == Mode::KTH_ONLY) {
            // KTH_ONLY: single output tensor; other modes also emit indices
            checker.execl({layout, {}});
        } else {
            checker.execl({layout, {}, {}});
        }
        if (!checker.prev_succ()) {
            fprintf(stderr,
                    "topk failed for (%zu,%zu):%d mode=%d cont=%d tie=%d\n", m,
                    n, k, static_cast<int>(mode), !lda, tie_breaking_mode);
            return;
        }
    };
    // Select the input RNG per dtype; the NoReplacementRNG wrappers produce
    // distinct values so results are unambiguous outside tie_breaking_mode.
    std::unique_ptr<IIDRNG> rng0;
    std::unique_ptr<RNG> rngf16;
    std::unique_ptr<NoReplacementRNG> rng1;
    switch (DTypeTrait<Dtype>::enumv) {
        case DTypeEnum::Float32: {
            rng0 = std::make_unique<UniformFloatRNG>(-100.f, 100.f);
            rng1 = std::make_unique<NoReplacementRNG>(rng0.get());
            checker.set_rng(0, rng1.get());
            break;
        }
        case DTypeEnum::Int32: {
            rng0 = std::make_unique<UniformIntRNG>(INT_MIN, INT_MAX);
            rng1 = std::make_unique<NoReplacementRNG>(rng0.get());
            checker.set_rng(0, rng1.get());
            break;
        }
        case DTypeEnum::Float16: {
            rngf16 = std::make_unique<Float16PeriodicalRNG>();
            checker.set_rng(0, rngf16.get());
            break;
        }
        default: {
            megdnn_throw(
                    ssprintf("only float32,int32 and float16 supported for "
                             "cuda and opencl topk"));
        }
    }
    // Main sweep: every mode over small/large, contiguous/strided shapes.
    for (auto mode :
         {Mode::KTH_ONLY, Mode::VALUE_IDX_NOSORT, Mode::VALUE_IDX_SORTED}) {
        run(1, 1, 1, mode);
        run(-1, 1, 1, mode);
        run(1, 23, 1, mode);
        run(1, 23, 100, mode);
        run(-1, 23, 100, mode);
        run(5, 23, 100, mode);
        run(-7, 23, 100, mode);
        run(23, 3, 50001, mode);
        run(5, 123, 3, mode);         // equiv to sort
        run(-5, 123, 3, mode);        // equiv to rev sort
        run(5, 3, 1231, mode, 2000);  // non contig
        //! opencl does not support large batch. fix it in the future.
#if MGB_CUDA
        run(3, 70000, 5, mode, 10);  // non contig
#endif
    }
    // special case to check if tie-break is correct
    auto tie_rng = std::make_unique<EqualValueRng>();
    tie_breaking_mode = true;
    checker.set_rng(0, tie_rng.get());
    for (auto mode : {Mode::VALUE_IDX_NOSORT, Mode::VALUE_IDX_SORTED}) {
        run(3, 1, 5, mode);
        run(3, 25, 4567, mode);
        run(8, 132, 10, mode);
    }
}
  173. namespace megdnn {
  174. namespace test {
  175. #define INST(t) template void run_topk_test<t>(Handle*)
  176. INST(dtype::Float32);
  177. INST(dtype::Int32);
  178. DNN_INC_FLOAT16(INST(dtype::Float16));
  179. #undef INST
  180. }
  181. } // namespace megdnn
  182. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台