
indexing_multi_axis_vec.cpp

/**
 * \file dnn/test/cuda/indexing_multi_axis_vec.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/cuda/fixture.h"

#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/common/index.h"
#include "test/common/indexing_multi_axis_vec.h"

#include <random>

using namespace megdnn;
using namespace test;

namespace {

class OrderedRNG final : public RNG {
public:
    void gen(const TensorND& tensor) override {
        auto span = tensor.layout.span();
        if (tensor.layout.dtype == dtype::Float32()) {
            auto ptr = tensor.ptr<float>() + span.low_elem;
            for (size_t i = 0, it = span.dist_elem(); i < it; ++i) {
                ptr[i] = i;
            }
        } else if (tensor.layout.dtype == dtype::Float16()) {
            auto ptr = tensor.ptr<dt_float16>() + span.low_elem;
            for (size_t i = 0, it = span.dist_elem(); i < it; ++i) {
                ptr[i] = i;
            }
        } else {
            auto ptr = tensor.ptr<int>() + span.low_elem;
            for (size_t i = 0, it = span.dist_elem(); i < it; ++i) {
                ptr[i] = i;
            }
        }
    }
};

template <class Opr>
void run_check(Handle* handle) {
    // see OprProxyIndexingMultiAxisVecHelper for more details
    // set_proxy() sets the axes to index on
    // execs() give input, output and index layouts
    Checker<Opr> checker(handle);
    size_t idx_size0, idx_size1;
    OrderedRNG rng_inp;
    IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
    checker.set_dtype(0, dtype::Float32())
            .  // data
            set_dtype(1, dtype::Float32())
            .  // value
            set_dtype(2, dtype::Int32())
            .  // idx0
            set_dtype(3, dtype::Int32())
            .  // idx1
            set_rng(0, &rng_inp)
            .set_rng(1, &rng_inp)
            .set_rng(2, &rng0)
            .set_rng(3, &rng1);

    idx_size0 = 23;
    checker.set_proxy({{0}})
            .execs({{23}, {100}, {100}})
            .execs({{23, 5}, {100, 5}, {100}});

    idx_size0 = 2;
    idx_size1 = 3;
    checker.set_proxy({{0, 1}})
            .execs({{2, 3}, {10}, {10}, {10}})
            .execs({{2, 3, 5}, {10, 5}, {10}, {10}});

    idx_size0 = 4;
    idx_size1 = 6;
    TensorLayout inp_layout{{3, 4, 5, 6}, dtype::Float32()};
    inp_layout.stride[0] *= 8;
    inp_layout.stride[1] *= 2;
    checker.set_proxy({{1, 3}}).execl({
            inp_layout,
            {{7, 3, 5}, dtype::Float32()},
            {{7}, dtype::Int32()},
            {{1}, dtype::Int32()},
    });

    idx_size0 = 4;
    idx_size1 = 5;
    checker.set_proxy({{2, 3}}).execs(
            {{2, 3, 4, 5, 6, 7}, {2, 3, 10, 6, 7}, {10}, {10}});

    idx_size0 = 4;
    checker.set_proxy({{1}}).execs({{1, 4}, {1, 1024 * 1024}, {1024 * 1024}});

    if (std::is_same<Opr, IndexingIncrMultiAxisVec>::value) {
        idx_size0 = 4;
        TensorLayout val_layout{{23}, dtype::Float32()};
        val_layout.stride[0] = 0;
        checker.set_proxy({{0}}).execl(
                {{{4}, dtype::Float32()}, val_layout, {{23}, dtype::Int32()}});
    }
}

}  // namespace

TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) {
    run_check<IndexingMultiAxisVec>(handle_cuda());
    Checker<IndexingMultiAxisVec> checker(handle_cuda());
    size_t idx_size0;
    OrderedRNG rng_inp;
    IndexRNG rng0{idx_size0, 2};
    checker.set_dtype(0, dtype::Float32())
            .  // data
            set_dtype(1, dtype::Float32())
            .  // value
            set_dtype(2, dtype::Int32())
            .  // idx0
            set_rng(0, &rng_inp)
            .set_rng(1, &rng_inp)
            .set_rng(2, &rng0);
    idx_size0 = 20;
    checker.set_proxy({{0}}).execl(
            {TensorLayout{{20}, dtype::Float32()}, TensorLayout{{9}, dtype::Float32()},
             TensorLayout{TensorShape{9}, {-1}, dtype::Int32()}});
}

TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC_ND_INDEX) {
    run_check<IndexingMultiAxisVec>(handle_cuda());
    Checker<IndexingMultiAxisVec> checker(handle_cuda());
    OrderedRNG rng;
    checker.set_dtype(0, dtype::Float32())
            .set_dtype(1, dtype::Float32())
            .set_dtype(2, dtype::Int32())
            .set_dtype(3, dtype::Int32())
            .set_dtype(4, dtype::Int32())
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng)
            .set_rng(3, &rng)
            .set_rng(4, &rng);
    checker.set_proxy({{1, 2, 3}})
            .execs({{5, 5, 6, 7, 3}, {5, 2, 3, 4, 3}, {3, 1}, {2, 1, 1}, {1, 4}});
}

TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) {
    run_check<IndexingIncrMultiAxisVec>(handle_cuda());
    Checker<IndexingIncrMultiAxisVec> checker(handle_cuda());
    OrderedRNG rng;
    checker.set_dtype(0, dtype::Float16())
            .  // data
            set_dtype(1, dtype::Float16())
            .  // value
            set_dtype(2, dtype::Int32())
            .  // idx0
            set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    checker.set_proxy({{1}}).execs({{5, 8, 3}, {5, 2, 3}, {2}});
}

TEST_F(CUDA, INDEXING_SET_MULTI_AXIS_VEC) {
    Checker<IndexingSetMultiAxisVec> checker(handle_cuda());
    OrderedRNG rng;
    checker.set_dtype(0, dtype::Float32())
            .  // data
            set_dtype(1, dtype::Float32())
            .  // value
            set_dtype(2, dtype::Int32())
            .  // idx0
            set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    checker.set_proxy({{1}}).execs({{5, 8, 3}, {5, 2, 3}, {2}});
}

TEST_F(CUDA_ERROR_INFO, INDEXING_MULTI_AXIS_VEC) {
    Checker<IndexingMultiAxisVec> checker(handle_cuda());
    UniformIntRNG idx_rng{-5, 5};
    checker.set_dtype(0, dtype::Float32())
            .  // data
            set_dtype(1, dtype::Float32())
            .  // value
            set_dtype(2, dtype::Int32())
            .  // idx
            set_rng(2, &idx_rng);

    bool failed = false;
    ASSERT_EQ(0u, get_error_info().nr_error);
    auto on_fail = [&failed, this]() {
        failed = true;
        auto info = get_error_info();
        ASSERT_GE(info.nr_error, 1u);
        printf("error msg: ");
        printf(info.msg, info.msg_args[0], info.msg_args[1], info.msg_args[2],
               info.msg_args[3]);
        printf("\n");
    };
    checker.set_proxy({{0}}).execs({{23}, {100}, {100}});

    idx_rng = {-500, 500};
    checker.set_expect_exec_fail(on_fail).execs({{23}, {100}, {100}});
    ASSERT_TRUE(failed);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
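For readers unfamiliar with the operator, the shape triples and quadruples passed to execs() follow NumPy-style advanced indexing: each index vector selects entries along one of the axes passed to set_proxy(), and the remaining axes are copied through. The following standalone sketch is not part of the test file; it is a minimal reference loop, under the assumption that IndexingMultiAxisVec gathers like NumPy advanced indexing, mirroring the `set_proxy({{0, 1}}).execs({{2, 3, 5}, {10, 5}, {10}, {10}})` case above. All names and sizes in it are illustrative.

// Hypothetical reference sketch: out[i][k] = data[idx0[i]][idx1[i]][k]
// for data of shape {2, 3, 5}, index vectors of length 10, output {10, 5}.
#include <cstdio>
#include <vector>

int main() {
    const int D0 = 2, D1 = 3, D2 = 5, N = 10;
    std::vector<float> data(D0 * D1 * D2);
    for (size_t i = 0; i < data.size(); ++i)
        data[i] = float(i);  // same ordered fill as OrderedRNG above
    std::vector<int> idx0(N), idx1(N);
    for (int i = 0; i < N; ++i) {
        idx0[i] = i % D0;  // indices into axis 0, kept in range like IndexRNG
        idx1[i] = i % D1;  // indices into axis 1
    }
    std::vector<float> out(N * D2);
    for (int i = 0; i < N; ++i)
        for (int k = 0; k < D2; ++k)
            out[i * D2 + k] = data[(idx0[i] * D1 + idx1[i]) * D2 + k];
    printf("out[0][0] = %g\n", out[0]);
    return 0;
}

The Set and Incr variants exercised by the other tests are, by the same reading, the corresponding scatter forms: IndexingSetMultiAxisVec writes the value tensor into the indexed positions, and IndexingIncrMultiAxisVec accumulates into them.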

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep-learning development on a cloud GPU compute platform, you are welcome to visit the MegStudio platform.