You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

op_base.h 25 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. /**
  2. * \file dnn/src/fallback/elemwise_helper/kimpl/op_base.h
  3. */
  4. #pragma once
  5. #include <cmath>
  6. #include "megdnn/dtype.h"
  7. #include "megdnn/oprs.h"
  8. #include "src/common/utils.h"
  9. #include "src/fallback/elemwise/gi_impl/gi_mathfun.h"
  10. #include "src/fallback/quantized_converter.h"
  11. #include "src/fallback/general_intrinsic/gi_float.h"
  12. #include "src/fallback/general_intrinsic/gi_int.h"
  13. namespace megdnn {
  14. namespace fallback {
  15. ////////////////////////// unary //////////////////////////
  16. template <typename _src_ctype, typename _dst_ctype = _src_ctype>
  17. struct OpBase {
  18. using src_ctype = _src_ctype;
  19. using dst_ctype = _dst_ctype;
  20. OpBase() = default;
  21. };
  22. template <typename src_ctype, typename dst_ctype = src_ctype>
  23. struct UnaryOpBase : OpBase<src_ctype, dst_ctype> {
  24. using OpBase<src_ctype, dst_ctype>::OpBase;
  25. UnaryOpBase() = default;
  26. UnaryOpBase(DType /*src_dtype*/, DType /*dst_dtype*/) {}
  27. };
  28. #define OPERATOR_UNARY_QINT8_FALLBACK \
  29. GI_INT16_t vsrct0 = GiMoveLowLongInt8(GiGetSubVectorInt8V2(vsrc, 0)); \
  30. GI_INT32_V2_t tmp; \
  31. GiSetSubVectorInt32V2(tmp, 0, GiMoveLowLongInt16(vsrct0)); \
  32. GiSetSubVectorInt32V2(tmp, 1, GiMoveHighLongInt16(vsrct0)); \
  33. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(tmp)); \
  34. GI_INT16_t vsrct1 = GiMoveHighLongInt8(GiGetSubVectorInt8V2(vsrc, 0)); \
  35. GiSetSubVectorInt32V2(tmp, 0, GiMoveLowLongInt16(vsrct1)); \
  36. GiSetSubVectorInt32V2(tmp, 1, GiMoveHighLongInt16(vsrct1)); \
  37. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst + 8), operator()(tmp)); \
  38. GI_INT16_t vsrct2 = GiMoveLowLongInt8(GiGetSubVectorInt8V2(vsrc, 1)); \
  39. GiSetSubVectorInt32V2(tmp, 0, GiMoveLowLongInt16(vsrct2)); \
  40. GiSetSubVectorInt32V2(tmp, 1, GiMoveHighLongInt16(vsrct2)); \
  41. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst + 16), operator()(tmp)); \
  42. GI_INT16_t vsrct3 = GiMoveHighLongInt8(GiGetSubVectorInt8V2(vsrc, 1)); \
  43. GiSetSubVectorInt32V2(tmp, 0, GiMoveLowLongInt16(vsrct3)); \
  44. GiSetSubVectorInt32V2(tmp, 1, GiMoveHighLongInt16(vsrct3)); \
  45. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst + 24), operator()(tmp))
  46. //! scale_src = src.scale; scale_dst = 1.f / dst.scale (div -> mul)
  47. //! scale = src.scale / dst.scale
  48. template <>
  49. struct UnaryOpBase<dt_qint8, dt_qint8> : OpBase<dt_qint8, dt_qint8> {
  50. using OpBase::OpBase;
  51. float scale_src, scale_dst;
  52. GI_FLOAT32_FIXLEN_t vscale_src, vscale_dst;
  53. float scale;
  54. GI_FLOAT32_FIXLEN_t vscale;
  55. void init(float src_scale, float dst_scale) {
  56. scale_src = src_scale;
  57. vscale_src = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale_src));
  58. scale_dst = 1.f / dst_scale;
  59. vscale_dst = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale_dst));
  60. scale = src_scale / dst_scale;
  61. vscale = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale));
  62. }
  63. UnaryOpBase(DType src_dtype, DType dst_dtype) {
  64. float src_scale = src_dtype.param<dtype::QuantizedS8>().scale;
  65. float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
  66. init(src_scale, dst_scale);
  67. }
  68. UnaryOpBase(float src_scale, float dst_scale) { init(src_scale, dst_scale); }
  69. };
  70. template <>
  71. struct UnaryOpBase<dt_qint32, dt_qint8> : OpBase<dt_qint32, dt_qint8> {
  72. using OpBase::OpBase;
  73. using src_ctype = dt_qint32;
  74. using dst_ctype = dt_qint8;
  75. float scale;
  76. GI_FLOAT32_FIXLEN_t vscale;
  77. float scale_src, scale_dst;
  78. GI_FLOAT32_FIXLEN_t vscale_src, vscale_dst;
  79. void init(float src_scale, float dst_scale) {
  80. scale_src = src_scale;
  81. vscale_src = GiFloat32Type2FixLenType(GiBroadcastFloat32(src_scale));
  82. scale_dst = 1 / dst_scale;
  83. vscale_dst = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale_dst));
  84. scale = src_scale / dst_scale;
  85. vscale = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale));
  86. }
  87. UnaryOpBase(DType src_dtype, DType dst_dtype) {
  88. float src_scale = src_dtype.param<dtype::QuantizedS32>().scale;
  89. float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
  90. init(src_scale, dst_scale);
  91. }
  92. UnaryOpBase(float src_scale, float dst_scale) { init(src_scale, dst_scale); }
  93. };
  94. ////////////////////////// binary //////////////////////////
  95. template <typename src_ctype, typename dst_ctype = src_ctype>
  96. struct BinaryOpBase : OpBase<src_ctype, dst_ctype> {
  97. using OpBase<src_ctype, dst_ctype>::OpBase;
  98. BinaryOpBase() = default;
  99. BinaryOpBase(DType /*src0_dtype*/, DType /*src1_dtype*/, DType /*dst_dtype*/) {}
  100. };
  101. /* ================= binary op for quantized types ================== */
  102. #define OPERATOR_BINARY_QINT8_FALLBACK \
  103. GI_INT16_t vsrct0_0 = GiMoveLowLongInt8(GiGetSubVectorInt8V2(vsrc0, 0)); \
  104. GI_INT16_t vsrct1_0 = GiMoveLowLongInt8(GiGetSubVectorInt8V2(vsrc1, 0)); \
  105. GI_INT32_V2_t tmp0, tmp1; \
  106. GiSetSubVectorInt32V2(tmp0, 0, GiMoveLowLongInt16(vsrct0_0)); \
  107. GiSetSubVectorInt32V2(tmp0, 1, GiMoveHighLongInt16(vsrct0_0)); \
  108. GiSetSubVectorInt32V2(tmp1, 0, GiMoveLowLongInt16(vsrct1_0)); \
  109. GiSetSubVectorInt32V2(tmp1, 1, GiMoveHighLongInt16(vsrct1_0)); \
  110. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(tmp0, tmp1)); \
  111. GI_INT16_t vsrct0_1 = GiMoveHighLongInt8(GiGetSubVectorInt8V2(vsrc0, 0)); \
  112. GI_INT16_t vsrct1_1 = GiMoveHighLongInt8(GiGetSubVectorInt8V2(vsrc1, 0)); \
  113. GiSetSubVectorInt32V2(tmp0, 0, GiMoveLowLongInt16(vsrct0_1)); \
  114. GiSetSubVectorInt32V2(tmp0, 1, GiMoveHighLongInt16(vsrct0_1)); \
  115. GiSetSubVectorInt32V2(tmp1, 0, GiMoveLowLongInt16(vsrct1_1)); \
  116. GiSetSubVectorInt32V2(tmp1, 1, GiMoveHighLongInt16(vsrct1_1)); \
  117. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst + 8), operator()(tmp0, tmp1)); \
  118. GI_INT16_t vsrct0_2 = GiMoveLowLongInt8(GiGetSubVectorInt8V2(vsrc0, 1)); \
  119. GI_INT16_t vsrct1_2 = GiMoveLowLongInt8(GiGetSubVectorInt8V2(vsrc1, 1)); \
  120. GiSetSubVectorInt32V2(tmp0, 0, GiMoveLowLongInt16(vsrct0_2)); \
  121. GiSetSubVectorInt32V2(tmp0, 1, GiMoveHighLongInt16(vsrct0_2)); \
  122. GiSetSubVectorInt32V2(tmp1, 0, GiMoveLowLongInt16(vsrct1_2)); \
  123. GiSetSubVectorInt32V2(tmp1, 1, GiMoveHighLongInt16(vsrct1_2)); \
  124. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst + 16), operator()(tmp0, tmp1)); \
  125. GI_INT16_t vsrct0_3 = GiMoveHighLongInt8(GiGetSubVectorInt8V2(vsrc0, 1)); \
  126. GI_INT16_t vsrct1_3 = GiMoveHighLongInt8(GiGetSubVectorInt8V2(vsrc1, 1)); \
  127. GiSetSubVectorInt32V2(tmp0, 0, GiMoveLowLongInt16(vsrct0_3)); \
  128. GiSetSubVectorInt32V2(tmp0, 1, GiMoveHighLongInt16(vsrct0_3)); \
  129. GiSetSubVectorInt32V2(tmp1, 0, GiMoveLowLongInt16(vsrct1_3)); \
  130. GiSetSubVectorInt32V2(tmp1, 1, GiMoveHighLongInt16(vsrct1_3)); \
  131. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst + 24), operator()(tmp0, tmp1));
  132. //! scale_src0 = src0.scale; scale_src1 = src1.scale; scale_dst = 1.f /
  133. //! dst.scale scale0 = src0.scale / dst.scale; scale1 = src1.scale / dst.scale
  134. template <>
  135. struct BinaryOpBase<dt_qint8, dt_qint8> : OpBase<dt_qint8, dt_qint8> {
  136. using OpBase::OpBase;
  137. using src_ctype = dt_qint8;
  138. using dst_ctype = dt_qint8;
  139. float scale_src0, scale_src1, scale_dst;
  140. GI_FLOAT32_FIXLEN_t vscale_src0, vscale_src1, vscale_dst;
  141. float scale0, scale1;
  142. GI_FLOAT32_FIXLEN_t vscale0, vscale1;
  143. void init(float src0_scale, float src1_scale, float dst_scale) {
  144. scale_src0 = src0_scale;
  145. vscale_src0 = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale_src0));
  146. scale_src1 = src1_scale;
  147. vscale_src1 = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale_src1));
  148. scale_dst = 1.f / dst_scale;
  149. vscale_dst = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale_dst));
  150. scale0 = src0_scale / dst_scale;
  151. vscale0 = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale0));
  152. scale1 = src1_scale / dst_scale;
  153. vscale1 = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale1));
  154. }
  155. BinaryOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) {
  156. float src0_scale = src0_dtype.param<dtype::QuantizedS8>().scale;
  157. float src1_scale = src1_dtype.param<dtype::QuantizedS8>().scale;
  158. float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
  159. init(src0_scale, src1_scale, dst_scale);
  160. }
  161. BinaryOpBase(float src0_scale, float src1_scale, float dst_scale) {
  162. init(src0_scale, src1_scale, dst_scale);
  163. }
  164. };
  165. template <>
  166. struct BinaryOpBase<dt_qint32, dt_qint8> : OpBase<dt_qint32, dt_qint8> {
  167. using OpBase::OpBase;
  168. using src_ctype = dt_qint32;
  169. using dst_ctype = dt_qint8;
  170. float scale0, scale1;
  171. GI_FLOAT32_FIXLEN_t vscale0, vscale1;
  172. float scale_src0, scale_src1, scale_dst;
  173. GI_FLOAT32_FIXLEN_t vscale_src0, vscale_src1, vscale_dst;
  174. void init(float src0_scale, float src1_scale, float dst_scale) {
  175. scale_src0 = src0_scale;
  176. vscale_src0 = GiFloat32Type2FixLenType(GiBroadcastFloat32(src0_scale));
  177. scale_src1 = src1_scale;
  178. vscale_src1 = GiFloat32Type2FixLenType(GiBroadcastFloat32(src1_scale));
  179. scale_dst = 1 / dst_scale;
  180. vscale_dst = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale_dst));
  181. scale0 = src0_scale / dst_scale;
  182. vscale0 = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale0));
  183. scale1 = src1_scale / dst_scale;
  184. vscale1 = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale1));
  185. }
  186. BinaryOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) {
  187. float src0_scale = src0_dtype.param<dtype::QuantizedS32>().scale;
  188. float src1_scale = src1_dtype.param<dtype::QuantizedS32>().scale;
  189. float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
  190. init(src0_scale, src1_scale, dst_scale);
  191. }
  192. BinaryOpBase(float src0_scale, float src1_scale, float dst_scale) {
  193. init(src0_scale, src1_scale, dst_scale);
  194. }
  195. };
  196. ////////////////////////// ternary //////////////////////////
  197. template <typename src_ctype, typename dst_ctype = src_ctype>
  198. struct TernaryOpBase : OpBase<src_ctype, dst_ctype> {
  199. using OpBase<src_ctype, dst_ctype>::OpBase;
  200. TernaryOpBase() = default;
  201. TernaryOpBase(
  202. DType /*src0_dtype*/, DType /*src1_dtype*/, DType /*src2_dtype*/,
  203. DType /*dst_dtype*/) {}
  204. };
  205. #define OPERATOR_TERNARY_QINT8_FALLBACK \
  206. GI_INT16_t vsrct0 = GiMoveLowLongInt8(GiGetSubVectorInt8V2(vsrc0, 0)); \
  207. GI_INT16_t vsrct1 = GiMoveLowLongInt8(GiGetSubVectorInt8V2(vsrc1, 0)); \
  208. GI_INT16_t vsrct2 = GiMoveLowLongInt8(GiGetSubVectorInt8V2(vsrc2, 0)); \
  209. GI_INT32_V2_t tmp0, tmp1, tmp2; \
  210. GiSetSubVectorInt32V2(tmp0, 0, GiMoveLowLongInt16(vsrct0)); \
  211. GiSetSubVectorInt32V2(tmp0, 1, GiMoveHighLongInt16(vsrct0)); \
  212. GiSetSubVectorInt32V2(tmp1, 0, GiMoveLowLongInt16(vsrct1)); \
  213. GiSetSubVectorInt32V2(tmp1, 1, GiMoveHighLongInt16(vsrct1)); \
  214. GiSetSubVectorInt32V2(tmp2, 0, GiMoveLowLongInt16(vsrct2)); \
  215. GiSetSubVectorInt32V2(tmp2, 1, GiMoveHighLongInt16(vsrct2)); \
  216. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(tmp0, tmp1, tmp2)); \
  217. vsrct0 = GiMoveHighLongInt8(GiGetSubVectorInt8V2(vsrc0, 0)); \
  218. vsrct1 = GiMoveHighLongInt8(GiGetSubVectorInt8V2(vsrc1, 0)); \
  219. vsrct2 = GiMoveHighLongInt8(GiGetSubVectorInt8V2(vsrc2, 0)); \
  220. GiSetSubVectorInt32V2(tmp0, 0, GiMoveLowLongInt16(vsrct0)); \
  221. GiSetSubVectorInt32V2(tmp0, 1, GiMoveHighLongInt16(vsrct0)); \
  222. GiSetSubVectorInt32V2(tmp1, 0, GiMoveLowLongInt16(vsrct1)); \
  223. GiSetSubVectorInt32V2(tmp1, 1, GiMoveHighLongInt16(vsrct1)); \
  224. GiSetSubVectorInt32V2(tmp2, 0, GiMoveLowLongInt16(vsrct2)); \
  225. GiSetSubVectorInt32V2(tmp2, 1, GiMoveHighLongInt16(vsrct2)); \
  226. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst + 8), operator()(tmp0, tmp1, tmp2)); \
  227. vsrct0 = GiMoveLowLongInt8(GiGetSubVectorInt8V2(vsrc0, 1)); \
  228. vsrct1 = GiMoveLowLongInt8(GiGetSubVectorInt8V2(vsrc1, 1)); \
  229. vsrct2 = GiMoveLowLongInt8(GiGetSubVectorInt8V2(vsrc2, 1)); \
  230. GiSetSubVectorInt32V2(tmp0, 0, GiMoveLowLongInt16(vsrct0)); \
  231. GiSetSubVectorInt32V2(tmp0, 1, GiMoveHighLongInt16(vsrct0)); \
  232. GiSetSubVectorInt32V2(tmp1, 0, GiMoveLowLongInt16(vsrct1)); \
  233. GiSetSubVectorInt32V2(tmp1, 1, GiMoveHighLongInt16(vsrct1)); \
  234. GiSetSubVectorInt32V2(tmp2, 0, GiMoveLowLongInt16(vsrct2)); \
  235. GiSetSubVectorInt32V2(tmp2, 1, GiMoveHighLongInt16(vsrct2)); \
  236. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst + 16), operator()(tmp0, tmp1, tmp2)); \
  237. vsrct0 = GiMoveHighLongInt8(GiGetSubVectorInt8V2(vsrc0, 1)); \
  238. vsrct1 = GiMoveHighLongInt8(GiGetSubVectorInt8V2(vsrc1, 1)); \
  239. vsrct2 = GiMoveHighLongInt8(GiGetSubVectorInt8V2(vsrc2, 1)); \
  240. GiSetSubVectorInt32V2(tmp0, 0, GiMoveLowLongInt16(vsrct0)); \
  241. GiSetSubVectorInt32V2(tmp0, 1, GiMoveHighLongInt16(vsrct0)); \
  242. GiSetSubVectorInt32V2(tmp1, 0, GiMoveLowLongInt16(vsrct1)); \
  243. GiSetSubVectorInt32V2(tmp1, 1, GiMoveHighLongInt16(vsrct1)); \
  244. GiSetSubVectorInt32V2(tmp2, 0, GiMoveLowLongInt16(vsrct2)); \
  245. GiSetSubVectorInt32V2(tmp2, 1, GiMoveHighLongInt16(vsrct2)); \
  246. GiStoreLowInt8(reinterpret_cast<int8_t*>(dst + 24), operator()(tmp0, tmp1, tmp2));
  247. /*========================= ternaty op for quanzited ====================*/
  248. template <>
  249. struct TernaryOpBase<dt_qint8, dt_qint8> : OpBase<dt_qint8, dt_qint8> {
  250. using OpBase::OpBase;
  251. using src_ctype = dt_qint8;
  252. using dst_ctype = dt_qint8;
  253. float scale_src0, scale_src1, scale_src2, scale_dst;
  254. GI_FLOAT32_FIXLEN_t vscale_src0, vscale_src1, vscale_src2, vscale_dst;
  255. float scale0, scale1, scale2;
  256. GI_FLOAT32_FIXLEN_t vscale0, vscale1, vscale2;
  257. void init(float src0_scale, float src1_scale, float src2_scale, float dst_scale) {
  258. scale_src0 = src0_scale;
  259. scale_src1 = src1_scale;
  260. scale_src2 = src2_scale;
  261. scale_dst = 1.f / dst_scale;
  262. vscale_src0 = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale_src0));
  263. vscale_src1 = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale_src1));
  264. vscale_src2 = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale_src2));
  265. vscale_dst = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale_dst));
  266. scale0 = src0_scale / dst_scale;
  267. scale1 = src1_scale / dst_scale;
  268. scale2 = src2_scale / dst_scale;
  269. vscale0 = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale0));
  270. vscale1 = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale1));
  271. vscale2 = GiFloat32Type2FixLenType(GiBroadcastFloat32(scale2));
  272. }
  273. TernaryOpBase(
  274. DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype) {
  275. float src0_scale = src0_dtype.param<dtype::QuantizedS8>().scale;
  276. float src1_scale = src1_dtype.param<dtype::QuantizedS8>().scale;
  277. float src2_scale = src2_dtype.param<dtype::QuantizedS8>().scale;
  278. float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
  279. init(src0_scale, src1_scale, src2_scale, dst_scale);
  280. }
  281. TernaryOpBase(
  282. float src0_scale, float src1_scale, float src2_scale, float dst_scale) {
  283. init(src0_scale, src1_scale, src2_scale, dst_scale);
  284. }
  285. };
  286. ////////////////////////// fixup //////////////////////////
  287. struct FixupBase {
  288. GI_INT32_FIXLEN_t vmultiplier, vshift;
  289. FixupBase(float scale) {
  290. //! ignore Fixup if scale >= 0.5, using typecvt instead of shift &
  291. //! multiplier, as it may introduce errors.
  292. if (scale >= 0.5)
  293. return;
  294. int shift = static_cast<int>(::ceilf(::log2f(0.5 / scale)));
  295. scale *= ::powf(2, shift);
  296. //! Using double can get full precision here, but it can be ignored.
  297. vmultiplier = GiInt32Type2FixLenType(GiBroadcastInt32(
  298. std::round(static_cast<double>(scale) * ((2LL) << 30))));
  299. vshift = GiInt32Type2FixLenType(GiBroadcastInt32(-shift));
  300. }
  301. };
  302. //////////////////////// quantization common ////////////////////
  303. template <typename src_type, typename dst_type, typename Op>
  304. struct UnaryQuantizationOp;
  305. template <typename Op>
  306. struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qint8> {
  307. using UnaryOpBase<dt_qint8, dt_qint8>::UnaryOpBase;
  308. constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
  309. Op op;
  310. void operator()(const dt_qint8& src, dt_qint8* dst) const {
  311. *dst = operator()(src);
  312. }
  313. dt_qint8 operator()(const dt_qint8& src) const {
  314. float fsrc = src.as_int8() * this->scale_src;
  315. fsrc = op(fsrc);
  316. fsrc = fsrc * this->scale_dst;
  317. return QConverter::convert<dt_qint8, float>(fsrc);
  318. }
  319. void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const {
  320. OPERATOR_UNARY_QINT8_FALLBACK;
  321. }
  322. GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
  323. auto vitem0 = GiMultiplyFloat32(
  324. GiCastToFloat32(GiGetSubVectorInt32V2(vsrc, 0)),
  325. GiFixLenType2GiFloat32Type(this->vscale_src));
  326. auto vitem1 = GiMultiplyFloat32(
  327. GiCastToFloat32(GiGetSubVectorInt32V2(vsrc, 1)),
  328. GiFixLenType2GiFloat32Type(this->vscale_src));
  329. GI_FLOAT32_V2_t tmp;
  330. GiSetSubVectorFloat32V2(tmp, 0, vitem0);
  331. GiSetSubVectorFloat32V2(tmp, 1, vitem1);
  332. auto val = this->op(tmp);
  333. GI_FLOAT32_t a = GiMultiplyFloat32(
  334. GiGetSubVectorFloat32V2(val, 0),
  335. GiFixLenType2GiFloat32Type(this->vscale_dst));
  336. GI_FLOAT32_t b = GiMultiplyFloat32(
  337. GiGetSubVectorFloat32V2(val, 1),
  338. GiFixLenType2GiFloat32Type(this->vscale_dst));
  339. GiSetSubVectorFloat32V2(val, 0, a);
  340. GiSetSubVectorFloat32V2(val, 1, b);
  341. return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(val);
  342. }
  343. };
  344. template <typename src_type, typename dst_type, typename Op>
  345. struct BinaryQuantizationOp;
  346. template <typename Op>
  347. struct BinaryQuantizationOp<dt_qint8, dt_qint8, Op> : BinaryOpBase<dt_qint8, dt_qint8> {
  348. using BinaryOpBase<dt_qint8, dt_qint8>::BinaryOpBase;
  349. constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
  350. Op op;
  351. void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
  352. *dst = operator()(src0, src1);
  353. }
  354. dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const {
  355. float fsrc0 = src0.as_int8() * this->scale_src0;
  356. float fsrc1 = src1.as_int8() * this->scale_src1;
  357. float fdst = op(fsrc0, fsrc1);
  358. fdst = fdst * this->scale_dst;
  359. return QConverter::convert<dt_qint8, float>(fdst);
  360. }
  361. void operator()(
  362. const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
  363. OPERATOR_BINARY_QINT8_FALLBACK;
  364. }
  365. GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
  366. auto val0 = GiMultiplyFloat32(
  367. GiCastToFloat32(GiGetSubVectorInt32V2(vsrc0, 0)),
  368. GiFixLenType2GiFloat32Type(this->vscale_src0));
  369. auto val1 = GiMultiplyFloat32(
  370. GiCastToFloat32(GiGetSubVectorInt32V2(vsrc0, 1)),
  371. GiFixLenType2GiFloat32Type(this->vscale_src0));
  372. auto val2 = GiMultiplyFloat32(
  373. GiCastToFloat32(GiGetSubVectorInt32V2(vsrc1, 0)),
  374. GiFixLenType2GiFloat32Type(this->vscale_src1));
  375. auto val3 = GiMultiplyFloat32(
  376. GiCastToFloat32(GiGetSubVectorInt32V2(vsrc1, 1)),
  377. GiFixLenType2GiFloat32Type(this->vscale_src1));
  378. GI_FLOAT32_V2_t tmp0, tmp1;
  379. GiSetSubVectorFloat32V2(tmp0, 0, val0);
  380. GiSetSubVectorFloat32V2(tmp0, 1, val1);
  381. GiSetSubVectorFloat32V2(tmp1, 0, val2);
  382. GiSetSubVectorFloat32V2(tmp1, 1, val3);
  383. auto val = op(tmp0, tmp1);
  384. GI_FLOAT32_t a = GiMultiplyFloat32(
  385. GiGetSubVectorFloat32V2(val, 0),
  386. GiFixLenType2GiFloat32Type(this->vscale_dst));
  387. GI_FLOAT32_t b = GiMultiplyFloat32(
  388. GiGetSubVectorFloat32V2(val, 1),
  389. GiFixLenType2GiFloat32Type(this->vscale_dst));
  390. GiSetSubVectorFloat32V2(val, 0, a);
  391. GiSetSubVectorFloat32V2(val, 1, b);
  392. return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(val);
  393. }
  394. };
  395. template <typename src_type, typename dst_type, typename Op>
  396. struct TernaryQuantizationOp;
  397. template <typename Op>
  398. struct TernaryQuantizationOp<dt_qint8, dt_qint8, Op>
  399. : TernaryOpBase<dt_qint8, dt_qint8> {
  400. using TernaryOpBase<dt_qint8, dt_qint8>::TernaryOpBase;
  401. constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
  402. Op op;
  403. void operator()(
  404. const dt_qint8& src0, const dt_qint8& src1, const dt_qint8& src2,
  405. dt_qint8* dst) const {
  406. *dst = operator()(src0, src1, src2);
  407. }
  408. dt_qint8 operator()(
  409. const dt_qint8& src0, const dt_qint8& src1, const dt_qint8& src2) const {
  410. float fsrc0 = src0.as_int8() * this->scale_src0;
  411. float fsrc1 = src1.as_int8() * this->scale_src1;
  412. float fsrc2 = src2.as_int8() * this->scale_src2;
  413. float fdst = op(fsrc0, fsrc1, fsrc2);
  414. fdst = fdst * this->scale_dst;
  415. return QConverter::convert<dt_qint8, float>(fdst);
  416. }
  417. void operator()(
  418. const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1,
  419. const GI_INT8_V2_t& vsrc2, dt_qint8* dst) const {
  420. OPERATOR_TERNARY_QINT8_FALLBACK;
  421. }
  422. GI_INT8_t operator()(
  423. const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1,
  424. const GI_INT32_V2_t& vsrc2) const {
  425. auto val0 = GiMultiplyFloat32(
  426. GiCastToFloat32(GiGetSubVectorInt32V2(vsrc0, 0)),
  427. GiFixLenType2GiFloat32Type(this->vscale_src0));
  428. auto val1 = GiMultiplyFloat32(
  429. GiCastToFloat32(GiGetSubVectorInt32V2(vsrc0, 1)),
  430. GiFixLenType2GiFloat32Type(this->vscale_src0));
  431. auto val2 = GiMultiplyFloat32(
  432. GiCastToFloat32(GiGetSubVectorInt32V2(vsrc1, 0)),
  433. GiFixLenType2GiFloat32Type(this->vscale_src1));
  434. auto val3 = GiMultiplyFloat32(
  435. GiCastToFloat32(GiGetSubVectorInt32V2(vsrc1, 1)),
  436. GiFixLenType2GiFloat32Type(this->vscale_src1));
  437. auto val4 = GiMultiplyFloat32(
  438. GiCastToFloat32(GiGetSubVectorInt32V2(vsrc2, 0)),
  439. GiFixLenType2GiFloat32Type(this->vscale_src2));
  440. auto val5 = GiMultiplyFloat32(
  441. GiCastToFloat32(GiGetSubVectorInt32V2(vsrc2, 1)),
  442. GiFixLenType2GiFloat32Type(this->vscale_src2));
  443. GI_FLOAT32_V2_t tmp0, tmp1, tmp2;
  444. GiSetSubVectorFloat32V2(tmp0, 0, val0);
  445. GiSetSubVectorFloat32V2(tmp0, 1, val1);
  446. GiSetSubVectorFloat32V2(tmp1, 0, val2);
  447. GiSetSubVectorFloat32V2(tmp1, 1, val3);
  448. GiSetSubVectorFloat32V2(tmp2, 0, val4);
  449. GiSetSubVectorFloat32V2(tmp2, 1, val5);
  450. auto val = op(tmp0, tmp1, tmp2);
  451. GI_FLOAT32_t a = GiMultiplyFloat32(
  452. GiGetSubVectorFloat32V2(val, 0),
  453. GiFixLenType2GiFloat32Type(this->vscale_dst));
  454. GI_FLOAT32_t b = GiMultiplyFloat32(
  455. GiGetSubVectorFloat32V2(val, 1),
  456. GiFixLenType2GiFloat32Type(this->vscale_dst));
  457. GiSetSubVectorFloat32V2(val, 0, a);
  458. GiSetSubVectorFloat32V2(val, 1, b);
  459. return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(val);
  460. }
  461. };
  462. } // namespace fallback
  463. } // namespace megdnn
  464. // vim: syntax=cpp.doxygen