You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hswish.h 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. /**
  2. * \file dnn/src/fallback/elemwise_helper/kimpl/hswish.h
  3. */
  4. #pragma once
  5. #include "src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h"
  6. #include "src/fallback/elemwise_helper/kimpl/op_base.h"
  7. namespace megdnn {
  8. namespace fallback {
  9. template <typename src_ctype, typename dst_ctype = src_ctype>
  10. struct HSwishOpBase : UnaryOpBase<src_ctype, dst_ctype> {
  11. using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
  12. void operator()(const src_ctype& src, dst_ctype* dst) const {
  13. *dst = operator()(src);
  14. }
  15. dst_ctype operator()(const src_ctype& src) const {
  16. float tmp = src;
  17. tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f;
  18. return (tmp);
  19. }
  20. };
  21. //! h_swish(x) = x * clip(x + 3, 0, 6) / 6
  22. template <typename src_ctype, typename dst_ctype = src_ctype>
  23. struct HSwishOp;
  24. #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
  25. template <> \
  26. struct HSwishOp<_ctype> : HSwishOpBase<_ctype> { \
  27. using HSwishOpBase::HSwishOpBase; \
  28. using HSwishOpBase::operator(); \
  29. constexpr static size_t SIMD_WIDTH = _simd_width; \
  30. void operator()(const _simd_type2& src, _ctype* dst) const { \
  31. auto vitem = operator()(src); \
  32. GiStore##_func_suffix(dst, vitem.val[0]); \
  33. GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
  34. } \
  35. void operator()(const _simd_type& src, _ctype* dst) const { \
  36. auto vitem = operator()(src); \
  37. GiStore##_func_suffix(dst, vitem); \
  38. } \
  39. _simd_type2 operator()(const _simd_type2& src) const { \
  40. auto val1 = src.val[0]; \
  41. auto val2 = src.val[1]; \
  42. H_SWISH_KERN_FALLBACK(_func_suffix, val1, val2); \
  43. return {{val1, val2}}; \
  44. } \
  45. _simd_type operator()(const _simd_type& src) const { \
  46. auto val_zero = GiBroadcast##_func_suffix(0.f); \
  47. auto val_six = GiBroadcast##_func_suffix(6.f); \
  48. auto val_three = GiBroadcast##_func_suffix(3.f); \
  49. auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \
  50. auto clip1 = GiMaximum##_func_suffix( \
  51. GiMinimum##_func_suffix( \
  52. GiAdd##_func_suffix(src, val_three), val_six), \
  53. val_zero); \
  54. return GiMultiply##_func_suffix( \
  55. GiMultiply##_func_suffix(src, clip1), val_rec_six); \
  56. } \
  57. };
  58. OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
  59. #undef OP
  60. template <>
  61. struct HSwishOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
  62. using UnaryOpBase::UnaryOpBase;
  63. void operator()(const dt_qint32& src, dt_qint8* dst) const {
  64. *dst = operator()(src);
  65. }
  66. dt_qint8 operator()(const dt_qint32& src) const {
  67. float tmp = src.as_int32() * this->scale_src;
  68. tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f;
  69. tmp *= this->scale_dst;
  70. return QConverter::convert<dt_qint8, float>(tmp);
  71. }
  72. };
  73. template <>
  74. struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> {
  75. using HSwishOpBase::HSwishOpBase;
  76. using HSwishOpBase::operator();
  77. constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);
  78. void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
  79. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
  80. }
  81. GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
  82. auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale_src);
  83. auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale_src);
  84. H_SWISH_KERN_FALLBACK(Float32, vitem0, vitem1);
  85. vitem0 = GiMultiplyFloat32(vitem0, this->vscale_dst);
  86. vitem1 = GiMultiplyFloat32(vitem1, this->vscale_dst);
  87. return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
  88. }
  89. };
  90. } // namespace fallback
  91. } // namespace megdnn
  92. #include "src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h"
  93. // vim: syntax=cpp.doxygen