
indexing_multi_axis_vec.cpp 5.9 kB

/**
 * \file dnn/test/cuda/indexing_multi_axis_vec.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/cuda/fixture.h"

#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/common/indexing_multi_axis_vec.h"
#include "test/common/index.h"

#include <random>

using namespace megdnn;
using namespace test;
namespace {

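// Fills tensors with the increasing sequence 0, 1, 2, ... so that every data
// element is unique and indexed results can be verified deterministically.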
class OrderedRNG final : public RNG {
public:
    void gen(const TensorND& tensor) override {
        auto span = tensor.layout.span();
        if (tensor.layout.dtype == dtype::Float32()) {
            auto ptr = tensor.ptr<float>() + span.low_elem;
            for (size_t i = 0, it = span.dist_elem(); i < it; ++i) {
                ptr[i] = i;
            }
        } else {
            auto ptr = tensor.ptr<int>() + span.low_elem;
            for (size_t i = 0, it = span.dist_elem(); i < it; ++i) {
                ptr[i] = i;
            }
        }
    }
};

template <class Opr>
void run_check(Handle* handle) {
    // see OprProxyIndexingMultiAxisVecHelper for more details
    // set_proxy() sets the axes to index on
    // execs() gives input, output and index layouts
    Checker<Opr> checker(handle);
    size_t idx_size0, idx_size1;
    OrderedRNG rng_inp;
    IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
    checker.
            set_dtype(0, dtype::Float32()).  // data
            set_dtype(1, dtype::Float32()).  // value
            set_dtype(2, dtype::Int32()).    // idx0
            set_dtype(3, dtype::Int32()).    // idx1
            set_rng(0, &rng_inp).
            set_rng(1, &rng_inp).
            set_rng(2, &rng0).
            set_rng(3, &rng1);

    idx_size0 = 23;
    checker.
            set_proxy({{0}}).
            execs({{23}, {100}, {100}}).
            execs({{23, 5}, {100, 5}, {100}});

    idx_size0 = 2;
    idx_size1 = 3;
    checker.
            set_proxy({{0, 1}}).
            execs({{2, 3}, {10}, {10}, {10}}).
            execs({{2, 3, 5}, {10, 5}, {10}, {10}});
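
    // Non-contiguous input: enlarge the strides of axes 0 and 1 and pass the
    // layouts explicitly through execl() so that indexing on a strided tensor
    // is also covered.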
    idx_size0 = 4;
    idx_size1 = 6;
    TensorLayout inp_layout{{3, 4, 5, 6}, dtype::Float32()};
    inp_layout.stride[0] *= 8;
    inp_layout.stride[1] *= 2;
    checker.
            set_proxy({{1, 3}}).
            execl({inp_layout,
                   {{7, 3, 5}, dtype::Float32()},
                   {{7}, dtype::Int32()},
                   {{1}, dtype::Int32()},
            });

    idx_size0 = 4;
    idx_size1 = 5;
    checker.
            set_proxy({{2, 3}}).
            execs({{2, 3, 4, 5, 6, 7}, {2, 3, 10, 6, 7}, {10}, {10}});

    idx_size0 = 4;
    checker.
            set_proxy({{1}}).
            execs({{1, 4}, {1, 1024 * 1024}, {1024 * 1024}});
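
    // For IndexingIncrMultiAxisVec, additionally check a value tensor that is
    // broadcast along axis 0 (stride 0), so every indexed write reads the same
    // underlying value element.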
    if (std::is_same<Opr, IndexingIncrMultiAxisVec>::value) {
        idx_size0 = 4;
        TensorLayout val_layout{{23}, dtype::Float32()};
        val_layout.stride[0] = 0;
        checker.
                set_proxy({{0}}).
                execl({{{4}, dtype::Float32()},
                       val_layout,
                       {{23}, dtype::Int32()}
                });
    }
}

}  // anonymous namespace

TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) {
    run_check<IndexingMultiAxisVec>(handle_cuda());
    Checker<IndexingMultiAxisVec> checker(handle_cuda());
    size_t idx_size0;
    OrderedRNG rng_inp;
    IndexRNG rng0{idx_size0, 2};
    checker.
            set_dtype(0, dtype::Float32()).  // data
            set_dtype(1, dtype::Float32()).  // value
            set_dtype(2, dtype::Int32()).    // idx0
            set_rng(0, &rng_inp).
            set_rng(1, &rng_inp).
            set_rng(2, &rng0);
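
    // Index tensor with a negative stride (a reversed view over 9 indices).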
    idx_size0 = 20;
    checker.set_proxy({{0}})
            .execl({TensorLayout{{20}, dtype::Float32()},
                    TensorLayout{{9}, dtype::Float32()},
                    TensorLayout{TensorShape{9}, {-1}, dtype::Int32()}});
}

TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) {
    run_check<IndexingIncrMultiAxisVec>(handle_cuda());
}

TEST_F(CUDA, INDEXING_SET_MULTI_AXIS_VEC) {
    Checker<IndexingSetMultiAxisVec> checker(handle_cuda());
    OrderedRNG rng;
    checker.
            set_dtype(0, dtype::Float32()).  // data
            set_dtype(1, dtype::Float32()).  // value
            set_dtype(2, dtype::Int32()).    // idx0
            set_rng(0, &rng).
            set_rng(1, &rng).
            set_rng(2, &rng);
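
    // Write a {5, 2, 3} value tensor into a {5, 8, 3} tensor at 2 positions
    // indexed along axis 1.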
    checker.
            set_proxy({{1}}).
            execs({{5, 8, 3}, {5, 2, 3}, {2}});
}

TEST_F(CUDA_ERROR_INFO, INDEXING_MULTI_AXIS_VEC) {
    Checker<IndexingMultiAxisVec> checker(handle_cuda());
    UniformIntRNG idx_rng{-5, 5};
    checker.
            set_dtype(0, dtype::Float32()).  // data
            set_dtype(1, dtype::Float32()).  // value
            set_dtype(2, dtype::Int32()).    // idx
            set_rng(2, &idx_rng);

    bool failed = false;
    ASSERT_EQ(0u, get_error_info().nr_error);
    auto on_fail = [&failed, this]() {
        failed = true;
        auto info = get_error_info();
        ASSERT_GE(info.nr_error, 1u);
        printf("error msg: ");
        printf(info.msg, info.msg_args[0], info.msg_args[1], info.msg_args[2],
               info.msg_args[3]);
        printf("\n");
    };
    checker.
            set_proxy({{0}}).
            execs({{23}, {100}, {100}});
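
    // Indices drawn from [-500, 500] are mostly out of range for a length-23
    // axis, so the operator is expected to report an error and invoke on_fail.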
    idx_rng = {-500, 500};
    checker.
            set_expect_exec_fail(on_fail).
            execs({{23}, {100}, {100}});
    ASSERT_TRUE(failed);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has a GPU and the driver installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.