You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

indexing_multi_axis_vec.cpp 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. /**
  2. * \file dnn/test/rocm/indexing_multi_axis_vec.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "hcc_detail/hcc_defs_prologue.h"
  12. #include "test/rocm/benchmarker.h"
  13. #include "test/rocm/fixture.h"
  14. #include "megdnn/oprs.h"
  15. #include "test/common/checker.h"
  16. #include "test/common/index.h"
  17. #include "test/common/indexing_multi_axis_vec.h"
  18. #include <random>
  19. using namespace megdnn;
  20. using namespace test;
  21. namespace {
  22. class OrderedRNG final : public RNG {
  23. public:
  24. void gen(const TensorND& tensor) override {
  25. auto span = tensor.layout.span();
  26. if (tensor.layout.dtype == dtype::Float32()) {
  27. auto ptr = tensor.ptr<float>() + span.low_elem;
  28. for (size_t i = 0, it = span.dist_elem(); i < it; ++i) {
  29. ptr[i] = i;
  30. }
  31. } else {
  32. auto ptr = tensor.ptr<int>() + span.low_elem;
  33. for (size_t i = 0, it = span.dist_elem(); i < it; ++i) {
  34. ptr[i] = i;
  35. }
  36. }
  37. }
  38. };
  39. template <class Opr>
  40. void run_check(Handle* handle) {
  41. // see OprProxyIndexingMultiAxisVecHelper for more details
  42. // execs() give input, output and index layouts
  43. Checker<Opr> checker(handle);
  44. size_t idx_size0, idx_size1;
  45. OrderedRNG rng_inp;
  46. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  47. checker.set_dtype(0, dtype::Float32())
  48. . // data
  49. set_dtype(1, dtype::Float32())
  50. . // value
  51. set_dtype(2, dtype::Int32())
  52. . // idx0
  53. set_dtype(3, dtype::Int32())
  54. . // idx1
  55. set_rng(0, &rng_inp)
  56. .set_rng(1, &rng_inp)
  57. .set_rng(2, &rng0)
  58. .set_rng(3, &rng1);
  59. idx_size0 = 23;
  60. checker.set_proxy({{0}})
  61. .execs({{23}, {100}, {100}})
  62. .execs({{23, 5}, {100, 5}, {100}});
  63. idx_size0 = 2;
  64. idx_size1 = 3;
  65. checker.set_proxy({{0, 1}})
  66. .execs({{2, 3}, {10}, {10}, {10}})
  67. .execs({{2, 3, 5}, {10, 5}, {10}, {10}});
  68. idx_size0 = 4;
  69. idx_size1 = 6;
  70. TensorLayout inp_layout{{3, 4, 5, 6}, dtype::Float32()};
  71. inp_layout.stride[0] *= 8;
  72. inp_layout.stride[1] *= 2;
  73. checker.set_proxy({{1, 3}}).execl({
  74. inp_layout,
  75. {{7, 3, 5}, dtype::Float32()},
  76. {{7}, dtype::Int32()},
  77. {{1}, dtype::Int32()},
  78. });
  79. idx_size0 = 4;
  80. idx_size1 = 5;
  81. checker.set_proxy({{2, 3}}).execs(
  82. {{2, 3, 4, 5, 6, 7}, {2, 3, 10, 6, 7}, {10}, {10}});
  83. idx_size0 = 4;
  84. checker.set_proxy({{1}}).execs({{1, 4}, {1, 1024 * 1024}, {1024 * 1024}});
  85. if (std::is_same<Opr, IndexingIncrMultiAxisVec>::value) {
  86. idx_size0 = 4;
  87. TensorLayout val_layout{{23}, dtype::Float32()};
  88. val_layout.stride[0] = 0;
  89. checker.set_proxy({{0}}).execl(
  90. {{{4}, dtype::Float32()}, val_layout, {{23}, dtype::Int32()}});
  91. }
  92. }
  93. } // namespace
  94. TEST_F(ROCM, INDEXING_MULTI_AXIS_VEC) {
  95. run_check<IndexingMultiAxisVec>(handle_rocm());
  96. Checker<IndexingMultiAxisVec> checker(handle_rocm());
  97. size_t idx_size0;
  98. OrderedRNG rng_inp;
  99. IndexRNG rng0{idx_size0, 2};
  100. checker.set_dtype(0, dtype::Float32())
  101. . // data
  102. set_dtype(1, dtype::Float32())
  103. . // value
  104. set_dtype(2, dtype::Int32())
  105. . // idx0
  106. set_rng(0, &rng_inp)
  107. .set_rng(1, &rng_inp)
  108. .set_rng(2, &rng0);
  109. idx_size0 = 20;
  110. checker.set_proxy({{0}}).execl(
  111. {TensorLayout{{20}, dtype::Float32()}, TensorLayout{{9}, dtype::Float32()},
  112. TensorLayout{TensorShape{9}, {-1}, dtype::Int32()}});
  113. }
  114. TEST_F(ROCM, INDEXING_INCR_MULTI_AXIS_VEC) {
  115. run_check<IndexingIncrMultiAxisVec>(handle_rocm());
  116. }
  117. TEST_F(ROCM, INDEXING_SET_MULTI_AXIS_VEC) {
  118. Checker<IndexingSetMultiAxisVec> checker(handle_rocm());
  119. OrderedRNG rng;
  120. checker.set_dtype(0, dtype::Float32())
  121. . // data
  122. set_dtype(1, dtype::Float32())
  123. . // value
  124. set_dtype(2, dtype::Int32())
  125. . // idx0
  126. set_rng(0, &rng)
  127. .set_rng(1, &rng)
  128. .set_rng(2, &rng);
  129. checker.set_proxy({{1}}).execs({{5, 8, 3}, {5, 2, 3}, {2}});
  130. }
  131. TEST_F(ROCM_ERROR_INFO, INDEXING_MULTI_AXIS_VEC) {
  132. Checker<IndexingMultiAxisVec> checker(handle_rocm());
  133. UniformIntRNG idx_rng{-5, 5};
  134. checker.set_dtype(0, dtype::Float32())
  135. . // data
  136. set_dtype(1, dtype::Float32())
  137. . // value
  138. set_dtype(2, dtype::Int32())
  139. . // idx
  140. set_rng(2, &idx_rng);
  141. bool failed = false;
  142. ASSERT_EQ(0u, get_error_info().nr_error);
  143. auto on_fail = [&failed, this]() {
  144. failed = true;
  145. auto info = get_error_info();
  146. ASSERT_GE(info.nr_error, 1u);
  147. printf("error msg: ");
  148. printf(info.msg, info.msg_args[0], info.msg_args[1], info.msg_args[2],
  149. info.msg_args[3]);
  150. printf("\n");
  151. };
  152. checker.set_proxy({{0}}).execs({{23}, {100}, {100}});
  153. idx_rng = {-500, 500};
  154. checker.set_expect_exec_fail(on_fail).execs({{23}, {100}, {100}});
  155. ASSERT_TRUE(failed);
  156. }
  157. TEST_F(ROCM, INDEXING_MULTI_AXIS_VEC_BENCHMARK) {
  158. ROCMBenchmarker<IndexingMultiAxisVec> benchmarker(
  159. handle_rocm(), handle_naive(false));
  160. benchmarker.set_display(true);
  161. OrderedRNG rng_inp;
  162. size_t idx_size = 10000;
  163. IndexRNG rng0{idx_size, 3}, rng1{idx_size, 1};
  164. benchmarker.set_dtype(0, dtype::Float32())
  165. .set_dtype(1, dtype::Float32())
  166. .set_dtype(2, dtype::Int32())
  167. .set_dtype(3, dtype::Int32())
  168. .set_rng(0, &rng_inp)
  169. .set_rng(1, &rng_inp)
  170. .set_rng(2, &rng0)
  171. .set_rng(3, &rng1)
  172. .set_proxy({{0, 1}});
  173. auto time_ms =
  174. benchmarker.execs({{1000, 1000, 1000}, {1000, 1000}, {1000}, {1000}});
  175. long io = 2 * 1000 * 1000 * dtype::Float32().size();
  176. printf("io = %.3f GB, random access bandwidth = %.3f GB/s\n", (float)(io / 1e9),
  177. (float)(io / (time_ms * 1e6)));
  178. }
  179. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台