You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resize.cpp 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #include "test/common/resize.h"
  2. #include "test/common/checker.h"
  3. #include "test/cpu/fixture.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(CPU, RESIZE_CV) {
  7. using namespace resize;
  8. std::vector<TestArg> args = get_cv_args();
  9. Checker<Resize> checker(handle());
  10. for (auto&& arg : args) {
  11. checker.set_param(arg.param)
  12. .set_dtype(0, dtype::Uint8())
  13. .set_dtype(1, dtype::Uint8())
  14. .set_epsilon(1 + 1e-3)
  15. .execs({arg.src, arg.dst});
  16. }
  17. for (auto&& arg : args) {
  18. checker.set_param(arg.param)
  19. .set_dtype(0, dtype::Float32())
  20. .set_dtype(1, dtype::Float32())
  21. .execs({arg.src, arg.dst});
  22. }
  23. }
  24. TEST_F(CPU, RESIZE) {
  25. using namespace resize;
  26. std::vector<TestArg> args = get_args();
  27. Checker<Resize> checker(handle());
  28. for (auto&& arg : args) {
  29. checker.set_param(arg.param)
  30. .set_dtype(0, dtype::Uint8())
  31. .set_dtype(1, dtype::Uint8())
  32. .set_epsilon(1 + 1e-3)
  33. .execs({arg.src, arg.dst});
  34. }
  35. for (auto&& arg : args) {
  36. checker.set_param(arg.param)
  37. .set_dtype(0, dtype::Float32())
  38. .set_dtype(1, dtype::Float32())
  39. .execs({arg.src, arg.dst});
  40. }
  41. }
  42. TEST_F(CPU, RESIZE_NCHW_WITH_STRIDE) {
  43. param::Resize param;
  44. param.format = param::Resize::Format::NCHW;
  45. param.imode = param::Resize::InterpolationMode::LINEAR;
  46. Checker<Resize> checker(handle());
  47. checker.set_epsilon(1 + 1e-3).set_param(param);
  48. auto run = [&](TensorShape src_shape, std::vector<ptrdiff_t> src_layout,
  49. TensorShape dst_shape, DType dtype) {
  50. checker.set_dtype(0, dtype).set_dtype(1, dtype).execl(
  51. {{src_shape, src_layout, dtype}, {dst_shape, dtype}});
  52. };
  53. for (DType& dtype : std::vector<DType>{dtype::Float32(), dtype::Uint8()}) {
  54. run({2, 3, 4, 4}, {256, 32, 8, 1}, {2, 3, 3, 3}, dtype);
  55. run({1, 3, 4, 3}, {105, 35, 7, 2}, {1, 3, 5, 5}, dtype);
  56. run({2, 3, 4, 4}, {-256, 32, -8, 1}, {2, 3, 3, 3}, dtype);
  57. run({2, 3, 4, 4}, {256, -32, 8, -1}, {2, 3, 3, 3}, dtype);
  58. run({2, 3, 4, 4}, {-256, -32, -8, -1}, {2, 3, 3, 3}, dtype);
  59. }
  60. }
  61. TEST_F(CPU, RESIZE_NCHW4) {
  62. using namespace resize;
  63. auto args = get_nchw4_args();
  64. Checker<Resize> checker(handle());
  65. for (auto&& arg : args) {
  66. checker.set_param(arg.param)
  67. .set_dtype(0, dtype::QuantizedS8(1.0f))
  68. .set_dtype(1, dtype::QuantizedS8(1.0f))
  69. .set_epsilon(1 + 1e-3)
  70. .execs({arg.src, arg.dst});
  71. }
  72. }
  73. } // namespace test
  74. } // namespace megdnn
  75. // vim: syntax=cpp.doxygen