You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

remap.h 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. /**
  2. * \file dnn/test/common/remap.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include <iostream>
  14. #include "megdnn/basic_types.h"
  15. #include "megdnn/opr_param_defs.h"
  16. #include "./rng.h"
  17. namespace megdnn {
  18. namespace test {
  19. namespace remap {
  20. struct TestArg {
  21. param::Remap param;
  22. TensorShape src;
  23. TensorShape map_xy;
  24. TensorShape dst;
  25. TestArg(param::Remap param_, TensorShape src_, TensorShape map_xy_,
  26. TensorShape dst_)
  27. : param(param_), src(src_), map_xy(map_xy_), dst(dst_) {}
  28. };
  29. static inline std::vector<TestArg> get_nchw_args() {
  30. std::vector<TestArg> args;
  31. param::Remap param;
  32. std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NCHW};
  33. std::vector<param::Remap::InterpolationMode> interp_mode_vec = {
  34. param::Remap::InterpolationMode::NEAREST,
  35. param::Remap::InterpolationMode::LINEAR};
  36. std::vector<param::Remap::BorderMode> border_mode_vec = {
  37. param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101,
  38. param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP,
  39. param::Remap::BorderMode::REPLICATE};
  40. // current do not test this.
  41. std::vector<float> scalar;
  42. for (auto fmt : format_vec) {
  43. for (auto interp_mode : interp_mode_vec) {
  44. for (auto border_type : border_mode_vec) {
  45. param.format = fmt;
  46. param.imode = interp_mode;
  47. param.border_type = border_type;
  48. args.emplace_back(
  49. param, TensorShape{70000, 1, 2, 2}, TensorShape{70000, 2, 2, 2},
  50. TensorShape{70000, 1, 2, 2});
  51. args.emplace_back(
  52. param, TensorShape{1, 1, 2, 2}, TensorShape{1, 2, 2, 2},
  53. TensorShape{1, 1, 2, 2});
  54. args.emplace_back(
  55. param, TensorShape{1, 3, 2, 2}, TensorShape{1, 2, 2, 2},
  56. TensorShape{1, 3, 2, 2});
  57. args.emplace_back(
  58. param, TensorShape{1, 10, 100, 100},
  59. TensorShape{1, 100, 100, 2}, TensorShape{1, 10, 100, 100});
  60. args.emplace_back(
  61. param, TensorShape{2, 4, 100, 200}, TensorShape{2, 100, 200, 2},
  62. TensorShape{2, 4, 100, 200});
  63. args.emplace_back(
  64. param, TensorShape{2, 4, 100, 200}, TensorShape{2, 20, 30, 2},
  65. TensorShape{2, 4, 20, 30});
  66. args.emplace_back(
  67. param, TensorShape{2, 4, 10, 10}, TensorShape{2, 20, 30, 2},
  68. TensorShape{2, 4, 20, 30});
  69. }
  70. }
  71. }
  72. return args;
  73. }
  74. static inline std::vector<TestArg> get_nhwcd4_args() {
  75. std::vector<TestArg> args;
  76. param::Remap param;
  77. param.format = param::Remap::Format::NHWCD4;
  78. param.imode = param::Remap::InterpolationMode::LINEAR;
  79. param.border_type = param::Remap::BorderMode::CONSTANT;
  80. // FIXME: when fractional part of bval is not zero, naive and opencl bankend may
  81. // have different rounding result
  82. param.scalar = 77;
  83. args.emplace_back(
  84. param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2},
  85. TensorShape{2, 4, 1, 6, 4});
  86. args.emplace_back(
  87. param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2},
  88. TensorShape{2, 2, 1, 3, 4});
  89. param.imode = param::Remap::InterpolationMode::NEAREST;
  90. args.emplace_back(
  91. param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2},
  92. TensorShape{2, 4, 1, 6, 4});
  93. args.emplace_back(
  94. param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2},
  95. TensorShape{2, 2, 1, 3, 4});
  96. return args;
  97. }
  98. static inline std::vector<TestArg> get_nhwc_args() {
  99. std::vector<TestArg> args;
  100. param::Remap param;
  101. std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NHWC};
  102. std::vector<param::Remap::InterpolationMode> interp_mode_vec = {
  103. param::Remap::InterpolationMode::NEAREST,
  104. param::Remap::InterpolationMode::LINEAR};
  105. std::vector<param::Remap::BorderMode> border_mode_vec = {
  106. param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101,
  107. param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP,
  108. param::Remap::BorderMode::REPLICATE};
  109. // current do not test this.
  110. std::vector<float> scalar;
  111. for (auto fmt : format_vec) {
  112. for (auto interp_mode : interp_mode_vec) {
  113. for (auto border_type : border_mode_vec) {
  114. param.format = fmt;
  115. param.imode = interp_mode;
  116. param.border_type = border_type;
  117. param.scalar = 12.f;
  118. args.emplace_back(
  119. param, TensorShape{70000, 2, 2, 1}, TensorShape{70000, 2, 2, 2},
  120. TensorShape{70000, 2, 2, 1});
  121. args.emplace_back(
  122. param, TensorShape{1, 2, 2, 1}, TensorShape{1, 2, 2, 2},
  123. TensorShape{1, 2, 2, 1});
  124. args.emplace_back(
  125. param, TensorShape{1, 2, 2, 3}, TensorShape{1, 2, 2, 2},
  126. TensorShape{1, 2, 2, 3});
  127. args.emplace_back(
  128. param, TensorShape{1, 2, 2, 66}, TensorShape{1, 2, 2, 2},
  129. TensorShape{1, 2, 2, 66});
  130. args.emplace_back(
  131. param, TensorShape{1, 100, 100, 10},
  132. TensorShape{1, 100, 100, 2}, TensorShape{1, 100, 100, 10});
  133. args.emplace_back(
  134. param, TensorShape{2, 100, 200, 4}, TensorShape{2, 100, 200, 2},
  135. TensorShape{2, 100, 200, 4});
  136. args.emplace_back(
  137. param, TensorShape{2, 100, 200, 4}, TensorShape{2, 20, 30, 2},
  138. TensorShape{2, 20, 30, 4});
  139. args.emplace_back(
  140. param, TensorShape{2, 10, 10, 4}, TensorShape{2, 20, 30, 2},
  141. TensorShape{2, 20, 30, 4});
  142. }
  143. }
  144. }
  145. return args;
  146. }
  147. } // namespace remap
  148. } // namespace test
  149. } // namespace megdnn
  150. // vim: syntax=cpp.doxygen