You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mesh_indexing.cpp 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. #include "test/common/mesh_indexing.h"
  2. #include "megdnn/basic_types.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/index.h"
  5. #include "test/cuda/fixture.h"
  6. using namespace megdnn;
  7. using namespace test;
  8. TEST_F(CUDA, MESH_INDEXING) {
  9. Checker<MeshIndexing> checker(handle_cuda());
  10. size_t idx_size0, idx_size1;
  11. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  12. checker.set_dtype(0, dtype::Float32())
  13. .set_dtype(1, dtype::Float32())
  14. .set_dtype(2, dtype::Int32())
  15. .set_dtype(3, dtype::Int32())
  16. .set_rng(2, &rng0)
  17. .set_rng(3, &rng1);
  18. SmallVector<size_t> init_axes;
  19. idx_size0 = 23;
  20. init_axes = {0};
  21. checker.set_proxy({init_axes})
  22. .execs({{23}, {100}, {100}})
  23. .execs({{23, 5}, {100, 5}, {100}});
  24. idx_size0 = 3;
  25. init_axes = {1};
  26. checker.set_proxy({init_axes})
  27. .execs({{2, 3}, {2, 10}, {10}})
  28. .execs({{2, 3, 5}, {2, 50, 5}, {50}})
  29. .execs({{2, 3, 5, 7}, {2, 55, 5, 7}, {55}});
  30. idx_size0 = 23;
  31. idx_size1 = 17;
  32. init_axes = {3, 1};
  33. checker.set_proxy({init_axes})
  34. .execs({{3, 17, 9, 23}, {3, 100, 9, 100}, {100}, {100}})
  35. .execs({{3, 17, 29, 30}, {3, 66, 29, 99}, {99}, {66}});
  36. }
  37. TEST_F(CUDA, BATCHED_MESH_INDEXING) {
  38. Checker<BatchedMeshIndexing> checker(handle_cuda());
  39. size_t idx_size0, idx_size1;
  40. IndexRNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  41. checker.set_dtype(0, dtype::Float32())
  42. .set_dtype(1, dtype::Float32())
  43. .set_dtype(2, dtype::Int32())
  44. .set_dtype(3, dtype::Int32())
  45. .set_rng(2, &rng0)
  46. .set_rng(3, &rng1);
  47. SmallVector<size_t> init_axes;
  48. init_axes = {1};
  49. idx_size0 = 5;
  50. checker.set_proxy({init_axes}).execs({{2, 5}, {2, 3}, {2, 3}});
  51. idx_size0 = 23;
  52. idx_size1 = 17;
  53. init_axes = {3, 1};
  54. checker.set_proxy({init_axes})
  55. .execs({{3, 17, 9, 23}, {3, 100, 9, 100}, {3, 100}, {3, 100}})
  56. .execs({{3, 17, 29, 30}, {3, 66, 29, 99}, {3, 99}, {3, 66}});
  57. idx_size0 = 5;
  58. init_axes = {1};
  59. TensorLayout index_layout{TensorShape{1, 3}, dtype::Int32()};
  60. index_layout = index_layout.broadcast({2, 3});
  61. checker.set_proxy({init_axes})
  62. .execl({TensorLayout{TensorShape{2, idx_size0}, dtype::Float32()},
  63. TensorLayout{TensorShape{2, 3}, dtype::Float32()}, index_layout});
  64. }
  65. namespace {
  66. template <typename T, typename RNG>
  67. void run_modify_test(Handle* handle) {
  68. Checker<T> checker(handle);
  69. size_t idx_size0, idx_size1;
  70. RNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  71. checker.set_dtype(0, dtype::Float32())
  72. .set_dtype(1, dtype::Float32())
  73. .set_dtype(2, dtype::Int32())
  74. .set_dtype(3, dtype::Int32())
  75. .set_rng(2, &rng0)
  76. .set_rng(3, &rng1);
  77. SmallVector<size_t> init_axes;
  78. idx_size0 = 230;
  79. init_axes = {0};
  80. checker.set_proxy({init_axes})
  81. .execs({{230}, {100}, {100}})
  82. .execs({{230, 5}, {100, 5}, {100}});
  83. idx_size0 = 30;
  84. init_axes = {1};
  85. checker.set_proxy(init_axes)
  86. .execs({{2, 30}, {2, 10}, {10}})
  87. .execs({{2, 30, 5}, {2, 20, 5}, {20}})
  88. .execs({{2, 30, 5, 7}, {2, 25, 5, 7}, {25}});
  89. }
  90. template <typename T, typename RNG>
  91. void run_batch_modify_test(Handle* handle) {
  92. Checker<T> checker(handle);
  93. size_t idx_size0, idx_size1;
  94. RNG rng0{idx_size0, 2}, rng1{idx_size1, 3};
  95. checker.set_dtype(0, dtype::Float32())
  96. .set_dtype(1, dtype::Float32())
  97. .set_dtype(2, dtype::Int32())
  98. .set_dtype(3, dtype::Int32())
  99. .set_rng(2, &rng0)
  100. .set_rng(3, &rng1);
  101. SmallVector<size_t> init_axes;
  102. init_axes = {1};
  103. idx_size0 = 5;
  104. checker.set_proxy({init_axes}).execs({{2, 5}, {2, 3}, {2, 3}});
  105. idx_size0 = 23;
  106. idx_size1 = 17;
  107. init_axes = {3, 1};
  108. checker.set_proxy({init_axes})
  109. .execs({{3, 17, 9, 23}, {3, 10, 9, 10}, {3, 10}, {3, 10}})
  110. .execs({{3, 17, 29, 30}, {3, 11, 29, 22}, {3, 22}, {3, 11}});
  111. }
  112. } // namespace
  113. TEST_F(CUDA, MESH_MODIFY_INCREMENT) {
  114. run_modify_test<IncrMeshIndexing, IndexRNG>(handle_cuda());
  115. }
  116. TEST_F(CUDA, MESH_MODIFY_SETTING) {
  117. run_modify_test<SetMeshIndexing, mesh_indexing::NoReplacementIndexRNG>(
  118. handle_cuda());
  119. }
  120. TEST_F(CUDA, BATCHED_MESH_MODIFY_INCREMENT) {
  121. run_batch_modify_test<BatchedIncrMeshIndexing, IndexRNG>(handle_cuda());
  122. }
  123. TEST_F(CUDA, BATCHED_MESH_MODIFY_SETTING) {
  124. run_batch_modify_test<BatchedSetMeshIndexing, mesh_indexing::NoReplacementIndexRNG>(
  125. handle_cuda());
  126. }