You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pooling.h 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
#pragma once
#include <cstddef>
#include <cstdint>
#include <vector>

#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
  5. namespace megdnn {
  6. namespace test {
  7. namespace pooling {
  8. struct TestArg {
  9. param::Pooling param;
  10. TensorShape ishape;
  11. TestArg(param::Pooling param, TensorShape ishape) : param(param), ishape(ishape) {}
  12. };
  13. inline std::vector<TestArg> get_args() {
  14. std::vector<TestArg> args;
  15. using Param = param::Pooling;
  16. using Mode = param::Pooling::Mode;
  17. // ppssww
  18. for (size_t i = 32; i < 40; ++i) {
  19. args.emplace_back(
  20. Param{Mode::AVERAGE, 1, 1, 2, 2, 2, 2}, TensorShape{2, 3, i, i + 1});
  21. /* reserved for future test */
  22. /*
  23. args.emplace_back(Param{Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 1, 1, 2, 2,
  24. 2, 2}, TensorShape{2, 3, i, i+1});
  25. */
  26. args.emplace_back(
  27. Param{Mode::MAX, 1, 1, 2, 2, 2, 2}, TensorShape{2, 3, i, i + 1});
  28. }
  29. for (size_t i = 32; i < 40; ++i) {
  30. args.emplace_back(
  31. Param{Mode::MAX, 1, 1, 2, 2, 3, 3}, TensorShape{2, 3, i, i + 1});
  32. }
  33. for (uint32_t ph : {0, 1, 2})
  34. for (uint32_t pw : {0, 1, 2}) {
  35. args.emplace_back(
  36. Param{Mode::MAX, ph, pw, 1, 1, 3, 3}, TensorShape{2, 3, 20, 22});
  37. }
  38. // small shape for float16
  39. for (size_t i = 5; i < 10; ++i) {
  40. args.emplace_back(
  41. Param{Mode::AVERAGE, 1, 1, 2, 2, 2, 2}, TensorShape{2, 3, i, i + 1});
  42. /* reserved for future test */
  43. /*
  44. args.emplace_back(Param{Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 1, 1, 2, 2,
  45. 2, 2}, TensorShape{2, 3, i, i+1});
  46. */
  47. }
  48. for (uint32_t ph : {0, 1, 2})
  49. for (uint32_t pw : {0, 1, 2}) {
  50. args.emplace_back(
  51. Param{Mode::MAX, ph, pw, 1, 1, 3, 3}, TensorShape{1, 2, 10, 11});
  52. }
  53. return args;
  54. }
  55. } // namespace pooling
  56. } // namespace test
  57. } // namespace megdnn
  58. // vim: syntax=cpp.doxygen