You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resize.cpp 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. #include "test/common/resize.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/task_record_check.h"
  4. #include "test/fallback/fixture.h"
  5. namespace megdnn {
  6. namespace test {
  7. TEST_F(FALLBACK, RESIZE_CV) {
  8. using namespace resize;
  9. std::vector<TestArg> args = get_cv_args();
  10. Checker<Resize> checker(handle());
  11. for (auto&& arg : args) {
  12. checker.set_param(arg.param)
  13. .set_dtype(0, dtype::Uint8())
  14. .set_dtype(1, dtype::Uint8())
  15. .set_epsilon(1 + 1e-3)
  16. .execs({arg.src, arg.dst});
  17. }
  18. for (auto&& arg : args) {
  19. checker.set_param(arg.param)
  20. .set_dtype(0, dtype::Float32())
  21. .set_dtype(1, dtype::Float32())
  22. .execs({arg.src, arg.dst});
  23. }
  24. }
  25. TEST_F(FALLBACK, RESIZE_CV_RECORD) {
  26. using namespace resize;
  27. std::vector<TestArg> args = get_cv_args();
  28. TaskRecordChecker<Resize> checker(1);
  29. for (auto&& arg : args) {
  30. checker.set_param(arg.param)
  31. .set_dtype(0, dtype::Uint8())
  32. .set_dtype(1, dtype::Uint8())
  33. .set_epsilon(1 + 1e-3)
  34. .execs({arg.src, arg.dst});
  35. }
  36. for (auto&& arg : args) {
  37. checker.set_param(arg.param)
  38. .set_dtype(0, dtype::Float32())
  39. .set_dtype(1, dtype::Float32())
  40. .execs({arg.src, arg.dst});
  41. }
  42. }
  43. TEST_F(FALLBACK, RESIZE) {
  44. using namespace resize;
  45. std::vector<TestArg> args = get_args();
  46. Checker<Resize> checker(handle());
  47. for (auto&& arg : args) {
  48. checker.set_param(arg.param)
  49. .set_dtype(0, dtype::Uint8())
  50. .set_dtype(1, dtype::Uint8())
  51. .set_epsilon(1 + 1e-3)
  52. .execs({arg.src, arg.dst});
  53. }
  54. for (auto&& arg : args) {
  55. checker.set_param(arg.param)
  56. .set_dtype(0, dtype::Float32())
  57. .set_dtype(1, dtype::Float32())
  58. .execs({arg.src, arg.dst});
  59. }
  60. }
  61. TEST_F(FALLBACK, RESIZE_RECORD) {
  62. using namespace resize;
  63. std::vector<TestArg> args = get_args();
  64. TaskRecordChecker<Resize> checker(1);
  65. for (auto&& arg : args) {
  66. checker.set_param(arg.param)
  67. .set_dtype(0, dtype::Uint8())
  68. .set_dtype(1, dtype::Uint8())
  69. .set_epsilon(1 + 1e-3)
  70. .execs({arg.src, arg.dst});
  71. }
  72. for (auto&& arg : args) {
  73. checker.set_param(arg.param)
  74. .set_dtype(0, dtype::Float32())
  75. .set_dtype(1, dtype::Float32())
  76. .execs({arg.src, arg.dst});
  77. }
  78. }
  79. TEST_F(FALLBACK, RESIZE_NCHW_WITH_STRIDE) {
  80. param::Resize param;
  81. param.format = param::Resize::Format::NCHW;
  82. param.imode = param::Resize::InterpolationMode::LINEAR;
  83. Checker<Resize> checker(handle());
  84. checker.set_epsilon(1 + 1e-3).set_param(param);
  85. auto run = [&](TensorShape src_shape, std::vector<ptrdiff_t> src_layout,
  86. TensorShape dst_shape, DType dtype) {
  87. checker.set_dtype(0, dtype).set_dtype(1, dtype).execl(
  88. {{src_shape, src_layout, dtype}, {dst_shape, dtype}});
  89. };
  90. for (DType& dtype : std::vector<DType>{dtype::Float32(), dtype::Uint8()}) {
  91. run({2, 3, 4, 4}, {256, 32, 8, 1}, {2, 3, 3, 3}, dtype);
  92. run({1, 3, 4, 3}, {105, 35, 7, 2}, {1, 3, 5, 5}, dtype);
  93. run({2, 3, 4, 4}, {-256, 32, -8, 1}, {2, 3, 3, 3}, dtype);
  94. run({2, 3, 4, 4}, {256, -32, 8, -1}, {2, 3, 3, 3}, dtype);
  95. run({2, 3, 4, 4}, {-256, -32, -8, -1}, {2, 3, 3, 3}, dtype);
  96. }
  97. }
  98. TEST_F(FALLBACK, RESIZE_NCHW_WITH_STRIDE_RECORD) {
  99. param::Resize param;
  100. param.format = param::Resize::Format::NCHW;
  101. param.imode = param::Resize::InterpolationMode::LINEAR;
  102. TaskRecordChecker<Resize> checker(1);
  103. checker.set_epsilon(1 + 1e-3).set_param(param);
  104. auto run = [&](TensorShape src_shape, std::vector<ptrdiff_t> src_layout,
  105. TensorShape dst_shape, DType dtype) {
  106. checker.set_dtype(0, dtype).set_dtype(1, dtype).execl(
  107. {{src_shape, src_layout, dtype}, {dst_shape, dtype}});
  108. };
  109. for (DType& dtype : std::vector<DType>{dtype::Float32(), dtype::Uint8()}) {
  110. run({2, 3, 4, 4}, {256, 32, 8, 1}, {2, 3, 3, 3}, dtype);
  111. run({1, 3, 4, 3}, {105, 35, 7, 2}, {1, 3, 5, 5}, dtype);
  112. run({2, 3, 4, 4}, {-256, 32, -8, 1}, {2, 3, 3, 3}, dtype);
  113. run({2, 3, 4, 4}, {256, -32, 8, -1}, {2, 3, 3, 3}, dtype);
  114. run({2, 3, 4, 4}, {-256, -32, -8, -1}, {2, 3, 3, 3}, dtype);
  115. }
  116. }
  117. TEST_F(FALLBACK, RESIZE_NCHW4) {
  118. using namespace resize;
  119. auto args = get_nchw4_args();
  120. Checker<Resize> checker(handle());
  121. for (auto&& arg : args) {
  122. checker.set_param(arg.param)
  123. .set_dtype(0, dtype::QuantizedS8(1.0f))
  124. .set_dtype(1, dtype::QuantizedS8(1.0f))
  125. .set_epsilon(1 + 1e-3)
  126. .execs({arg.src, arg.dst});
  127. }
  128. }
  129. TEST_F(FALLBACK, RESIZE_NCHW4_RECORD) {
  130. using namespace resize;
  131. auto args = get_nchw4_args();
  132. TaskRecordChecker<Resize> checker(1);
  133. for (auto&& arg : args) {
  134. checker.set_param(arg.param)
  135. .set_dtype(0, dtype::QuantizedS8(1.0f))
  136. .set_dtype(1, dtype::QuantizedS8(1.0f))
  137. .set_epsilon(1 + 1e-3)
  138. .execs({arg.src, arg.dst});
  139. }
  140. }
  141. namespace {
  142. static void set_nchw_args(resize::IMode imode, std::vector<resize::TestArg>& args) {
  143. param::Resize param;
  144. param.format = param::Resize::Format::NCHW;
  145. param.imode = imode;
  146. rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
  147. args.emplace_back(
  148. param, TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul},
  149. TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul});
  150. args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 20, 20});
  151. args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 7, 9});
  152. args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8});
  153. args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
  154. }
  155. } // namespace
  156. TEST_F(FALLBACK, RESIZE_NCHW_FP32) {
  157. std::vector<resize::TestArg> args;
  158. set_nchw_args(resize::IMode::INTER_LINEAR, args);
  159. set_nchw_args(resize::IMode::INTER_NEAREST, args);
  160. Checker<Resize> checker(handle());
  161. for (auto&& arg : args) {
  162. checker.set_param(arg.param)
  163. .set_dtype(0, dtype::Float32())
  164. .set_dtype(1, dtype::Float32())
  165. .execs({arg.src, arg.dst});
  166. }
  167. }
  168. TEST_F(FALLBACK, RESIZE_NCHW44_FP32) {
  169. std::vector<resize::TestArg> args = resize::get_nchw44_args();
  170. Checker<Resize> checker(handle());
  171. for (auto&& arg : args) {
  172. checker.set_param(arg.param)
  173. .set_dtype(0, dtype::Float32())
  174. .set_dtype(1, dtype::Float32())
  175. .execs({arg.src, arg.dst});
  176. }
  177. }
  178. } // namespace test
  179. } // namespace megdnn
  180. // vim: syntax=cpp.doxygen