You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

helper.h 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. #pragma once
  #include <algorithm>
  #include <cmath>
  #include <numeric>
  #include "megdnn/basic_types.h"
  #include "megdnn/dtype.h"
  #include "src/common/utils.h"
  7. using namespace megdnn;
  8. /* anonymous namespace */
  9. namespace {
  10. using Mode = Reduce::Mode;
  11. /* Reduce Trait */
  12. template <Mode mode, typename ctype>
  13. struct Trait;
  14. template <typename ctype>
  15. struct Trait<Mode::SUM, ctype> {
  16. static const ctype INIT;
  17. static ctype apply(ctype x, ctype y) { return x + y; }
  18. static ctype visit(ctype x) { return x; }
  19. static ctype write(ctype x, size_t) { return x; }
  20. };
  21. template <typename ctype>
  22. const ctype Trait<Mode::SUM, ctype>::INIT = ctype(0);
  23. template <typename ctype>
  24. struct Trait<Mode::MEAN, ctype> {
  25. static const ctype INIT;
  26. static ctype apply(ctype x, ctype y) { return x + y; }
  27. static ctype visit(ctype x) { return x; }
  28. static ctype write(ctype x, size_t B) { return x / (ctype)B; }
  29. };
  30. template <typename ctype>
  31. const ctype Trait<Mode::MEAN, ctype>::INIT = ctype(0);
  32. template <typename ctype>
  33. struct Trait<Mode::SUM_SQR, ctype> {
  34. static const ctype INIT;
  35. static ctype apply(ctype x, ctype y) { return x + y; }
  36. static ctype visit(ctype x) { return x * x; }
  37. static ctype write(ctype x, size_t) { return x; }
  38. };
  39. template <typename ctype>
  40. const ctype Trait<Mode::SUM_SQR, ctype>::INIT = ctype(0);
  41. template <typename ctype>
  42. struct Trait<Mode::PRODUCT, ctype> {
  43. static const ctype INIT;
  44. static ctype apply(ctype x, ctype y) { return x * y; }
  45. static ctype visit(ctype x) { return x; }
  46. static ctype write(ctype x, size_t) { return x; }
  47. };
  48. template <typename ctype>
  49. const ctype Trait<Mode::PRODUCT, ctype>::INIT = ctype(1);
  50. template <typename ctype>
  51. struct Trait<Mode::MIN, ctype> {
  52. static ctype apply(ctype x, ctype y) { return x < y ? x : y; }
  53. static ctype visit(ctype x) { return x; }
  54. static ctype write(ctype x, size_t) { return x; }
  55. };
  56. template <>
  57. struct Trait<Mode::MIN, dt_float32> {
  58. using ctype = dt_float32;
  59. static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x < y) ? x : y; }
  60. static ctype visit(ctype x) { return x; }
  61. static ctype write(ctype x, size_t) { return x; }
  62. };
  63. template <typename ctype>
  64. struct Trait<Mode::MAX, ctype> {
  65. static ctype apply(ctype x, ctype y) { return x > y ? x : y; }
  66. static ctype visit(ctype x) { return x; }
  67. static ctype write(ctype x, size_t) { return x; }
  68. };
  69. template <>
  70. struct Trait<Mode::MAX, dt_float32> {
  71. using ctype = dt_float32;
  72. static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x > y) ? x : y; }
  73. static ctype visit(ctype x) { return x; }
  74. static ctype write(ctype x, size_t) { return x; }
  75. };
  76. /* NormOp */
  77. template <typename ctype>
  78. struct NormOp;
  79. template <>
  80. struct NormOp<dt_float32> {
  81. typedef dt_float32 ctype;
  82. static const ctype INIT;
  83. static ctype apply(ctype x, ctype y) { return x + y; }
  84. static ctype visit(ctype x, dt_float32 p) { return powf(fabs(x), p); }
  85. static ctype write(ctype x, size_t, dt_float32 p) { return powf(x, 1.f / p); }
  86. };
  87. #if !MEGDNN_DISABLE_FLOAT16
  88. template <>
  89. struct NormOp<dt_float16> {
  90. typedef dt_float16 ctype;
  91. static const ctype INIT;
  92. static ctype apply(ctype x, ctype y) { return x + y; }
  93. static ctype visit(ctype x, dt_float32 p) {
  94. return half_float::pow(half_float::abs(x), half_float::half(p));
  95. }
  96. static ctype write(ctype x, size_t, dt_float32 p) {
  97. return half_float::pow(x, half_float::half(1.f / p));
  98. }
  99. };
  100. #endif
  101. template <typename ctype>
  102. struct NormZeroOp;
  103. template <>
  104. struct NormZeroOp<dt_float32> {
  105. typedef dt_float32 ctype;
  106. static const ctype INIT;
  107. static ctype apply(ctype x, ctype y) { return x + y; }
  108. static ctype visit(ctype x) { return x - 0.f < 0.00001f ? 0.f : 1.f; }
  109. static ctype write(ctype x, size_t) { return x; }
  110. };
  111. #if !MEGDNN_DISABLE_FLOAT16
  112. template <>
  113. struct NormZeroOp<dt_float16> {
  114. typedef dt_float16 ctype;
  115. static const ctype INIT;
  116. static ctype apply(ctype x, ctype y) { return x + y; }
  117. static ctype visit(ctype x) {
  118. return x - half_float::half(0.f) < half_float::half(0.00001f)
  119. ? half_float::half(0.f)
  120. : half_float::half(1.f);
  121. }
  122. static ctype write(ctype x, size_t) { return x; }
  123. };
  124. #endif
  125. } // namespace