You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

relayout.cpp 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #include "megdnn/oprs/general.h"
  2. #include "src/common/relayout_helper.h"
  3. #include "test/common/benchmarker.h"
  4. #include "test/common/checker.h"
  5. #include "test/common/relayout.h"
  6. using namespace megdnn;
  7. using namespace test;
  8. using namespace megdnn::relayout;
  9. using namespace test::relayout;
  10. namespace {
  11. TestArg generate_transpose_args(
  12. size_t batch, size_t m, size_t n, size_t c, DType dtype) {
  13. TestArg arg;
  14. arg.src = TensorLayout(
  15. TensorShape{batch, n, m, c},
  16. {static_cast<std::ptrdiff_t>(n * m * c), static_cast<std::ptrdiff_t>(c),
  17. static_cast<std::ptrdiff_t>(n * c), 1},
  18. dtype);
  19. arg.dst = TensorLayout(TensorShape{batch, n, m, c}, dtype);
  20. return arg;
  21. }
  22. } // anonymous namespace
  23. namespace megdnn {
  24. namespace test {
  25. namespace relayout {
  26. void run_test_cv(Handle* handle, size_t CH) {
  27. std::vector<TestArg> args;
  28. for (size_t M = 124; M <= 130; ++M) {
  29. for (size_t N = 124; N <= 130; ++N) {
  30. args.push_back(generate_transpose_args(1, M, N, CH, dtype::Uint8()));
  31. args.push_back(generate_transpose_args(1, M, N, CH, dtype::Int32()));
  32. args.push_back(generate_transpose_args(1, M, N, CH, dtype::Float32()));
  33. args.push_back(generate_transpose_args(3, M, N, CH, dtype::Float32()));
  34. }
  35. }
  36. Checker<Relayout> checker(handle);
  37. for (auto&& arg : args) {
  38. checker.execl({arg.src, arg.dst});
  39. }
  40. }
// DEF_TEST(name) expands to the header of an explicit specialization of the
// run_test<>() function template for the test tag `name`; the specialization
// body is written directly after the macro invocation.
#define DEF_TEST(name) \
template <> \
void run_test<name>(Handle * handle)
// cv transpose checks with a single channel (CH = 1).
DEF_TEST(cv) {
run_test_cv(handle, 1);
}
// cv transpose checks with three channels (CH = 3).
DEF_TEST(cv_ch3) {
run_test_cv(handle, 3);
}
// cv transpose checks with five channels (CH = 5).
DEF_TEST(cv_ch5) {
run_test_cv(handle, 5);
}
  53. DEF_TEST(broadcast) {
  54. std::vector<TestArg> args;
  55. TensorLayout src{{2, 3, 4}, dtype::Float32()}, dst{{2, 3, 4}, dtype::Float32()};
  56. src.stride[0] = 4;
  57. src.stride[1] = 0;
  58. args.emplace_back(src, dst);
  59. // last stride contiguous
  60. args.emplace_back(
  61. TensorLayout({3, 100, 2}, {2, 0, 1}, dtype::Float16()),
  62. TensorLayout({3, 100, 2}, {200, 2, 1}, dtype::Float16()));
  63. Checker<Relayout> checker(handle);
  64. for (auto&& arg : args) {
  65. checker.execl({arg.src, arg.dst});
  66. }
  67. }
  68. DEF_TEST(negative) {
  69. TensorLayout src{{7, 8, 10}, dtype::Float32()}, dst{{7, 8, 10}, dtype::Float32()};
  70. src.stride[0] *= -1;
  71. Checker<Relayout> checker(handle);
  72. checker.execl({src, dst});
  73. }
  74. DEF_TEST(transpose) {
  75. Checker<Relayout> checker(handle);
  76. {
  77. TensorLayout sl({8, 10}, dtype::Int32()), dl({10, 8}, dtype::Int32());
  78. sl = sl.dimshuffle({1, 0});
  79. checker.execl({sl, dl});
  80. checker.execl({dl, sl});
  81. }
  82. {
  83. TensorLayout sl({8, 10, 2}, dtype::Int32()), dl({2, 8, 10}, dtype::Int32());
  84. sl = sl.dimshuffle({2, 0, 1});
  85. checker.execl({sl, dl});
  86. checker.execl({dl, sl});
  87. }
  88. }
  89. #undef DEF_TEST
  90. } // namespace relayout
  91. } // namespace test
  92. } // namespace megdnn
  93. void test::relayout::run_cv_benchmark(Handle* handle) {
  94. auto handle_naive = create_cpu_handle(2);
  95. std::vector<TestArg> args;
  96. args.push_back(generate_transpose_args(1, 255, 256, 1, dtype::Int32()));
  97. args.push_back(generate_transpose_args(1, 513, 1025, 3, dtype::Int32()));
  98. args.push_back(generate_transpose_args(1, 255, 256, 1, dtype::Uint8()));
  99. args.push_back(generate_transpose_args(1, 513, 1025, 3, dtype::Uint8()));
  100. args.push_back(generate_transpose_args(1, 255, 256, 3, dtype::Float32()));
  101. args.push_back(generate_transpose_args(1, 513, 1025, 1, dtype::Float32()));
  102. args.push_back(generate_transpose_args(2, 987, 573, 6, dtype::Float32()));
  103. Benchmarker<Relayout> benchmarker(handle);
  104. Benchmarker<Relayout> benchmarker_naive(handle_naive.get());
  105. Checker<Relayout> checker(handle);
  106. benchmarker_naive.set_times(1).set_display(false);
  107. benchmarker.set_times(1).set_display(false);
  108. for (auto&& arg : args) {
  109. checker.execl({arg.src, arg.dst});
  110. auto t0 = benchmarker.execl({arg.src, arg.dst});
  111. auto t1 = benchmarker_naive.execl({arg.src, arg.dst});
  112. double k = arg.dst.span().dist_byte() * 1e3 / (1024 * 1024 * 1024);
  113. printf("cur=%7.3fms,%5.2fGiB/s naive=%7.3fms,%5.2fGiB/s %s %s\n", t0, k / t0,
  114. t1, k / t1, arg.dst.TensorShape::to_string().c_str(),
  115. arg.dst.dtype.name());
  116. }
  117. }
  118. TEST(RELAYOUT, TRANSPOSE_DET) {
  119. auto run = [](const TensorShape& shape, const std::vector<size_t>& dimshuffle,
  120. bool expect_is_transpose, const TransposeParam& p = {}) {
  121. TensorLayout src{shape, dtype::Float32{}};
  122. src = src.dimshuffle(dimshuffle).collapse_contiguous();
  123. TensorLayout dst{TensorShape{src.total_nr_elems()}, src.dtype};
  124. TransposeParam p_get;
  125. bool succ = is_transpose(src, dst, p_get);
  126. ASSERT_EQ(expect_is_transpose, succ);
  127. if (succ) {
  128. ASSERT_EQ(p_get.batch, p.batch);
  129. ASSERT_EQ(p_get.m, p.m);
  130. ASSERT_EQ(p_get.n, p.n);
  131. ASSERT_EQ(p_get.c, p.c);
  132. }
  133. // swap m, n
  134. succ = is_transpose(dst, src, p_get);
  135. ASSERT_EQ(expect_is_transpose, succ);
  136. if (succ) {
  137. ASSERT_EQ(p_get.batch, p.batch);
  138. ASSERT_EQ(p_get.m, p.n);
  139. ASSERT_EQ(p_get.n, p.m);
  140. ASSERT_EQ(p_get.c, p.c);
  141. }
  142. };
  143. run({2, 3}, {1, 0}, true, {1, 2, 3, 1, 0});
  144. run({2, 3, 5}, {1, 0, 2}, true, {1, 2, 3, 5, 0});
  145. run({2, 3, 5}, {0, 2, 1}, true, {2, 3, 5, 1, 0});
  146. run({3, 2, 3, 5}, {0, 2, 1, 3}, true, {3, 2, 3, 5, 0});
  147. run({3, 2, 3, 5}, {0, 1, 3, 2}, true, {6, 3, 5, 1, 0});
  148. run({2, 3, 5}, {2, 1, 0}, false);
  149. run({3, 2, 3, 5}, {3, 2, 1, 0}, false);
  150. }
  151. // vim: syntax=cpp.doxygen