You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

indexing_multi_axis_vec.cpp 6.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. /**
  2. * \file dnn/test/rocm/indexing_multi_axis_vec.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "hcc_detail/hcc_defs_prologue.h"
  12. #include "test/rocm/fixture.h"
  13. #include "test/rocm/benchmarker.h"
  14. #include "megdnn/oprs.h"
  15. #include "test/common/checker.h"
  16. #include "test/common/indexing_multi_axis_vec.h"
  17. #include "test/common/index.h"
  18. #include <random>
  19. using namespace megdnn;
  20. using namespace test;
  21. namespace {
  22. class OrderedRNG final: public RNG {
  23. public:
  24. void gen(const TensorND &tensor) override {
  25. auto span = tensor.layout.span();
  26. if (tensor.layout.dtype == dtype::Float32()) {
  27. auto ptr = tensor.ptr<float>() + span.low_elem;
  28. for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
  29. ptr[i] = i;
  30. }
  31. } else {
  32. auto ptr = tensor.ptr<int>() + span.low_elem;
  33. for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
  34. ptr[i] = i;
  35. }
  36. }
  37. }
  38. };
  39. template<class Opr>
  40. void run_check(Handle *handle) {
  41. // see OprProxyIndexingMultiAxisVecHelper for more details
  42. // execs() give input, output and index layouts
  43. Checker<Opr> checker(handle);
  44. size_t idx_size0, idx_size1;
  45. OrderedRNG rng_inp;
  46. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  47. checker.
  48. set_dtype(0, dtype::Float32()). // data
  49. set_dtype(1, dtype::Float32()). // value
  50. set_dtype(2, dtype::Int32()). // idx0
  51. set_dtype(3, dtype::Int32()). // idx1
  52. set_rng(0, &rng_inp).
  53. set_rng(1, &rng_inp).
  54. set_rng(2, &rng0).
  55. set_rng(3, &rng1);
  56. idx_size0 = 23;
  57. checker.
  58. set_proxy({{0}}).
  59. execs({{23}, {100}, {100}}).
  60. execs({{23, 5}, {100, 5}, {100}});
  61. idx_size0 = 2;
  62. idx_size1 = 3;
  63. checker.
  64. set_proxy({{0, 1}}).
  65. execs({{2, 3}, {10}, {10}, {10}}).
  66. execs({{2, 3, 5}, {10, 5}, {10}, {10}});
  67. idx_size0 = 4;
  68. idx_size1 = 6;
  69. TensorLayout inp_layout{{3, 4, 5, 6}, dtype::Float32()};
  70. inp_layout.stride[0] *= 8;
  71. inp_layout.stride[1] *= 2;
  72. checker.
  73. set_proxy({{1, 3}}).
  74. execl({inp_layout,
  75. {{7, 3, 5}, dtype::Float32()},
  76. {{7}, dtype::Int32()},
  77. {{1}, dtype::Int32()},
  78. });
  79. idx_size0 = 4;
  80. idx_size1 = 5;
  81. checker.
  82. set_proxy({{2, 3}}).
  83. execs({{2, 3, 4, 5, 6, 7}, {2, 3, 10, 6, 7}, {10}, {10}});
  84. idx_size0 = 4;
  85. checker.
  86. set_proxy({{1}}).
  87. execs({{1, 4}, {1, 1024 * 1024}, {1024 * 1024}});
  88. if (std::is_same<Opr, IndexingIncrMultiAxisVec>::value) {
  89. idx_size0 = 4;
  90. TensorLayout val_layout{{23}, dtype::Float32()};
  91. val_layout.stride[0] = 0;
  92. checker.
  93. set_proxy({{0}}).
  94. execl({{{4}, dtype::Float32()},
  95. val_layout,
  96. {{23}, dtype::Int32()}
  97. });
  98. }
  99. }
  100. }
  101. TEST_F(ROCM, INDEXING_MULTI_AXIS_VEC) {
  102. run_check<IndexingMultiAxisVec>(handle_rocm());
  103. Checker<IndexingMultiAxisVec> checker(handle_rocm());
  104. size_t idx_size0;
  105. OrderedRNG rng_inp;
  106. IndexRNG rng0{idx_size0, 2};
  107. checker.
  108. set_dtype(0, dtype::Float32()). // data
  109. set_dtype(1, dtype::Float32()). // value
  110. set_dtype(2, dtype::Int32()). // idx0
  111. set_rng(0, &rng_inp).
  112. set_rng(1, &rng_inp).
  113. set_rng(2, &rng0);
  114. idx_size0 = 20;
  115. checker.set_proxy({{0}})
  116. .execl({TensorLayout{{20}, dtype::Float32()},
  117. TensorLayout{{9}, dtype::Float32()},
  118. TensorLayout{TensorShape{9}, {-1}, dtype::Int32()}});
  119. }
  120. TEST_F(ROCM, INDEXING_INCR_MULTI_AXIS_VEC) {
  121. run_check<IndexingIncrMultiAxisVec>(handle_rocm());
  122. }
  123. TEST_F(ROCM, INDEXING_SET_MULTI_AXIS_VEC) {
  124. Checker<IndexingSetMultiAxisVec> checker(handle_rocm());
  125. OrderedRNG rng;
  126. checker.
  127. set_dtype(0, dtype::Float32()). // data
  128. set_dtype(1, dtype::Float32()). // value
  129. set_dtype(2, dtype::Int32()). // idx0
  130. set_rng(0, &rng).
  131. set_rng(1, &rng).
  132. set_rng(2, &rng);
  133. checker.
  134. set_proxy({{1}}).
  135. execs({{5, 8, 3}, {5, 2, 3}, {2}});
  136. }
  137. TEST_F(ROCM_ERROR_INFO, INDEXING_MULTI_AXIS_VEC) {
  138. Checker<IndexingMultiAxisVec> checker(handle_rocm());
  139. UniformIntRNG idx_rng{-5, 5};
  140. checker.
  141. set_dtype(0, dtype::Float32()). // data
  142. set_dtype(1, dtype::Float32()). // value
  143. set_dtype(2, dtype::Int32()). // idx
  144. set_rng(2, &idx_rng);
  145. bool failed = false;
  146. ASSERT_EQ(0u, get_error_info().nr_error);
  147. auto on_fail = [&failed, this]() {
  148. failed = true;
  149. auto info = get_error_info();
  150. ASSERT_GE(info.nr_error, 1u);
  151. printf("error msg: ");
  152. printf(info.msg, info.msg_args[0], info.msg_args[1], info.msg_args[2],
  153. info.msg_args[3]);
  154. printf("\n");
  155. };
  156. checker.
  157. set_proxy({{0}}).
  158. execs({{23}, {100}, {100}});
  159. idx_rng = {-500, 500};
  160. checker.
  161. set_expect_exec_fail(on_fail).
  162. execs({{23}, {100}, {100}});
  163. ASSERT_TRUE(failed);
  164. }
  165. TEST_F(ROCM, INDEXING_MULTI_AXIS_VEC_BENCHMARK) {
  166. ROCMBenchmarker<IndexingMultiAxisVec> benchmarker(handle_rocm(), handle_naive(false));
  167. benchmarker.set_display(true);
  168. OrderedRNG rng_inp;
  169. size_t idx_size = 10000;
  170. IndexRNG rng0{idx_size, 3}, rng1{idx_size, 1};
  171. benchmarker.
  172. set_dtype(0, dtype::Float32()).
  173. set_dtype(1, dtype::Float32()).
  174. set_dtype(2, dtype::Int32()).
  175. set_dtype(3, dtype::Int32()).
  176. set_rng(0, &rng_inp).
  177. set_rng(1, &rng_inp).
  178. set_rng(2, &rng0).
  179. set_rng(3, &rng1).
  180. set_proxy({{0, 1}});
  181. auto time_ms = benchmarker.execs({{1000, 1000, 1000}, {1000, 1000}, {1000}, {1000}});
  182. long io = 2 * 1000 * 1000 * dtype::Float32().size();
  183. printf("io = %.3f GB, random access bandwidth = %.3f GB/s\n",
  184. (float)(io / 1e9), (float)(io / (time_ms * 1e6)));
  185. }
  186. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。