You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mesh_indexing.cpp 9.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. #include "test/common/mesh_indexing.h"
  2. #include "megdnn/basic_types.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/index.h"
  5. #include "test/naive/fixture.h"
  6. using namespace megdnn;
  7. using namespace test;
  8. TEST_F(NAIVE, MESH_INDEXING) {
  9. SmallVector<size_t> init_axes;
  10. auto multi_axis_index_impl = [this, &init_axes](const TensorNDArray& tensors) {
  11. auto opr = handle()->create_operator<IndexingMultiAxisVec>();
  12. OprProxy<IndexingMultiAxisVec> proxy(init_axes);
  13. proxy.exec(opr.get(), tensors);
  14. };
  15. Checker<MeshIndexing> checker(handle());
  16. checker.set_extra_opr_impl(multi_axis_index_impl);
  17. size_t idx_size0, idx_size1;
  18. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  19. checker.set_dtype(0, dtype::Float32())
  20. .set_dtype(1, dtype::Float32())
  21. .set_dtype(2, dtype::Int32())
  22. .set_dtype(3, dtype::Int32())
  23. .set_rng(2, &rng0)
  24. .set_rng(3, &rng1);
  25. idx_size0 = 23;
  26. init_axes = {0};
  27. checker.set_proxy({init_axes})
  28. .execs({{23}, {100}, {100}})
  29. .execs({{23, 5}, {100, 5}, {100}});
  30. idx_size0 = 3;
  31. init_axes = {1};
  32. checker.set_proxy(init_axes)
  33. .execs({{2, 3}, {2, 10}, {10}})
  34. .execs({{2, 3, 5}, {2, 50, 5}, {50}})
  35. .execs({{2, 3, 5, 7}, {2, 55, 5, 7}, {55}});
  36. }
  37. TEST_F(NAIVE, BATCHED_MESH_INDEXING) {
  38. SmallVector<size_t> init_axes;
  39. auto extra_impl = [this, &init_axes](const TensorNDArray& tensors) {
  40. auto opr = handle()->create_operator<MeshIndexing>();
  41. OprProxy<MeshIndexing> proxy(init_axes);
  42. size_t N = tensors[0].layout[0];
  43. for (size_t n = 0; n < N; ++n) {
  44. TensorNDArray new_tensors;
  45. for (size_t i = 0; i < tensors.size(); ++i) {
  46. auto&& tensor = tensors[i];
  47. TensorLayout layout = tensor.layout.remove_axis(0);
  48. if (i < 2) {
  49. layout.add_axis_cont_inplace(0);
  50. }
  51. void* ptr = static_cast<dt_byte*>(tensor.raw_ptr()) +
  52. tensor.layout.stride[0] * n * tensor.layout.dtype.size();
  53. new_tensors.emplace_back(ptr, layout);
  54. }
  55. proxy.exec(opr.get(), new_tensors);
  56. }
  57. };
  58. Checker<BatchedMeshIndexing> checker(handle());
  59. checker.set_extra_opr_impl(extra_impl);
  60. size_t idx_size0, idx_size1;
  61. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  62. checker.set_dtype(0, dtype::Float32())
  63. .set_dtype(1, dtype::Float32())
  64. .set_dtype(2, dtype::Int32())
  65. .set_dtype(3, dtype::Int32())
  66. .set_rng(2, &rng0)
  67. .set_rng(3, &rng1);
  68. idx_size0 = 5;
  69. init_axes = {1};
  70. checker.set_proxy({init_axes}).execs({{1, idx_size0}, {1, 3}, {1, 3}});
  71. idx_size0 = 23;
  72. idx_size1 = 17;
  73. init_axes = {1, 2};
  74. checker.set_proxy({init_axes})
  75. .execs({{7, idx_size0, idx_size1}, {7, 10, 20}, {7, 10}, {7, 20}})
  76. .execs({{7, idx_size0, idx_size1, 9}, {7, 10, 20, 9}, {7, 10}, {7, 20}});
  77. init_axes = {2, 1};
  78. checker.set_proxy({init_axes})
  79. .execs({{8, idx_size1, idx_size0}, {8, 20, 10}, {8, 10}, {8, 20}})
  80. .execs({{8, idx_size1, idx_size0, 9}, {8, 20, 10, 9}, {8, 10}, {8, 20}});
  81. idx_size0 = 5;
  82. init_axes = {1};
  83. TensorLayout index_layout{TensorShape{1, 3}, dtype::Int32()};
  84. index_layout = index_layout.broadcast({2, 3});
  85. checker.set_proxy({init_axes})
  86. .execl({TensorLayout{TensorShape{2, idx_size0}, dtype::Float32()},
  87. TensorLayout{TensorShape{2, 3}, dtype::Float32()}, index_layout});
  88. }
  89. TEST_F(NAIVE, MESH_MODIFY_INCREMENT) {
  90. SmallVector<size_t> init_axes;
  91. auto multi_axis_index_impl = [this, &init_axes](const TensorNDArray& tensors) {
  92. auto opr = handle()->create_operator<IndexingIncrMultiAxisVec>();
  93. OprProxy<IndexingIncrMultiAxisVec> proxy(init_axes);
  94. proxy.exec(opr.get(), tensors);
  95. };
  96. Checker<IncrMeshIndexing> checker(handle());
  97. checker.set_extra_opr_impl(multi_axis_index_impl);
  98. size_t idx_size0, idx_size1;
  99. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  100. checker.set_dtype(0, dtype::Float32())
  101. .set_dtype(1, dtype::Float32())
  102. .set_dtype(2, dtype::Int32())
  103. .set_dtype(3, dtype::Int32())
  104. .set_rng(2, &rng0)
  105. .set_rng(3, &rng1);
  106. idx_size0 = 23;
  107. init_axes = {0};
  108. checker.set_proxy({init_axes})
  109. .execs({{23}, {100}, {100}})
  110. .execs({{23, 5}, {100, 5}, {100}});
  111. idx_size0 = 3;
  112. init_axes = {1};
  113. checker.set_proxy(init_axes)
  114. .execs({{2, 3}, {2, 10}, {10}})
  115. .execs({{2, 3, 5}, {2, 50, 5}, {50}})
  116. .execs({{2, 3, 5, 7}, {2, 55, 5, 7}, {55}});
  117. }
  118. TEST_F(NAIVE, BATCHED_MESH_MODIFY_INCREMENT) {
  119. SmallVector<size_t> init_axes;
  120. auto extra_impl = [this, &init_axes](const TensorNDArray& tensors) {
  121. auto opr = handle()->create_operator<IncrMeshIndexing>();
  122. OprProxy<IncrMeshIndexing> proxy(init_axes);
  123. size_t N = tensors[0].layout[0];
  124. for (size_t n = 0; n < N; ++n) {
  125. TensorNDArray new_tensors;
  126. for (size_t i = 0; i < tensors.size(); ++i) {
  127. auto&& tensor = tensors[i];
  128. TensorLayout layout = tensor.layout.remove_axis(0);
  129. if (i < 2) {
  130. layout.add_axis_cont_inplace(0);
  131. }
  132. void* ptr = static_cast<dt_byte*>(tensor.raw_ptr()) +
  133. tensor.layout.dtype.size(tensor.layout.stride[0] * n);
  134. new_tensors.emplace_back(ptr, layout);
  135. }
  136. proxy.exec(opr.get(), new_tensors);
  137. }
  138. };
  139. Checker<BatchedIncrMeshIndexing> checker(handle());
  140. checker.set_extra_opr_impl(extra_impl);
  141. size_t idx_size0, idx_size1;
  142. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  143. checker.set_dtype(0, dtype::Float32())
  144. .set_dtype(1, dtype::Float32())
  145. .set_dtype(2, dtype::Int32())
  146. .set_dtype(3, dtype::Int32())
  147. .set_rng(2, &rng0)
  148. .set_rng(3, &rng1);
  149. idx_size0 = 5;
  150. init_axes = {1};
  151. checker.set_proxy({init_axes}).execs({{1, idx_size0}, {1, 3}, {1, 3}});
  152. idx_size0 = 23;
  153. idx_size1 = 17;
  154. init_axes = {1, 2};
  155. checker.set_proxy({init_axes})
  156. .execs({{7, idx_size0, idx_size1}, {7, 10, 20}, {7, 10}, {7, 20}})
  157. .execs({{7, idx_size0, idx_size1, 9}, {7, 10, 20, 9}, {7, 10}, {7, 20}});
  158. init_axes = {2, 1};
  159. checker.set_proxy({init_axes})
  160. .execs({{8, idx_size1, idx_size0}, {8, 20, 10}, {8, 10}, {8, 20}})
  161. .execs({{8, idx_size1, idx_size0, 9}, {8, 20, 10, 9}, {8, 10}, {8, 20}});
  162. }
  163. TEST_F(NAIVE, MESH_MODIFY_SETTING) {
  164. SmallVector<size_t> init_axes;
  165. auto extra_impl = [this, &init_axes](const TensorNDArray& tensors) {
  166. auto opr = handle()->create_operator<IncrMeshIndexing>();
  167. OprProxy<IncrMeshIndexing> proxy(init_axes);
  168. proxy.exec(opr.get(), tensors);
  169. };
  170. Checker<SetMeshIndexing> checker(handle());
  171. checker.set_extra_opr_impl(extra_impl);
  172. size_t idx_size0, idx_size1;
  173. mesh_indexing::NoReplacementIndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  174. ConstValue zero_gen;
  175. checker.set_dtype(0, dtype::Float32())
  176. .set_dtype(1, dtype::Float32())
  177. .set_dtype(2, dtype::Int32())
  178. .set_dtype(3, dtype::Int32())
  179. .set_rng(2, &rng0)
  180. .set_rng(3, &rng1)
  181. .set_rng(0, &zero_gen);
  182. idx_size0 = 5;
  183. init_axes = {1};
  184. checker.set_proxy({init_axes}).execs({{1, idx_size0}, {1, 3}, {3}});
  185. idx_size0 = 23;
  186. idx_size1 = 20;
  187. init_axes = {1, 2};
  188. checker.set_proxy({init_axes})
  189. .execs({{7, idx_size0, idx_size1}, {7, 10, 20}, {10}, {20}})
  190. .execs({{7, idx_size0, idx_size1, 9}, {7, 10, 20, 9}, {10}, {20}});
  191. init_axes = {2, 1};
  192. checker.set_proxy({init_axes})
  193. .execs({{8, idx_size1, idx_size0}, {8, 20, 10}, {10}, {20}})
  194. .execs({{8, idx_size1, idx_size0, 9}, {8, 20, 10, 9}, {10}, {20}});
  195. }
  196. TEST_F(NAIVE, BATCHED_MESH_MODIFY_SETTING) {
  197. SmallVector<size_t> init_axes;
  198. auto extra_impl = [this, &init_axes](const TensorNDArray& tensors) {
  199. auto opr = handle()->create_operator<BatchedIncrMeshIndexing>();
  200. OprProxy<BatchedIncrMeshIndexing> proxy(init_axes);
  201. proxy.exec(opr.get(), tensors);
  202. };
  203. Checker<BatchedSetMeshIndexing> checker(handle());
  204. checker.set_extra_opr_impl(extra_impl);
  205. size_t idx_size0, idx_size1;
  206. mesh_indexing::NoReplacementIndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  207. ConstValue zero_gen;
  208. checker.set_dtype(0, dtype::Float32())
  209. .set_dtype(1, dtype::Float32())
  210. .set_dtype(2, dtype::Int32())
  211. .set_dtype(3, dtype::Int32())
  212. .set_rng(2, &rng0)
  213. .set_rng(3, &rng1)
  214. .set_rng(0, &zero_gen);
  215. idx_size0 = 5;
  216. init_axes = {1};
  217. checker.set_proxy({init_axes}).execs({{1, idx_size0}, {1, 3}, {1, 3}});
  218. idx_size0 = 23;
  219. idx_size1 = 20;
  220. init_axes = {1, 2};
  221. checker.set_proxy({init_axes})
  222. .execs({{7, idx_size0, idx_size1}, {7, 10, 20}, {7, 10}, {7, 20}})
  223. .execs({{7, idx_size0, idx_size1, 9}, {7, 10, 20, 9}, {7, 10}, {7, 20}});
  224. init_axes = {2, 1};
  225. checker.set_proxy({init_axes})
  226. .execs({{8, idx_size1, idx_size0}, {8, 20, 10}, {8, 10}, {8, 20}})
  227. .execs({{8, idx_size1, idx_size0, 9}, {8, 20, 10, 9}, {8, 10}, {8, 20}});
  228. }