You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pooling.cpp 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. #include "test/naive/fixture.h"
  2. #include "megdnn/oprs/nn.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/random_state.h"
  5. using namespace megdnn;
  6. using namespace test;
  7. TEST_F(NAIVE, POOLING_QUANTIZED) {
  8. using Mode = Pooling::Param::Mode;
  9. Checker<Pooling> checker(handle(), /* check_dispatch */ false);
  10. Pooling::Param param{Mode::MAX, 1, 1, 2, 2, 2, 2};
  11. auto dt = dtype::Quantized8Asymm(0.1f, (uint8_t)128);
  12. Testcase input{
  13. TensorValue({1, 1, 3, 3}, dt, {90, 136, 85, 48, 9, 226, 118, 109, 87}), {}};
  14. checker.set_param(param).exect(
  15. input, Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {90, 136, 118, 226})});
  16. param = {Mode::AVERAGE, 1, 1, 2, 2, 2, 2};
  17. checker.set_param(param).exect(
  18. input, Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {119, 119, 106, 108})});
  19. param = {Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 1, 1, 2, 2, 2, 2};
  20. checker.set_param(param).exect(
  21. input, Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {90, 111, 83, 108})});
  22. auto dt32 = dtype::QuantizedS32(0.233f);
  23. Testcase input32{
  24. TensorValue(
  25. {1, 1, 3, 3}, dt32,
  26. {12315, 10086, 10010, 12306, 23333, 19191, 9987, 12450, 12345}),
  27. {}};
  28. param = {Mode::MAX, 1, 1, 2, 2, 2, 2};
  29. checker.set_param(param).exect(
  30. input32,
  31. Testcase{
  32. {}, TensorValue({1, 1, 2, 2}, dt32, {12315, 10086, 12306, 23333})});
  33. }
  34. TEST_F(NAIVE, POOLING_QUANTIZED_Q4) {
  35. using Mode = Pooling::Param::Mode;
  36. Checker<Pooling> checker(handle(), /* check_dispatch */ false);
  37. {
  38. auto q4_dt = dtype::QuantizedS4(1.f);
  39. std::vector<int> i8_src_vec{1, 2, 3, 4, 5, 6, 7, -1, -2};
  40. std::vector<int> i8_max_dst_vec{1, 3, 7, 6};
  41. std::vector<int> i8_avg_dst_vec{0, 1, 3, 2};
  42. std::vector<int> i8_avg_exclu_dst_vec{1, 3, 6, 2};
  43. Pooling::Param param{Mode::MAX, 1, 1, 2, 2, 2, 2};
  44. Testcase input{TensorValueLowbit4({1, 1, 3, 3}, q4_dt, i8_src_vec), {}};
  45. checker.set_param(param).exect(
  46. input,
  47. Testcase{{}, TensorValueLowbit4({1, 1, 2, 2}, q4_dt, i8_max_dst_vec)});
  48. param = {Mode::AVERAGE, 1, 1, 2, 2, 2, 2};
  49. checker.set_param(param).exect(
  50. input,
  51. Testcase{{}, TensorValueLowbit4({1, 1, 2, 2}, q4_dt, i8_avg_dst_vec)});
  52. param = {Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 1, 1, 2, 2, 2, 2};
  53. checker.set_param(param).exect(
  54. input,
  55. Testcase{
  56. {},
  57. TensorValueLowbit4({1, 1, 2, 2}, q4_dt, i8_avg_exclu_dst_vec)});
  58. }
  59. {
  60. auto u4_dt = dtype::Quantized4Asymm(0.1f, 3);
  61. std::vector<int> u8_src_vec{1, 2, 3, 4, 5, 6, 7, 8, 9};
  62. std::vector<int> u8_max_dst_vec{1, 3, 7, 9};
  63. std::vector<int> u8_avg_dst_vec{3, 3, 4, 7};
  64. std::vector<int> u8_avg_exclu_dst_vec{1, 3, 6, 7};
  65. Pooling::Param param{Mode::MAX, 1, 1, 2, 2, 2, 2};
  66. Testcase input{TensorValueLowbit4({1, 1, 3, 3}, u4_dt, u8_src_vec), {}};
  67. checker.set_param(param).exect(
  68. input,
  69. Testcase{{}, TensorValueLowbit4({1, 1, 2, 2}, u4_dt, u8_max_dst_vec)});
  70. param = {Mode::AVERAGE, 1, 1, 2, 2, 2, 2};
  71. checker.set_param(param).exect(
  72. input,
  73. Testcase{{}, TensorValueLowbit4({1, 1, 2, 2}, u4_dt, u8_avg_dst_vec)});
  74. param = {Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 1, 1, 2, 2, 2, 2};
  75. checker.set_param(param).exect(
  76. input,
  77. Testcase{
  78. {},
  79. TensorValueLowbit4({1, 1, 2, 2}, u4_dt, u8_avg_exclu_dst_vec)});
  80. }
  81. }
  82. TEST_F(NAIVE, POOLING_INT_AVERAGE) {
  83. using Mode = Pooling::Param::Mode;
  84. Checker<Pooling> checker(handle(), /* check_dispatch */ false);
  85. auto dt = dtype::Int8();
  86. Pooling::Param param = {Mode::AVERAGE, 0, 0, 1, 1, 2, 2};
  87. Testcase input_positive{
  88. TensorValue(
  89. {1, 1, 3, 3}, dt, {127, 127, 127, 127, 127, 127, 127, 127, 127}),
  90. {}};
  91. Testcase input_negative{
  92. TensorValue(
  93. {1, 1, 3, 3}, dt,
  94. {-127, -127, -127, -127, -127, -127, -127, -127, -127}),
  95. {}};
  96. checker.set_param(param).exect(
  97. input_positive,
  98. Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {127, 127, 127, 127})});
  99. checker.set_param(param).exect(
  100. input_negative,
  101. Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {-127, -127, -127, -127})});
  102. param = {Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 0, 0, 1, 1, 2, 2};
  103. checker.set_param(param).exect(
  104. input_positive,
  105. Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {127, 127, 127, 127})});
  106. checker.set_param(param).exect(
  107. input_negative,
  108. Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {-127, -127, -127, -127})});
  109. }
  110. // vim: syntax=cpp.doxygen