You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

rotate.h 1.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
#pragma once

#include <cstddef>
#include <utility>
#include <vector>

#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
  4. namespace megdnn {
  5. namespace test {
  6. namespace rotate {
  7. struct TestArg {
  8. param::Rotate param;
  9. TensorShape src;
  10. DType dtype;
  11. TestArg(param::Rotate param, TensorShape src, DType dtype)
  12. : param(param), src(src), dtype(dtype) {}
  13. };
  14. static inline std::vector<TestArg> get_args() {
  15. std::vector<TestArg> args;
  16. param::Rotate cur_param;
  17. for (size_t i = 8; i < 129; i *= 4) {
  18. cur_param.clockwise = true;
  19. args.emplace_back(cur_param, TensorShape{1, i, i, 1}, dtype::Uint8());
  20. args.emplace_back(cur_param, TensorShape{1, i, i, 3}, dtype::Uint8());
  21. args.emplace_back(cur_param, TensorShape{2, i, i, 3}, dtype::Uint8());
  22. args.emplace_back(cur_param, TensorShape{2, i, i, 3}, dtype::Float32());
  23. cur_param.clockwise = false;
  24. args.emplace_back(cur_param, TensorShape{2, i, i, 3}, dtype::Uint8());
  25. args.emplace_back(cur_param, TensorShape{2, i, i, 3}, dtype::Float32());
  26. }
  27. std::vector<std::pair<size_t, size_t>> test_cases = {{23, 28}, {17, 3}, {3, 83}};
  28. for (auto&& item : test_cases) {
  29. for (auto&& CH : {1U, 3U}) {
  30. for (bool clockwise : {false, true}) {
  31. cur_param.clockwise = clockwise;
  32. args.emplace_back(
  33. cur_param, TensorShape{1, item.first, item.second, CH},
  34. dtype::Uint8());
  35. args.emplace_back(
  36. cur_param, TensorShape{1, item.first, item.second, CH},
  37. dtype::Float32());
  38. }
  39. }
  40. }
  41. return args;
  42. }
  43. } // namespace rotate
  44. } // namespace test
  45. } // namespace megdnn
  46. // vim: syntax=cpp.doxygen