You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

mesh_indexing.cpp 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. /**
  2. * \file dnn/test/naive/mesh_indexing.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/mesh_indexing.h"
  12. #include "megdnn/basic_types.h"
  13. #include "test/common/checker.h"
  14. #include "test/common/index.h"
  15. #include "test/naive/fixture.h"
  16. using namespace megdnn;
  17. using namespace test;
  18. TEST_F(NAIVE, MESH_INDEXING) {
  19. SmallVector<size_t> init_axes;
  20. auto multi_axis_index_impl = [this, &init_axes](const TensorNDArray& tensors) {
  21. auto opr = handle()->create_operator<IndexingMultiAxisVec>();
  22. OprProxy<IndexingMultiAxisVec> proxy(init_axes);
  23. proxy.exec(opr.get(), tensors);
  24. };
  25. Checker<MeshIndexing> checker(handle());
  26. checker.set_extra_opr_impl(multi_axis_index_impl);
  27. size_t idx_size0, idx_size1;
  28. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  29. checker.set_dtype(0, dtype::Float32())
  30. .set_dtype(1, dtype::Float32())
  31. .set_dtype(2, dtype::Int32())
  32. .set_dtype(3, dtype::Int32())
  33. .set_rng(2, &rng0)
  34. .set_rng(3, &rng1);
  35. idx_size0 = 23;
  36. init_axes = {0};
  37. checker.set_proxy({init_axes})
  38. .execs({{23}, {100}, {100}})
  39. .execs({{23, 5}, {100, 5}, {100}});
  40. idx_size0 = 3;
  41. init_axes = {1};
  42. checker.set_proxy(init_axes)
  43. .execs({{2, 3}, {2, 10}, {10}})
  44. .execs({{2, 3, 5}, {2, 50, 5}, {50}})
  45. .execs({{2, 3, 5, 7}, {2, 55, 5, 7}, {55}});
  46. }
  47. TEST_F(NAIVE, BATCHED_MESH_INDEXING) {
  48. SmallVector<size_t> init_axes;
  49. auto extra_impl = [this, &init_axes](const TensorNDArray& tensors) {
  50. auto opr = handle()->create_operator<MeshIndexing>();
  51. OprProxy<MeshIndexing> proxy(init_axes);
  52. size_t N = tensors[0].layout[0];
  53. for (size_t n = 0; n < N; ++n) {
  54. TensorNDArray new_tensors;
  55. for (size_t i = 0; i < tensors.size(); ++i) {
  56. auto&& tensor = tensors[i];
  57. TensorLayout layout = tensor.layout.remove_axis(0);
  58. if (i < 2) {
  59. layout.add_axis_cont_inplace(0);
  60. }
  61. void* ptr = static_cast<dt_byte*>(tensor.raw_ptr()) +
  62. tensor.layout.stride[0] * n * tensor.layout.dtype.size();
  63. new_tensors.emplace_back(ptr, layout);
  64. }
  65. proxy.exec(opr.get(), new_tensors);
  66. }
  67. };
  68. Checker<BatchedMeshIndexing> checker(handle());
  69. checker.set_extra_opr_impl(extra_impl);
  70. size_t idx_size0, idx_size1;
  71. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  72. checker.set_dtype(0, dtype::Float32())
  73. .set_dtype(1, dtype::Float32())
  74. .set_dtype(2, dtype::Int32())
  75. .set_dtype(3, dtype::Int32())
  76. .set_rng(2, &rng0)
  77. .set_rng(3, &rng1);
  78. idx_size0 = 5;
  79. init_axes = {1};
  80. checker.set_proxy({init_axes}).execs({{1, idx_size0}, {1, 3}, {1, 3}});
  81. idx_size0 = 23;
  82. idx_size1 = 17;
  83. init_axes = {1, 2};
  84. checker.set_proxy({init_axes})
  85. .execs({{7, idx_size0, idx_size1}, {7, 10, 20}, {7, 10}, {7, 20}})
  86. .execs({{7, idx_size0, idx_size1, 9}, {7, 10, 20, 9}, {7, 10}, {7, 20}});
  87. init_axes = {2, 1};
  88. checker.set_proxy({init_axes})
  89. .execs({{8, idx_size1, idx_size0}, {8, 20, 10}, {8, 10}, {8, 20}})
  90. .execs({{8, idx_size1, idx_size0, 9}, {8, 20, 10, 9}, {8, 10}, {8, 20}});
  91. idx_size0 = 5;
  92. init_axes = {1};
  93. TensorLayout index_layout{TensorShape{1, 3}, dtype::Int32()};
  94. index_layout = index_layout.broadcast({2, 3});
  95. checker.set_proxy({init_axes})
  96. .execl({TensorLayout{TensorShape{2, idx_size0}, dtype::Float32()},
  97. TensorLayout{TensorShape{2, 3}, dtype::Float32()}, index_layout});
  98. }
  99. TEST_F(NAIVE, MESH_MODIFY_INCREMENT) {
  100. SmallVector<size_t> init_axes;
  101. auto multi_axis_index_impl = [this, &init_axes](const TensorNDArray& tensors) {
  102. auto opr = handle()->create_operator<IndexingIncrMultiAxisVec>();
  103. OprProxy<IndexingIncrMultiAxisVec> proxy(init_axes);
  104. proxy.exec(opr.get(), tensors);
  105. };
  106. Checker<IncrMeshIndexing> checker(handle());
  107. checker.set_extra_opr_impl(multi_axis_index_impl);
  108. size_t idx_size0, idx_size1;
  109. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  110. checker.set_dtype(0, dtype::Float32())
  111. .set_dtype(1, dtype::Float32())
  112. .set_dtype(2, dtype::Int32())
  113. .set_dtype(3, dtype::Int32())
  114. .set_rng(2, &rng0)
  115. .set_rng(3, &rng1);
  116. idx_size0 = 23;
  117. init_axes = {0};
  118. checker.set_proxy({init_axes})
  119. .execs({{23}, {100}, {100}})
  120. .execs({{23, 5}, {100, 5}, {100}});
  121. idx_size0 = 3;
  122. init_axes = {1};
  123. checker.set_proxy(init_axes)
  124. .execs({{2, 3}, {2, 10}, {10}})
  125. .execs({{2, 3, 5}, {2, 50, 5}, {50}})
  126. .execs({{2, 3, 5, 7}, {2, 55, 5, 7}, {55}});
  127. }
  128. TEST_F(NAIVE, BATCHED_MESH_MODIFY_INCREMENT) {
  129. SmallVector<size_t> init_axes;
  130. auto extra_impl = [this, &init_axes](const TensorNDArray& tensors) {
  131. auto opr = handle()->create_operator<IncrMeshIndexing>();
  132. OprProxy<IncrMeshIndexing> proxy(init_axes);
  133. size_t N = tensors[0].layout[0];
  134. for (size_t n = 0; n < N; ++n) {
  135. TensorNDArray new_tensors;
  136. for (size_t i = 0; i < tensors.size(); ++i) {
  137. auto&& tensor = tensors[i];
  138. TensorLayout layout = tensor.layout.remove_axis(0);
  139. if (i < 2) {
  140. layout.add_axis_cont_inplace(0);
  141. }
  142. void* ptr = static_cast<dt_byte*>(tensor.raw_ptr()) +
  143. tensor.layout.dtype.size(tensor.layout.stride[0] * n);
  144. new_tensors.emplace_back(ptr, layout);
  145. }
  146. proxy.exec(opr.get(), new_tensors);
  147. }
  148. };
  149. Checker<BatchedIncrMeshIndexing> checker(handle());
  150. checker.set_extra_opr_impl(extra_impl);
  151. size_t idx_size0, idx_size1;
  152. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  153. checker.set_dtype(0, dtype::Float32())
  154. .set_dtype(1, dtype::Float32())
  155. .set_dtype(2, dtype::Int32())
  156. .set_dtype(3, dtype::Int32())
  157. .set_rng(2, &rng0)
  158. .set_rng(3, &rng1);
  159. idx_size0 = 5;
  160. init_axes = {1};
  161. checker.set_proxy({init_axes}).execs({{1, idx_size0}, {1, 3}, {1, 3}});
  162. idx_size0 = 23;
  163. idx_size1 = 17;
  164. init_axes = {1, 2};
  165. checker.set_proxy({init_axes})
  166. .execs({{7, idx_size0, idx_size1}, {7, 10, 20}, {7, 10}, {7, 20}})
  167. .execs({{7, idx_size0, idx_size1, 9}, {7, 10, 20, 9}, {7, 10}, {7, 20}});
  168. init_axes = {2, 1};
  169. checker.set_proxy({init_axes})
  170. .execs({{8, idx_size1, idx_size0}, {8, 20, 10}, {8, 10}, {8, 20}})
  171. .execs({{8, idx_size1, idx_size0, 9}, {8, 20, 10, 9}, {8, 10}, {8, 20}});
  172. }
  173. TEST_F(NAIVE, MESH_MODIFY_SETTING) {
  174. SmallVector<size_t> init_axes;
  175. auto extra_impl = [this, &init_axes](const TensorNDArray& tensors) {
  176. auto opr = handle()->create_operator<IncrMeshIndexing>();
  177. OprProxy<IncrMeshIndexing> proxy(init_axes);
  178. proxy.exec(opr.get(), tensors);
  179. };
  180. Checker<SetMeshIndexing> checker(handle());
  181. checker.set_extra_opr_impl(extra_impl);
  182. size_t idx_size0, idx_size1;
  183. mesh_indexing::NoReplacementIndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  184. ConstValue zero_gen;
  185. checker.set_dtype(0, dtype::Float32())
  186. .set_dtype(1, dtype::Float32())
  187. .set_dtype(2, dtype::Int32())
  188. .set_dtype(3, dtype::Int32())
  189. .set_rng(2, &rng0)
  190. .set_rng(3, &rng1)
  191. .set_rng(0, &zero_gen);
  192. idx_size0 = 5;
  193. init_axes = {1};
  194. checker.set_proxy({init_axes}).execs({{1, idx_size0}, {1, 3}, {3}});
  195. idx_size0 = 23;
  196. idx_size1 = 20;
  197. init_axes = {1, 2};
  198. checker.set_proxy({init_axes})
  199. .execs({{7, idx_size0, idx_size1}, {7, 10, 20}, {10}, {20}})
  200. .execs({{7, idx_size0, idx_size1, 9}, {7, 10, 20, 9}, {10}, {20}});
  201. init_axes = {2, 1};
  202. checker.set_proxy({init_axes})
  203. .execs({{8, idx_size1, idx_size0}, {8, 20, 10}, {10}, {20}})
  204. .execs({{8, idx_size1, idx_size0, 9}, {8, 20, 10, 9}, {10}, {20}});
  205. }
  206. TEST_F(NAIVE, BATCHED_MESH_MODIFY_SETTING) {
  207. SmallVector<size_t> init_axes;
  208. auto extra_impl = [this, &init_axes](const TensorNDArray& tensors) {
  209. auto opr = handle()->create_operator<BatchedIncrMeshIndexing>();
  210. OprProxy<BatchedIncrMeshIndexing> proxy(init_axes);
  211. proxy.exec(opr.get(), tensors);
  212. };
  213. Checker<BatchedSetMeshIndexing> checker(handle());
  214. checker.set_extra_opr_impl(extra_impl);
  215. size_t idx_size0, idx_size1;
  216. mesh_indexing::NoReplacementIndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  217. ConstValue zero_gen;
  218. checker.set_dtype(0, dtype::Float32())
  219. .set_dtype(1, dtype::Float32())
  220. .set_dtype(2, dtype::Int32())
  221. .set_dtype(3, dtype::Int32())
  222. .set_rng(2, &rng0)
  223. .set_rng(3, &rng1)
  224. .set_rng(0, &zero_gen);
  225. idx_size0 = 5;
  226. init_axes = {1};
  227. checker.set_proxy({init_axes}).execs({{1, idx_size0}, {1, 3}, {1, 3}});
  228. idx_size0 = 23;
  229. idx_size1 = 20;
  230. init_axes = {1, 2};
  231. checker.set_proxy({init_axes})
  232. .execs({{7, idx_size0, idx_size1}, {7, 10, 20}, {7, 10}, {7, 20}})
  233. .execs({{7, idx_size0, idx_size1, 9}, {7, 10, 20, 9}, {7, 10}, {7, 20}});
  234. init_axes = {2, 1};
  235. checker.set_proxy({init_axes})
  236. .execs({{8, idx_size1, idx_size0}, {8, 20, 10}, {8, 10}, {8, 20}})
  237. .execs({{8, idx_size1, idx_size0, 9}, {8, 20, 10, 9}, {8, 10}, {8, 20}});
  238. }