You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resize.cpp 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. #include "test/common/resize.h"
  2. #include "test/arm_common/fixture.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/task_record_check.h"
  5. namespace megdnn {
  6. namespace test {
  7. using namespace resize;
  8. static void set_nchw_args(IMode imode, std::vector<TestArg>& args) {
  9. param::Resize param;
  10. param.format = param::Resize::Format::NCHW;
  11. param.imode = imode;
  12. rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
  13. args.emplace_back(
  14. param, TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul},
  15. TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul});
  16. args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 20, 20});
  17. args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 7, 9});
  18. args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8});
  19. args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
  20. }
  21. TEST_F(ARM_COMMON, RESIZE_CV) {
  22. std::vector<TestArg> args = get_cv_args();
  23. Checker<Resize> checker(handle());
  24. for (auto&& arg : args) {
  25. checker.set_param(arg.param)
  26. .set_epsilon(1 + 1e-3)
  27. .set_dtype(0, dtype::Uint8())
  28. .set_dtype(1, dtype::Uint8())
  29. .execs({arg.src, arg.dst});
  30. }
  31. for (auto&& arg : args) {
  32. checker.set_param(arg.param)
  33. .set_dtype(0, dtype::Float32())
  34. .set_dtype(1, dtype::Float32())
  35. .execs({arg.src, arg.dst});
  36. }
  37. }
  38. TEST_F(ARM_COMMON, RESIZE_CV_RECORD) {
  39. std::vector<TestArg> args = get_cv_args();
  40. TaskRecordChecker<Resize> checker(0);
  41. for (auto&& arg : args) {
  42. checker.set_param(arg.param)
  43. .set_epsilon(1 + 1e-3)
  44. .set_dtype(0, dtype::Uint8())
  45. .set_dtype(1, dtype::Uint8())
  46. .execs({arg.src, arg.dst});
  47. }
  48. for (auto&& arg : args) {
  49. checker.set_param(arg.param)
  50. .set_dtype(0, dtype::Float32())
  51. .set_dtype(1, dtype::Float32())
  52. .execs({arg.src, arg.dst});
  53. }
  54. }
  55. #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  56. TEST_F(ARM_COMMON, RESIZE_NCHW_FP16) {
  57. std::vector<TestArg> args;
  58. set_nchw_args(IMode::INTER_LINEAR, args);
  59. set_nchw_args(IMode::INTER_NEAREST, args);
  60. Checker<Resize> checker(handle());
  61. for (auto&& arg : args) {
  62. checker.set_param(arg.param)
  63. .set_epsilon(0.01)
  64. .set_dtype(0, dtype::Float16())
  65. .set_dtype(1, dtype::Float16())
  66. .execs({arg.src, arg.dst});
  67. }
  68. }
  69. #endif
  70. TEST_F(ARM_COMMON, RESIZE_NCHW_FP32) {
  71. std::vector<TestArg> args;
  72. set_nchw_args(IMode::INTER_LINEAR, args);
  73. set_nchw_args(IMode::INTER_NEAREST, args);
  74. Checker<Resize> checker(handle());
  75. for (auto&& arg : args) {
  76. checker.set_param(arg.param)
  77. .set_dtype(0, dtype::Float32())
  78. .set_dtype(1, dtype::Float32())
  79. .execs({arg.src, arg.dst});
  80. }
  81. }
  82. TEST_F(ARM_COMMON, RESIZE_NCHW44_FP32) {
  83. std::vector<TestArg> args = get_nchw44_args();
  84. Checker<Resize> checker(handle());
  85. for (auto&& arg : args) {
  86. checker.set_param(arg.param)
  87. .set_dtype(0, dtype::Float32())
  88. .set_dtype(1, dtype::Float32())
  89. .execs({arg.src, arg.dst});
  90. }
  91. }
  92. #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  93. TEST_F(ARM_COMMON, RESIZE_NCHW88_FP16) {
  94. std::vector<TestArg> args = get_nchw88_args();
  95. Checker<Resize> checker(handle());
  96. for (auto&& arg : args) {
  97. checker.set_param(arg.param)
  98. .set_epsilon(0.01)
  99. .set_dtype(0, dtype::Float16())
  100. .set_dtype(1, dtype::Float16())
  101. .execs({arg.src, arg.dst});
  102. }
  103. }
  104. #endif
  105. } // namespace test
  106. } // namespace megdnn
  107. // vim: syntax=cpp.doxygen