You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

remap.h 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. #pragma once
  2. #include <iostream>
  3. #include "megdnn/basic_types.h"
  4. #include "megdnn/opr_param_defs.h"
  5. #include "./rng.h"
  6. namespace megdnn {
  7. namespace test {
  8. namespace remap {
  9. struct TestArg {
  10. param::Remap param;
  11. TensorShape src;
  12. TensorShape map_xy;
  13. TensorShape dst;
  14. TestArg(param::Remap param_, TensorShape src_, TensorShape map_xy_,
  15. TensorShape dst_)
  16. : param(param_), src(src_), map_xy(map_xy_), dst(dst_) {}
  17. };
  18. static inline std::vector<TestArg> get_nchw_args() {
  19. std::vector<TestArg> args;
  20. param::Remap param;
  21. std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NCHW};
  22. std::vector<param::Remap::InterpolationMode> interp_mode_vec = {
  23. param::Remap::InterpolationMode::NEAREST,
  24. param::Remap::InterpolationMode::LINEAR};
  25. std::vector<param::Remap::BorderMode> border_mode_vec = {
  26. param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101,
  27. param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP,
  28. param::Remap::BorderMode::REPLICATE};
  29. // current do not test this.
  30. std::vector<float> scalar;
  31. for (auto fmt : format_vec) {
  32. for (auto interp_mode : interp_mode_vec) {
  33. for (auto border_type : border_mode_vec) {
  34. param.format = fmt;
  35. param.imode = interp_mode;
  36. param.border_type = border_type;
  37. args.emplace_back(
  38. param, TensorShape{70000, 1, 2, 2}, TensorShape{70000, 2, 2, 2},
  39. TensorShape{70000, 1, 2, 2});
  40. args.emplace_back(
  41. param, TensorShape{1, 1, 2, 2}, TensorShape{1, 2, 2, 2},
  42. TensorShape{1, 1, 2, 2});
  43. args.emplace_back(
  44. param, TensorShape{1, 3, 2, 2}, TensorShape{1, 2, 2, 2},
  45. TensorShape{1, 3, 2, 2});
  46. args.emplace_back(
  47. param, TensorShape{1, 10, 100, 100},
  48. TensorShape{1, 100, 100, 2}, TensorShape{1, 10, 100, 100});
  49. args.emplace_back(
  50. param, TensorShape{2, 4, 100, 200}, TensorShape{2, 100, 200, 2},
  51. TensorShape{2, 4, 100, 200});
  52. args.emplace_back(
  53. param, TensorShape{2, 4, 100, 200}, TensorShape{2, 20, 30, 2},
  54. TensorShape{2, 4, 20, 30});
  55. args.emplace_back(
  56. param, TensorShape{2, 4, 10, 10}, TensorShape{2, 20, 30, 2},
  57. TensorShape{2, 4, 20, 30});
  58. }
  59. }
  60. }
  61. return args;
  62. }
  63. static inline std::vector<TestArg> get_nhwcd4_args() {
  64. std::vector<TestArg> args;
  65. param::Remap param;
  66. param.format = param::Remap::Format::NHWCD4;
  67. param.imode = param::Remap::InterpolationMode::LINEAR;
  68. param.border_type = param::Remap::BorderMode::CONSTANT;
  69. // FIXME: when fractional part of bval is not zero, naive and opencl bankend may
  70. // have different rounding result
  71. param.scalar = 77;
  72. args.emplace_back(
  73. param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2},
  74. TensorShape{2, 4, 1, 6, 4});
  75. args.emplace_back(
  76. param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2},
  77. TensorShape{2, 2, 1, 3, 4});
  78. param.imode = param::Remap::InterpolationMode::NEAREST;
  79. args.emplace_back(
  80. param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2},
  81. TensorShape{2, 4, 1, 6, 4});
  82. args.emplace_back(
  83. param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2},
  84. TensorShape{2, 2, 1, 3, 4});
  85. return args;
  86. }
  87. static inline std::vector<TestArg> get_nhwc_args() {
  88. std::vector<TestArg> args;
  89. param::Remap param;
  90. std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NHWC};
  91. std::vector<param::Remap::InterpolationMode> interp_mode_vec = {
  92. param::Remap::InterpolationMode::NEAREST,
  93. param::Remap::InterpolationMode::LINEAR};
  94. std::vector<param::Remap::BorderMode> border_mode_vec = {
  95. param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101,
  96. param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP,
  97. param::Remap::BorderMode::REPLICATE};
  98. // current do not test this.
  99. std::vector<float> scalar;
  100. for (auto fmt : format_vec) {
  101. for (auto interp_mode : interp_mode_vec) {
  102. for (auto border_type : border_mode_vec) {
  103. param.format = fmt;
  104. param.imode = interp_mode;
  105. param.border_type = border_type;
  106. param.scalar = 12.f;
  107. args.emplace_back(
  108. param, TensorShape{70000, 2, 2, 1}, TensorShape{70000, 2, 2, 2},
  109. TensorShape{70000, 2, 2, 1});
  110. args.emplace_back(
  111. param, TensorShape{1, 2, 2, 1}, TensorShape{1, 2, 2, 2},
  112. TensorShape{1, 2, 2, 1});
  113. args.emplace_back(
  114. param, TensorShape{1, 2, 2, 3}, TensorShape{1, 2, 2, 2},
  115. TensorShape{1, 2, 2, 3});
  116. args.emplace_back(
  117. param, TensorShape{1, 2, 2, 66}, TensorShape{1, 2, 2, 2},
  118. TensorShape{1, 2, 2, 66});
  119. args.emplace_back(
  120. param, TensorShape{1, 100, 100, 10},
  121. TensorShape{1, 100, 100, 2}, TensorShape{1, 100, 100, 10});
  122. args.emplace_back(
  123. param, TensorShape{2, 100, 200, 4}, TensorShape{2, 100, 200, 2},
  124. TensorShape{2, 100, 200, 4});
  125. args.emplace_back(
  126. param, TensorShape{2, 100, 200, 4}, TensorShape{2, 20, 30, 2},
  127. TensorShape{2, 20, 30, 4});
  128. args.emplace_back(
  129. param, TensorShape{2, 10, 10, 4}, TensorShape{2, 20, 30, 2},
  130. TensorShape{2, 20, 30, 4});
  131. }
  132. }
  133. }
  134. return args;
  135. }
  136. } // namespace remap
  137. } // namespace test
  138. } // namespace megdnn
  139. // vim: syntax=cpp.doxygen