You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

indexing_multi_axis_vec.cpp 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. /**
  2. * \file dnn/test/cuda/indexing_multi_axis_vec.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/cuda/fixture.h"
  12. #include "megdnn/oprs.h"
  13. #include "test/common/checker.h"
  14. #include "test/common/indexing_multi_axis_vec.h"
  15. #include "test/common/index.h"
  16. #include <random>
  17. using namespace megdnn;
  18. using namespace test;
  19. namespace {
  20. class OrderedRNG final: public RNG {
  21. public:
  22. void gen(const TensorND &tensor) override {
  23. auto span = tensor.layout.span();
  24. if (tensor.layout.dtype == dtype::Float32()) {
  25. auto ptr = tensor.ptr<float>() + span.low_elem;
  26. for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
  27. ptr[i] = i;
  28. }
  29. } else if (tensor.layout.dtype == dtype::Float16()) {
  30. auto ptr = tensor.ptr<dt_float16>() + span.low_elem;
  31. for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
  32. ptr[i] = i;
  33. }
  34. } else {
  35. auto ptr = tensor.ptr<int>() + span.low_elem;
  36. for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
  37. ptr[i] = i;
  38. }
  39. }
  40. }
  41. };
// Shared test body for IndexingMultiAxisVec-family operators: runs a battery
// of axis selections, layouts and strides through the checker.
// see OprProxyIndexingMultiAxisVecHelper for more details
// set_proxy() sets the axes to index on
// execs() give input, output and index layouts
template<class Opr>
void run_check(Handle *handle) {
    Checker<Opr> checker(handle);
    // NOTE: the IndexRNGs below hold references to idx_size0/idx_size1, so
    // every assignment to them changes the valid index range generated for
    // the execs() that follow — the assignment order is significant.
    size_t idx_size0, idx_size1;
    OrderedRNG rng_inp;
    IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
    checker.
        set_dtype(0, dtype::Float32()). // data
        set_dtype(1, dtype::Float32()). // value
        set_dtype(2, dtype::Int32()). // idx0
        set_dtype(3, dtype::Int32()). // idx1
        set_rng(0, &rng_inp).
        set_rng(1, &rng_inp).
        set_rng(2, &rng0).
        set_rng(3, &rng1);
    // single indexed axis (axis 0), with and without a trailing plain axis
    idx_size0 = 23;
    checker.
        set_proxy({{0}}).
        execs({{23}, {100}, {100}}).
        execs({{23, 5}, {100, 5}, {100}});
    // two adjacent indexed axes (0 and 1)
    idx_size0 = 2;
    idx_size1 = 3;
    checker.
        set_proxy({{0, 1}}).
        execs({{2, 3}, {10}, {10}, {10}}).
        execs({{2, 3, 5}, {10, 5}, {10}, {10}});
    // non-contiguous input: strides on axes 0/1 inflated; indexing on
    // non-adjacent axes {1, 3}; note the second index layout is {1} while
    // the first is {7} — presumably broadcast by the proxy, TODO confirm
    idx_size0 = 4;
    idx_size1 = 6;
    TensorLayout inp_layout{{3, 4, 5, 6}, dtype::Float32()};
    inp_layout.stride[0] *= 8;
    inp_layout.stride[1] *= 2;
    checker.
        set_proxy({{1, 3}}).
        execl({inp_layout,
                {{7, 3, 5}, dtype::Float32()},
                {{7}, dtype::Int32()},
                {{1}, dtype::Int32()},
        });
    // interior axes (2 and 3) of a rank-6 tensor
    idx_size0 = 4;
    idx_size1 = 5;
    checker.
        set_proxy({{2, 3}}).
        execs({{2, 3, 4, 5, 6, 7}, {2, 3, 10, 6, 7}, {10}, {10}});
    // large case: 1M indices on axis 1
    idx_size0 = 4;
    checker.
        set_proxy({{1}}).
        execs({{1, 4}, {1, 1024 * 1024}, {1024 * 1024}});
    // incr-specific case: value tensor with stride 0 (the same scalar
    // value applied through 23 possibly-repeated indices), exercising
    // accumulation on duplicate destinations
    if (std::is_same<Opr, IndexingIncrMultiAxisVec>::value) {
        idx_size0 = 4;
        TensorLayout val_layout{{23}, dtype::Float32()};
        val_layout.stride[0] = 0;
        checker.
            set_proxy({{0}}).
            execl({{{4}, dtype::Float32()},
                    val_layout,
                    {{23}, dtype::Int32()}
            });
    }
}
  104. }
  105. TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) {
  106. run_check<IndexingMultiAxisVec>(handle_cuda());
  107. Checker<IndexingMultiAxisVec> checker(handle_cuda());
  108. size_t idx_size0;
  109. OrderedRNG rng_inp;
  110. IndexRNG rng0{idx_size0, 2};
  111. checker.
  112. set_dtype(0, dtype::Float32()). // data
  113. set_dtype(1, dtype::Float32()). // value
  114. set_dtype(2, dtype::Int32()). // idx0
  115. set_rng(0, &rng_inp).
  116. set_rng(1, &rng_inp).
  117. set_rng(2, &rng0);
  118. idx_size0 = 20;
  119. checker.set_proxy({{0}})
  120. .execl({TensorLayout{{20}, dtype::Float32()},
  121. TensorLayout{{9}, dtype::Float32()},
  122. TensorLayout{TensorShape{9}, {-1}, dtype::Int32()}});
  123. }
  124. TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) {
  125. run_check<IndexingIncrMultiAxisVec>(handle_cuda());
  126. Checker<IndexingIncrMultiAxisVec> checker(handle_cuda());
  127. OrderedRNG rng;
  128. checker.
  129. set_dtype(0, dtype::Float16()). // data
  130. set_dtype(1, dtype::Float16()). // value
  131. set_dtype(2, dtype::Int32()). // idx0
  132. set_rng(0, &rng).
  133. set_rng(1, &rng).
  134. set_rng(2, &rng);
  135. checker.
  136. set_proxy({{1}}).
  137. execs({{5, 8, 3}, {5, 2, 3}, {2}});
  138. }
  139. TEST_F(CUDA, INDEXING_SET_MULTI_AXIS_VEC) {
  140. Checker<IndexingSetMultiAxisVec> checker(handle_cuda());
  141. OrderedRNG rng;
  142. checker.
  143. set_dtype(0, dtype::Float32()). // data
  144. set_dtype(1, dtype::Float32()). // value
  145. set_dtype(2, dtype::Int32()). // idx0
  146. set_rng(0, &rng).
  147. set_rng(1, &rng).
  148. set_rng(2, &rng);
  149. checker.
  150. set_proxy({{1}}).
  151. execs({{5, 8, 3}, {5, 2, 3}, {2}});
  152. }
// Verify that out-of-bound indices are reported through the async
// error-info mechanism rather than passing silently.
TEST_F(CUDA_ERROR_INFO, INDEXING_MULTI_AXIS_VEC) {
    Checker<IndexingMultiAxisVec> checker(handle_cuda());
    UniformIntRNG idx_rng{-5, 5};
    checker.
        set_dtype(0, dtype::Float32()). // data
        set_dtype(1, dtype::Float32()). // value
        set_dtype(2, dtype::Int32()). // idx
        set_rng(2, &idx_rng);
    bool failed = false;
    // no error recorded before any exec
    ASSERT_EQ(0u, get_error_info().nr_error);
    // callback for the expected failure: record it and print the
    // device-generated message (printf-style format with up to 4 args)
    auto on_fail = [&failed, this]() {
        failed = true;
        auto info = get_error_info();
        ASSERT_GE(info.nr_error, 1u);
        printf("error msg: ");
        printf(info.msg, info.msg_args[0], info.msg_args[1], info.msg_args[2],
               info.msg_args[3]);
        printf("\n");
    };
    // indices in [-5, 5] on a 23-element axis must succeed
    // (negative values presumably wrap as in numpy — confirm in kernel)
    checker.
        set_proxy({{0}}).
        execs({{23}, {100}, {100}});
    // widen the index range far past the axis size: exec must now fail
    idx_rng = {-500, 500};
    checker.
        set_expect_exec_fail(on_fail).
        execs({{23}, {100}, {100}});
    ASSERT_TRUE(failed);
}
  181. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台