You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

resize.cpp 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. /**
  2. * \file dnn/test/fallback/resize.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/resize.h"
  12. #include "test/common/checker.h"
  13. #include "test/common/task_record_check.h"
  14. #include "test/fallback/fixture.h"
  15. namespace megdnn {
  16. namespace test {
  17. TEST_F(FALLBACK, RESIZE_CV) {
  18. using namespace resize;
  19. std::vector<TestArg> args = get_cv_args();
  20. Checker<Resize> checker(handle());
  21. for (auto&& arg : args) {
  22. checker.set_param(arg.param)
  23. .set_dtype(0, dtype::Uint8())
  24. .set_dtype(1, dtype::Uint8())
  25. .set_epsilon(1 + 1e-3)
  26. .execs({arg.src, arg.dst});
  27. }
  28. for (auto&& arg : args) {
  29. checker.set_param(arg.param)
  30. .set_dtype(0, dtype::Float32())
  31. .set_dtype(1, dtype::Float32())
  32. .execs({arg.src, arg.dst});
  33. }
  34. }
  35. TEST_F(FALLBACK, RESIZE_CV_RECORD) {
  36. using namespace resize;
  37. std::vector<TestArg> args = get_cv_args();
  38. TaskRecordChecker<Resize> checker(1);
  39. for (auto&& arg : args) {
  40. checker.set_param(arg.param)
  41. .set_dtype(0, dtype::Uint8())
  42. .set_dtype(1, dtype::Uint8())
  43. .set_epsilon(1 + 1e-3)
  44. .execs({arg.src, arg.dst});
  45. }
  46. for (auto&& arg : args) {
  47. checker.set_param(arg.param)
  48. .set_dtype(0, dtype::Float32())
  49. .set_dtype(1, dtype::Float32())
  50. .execs({arg.src, arg.dst});
  51. }
  52. }
  53. TEST_F(FALLBACK, RESIZE) {
  54. using namespace resize;
  55. std::vector<TestArg> args = get_args();
  56. Checker<Resize> checker(handle());
  57. for (auto&& arg : args) {
  58. checker.set_param(arg.param)
  59. .set_dtype(0, dtype::Uint8())
  60. .set_dtype(1, dtype::Uint8())
  61. .set_epsilon(1 + 1e-3)
  62. .execs({arg.src, arg.dst});
  63. }
  64. for (auto&& arg : args) {
  65. checker.set_param(arg.param)
  66. .set_dtype(0, dtype::Float32())
  67. .set_dtype(1, dtype::Float32())
  68. .execs({arg.src, arg.dst});
  69. }
  70. }
  71. TEST_F(FALLBACK, RESIZE_RECORD) {
  72. using namespace resize;
  73. std::vector<TestArg> args = get_args();
  74. TaskRecordChecker<Resize> checker(1);
  75. for (auto&& arg : args) {
  76. checker.set_param(arg.param)
  77. .set_dtype(0, dtype::Uint8())
  78. .set_dtype(1, dtype::Uint8())
  79. .set_epsilon(1 + 1e-3)
  80. .execs({arg.src, arg.dst});
  81. }
  82. for (auto&& arg : args) {
  83. checker.set_param(arg.param)
  84. .set_dtype(0, dtype::Float32())
  85. .set_dtype(1, dtype::Float32())
  86. .execs({arg.src, arg.dst});
  87. }
  88. }
  89. TEST_F(FALLBACK, RESIZE_NCHW_WITH_STRIDE) {
  90. param::Resize param;
  91. param.format = param::Resize::Format::NCHW;
  92. param.imode = param::Resize::InterpolationMode::LINEAR;
  93. Checker<Resize> checker(handle());
  94. checker.set_epsilon(1 + 1e-3).set_param(param);
  95. auto run = [&](TensorShape src_shape, std::vector<ptrdiff_t> src_layout,
  96. TensorShape dst_shape, DType dtype) {
  97. checker.set_dtype(0, dtype).set_dtype(1, dtype).execl(
  98. {{src_shape, src_layout, dtype}, {dst_shape, dtype}});
  99. };
  100. for (DType& dtype : std::vector<DType>{dtype::Float32(), dtype::Uint8()}) {
  101. run({2, 3, 4, 4}, {256, 32, 8, 1}, {2, 3, 3, 3}, dtype);
  102. run({1, 3, 4, 3}, {105, 35, 7, 2}, {1, 3, 5, 5}, dtype);
  103. run({2, 3, 4, 4}, {-256, 32, -8, 1}, {2, 3, 3, 3}, dtype);
  104. run({2, 3, 4, 4}, {256, -32, 8, -1}, {2, 3, 3, 3}, dtype);
  105. run({2, 3, 4, 4}, {-256, -32, -8, -1}, {2, 3, 3, 3}, dtype);
  106. }
  107. }
  108. TEST_F(FALLBACK, RESIZE_NCHW_WITH_STRIDE_RECORD) {
  109. param::Resize param;
  110. param.format = param::Resize::Format::NCHW;
  111. param.imode = param::Resize::InterpolationMode::LINEAR;
  112. TaskRecordChecker<Resize> checker(1);
  113. checker.set_epsilon(1 + 1e-3).set_param(param);
  114. auto run = [&](TensorShape src_shape, std::vector<ptrdiff_t> src_layout,
  115. TensorShape dst_shape, DType dtype) {
  116. checker.set_dtype(0, dtype).set_dtype(1, dtype).execl(
  117. {{src_shape, src_layout, dtype}, {dst_shape, dtype}});
  118. };
  119. for (DType& dtype : std::vector<DType>{dtype::Float32(), dtype::Uint8()}) {
  120. run({2, 3, 4, 4}, {256, 32, 8, 1}, {2, 3, 3, 3}, dtype);
  121. run({1, 3, 4, 3}, {105, 35, 7, 2}, {1, 3, 5, 5}, dtype);
  122. run({2, 3, 4, 4}, {-256, 32, -8, 1}, {2, 3, 3, 3}, dtype);
  123. run({2, 3, 4, 4}, {256, -32, 8, -1}, {2, 3, 3, 3}, dtype);
  124. run({2, 3, 4, 4}, {-256, -32, -8, -1}, {2, 3, 3, 3}, dtype);
  125. }
  126. }
  127. TEST_F(FALLBACK, RESIZE_NCHW4) {
  128. using namespace resize;
  129. auto args = get_nchw4_args();
  130. Checker<Resize> checker(handle());
  131. for (auto&& arg : args) {
  132. checker.set_param(arg.param)
  133. .set_dtype(0, dtype::QuantizedS8(1.0f))
  134. .set_dtype(1, dtype::QuantizedS8(1.0f))
  135. .set_epsilon(1 + 1e-3)
  136. .execs({arg.src, arg.dst});
  137. }
  138. }
  139. TEST_F(FALLBACK, RESIZE_NCHW4_RECORD) {
  140. using namespace resize;
  141. auto args = get_nchw4_args();
  142. TaskRecordChecker<Resize> checker(1);
  143. for (auto&& arg : args) {
  144. checker.set_param(arg.param)
  145. .set_dtype(0, dtype::QuantizedS8(1.0f))
  146. .set_dtype(1, dtype::QuantizedS8(1.0f))
  147. .set_epsilon(1 + 1e-3)
  148. .execs({arg.src, arg.dst});
  149. }
  150. }
  151. } // namespace test
  152. } // namespace megdnn
  153. // vim: syntax=cpp.doxygen