You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

reduce.cpp 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
#include "test/cuda/fixture.h"

#include <limits>
#include <vector>
  5. using namespace megdnn;
  6. using namespace test;
  7. TEST_F(CUDA, REDUCE) {
  8. using Mode = Reduce::Param::Mode;
  9. Checker<Reduce> checker(handle_cuda());
  10. UniformFloatRNG rng(-1.0f, 1.0f);
  11. checker.set_epsilon(1e-2);
  12. checker.set_rng(0, &rng);
  13. checker.set_param({Mode::SUM, 1});
  14. // 1-step
  15. checker.execs({{2, 64, 32}, {}});
  16. // 2-step
  17. checker.execs({{2, 192, 32}, {}});
  18. // 3-step
  19. checker.execs({{2, 4333, 32}, {}});
  20. // single reduce
  21. checker.execs({{2, 1, 1}, {}});
  22. checker.execs({{2, 1 + 1, 1}, {}});
  23. checker.execs({{2, 2048 + 1, 1}, {}});
  24. checker.execs({{2, 2048 * 2048 + 1, 1}, {}});
  25. checker.execs({{2, 1 + 1, 31}, {}});
  26. checker.execs({{2, 16 + 1, 31}, {}});
  27. checker.execs({{2, 16 * 16 + 1, 31}, {}});
  28. checker.execs({{2, 16 * 16 * 16 + 1, 31}, {}});
  29. checker.execs({{2, 16 * 16 * 16 * 16 + 1, 31}, {}});
  30. #if MEGDNN_TEGRA_X1 || MEGDNN_TEGRA_X2
  31. checker.execs({{2, 8 * 16 * 16 * 16 * 16 + 1, 31}, {}});
  32. #else
  33. checker.execs({{2, 16 * 16 * 16 * 16 * 16 + 1, 31}, {}});
  34. #endif
  35. checker.execs({{3, 256 * 256 + 1, 2}, {}});
  36. checker.execs({{3, 128 * 128 + 1, 3}, {}});
  37. checker.execs({{3, 64 * 64 + 1, 7}, {}});
  38. checker.execs({{3, 32 * 32 + 1, 15}, {}});
  39. checker.execs({{3, 512, 500}, {}});
  40. // very large reduce
  41. checker.execs({{1, 4194304, 1}, {}});
  42. // inputs have nan
  43. {
  44. const auto nan = std::numeric_limits<float>::quiet_NaN();
  45. UniformFloatWithValueRNG rng1 =
  46. UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
  47. checker.set_allow_invalid_check(true).set_rng(0, &rng1);
  48. for (auto mode : {Mode::MIN, Mode::MAX}) {
  49. checker.set_param({mode, 1});
  50. checker.execs({{2, 64, 32}, {}});
  51. }
  52. checker.set_allow_invalid_check(false);
  53. }
  54. checker.set_rng(0, &rng);
  55. auto check = [&](Reduce::Mode mode, DType src_dtype, DType dst_dtype,
  56. Reduce::DataType data_type) {
  57. for (int32_t axis : {0, 1, 2, 3}) {
  58. if (data_type == Reduce::DataType::DEFAULT &&
  59. src_dtype == dtype::Float16()) {
  60. checker.set_epsilon(1e-2);
  61. } else {
  62. checker.set_epsilon(1e-3);
  63. }
  64. Reduce::Param param{mode, axis, data_type};
  65. auto dst_shape = TensorShape{2, 3, 100, 5};
  66. dst_shape[axis] = 1;
  67. checker.set_dtype(0, src_dtype)
  68. .set_dtype(1, dst_dtype)
  69. .set_param(param)
  70. .execs({{2, 3, 100, 5}, dst_shape});
  71. }
  72. };
  73. for (auto mode :
  74. {Mode::SUM, Mode::MEAN, Mode::SUM_SQR, Mode::PRODUCT, Mode::MIN, Mode::MAX}) {
  75. for (auto dtype :
  76. std::vector<DType>{dtype::Float16(), dtype::Float32(), dtype::Int32()}) {
  77. check(mode, dtype, dtype, Reduce::DataType::DEFAULT);
  78. }
  79. check(mode, dtype::Float16(), dtype::Float32(),
  80. Reduce::DataType::FLOAT_O32xC32);
  81. check(mode, dtype::Int32(), dtype::Float32(), Reduce::DataType::FLOAT_O32xC32);
  82. check(mode, dtype::Float16(), dtype::Float16(),
  83. Reduce::DataType::FLOAT_O16xC32);
  84. check(mode, dtype::Float32(), dtype::Float16(),
  85. Reduce::DataType::FLOAT_O16xC32);
  86. ASSERT_THROW(
  87. check(mode, dtype::Int32(), dtype::Float16(),
  88. Reduce::DataType::FLOAT_O16xC32),
  89. MegDNNError);
  90. ASSERT_THROW(
  91. check(mode, dtype::Float16(), dtype::Float16(),
  92. Reduce::DataType::FLOAT_IO16xC32),
  93. MegDNNError);
  94. }
  95. {
  96. // very large reduce for I16CO32
  97. Reduce::Param param{Mode::SUM_SQR, 1, Reduce::Param::DataType::FLOAT_O32xC32};
  98. checker.set_dtype(0, dtype::Float16())
  99. .set_dtype(1, dtype::Float32())
  100. .set_param(param)
  101. .execs({{1, 4194304, 1}, {1, 1, 1}});
  102. }
  103. {
  104. // large reduce_mean for O16C32
  105. Reduce::Param param{Mode::MEAN, 1, Reduce::Param::DataType::FLOAT_O16xC32};
  106. checker.set_dtype(0, dtype::Float16())
  107. .set_dtype(1, dtype::Float16())
  108. .set_param(param)
  109. .execs({{1, 65536, 5}, {1, 1, 5}});
  110. }
  111. }
  112. // vim: syntax=cpp.doxygen