You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resize.cpp 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. /**
  2. * \file dnn/test/arm_common/resize.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/common/resize.h"
  13. #include "test/arm_common/fixture.h"
  14. #include "test/common/checker.h"
  15. #include "test/common/task_record_check.h"
  16. namespace megdnn {
  17. namespace test {
  18. using namespace resize;
  19. static void set_nchw_args(IMode imode, std::vector<TestArg>& args) {
  20. param::Resize param;
  21. param.format = param::Resize::Format::NCHW;
  22. param.imode = imode;
  23. rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
  24. args.emplace_back(
  25. param, TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul},
  26. TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul});
  27. args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 20, 20});
  28. args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 7, 9});
  29. args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8});
  30. args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
  31. }
  32. TEST_F(ARM_COMMON, RESIZE_CV) {
  33. std::vector<TestArg> args = get_cv_args();
  34. Checker<Resize> checker(handle());
  35. for (auto&& arg : args) {
  36. checker.set_param(arg.param)
  37. .set_epsilon(1 + 1e-3)
  38. .set_dtype(0, dtype::Uint8())
  39. .set_dtype(1, dtype::Uint8())
  40. .execs({arg.src, arg.dst});
  41. }
  42. for (auto&& arg : args) {
  43. checker.set_param(arg.param)
  44. .set_dtype(0, dtype::Float32())
  45. .set_dtype(1, dtype::Float32())
  46. .execs({arg.src, arg.dst});
  47. }
  48. }
  49. TEST_F(ARM_COMMON, RESIZE_CV_RECORD) {
  50. std::vector<TestArg> args = get_cv_args();
  51. TaskRecordChecker<Resize> checker(0);
  52. for (auto&& arg : args) {
  53. checker.set_param(arg.param)
  54. .set_epsilon(1 + 1e-3)
  55. .set_dtype(0, dtype::Uint8())
  56. .set_dtype(1, dtype::Uint8())
  57. .execs({arg.src, arg.dst});
  58. }
  59. for (auto&& arg : args) {
  60. checker.set_param(arg.param)
  61. .set_dtype(0, dtype::Float32())
  62. .set_dtype(1, dtype::Float32())
  63. .execs({arg.src, arg.dst});
  64. }
  65. }
  // Float16 NCHW resize with bilinear and nearest interpolation. Only built
  // when the compiler defines __ARM_FEATURE_FP16_VECTOR_ARITHMETIC.
  #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  TEST_F(ARM_COMMON, RESIZE_NCHW_FP16) {
      std::vector<TestArg> args;
      for (IMode mode : {IMode::INTER_LINEAR, IMode::INTER_NEAREST}) {
          set_nchw_args(mode, args);
      }

      // Dtypes and tolerance are the same for every case; configure once.
      Checker<Resize> checker(handle());
      checker.set_epsilon(0.01)
              .set_dtype(0, dtype::Float16())
              .set_dtype(1, dtype::Float16());
      for (const auto& arg : args) {
          checker.set_param(arg.param).execs({arg.src, arg.dst});
      }
  }
  #endif
  81. TEST_F(ARM_COMMON, RESIZE_NCHW_FP32) {
  82. std::vector<TestArg> args;
  83. set_nchw_args(IMode::INTER_LINEAR, args);
  84. set_nchw_args(IMode::INTER_NEAREST, args);
  85. Checker<Resize> checker(handle());
  86. for (auto&& arg : args) {
  87. checker.set_param(arg.param)
  88. .set_dtype(0, dtype::Float32())
  89. .set_dtype(1, dtype::Float32())
  90. .execs({arg.src, arg.dst});
  91. }
  92. }
  93. TEST_F(ARM_COMMON, RESIZE_NCHW44_FP32) {
  94. std::vector<TestArg> args = get_nchw44_args();
  95. Checker<Resize> checker(handle());
  96. for (auto&& arg : args) {
  97. checker.set_param(arg.param)
  98. .set_dtype(0, dtype::Float32())
  99. .set_dtype(1, dtype::Float32())
  100. .execs({arg.src, arg.dst});
  101. }
  102. }
  // Float16 resize on the cases returned by get_nchw88_args(). Only built
  // when the compiler defines __ARM_FEATURE_FP16_VECTOR_ARITHMETIC.
  #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  TEST_F(ARM_COMMON, RESIZE_NCHW88_FP16) {
      const std::vector<TestArg> cases = get_nchw88_args();

      // Dtypes and tolerance are the same for every case; configure once.
      Checker<Resize> checker(handle());
      checker.set_epsilon(0.01)
              .set_dtype(0, dtype::Float16())
              .set_dtype(1, dtype::Float16());
      for (const auto& tc : cases) {
          checker.set_param(tc.param).execs({tc.src, tc.dst});
      }
  }
  #endif
  116. } // namespace test
  117. } // namespace megdnn
  118. // vim: syntax=cpp.doxygen