You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

indexing_multi_axis_vec.cpp 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. #include "test/cuda/fixture.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/index.h"
  5. #include "test/common/indexing_multi_axis_vec.h"
  6. #include <random>
  7. using namespace megdnn;
  8. using namespace test;
  9. namespace {
  10. class OrderedRNG final : public RNG {
  11. public:
  12. void gen(const TensorND& tensor) override {
  13. auto span = tensor.layout.span();
  14. if (tensor.layout.dtype == dtype::Float32()) {
  15. auto ptr = tensor.ptr<float>() + span.low_elem;
  16. for (size_t i = 0, it = span.dist_elem(); i < it; ++i) {
  17. ptr[i] = i;
  18. }
  19. } else if (tensor.layout.dtype == dtype::Float16()) {
  20. auto ptr = tensor.ptr<dt_float16>() + span.low_elem;
  21. for (size_t i = 0, it = span.dist_elem(); i < it; ++i) {
  22. ptr[i] = i;
  23. }
  24. } else {
  25. auto ptr = tensor.ptr<int>() + span.low_elem;
  26. for (size_t i = 0, it = span.dist_elem(); i < it; ++i) {
  27. ptr[i] = i;
  28. }
  29. }
  30. }
  31. };
  32. template <class Opr>
  33. void run_check(Handle* handle) {
  34. // see OprProxyIndexingMultiAxisVecHelper for more details
  35. // set_proxy() sets the axes to index on
  36. // execs() give input, output and index layouts
  37. Checker<Opr> checker(handle);
  38. size_t idx_size0, idx_size1;
  39. OrderedRNG rng_inp;
  40. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  41. checker.set_dtype(0, dtype::Float32())
  42. . // data
  43. set_dtype(1, dtype::Float32())
  44. . // value
  45. set_dtype(2, dtype::Int32())
  46. . // idx0
  47. set_dtype(3, dtype::Int32())
  48. . // idx1
  49. set_rng(0, &rng_inp)
  50. .set_rng(1, &rng_inp)
  51. .set_rng(2, &rng0)
  52. .set_rng(3, &rng1);
  53. idx_size0 = 23;
  54. checker.set_proxy({{0}})
  55. .execs({{23}, {100}, {100}})
  56. .execs({{23, 5}, {100, 5}, {100}});
  57. idx_size0 = 2;
  58. idx_size1 = 3;
  59. checker.set_proxy({{0, 1}})
  60. .execs({{2, 3}, {10}, {10}, {10}})
  61. .execs({{2, 3, 5}, {10, 5}, {10}, {10}});
  62. idx_size0 = 4;
  63. idx_size1 = 6;
  64. TensorLayout inp_layout{{3, 4, 5, 6}, dtype::Float32()};
  65. inp_layout.stride[0] *= 8;
  66. inp_layout.stride[1] *= 2;
  67. checker.set_proxy({{1, 3}}).execl({
  68. inp_layout,
  69. {{7, 3, 5}, dtype::Float32()},
  70. {{7}, dtype::Int32()},
  71. {{1}, dtype::Int32()},
  72. });
  73. idx_size0 = 4;
  74. idx_size1 = 5;
  75. checker.set_proxy({{2, 3}}).execs(
  76. {{2, 3, 4, 5, 6, 7}, {2, 3, 10, 6, 7}, {10}, {10}});
  77. idx_size0 = 4;
  78. checker.set_proxy({{1}}).execs({{1, 4}, {1, 1024 * 1024}, {1024 * 1024}});
  79. if (std::is_same<Opr, IndexingIncrMultiAxisVec>::value) {
  80. idx_size0 = 4;
  81. TensorLayout val_layout{{23}, dtype::Float32()};
  82. val_layout.stride[0] = 0;
  83. checker.set_proxy({{0}}).execl(
  84. {{{4}, dtype::Float32()}, val_layout, {{23}, dtype::Int32()}});
  85. }
  86. }
  87. } // namespace
  88. TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) {
  89. run_check<IndexingMultiAxisVec>(handle_cuda());
  90. Checker<IndexingMultiAxisVec> checker(handle_cuda());
  91. size_t idx_size0;
  92. OrderedRNG rng_inp;
  93. IndexRNG rng0{idx_size0, 2};
  94. checker.set_dtype(0, dtype::Float32())
  95. . // data
  96. set_dtype(1, dtype::Float32())
  97. . // value
  98. set_dtype(2, dtype::Int32())
  99. . // idx0
  100. set_rng(0, &rng_inp)
  101. .set_rng(1, &rng_inp)
  102. .set_rng(2, &rng0);
  103. idx_size0 = 20;
  104. checker.set_proxy({{0}}).execl(
  105. {TensorLayout{{20}, dtype::Float32()}, TensorLayout{{9}, dtype::Float32()},
  106. TensorLayout{TensorShape{9}, {-1}, dtype::Int32()}});
  107. }
  108. TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC_ND_INDEX) {
  109. run_check<IndexingMultiAxisVec>(handle_cuda());
  110. Checker<IndexingMultiAxisVec> checker(handle_cuda());
  111. OrderedRNG rng;
  112. checker.set_dtype(0, dtype::Float32())
  113. .set_dtype(1, dtype::Float32())
  114. .set_dtype(2, dtype::Int32())
  115. .set_dtype(3, dtype::Int32())
  116. .set_dtype(4, dtype::Int32())
  117. .set_rng(0, &rng)
  118. .set_rng(1, &rng)
  119. .set_rng(2, &rng)
  120. .set_rng(3, &rng)
  121. .set_rng(4, &rng);
  122. checker.set_proxy({{1, 2, 3}})
  123. .execs({{5, 5, 6, 7, 3}, {5, 2, 3, 4, 3}, {3, 1}, {2, 1, 1}, {1, 4}});
  124. }
  125. TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) {
  126. run_check<IndexingIncrMultiAxisVec>(handle_cuda());
  127. Checker<IndexingIncrMultiAxisVec> checker(handle_cuda());
  128. OrderedRNG rng;
  129. checker.set_dtype(0, dtype::Float16())
  130. . // data
  131. set_dtype(1, dtype::Float16())
  132. . // value
  133. set_dtype(2, dtype::Int32())
  134. . // idx0
  135. set_rng(0, &rng)
  136. .set_rng(1, &rng)
  137. .set_rng(2, &rng);
  138. checker.set_proxy({{1}}).execs({{5, 8, 3}, {5, 2, 3}, {2}});
  139. }
  140. TEST_F(CUDA, INDEXING_SET_MULTI_AXIS_VEC) {
  141. Checker<IndexingSetMultiAxisVec> checker(handle_cuda());
  142. OrderedRNG rng;
  143. checker.set_dtype(0, dtype::Float32())
  144. . // data
  145. set_dtype(1, dtype::Float32())
  146. . // value
  147. set_dtype(2, dtype::Int32())
  148. . // idx0
  149. set_rng(0, &rng)
  150. .set_rng(1, &rng)
  151. .set_rng(2, &rng);
  152. checker.set_proxy({{1}}).execs({{5, 8, 3}, {5, 2, 3}, {2}});
  153. }
  154. TEST_F(CUDA_ERROR_INFO, INDEXING_MULTI_AXIS_VEC) {
  155. Checker<IndexingMultiAxisVec> checker(handle_cuda());
  156. UniformIntRNG idx_rng{-5, 5};
  157. checker.set_dtype(0, dtype::Float32())
  158. . // data
  159. set_dtype(1, dtype::Float32())
  160. . // value
  161. set_dtype(2, dtype::Int32())
  162. . // idx
  163. set_rng(2, &idx_rng);
  164. bool failed = false;
  165. ASSERT_EQ(0u, get_error_info().nr_error);
  166. auto on_fail = [&failed, this]() {
  167. failed = true;
  168. auto info = get_error_info();
  169. ASSERT_GE(info.nr_error, 1u);
  170. printf("error msg: ");
  171. printf(info.msg, info.msg_args[0], info.msg_args[1], info.msg_args[2],
  172. info.msg_args[3]);
  173. printf("\n");
  174. };
  175. checker.set_proxy({{0}}).execs({{23}, {100}, {100}});
  176. idx_rng = {-500, 500};
  177. checker.set_expect_exec_fail(on_fail).execs({{23}, {100}, {100}});
  178. ASSERT_TRUE(failed);
  179. }
  180. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}