You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

indexing_multi_axis_vec.cpp 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. #include "hcc_detail/hcc_defs_prologue.h"
  2. #include "test/rocm/benchmarker.h"
  3. #include "test/rocm/fixture.h"
  4. #include "megdnn/oprs.h"
  5. #include "test/common/checker.h"
  6. #include "test/common/index.h"
  7. #include "test/common/indexing_multi_axis_vec.h"
  8. #include <random>
  9. using namespace megdnn;
  10. using namespace test;
  11. namespace {
  12. class OrderedRNG final : public RNG {
  13. public:
  14. void gen(const TensorND& tensor) override {
  15. auto span = tensor.layout.span();
  16. if (tensor.layout.dtype == dtype::Float32()) {
  17. auto ptr = tensor.ptr<float>() + span.low_elem;
  18. for (size_t i = 0, it = span.dist_elem(); i < it; ++i) {
  19. ptr[i] = i;
  20. }
  21. } else {
  22. auto ptr = tensor.ptr<int>() + span.low_elem;
  23. for (size_t i = 0, it = span.dist_elem(); i < it; ++i) {
  24. ptr[i] = i;
  25. }
  26. }
  27. }
  28. };
  29. template <class Opr>
  30. void run_check(Handle* handle) {
  31. // see OprProxyIndexingMultiAxisVecHelper for more details
  32. // execs() give input, output and index layouts
  33. Checker<Opr> checker(handle);
  34. size_t idx_size0, idx_size1;
  35. OrderedRNG rng_inp;
  36. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  37. checker.set_dtype(0, dtype::Float32())
  38. . // data
  39. set_dtype(1, dtype::Float32())
  40. . // value
  41. set_dtype(2, dtype::Int32())
  42. . // idx0
  43. set_dtype(3, dtype::Int32())
  44. . // idx1
  45. set_rng(0, &rng_inp)
  46. .set_rng(1, &rng_inp)
  47. .set_rng(2, &rng0)
  48. .set_rng(3, &rng1);
  49. idx_size0 = 23;
  50. checker.set_proxy({{0}})
  51. .execs({{23}, {100}, {100}})
  52. .execs({{23, 5}, {100, 5}, {100}});
  53. idx_size0 = 2;
  54. idx_size1 = 3;
  55. checker.set_proxy({{0, 1}})
  56. .execs({{2, 3}, {10}, {10}, {10}})
  57. .execs({{2, 3, 5}, {10, 5}, {10}, {10}});
  58. idx_size0 = 4;
  59. idx_size1 = 6;
  60. TensorLayout inp_layout{{3, 4, 5, 6}, dtype::Float32()};
  61. inp_layout.stride[0] *= 8;
  62. inp_layout.stride[1] *= 2;
  63. checker.set_proxy({{1, 3}}).execl({
  64. inp_layout,
  65. {{7, 3, 5}, dtype::Float32()},
  66. {{7}, dtype::Int32()},
  67. {{1}, dtype::Int32()},
  68. });
  69. idx_size0 = 4;
  70. idx_size1 = 5;
  71. checker.set_proxy({{2, 3}}).execs(
  72. {{2, 3, 4, 5, 6, 7}, {2, 3, 10, 6, 7}, {10}, {10}});
  73. idx_size0 = 4;
  74. checker.set_proxy({{1}}).execs({{1, 4}, {1, 1024 * 1024}, {1024 * 1024}});
  75. if (std::is_same<Opr, IndexingIncrMultiAxisVec>::value) {
  76. idx_size0 = 4;
  77. TensorLayout val_layout{{23}, dtype::Float32()};
  78. val_layout.stride[0] = 0;
  79. checker.set_proxy({{0}}).execl(
  80. {{{4}, dtype::Float32()}, val_layout, {{23}, dtype::Int32()}});
  81. }
  82. }
  83. } // namespace
  84. TEST_F(ROCM, INDEXING_MULTI_AXIS_VEC) {
  85. run_check<IndexingMultiAxisVec>(handle_rocm());
  86. Checker<IndexingMultiAxisVec> checker(handle_rocm());
  87. size_t idx_size0;
  88. OrderedRNG rng_inp;
  89. IndexRNG rng0{idx_size0, 2};
  90. checker.set_dtype(0, dtype::Float32())
  91. . // data
  92. set_dtype(1, dtype::Float32())
  93. . // value
  94. set_dtype(2, dtype::Int32())
  95. . // idx0
  96. set_rng(0, &rng_inp)
  97. .set_rng(1, &rng_inp)
  98. .set_rng(2, &rng0);
  99. idx_size0 = 20;
  100. checker.set_proxy({{0}}).execl(
  101. {TensorLayout{{20}, dtype::Float32()}, TensorLayout{{9}, dtype::Float32()},
  102. TensorLayout{TensorShape{9}, {-1}, dtype::Int32()}});
  103. }
  104. TEST_F(ROCM, INDEXING_INCR_MULTI_AXIS_VEC) {
  105. run_check<IndexingIncrMultiAxisVec>(handle_rocm());
  106. }
  107. TEST_F(ROCM, INDEXING_SET_MULTI_AXIS_VEC) {
  108. Checker<IndexingSetMultiAxisVec> checker(handle_rocm());
  109. OrderedRNG rng;
  110. checker.set_dtype(0, dtype::Float32())
  111. . // data
  112. set_dtype(1, dtype::Float32())
  113. . // value
  114. set_dtype(2, dtype::Int32())
  115. . // idx0
  116. set_rng(0, &rng)
  117. .set_rng(1, &rng)
  118. .set_rng(2, &rng);
  119. checker.set_proxy({{1}}).execs({{5, 8, 3}, {5, 2, 3}, {2}});
  120. }
  121. TEST_F(ROCM_ERROR_INFO, INDEXING_MULTI_AXIS_VEC) {
  122. Checker<IndexingMultiAxisVec> checker(handle_rocm());
  123. UniformIntRNG idx_rng{-5, 5};
  124. checker.set_dtype(0, dtype::Float32())
  125. . // data
  126. set_dtype(1, dtype::Float32())
  127. . // value
  128. set_dtype(2, dtype::Int32())
  129. . // idx
  130. set_rng(2, &idx_rng);
  131. bool failed = false;
  132. ASSERT_EQ(0u, get_error_info().nr_error);
  133. auto on_fail = [&failed, this]() {
  134. failed = true;
  135. auto info = get_error_info();
  136. ASSERT_GE(info.nr_error, 1u);
  137. printf("error msg: ");
  138. printf(info.msg, info.msg_args[0], info.msg_args[1], info.msg_args[2],
  139. info.msg_args[3]);
  140. printf("\n");
  141. };
  142. checker.set_proxy({{0}}).execs({{23}, {100}, {100}});
  143. idx_rng = {-500, 500};
  144. checker.set_expect_exec_fail(on_fail).execs({{23}, {100}, {100}});
  145. ASSERT_TRUE(failed);
  146. }
  147. TEST_F(ROCM, INDEXING_MULTI_AXIS_VEC_BENCHMARK) {
  148. ROCMBenchmarker<IndexingMultiAxisVec> benchmarker(
  149. handle_rocm(), handle_naive(false));
  150. benchmarker.set_display(true);
  151. OrderedRNG rng_inp;
  152. size_t idx_size = 10000;
  153. IndexRNG rng0{idx_size, 3}, rng1{idx_size, 1};
  154. benchmarker.set_dtype(0, dtype::Float32())
  155. .set_dtype(1, dtype::Float32())
  156. .set_dtype(2, dtype::Int32())
  157. .set_dtype(3, dtype::Int32())
  158. .set_rng(0, &rng_inp)
  159. .set_rng(1, &rng_inp)
  160. .set_rng(2, &rng0)
  161. .set_rng(3, &rng1)
  162. .set_proxy({{0, 1}});
  163. auto time_ms =
  164. benchmarker.execs({{1000, 1000, 1000}, {1000, 1000}, {1000}, {1000}});
  165. long io = 2 * 1000 * 1000 * dtype::Float32().size();
  166. printf("io = %.3f GB, random access bandwidth = %.3f GB/s\n", (float)(io / 1e9),
  167. (float)(io / (time_ms * 1e6)));
  168. }
  169. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}