GitOrigin-RevId: 7812900244
release-1.2
@@ -88,6 +88,7 @@ enum class AlgoDataType : uint32_t {
    QUINT8X8X32 = 1 << 3,
    INT8X8X16 = 1 << 4,
    INT16X16X32 = 1 << 5,
    INT4X4X16 = 1 << 6,
};

/*!
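The new `INT4X4X16` entry keeps the one-bit-per-type convention of this enum, so a kernel can still advertise several supported type combinations in a single bitmask. A minimal sketch of how such a mask is typically built and queried (illustrative only, not code from this diff):

    #include <cstdint>

    enum class AlgoDataType : uint32_t {
        QUINT8X8X32 = 1 << 3,
        INT8X8X16 = 1 << 4,
        INT16X16X32 = 1 << 5,
        INT4X4X16 = 1 << 6,
    };

    // True if the data type `t` is present in the support mask.
    static bool supports(uint32_t mask, AlgoDataType t) {
        return (mask & static_cast<uint32_t>(t)) != 0;
    }

    // e.g. a pack supporting both the int8x8x16 and the new int4x4x16 kernels:
    //   uint32_t mask = static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
    //                   static_cast<uint32_t>(AlgoDataType::INT4X4X16);
    //   supports(mask, AlgoDataType::INT4X4X16);  // -> true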
@@ -17,6 +17,7 @@
#include "src/aarch64/matrix_mul/int8/strategy.h"
#include "src/aarch64/matrix_mul/int8_dot/strategy.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h"
#include "src/aarch64/matrix_mul/quint8/strategy.h"
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h"
@@ -1394,4 +1395,75 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8,
                                     aarch64::matmul::gemm_s8x8x16_mk4_8x8x8,
                                     int8_t, int16_t, AlgoDataType::INT8X8X16,
                                     MK4);

/* ===================== Int4x4x16 K8x8x8 algo ===================== */
namespace {
void int4x4x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int4x4x16_k8x8x8_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();
        aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(M, N, K, A_type,
                                                        B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s4x4x16_s4_8x8x8>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace

bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::usable(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
           kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS4 &&
           kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16 &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           //! int4 values are stored two to a byte, so K and N must be even
           (kern_size_param.K & 1) == 0 && (kern_size_param.N & 1) == 0;
}

bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::preferred(
        const KernSizeParam& kern_size_param) const {
    MEGDNN_MARK_USED_VAR(kern_size_param);
    return true;
}

size_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt4x4x16K8x8x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(M, N, K, A_type,
                                                        B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::gemm_s4x4x16_s4_8x8x8>(M, N, K, trA,
                                                               trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_kern(
        const KernSizeParam&) const {
    return int4x4x16_k8x8x8_kern;
}

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt4x4x16K8x8x8,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt4x4x16K8x8x8Impl"_hash,
                                     aarch64::matmul::gemm_s4x4x16_s4_8x8x8,
                                     int8_t, int16_t, AlgoDataType::INT4X4X16,
                                     DEFAULT);
// vim: syntax=cpp.doxygen
@@ -192,6 +192,19 @@ public:
    MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K4X4X16)
};

class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    PackMode packmode() const override { return PackMode::DEFAULT; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
    MEGDNN_DECL_ALGO_TYPE(AARCH64_INT4X4X16_K8X8X8)
};

class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
@@ -925,6 +925,42 @@ static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1,
            : "v0", "v1", "v2", "v3", "memory");
}

template <typename T>
static inline void interleave_8x4_1_b_with_shift(
        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
        const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7,
        T* outptr) {
    static_assert(sizeof(T) == 1, "only support size == 1");
    asm volatile(
            "ld1 {v0.s}[0], [%[inptr0]], #4\n"
            "ld1 {v0.s}[1], [%[inptr1]], #4\n"
            "ld1 {v0.s}[2], [%[inptr2]], #4\n"
            "ld1 {v0.s}[3], [%[inptr3]], #4\n"
            "ld1 {v1.s}[0], [%[inptr4]], #4\n"
            "ld1 {v1.s}[1], [%[inptr5]], #4\n"
            "ld1 {v1.s}[2], [%[inptr6]], #4\n"
            "ld1 {v1.s}[3], [%[inptr7]], #4\n"
            "shl v2.16b, v0.16b, #4\n"
            "shl v5.16b, v1.16b, #4\n"
            "sshr v3.16b, v0.16b, #4\n"  // high nibbles, sign extended
            "sshr v4.16b, v2.16b, #4\n"  // low nibbles, sign extended
            "sshr v6.16b, v1.16b, #4\n"  // high nibbles, sign extended
            "sshr v7.16b, v5.16b, #4\n"  // low nibbles, sign extended
            "zip1 v8.16b, v4.16b, v3.16b\n"
            "zip2 v9.16b, v4.16b, v3.16b\n"
            "zip1 v10.16b, v7.16b, v6.16b\n"
            "zip2 v11.16b, v7.16b, v6.16b\n"
            "st1 {v8.16b-v11.16b}, [%[outptr]], #64"
            : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
              [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
              [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
              [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7),
              [outptr] "+r"(outptr)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
              "v10", "v11", "memory");
}
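The `shl`/`sshr` pairs above are the standard trick for sign-extending packed signed 4-bit values: shifting a byte left by four and arithmetically back right by four recovers the low nibble, while a single arithmetic right shift by four recovers the high nibble. A scalar sketch of what the NEON sequence computes per byte (illustrative only; the helper name is hypothetical, and it assumes arithmetic right shift on signed types, as on the AArch64 targets here):

    #include <cstdint>

    // Unpack one byte holding two signed 4-bit values (low nibble first, as
    // the packing code stores them) into two sign-extended int8 values.
    inline void unpack_s4(int8_t packed, int8_t& lo, int8_t& hi) {
        lo = static_cast<int8_t>(packed << 4) >> 4;  // shl #4 then sshr #4
        hi = packed >> 4;                            // sshr #4
    }

    // e.g. packed = 0xF7 -> lo = 7, hi = -1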
template <typename T>
static inline void interleave_8x8_1_h(const T*& inptr0, const T*& inptr1,
                                      const T*& inptr2, const T*& inptr3,

@@ -1059,6 +1095,7 @@ static inline void interleave_4x16_1_b(const T*& inptr0, const T*& inptr1,
            : "v0", "v1", "v2", "v3", "v4", "cc", "memory");
}

template <typename T>
static inline void interleave_4x16_1_s(const T*& inptr0, const T*& inptr1,
                                       const T*& inptr2, const T*& inptr3,
@@ -1773,6 +1810,54 @@ static inline void transpose_8x4_1_b(const T*& inptr0, const T*& inptr1,
}

template <typename T>
static inline void transpose_4x8_1_b_with_shift(const T*& inptr0, const T*& inptr1,
                                                const T*& inptr2, const T*& inptr3,
                                                const T*& inptr4, const T*& inptr5,
                                                const T*& inptr6, const T*& inptr7,
                                                T*& outptr) {
    static int8x16_t shuffle_idx = {0, 4, 8, 12, 1, 5, 9, 13,
                                    2, 6, 10, 14, 3, 7, 11, 15};
    static_assert(
            std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
            "transpose_4x8_1_b_with_shift only supports uint8_t and int8_t");
    asm volatile(
            "ld1 {v0.s}[0], [%[inptr0]], #4\n"  // A1A2A3A4
            "ld1 {v0.s}[1], [%[inptr1]], #4\n"  // B1B2B3B4
            "ld1 {v0.s}[2], [%[inptr2]], #4\n"  // C1C2C3C4
            "ld1 {v0.s}[3], [%[inptr3]], #4\n"  // D1D2D3D4
            "ld1 {v1.s}[0], [%[inptr4]], #4\n"  // E1E2E3E4
            "ld1 {v1.s}[1], [%[inptr5]], #4\n"  // F1F2F3F4
            "ld1 {v1.s}[2], [%[inptr6]], #4\n"  // G1G2G3G4
            "ld1 {v1.s}[3], [%[inptr7]], #4\n"  // H1H2H3H4
            "tbl v2.16b, {v0.16b}, %[shuffle_idx].16b\n"  // A1B1C1D1A2B2C2D2A3B3C3D3A4B4C4D4
            "tbl v3.16b, {v1.16b}, %[shuffle_idx].16b\n"  // E1F1G1H1E2F2G2H2E3F3G3H3E4F4G4H4
            "zip1 v4.4s, v2.4s, v3.4s\n"  // A1B1C1D1E1F1G1H1 A2B2C2D2E2F2G2H2
            "zip2 v5.4s, v2.4s, v3.4s\n"  // A3B3C3D3E3F3G3H3 A4B4C4D4E4F4G4H4
            "shl v6.16b, v4.16b, #4\n"
            "sshr v7.16b, v4.16b, #4\n"   // high nibbles, sign extended
            "sshr v8.16b, v6.16b, #4\n"   // low nibbles, sign extended
            "shl v9.16b, v5.16b, #4\n"
            "sshr v10.16b, v5.16b, #4\n"  // high nibbles, sign extended
            "sshr v11.16b, v9.16b, #4\n"  // low nibbles, sign extended
            "zip1 v0.2d, v8.2d, v7.2d\n"
            "zip2 v1.2d, v8.2d, v7.2d\n"
            "zip1 v2.2d, v11.2d, v10.2d\n"
            "zip2 v3.2d, v11.2d, v10.2d\n"
            "st1 {v0.2d-v3.2d}, [%[outptr]], #64\n"
            : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
              [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
              [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
              [inptr6] "+r"(inptr6),
              [inptr7] "+r"(inptr7), [shuffle_idx] "+w"(shuffle_idx),
              [outptr] "+r"(outptr)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
              "v10", "v11", "memory");
}

template <typename T>
static inline void transpose_8x8_1_b(const T*& inptr0, const T*& inptr1,
                                     const T*& inptr2, const T*& inptr3,
                                     const T*& inptr4, const T*& inptr5,
@@ -0,0 +1,913 @@
/**
 * \file dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include <inttypes.h>
#include <cstring>
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"

namespace megdnn {
namespace aarch64 {
namespace matmul_s4_4x4x16 {

/**
 * Overview of register layout:
 *
 * Lhs (packA) is loaded 16 bytes at a time into v20-v23 and broadcast
 * byte by byte into v0-v15; Rhs (packB) rows are loaded 8 bytes at a time
 * into v16-v19; v24-v31 hold the eight rows of 8 int16 accumulators.
 */
static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
                               int16_t* output, int LDC, bool is_first_k,
                               int m_remain, int n_remain) {
    K /= 8;
    LDC = LDC * sizeof(int16_t);
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
// clang-format off
#define LOAD_LINE(reg_index, n)                  \
    "cmp x8, #0\n"                               \
    "beq 105f\n"                                 \
    "cmp %w[n_remain], #8\n"                     \
    "blt 100" n "f\n"                            \
    "ld1 {v" reg_index ".8h}, [x" n "], #16\n"   \
    "b 101" n "f\n"                              \
    "100" n ":\n"                                \
    "cmp %w[n_remain], #0\n"                     \
    "beq 101" n "f\n"                            \
    "ld1 {v" reg_index ".h}[0], [x" n "], #2\n"  \
    "cmp %w[n_remain], #1\n"                     \
    "beq 101" n "f\n"                            \
    "ld1 {v" reg_index ".h}[1], [x" n "], #2\n"  \
    "cmp %w[n_remain], #2\n"                     \
    "beq 101" n "f\n"                            \
    "ld1 {v" reg_index ".h}[2], [x" n "], #2\n"  \
    "cmp %w[n_remain], #3\n"                     \
    "beq 101" n "f\n"                            \
    "ld1 {v" reg_index ".h}[3], [x" n "], #2\n"  \
    "cmp %w[n_remain], #4\n"                     \
    "beq 101" n "f\n"                            \
    "ld1 {v" reg_index ".h}[4], [x" n "], #2\n"  \
    "cmp %w[n_remain], #5\n"                     \
    "beq 101" n "f\n"                            \
    "ld1 {v" reg_index ".h}[5], [x" n "], #2\n"  \
    "cmp %w[n_remain], #6\n"                     \
    "beq 101" n "f\n"                            \
    "ld1 {v" reg_index ".h}[6], [x" n "], #2\n"  \
    "101" n ":\n"                                \
    "sub x8, x8, #1\n"

#define LOAD_C                \
    "mov x8, %x[m_remain]\n"  \
    LOAD_LINE("24", "0")      \
    LOAD_LINE("25", "1")      \
    LOAD_LINE("26", "2")      \
    LOAD_LINE("27", "3")      \
    LOAD_LINE("28", "4")      \
    LOAD_LINE("29", "5")      \
    LOAD_LINE("30", "6")      \
    LOAD_LINE("31", "7")      \
    "105:\n"

#define STORE_LINE(reg_index, n)                 \
    "cmp x8, #0\n"                               \
    "beq 105f\n"                                 \
    "cmp %w[n_remain], #8\n"                     \
    "blt 102" n "f\n"                            \
    "st1 {v" reg_index ".8h}, [x" n "], #16\n"   \
    "b 103" n "f\n"                              \
    "102" n ":\n"                                \
    "cmp %w[n_remain], #0\n"                     \
    "beq 103" n "f\n"                            \
    "st1 {v" reg_index ".h}[0], [x" n "], #2\n"  \
    "cmp %w[n_remain], #1\n"                     \
    "beq 103" n "f\n"                            \
    "st1 {v" reg_index ".h}[1], [x" n "], #2\n"  \
    "cmp %w[n_remain], #2\n"                     \
    "beq 103" n "f\n"                            \
    "st1 {v" reg_index ".h}[2], [x" n "], #2\n"  \
    "cmp %w[n_remain], #3\n"                     \
    "beq 103" n "f\n"                            \
    "st1 {v" reg_index ".h}[3], [x" n "], #2\n"  \
    "cmp %w[n_remain], #4\n"                     \
    "beq 103" n "f\n"                            \
    "st1 {v" reg_index ".h}[4], [x" n "], #2\n"  \
    "cmp %w[n_remain], #5\n"                     \
    "beq 103" n "f\n"                            \
    "st1 {v" reg_index ".h}[5], [x" n "], #2\n"  \
    "cmp %w[n_remain], #6\n"                     \
    "beq 103" n "f\n"                            \
    "st1 {v" reg_index ".h}[6], [x" n "], #2\n"  \
    "103" n ":\n"                                \
    "sub x8, x8, #1\n"

#define STORE_C               \
    "mov x8, %x[m_remain]\n"  \
    STORE_LINE("24", "0")     \
    STORE_LINE("25", "1")     \
    STORE_LINE("26", "2")     \
    STORE_LINE("27", "3")     \
    STORE_LINE("28", "4")     \
    STORE_LINE("29", "5")     \
    STORE_LINE("30", "6")     \
    STORE_LINE("31", "7")     \
    "105:\n"
// clang-format on

    register int16_t* outptr asm("x0") = output;
    asm volatile(
            "add x1, x0, %x[LDC]\n"
            "add x2, x1, %x[LDC]\n"
            "add x3, x2, %x[LDC]\n"
            "add x4, x3, %x[LDC]\n"
            "add x5, x4, %x[LDC]\n"
            "add x6, x5, %x[LDC]\n"
            "add x7, x6, %x[LDC]\n"

            "cmp %w[is_first_k], #1\n"
            "beq 2f\n"
            LOAD_C
            "b 1f\n"

            "2:\n"  // Clear the C regs.
            "eor v24.16b, v24.16b, v24.16b\n"
            "eor v25.16b, v25.16b, v25.16b\n"
            "eor v26.16b, v26.16b, v26.16b\n"
            "eor v27.16b, v27.16b, v27.16b\n"
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "eor v30.16b, v30.16b, v30.16b\n"
            "eor v31.16b, v31.16b, v31.16b\n"

            // General loop.
            "1:\n"
            "ld1 {v20.16b}, [%[a_ptr]], #16\n"
            "ld1 {v21.16b}, [%[a_ptr]], #16\n"
            "dup v0.8b, v20.b[0]\n"
            "dup v1.8b, v20.b[1]\n"
            "dup v2.8b, v20.b[2]\n"
            "dup v3.8b, v20.b[3]\n"
            "ld1 {v22.16b}, [%[a_ptr]], #16\n"
            "ld1 {v23.16b}, [%[a_ptr]], #16\n"
            "ld1 {v16.8b}, [%[b_ptr]], #8\n"
            "dup v4.8b, v20.b[4]\n"
            "dup v5.8b, v20.b[5]\n"
            "dup v6.8b, v20.b[6]\n"
            "dup v7.8b, v20.b[7]\n"
            "ld1 {v17.8b}, [%[b_ptr]], #8\n"
            "dup v8.8b, v20.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b, v20.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b, v20.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b, v20.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b, v20.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b, v20.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b, v20.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b, v20.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v18.8b}, [%[b_ptr]], #8\n"
            "dup v0.8b, v21.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b, v21.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b, v21.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b, v21.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b, v21.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b, v21.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b, v21.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b, v21.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v19.8b}, [%[b_ptr]], #8\n"
            "dup v8.8b, v21.b[8]\n"
            "smlal v24.8h, v0.8b, v18.8b\n"
            "dup v9.8b, v21.b[9]\n"
            "smlal v25.8h, v1.8b, v18.8b\n"
            "dup v10.8b, v21.b[10]\n"
            "smlal v26.8h, v2.8b, v18.8b\n"
            "dup v11.8b, v21.b[11]\n"
            "smlal v27.8h, v3.8b, v18.8b\n"
            "dup v12.8b, v21.b[12]\n"
            "smlal v28.8h, v4.8b, v18.8b\n"
            "dup v13.8b, v21.b[13]\n"
            "smlal v29.8h, v5.8b, v18.8b\n"
            "dup v14.8b, v21.b[14]\n"
            "smlal v30.8h, v6.8b, v18.8b\n"
            "dup v15.8b, v21.b[15]\n"
            "smlal v31.8h, v7.8b, v18.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], #8\n"
            "dup v0.8b, v22.b[0]\n"
            "smlal v24.8h, v8.8b, v19.8b\n"
            "dup v1.8b, v22.b[1]\n"
            "smlal v25.8h, v9.8b, v19.8b\n"
            "dup v2.8b, v22.b[2]\n"
            "smlal v26.8h, v10.8b, v19.8b\n"
            "dup v3.8b, v22.b[3]\n"
            "smlal v27.8h, v11.8b, v19.8b\n"
            "dup v4.8b, v22.b[4]\n"
            "smlal v28.8h, v12.8b, v19.8b\n"
            "dup v5.8b, v22.b[5]\n"
            "smlal v29.8h, v13.8b, v19.8b\n"
            "dup v6.8b, v22.b[6]\n"
            "smlal v30.8h, v14.8b, v19.8b\n"
            "dup v7.8b, v22.b[7]\n"
            "smlal v31.8h, v15.8b, v19.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], #8\n"
            "dup v8.8b, v22.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b, v22.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b, v22.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b, v22.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b, v22.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b, v22.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b, v22.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b, v22.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v18.8b}, [%[b_ptr]], #8\n"
            "dup v0.8b, v23.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b, v23.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b, v23.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b, v23.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b, v23.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b, v23.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b, v23.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b, v23.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v19.8b}, [%[b_ptr]], #8\n"
            "dup v8.8b, v23.b[8]\n"
            "smlal v24.8h, v0.8b, v18.8b\n"
            "dup v9.8b, v23.b[9]\n"
            "smlal v25.8h, v1.8b, v18.8b\n"
            "dup v10.8b, v23.b[10]\n"
            "smlal v26.8h, v2.8b, v18.8b\n"
            "dup v11.8b, v23.b[11]\n"
            "smlal v27.8h, v3.8b, v18.8b\n"
            "dup v12.8b, v23.b[12]\n"
            "smlal v28.8h, v4.8b, v18.8b\n"
            "dup v13.8b, v23.b[13]\n"
            "smlal v29.8h, v5.8b, v18.8b\n"
            "dup v14.8b, v23.b[14]\n"
            "smlal v30.8h, v6.8b, v18.8b\n"
            "dup v15.8b, v23.b[15]\n"
            "smlal v31.8h, v7.8b, v18.8b\n"
            "smlal v24.8h, v8.8b, v19.8b\n"
            "smlal v25.8h, v9.8b, v19.8b\n"
            "smlal v26.8h, v10.8b, v19.8b\n"
            "smlal v27.8h, v11.8b, v19.8b\n"
            "smlal v28.8h, v12.8b, v19.8b\n"
            "smlal v29.8h, v13.8b, v19.8b\n"
            "smlal v30.8h, v14.8b, v19.8b\n"
            "smlal v31.8h, v15.8b, v19.8b\n"

            "subs %w[K], %w[K], #1\n"
            "cbnz %w[K], 1b\n"

            "3:\n"
            // Store back into memory.
            STORE_C
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
              [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
              [outptr] "+r"(outptr), [m_remain] "+r"(m_remain),
              [n_remain] "+r"(n_remain)
            :
            : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
              "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
              "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
              "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
              "v28", "v29", "v30", "v31");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
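The asm above is easier to audit against a scalar model: after packing, A supplies one column of 8 int8 values per k-step (already unpacked from int4) and B supplies one 8-wide row per k-step, with products accumulated into int16 without intermediate saturation (safe for int4 operands at the K sizes exercised here). A minimal scalar sketch of the same 8x8 micro-kernel contract (hypothetical reference, not part of the diff):

    #include <cstdint>

    // Scalar model of one s4_kern_8x8 call: packA holds K groups of 8 row
    // values, packB holds K groups of 8 column values, and the output tile is
    // 8 x 8 int16 accumulators with row stride LDC (in elements).
    static void s4_kern_8x8_ref(const int8_t* packA, const int8_t* packB,
                                int K, int16_t* C, int LDC, bool is_first_k) {
        for (int m = 0; m < 8; ++m)
            for (int n = 0; n < 8; ++n) {
                int16_t acc = is_first_k ? 0 : C[m * LDC + n];
                for (int k = 0; k < K; ++k)
                    acc += int16_t(packA[k * 8 + m]) * int16_t(packB[k * 8 + n]);
                C[m * LDC + n] = acc;
            }
    }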
static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K,
                        int16_t* output, int LDC, bool is_first_k,
                        int m_remain, int n_remain) {
    K /= 8;
    LDC = LDC * sizeof(int16_t);
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
// clang-format off
#define LOAD_C_8                 \
    "ld1 {v24.8h}, [x0], #16\n"  \
    "ld1 {v25.8h}, [x1], #16\n"  \
    "ld1 {v26.8h}, [x2], #16\n"  \
    "ld1 {v27.8h}, [x3], #16\n"  \
    "ld1 {v28.8h}, [x4], #16\n"  \
    "ld1 {v29.8h}, [x5], #16\n"  \
    "ld1 {v30.8h}, [x6], #16\n"  \
    "ld1 {v31.8h}, [x7], #16\n"

#define STORE_C_8                \
    "st1 {v24.8h}, [x0], #16\n"  \
    "st1 {v25.8h}, [x1], #16\n"  \
    "st1 {v26.8h}, [x2], #16\n"  \
    "st1 {v27.8h}, [x3], #16\n"  \
    "st1 {v28.8h}, [x4], #16\n"  \
    "st1 {v29.8h}, [x5], #16\n"  \
    "st1 {v30.8h}, [x6], #16\n"  \
    "st1 {v31.8h}, [x7], #16\n"
// clang-format on

    register int16_t* outptr asm("x0") = output;
    asm volatile(
            "add x1, x0, %x[LDC]\n"
            "add x2, x1, %x[LDC]\n"
            "add x3, x2, %x[LDC]\n"
            "add x4, x3, %x[LDC]\n"
            "add x5, x4, %x[LDC]\n"
            "add x6, x5, %x[LDC]\n"
            "add x7, x6, %x[LDC]\n"

            "cmp %w[is_first_k], #1\n"
            "beq 2f\n"
            LOAD_C_8
            //! the accumulating path must also preload the first 32 bytes of
            //! A before entering the loop, since the loop reloads v20/v21 at
            //! its bottom for the next iteration
            "ld1 {v20.16b}, [%[a_ptr]], #16\n"
            "ld1 {v21.16b}, [%[a_ptr]], #16\n"
            "b 1f\n"

            "2:\n"  // Clear the C regs.
            "eor v24.16b, v24.16b, v24.16b\n"
            "eor v25.16b, v25.16b, v25.16b\n"
            "eor v26.16b, v26.16b, v26.16b\n"
            "eor v27.16b, v27.16b, v27.16b\n"
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "eor v30.16b, v30.16b, v30.16b\n"
            "eor v31.16b, v31.16b, v31.16b\n"
            // General loop.
            "ld1 {v20.16b}, [%[a_ptr]], #16\n"
            "ld1 {v21.16b}, [%[a_ptr]], #16\n"
            "PRFM PLDL1KEEP, [%[a_ptr], #512]\n"
            "PRFM PLDL1KEEP, [%[b_ptr], #512]\n"

            "1:\n"
            "dup v0.8b, v20.b[0]\n"
            "ld1 {v22.16b}, [%[a_ptr]], #16\n"
            "dup v1.8b, v20.b[1]\n"
            "ld1 {v23.16b}, [%[a_ptr]], #16\n"
            "dup v2.8b, v20.b[2]\n"
            "ld1 {v16.8b}, [%[b_ptr]], #8\n"
            "dup v3.8b, v20.b[3]\n"
            "dup v4.8b, v20.b[4]\n"
            "ld1 {v17.8b}, [%[b_ptr]], #8\n"
            "dup v5.8b, v20.b[5]\n"
            "dup v6.8b, v20.b[6]\n"
            "dup v7.8b, v20.b[7]\n"
            "dup v8.8b, v20.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b, v20.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b, v20.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b, v20.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b, v20.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b, v20.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b, v20.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b, v20.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], #8\n"
            "dup v0.8b, v21.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b, v21.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b, v21.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b, v21.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b, v21.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b, v21.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b, v21.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b, v21.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], #8\n"
            "dup v8.8b, v21.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b, v21.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b, v21.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b, v21.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b, v21.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b, v21.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b, v21.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b, v21.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], #8\n"
            "dup v0.8b, v22.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b, v22.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b, v22.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b, v22.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b, v22.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b, v22.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b, v22.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b, v22.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], #8\n"
            "dup v8.8b, v22.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b, v22.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b, v22.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b, v22.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b, v22.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b, v22.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b, v22.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b, v22.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], #8\n"
            "dup v0.8b, v23.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b, v23.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b, v23.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b, v23.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b, v23.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b, v23.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b, v23.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b, v23.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], #8\n"
            "dup v8.8b, v23.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b, v23.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b, v23.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b, v23.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b, v23.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b, v23.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b, v23.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b, v23.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v20.16b}, [%[a_ptr]], #16\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "ld1 {v21.16b}, [%[a_ptr]], #16\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "smlal v31.8h, v15.8b, v17.8b\n"

            "subs %w[K], %w[K], #1\n"
            "cbnz %w[K], 1b\n"

            "3:\n"
            // Store back into memory.
            STORE_C_8
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
              [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
              [outptr] "+r"(outptr), [m_remain] "+r"(m_remain),
              [n_remain] "+r"(n_remain)
            :
            : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
              "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
              "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
              "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
              "v28", "v29", "v30", "v31");
#undef LOAD_C_8
#undef STORE_C_8
}
//! pack A
static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* inptr,
                                              int ldin, int y0, int ymax, int k0,
                                              int kmax) {
    int8_t zerobuff[8];
    int8_t tmpbuff0[8];
    int8_t tmpbuff1[8];
    int8_t tmpbuff2[8];
    int8_t tmpbuff3[8];
    int8_t tmpbuff4[8];
    int8_t tmpbuff5[8];
    int8_t tmpbuff6[8];
    int8_t tmpbuff7[8];
    std::memset(zerobuff, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff0, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff1, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff2, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff3, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff4, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff5, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff6, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff7, 0, sizeof(int8_t) * 8);
    //! the input stride is counted in bytes; one byte holds two int4 values
    ldin /= 2;
    int y = y0;
    for (; y + 7 < ymax; y += 8) {
        const int8_t* inptr0 = inptr + y * ldin + k0;
        const int8_t* inptr1 = inptr0 + ldin;
        const int8_t* inptr2 = inptr1 + ldin;
        const int8_t* inptr3 = inptr2 + ldin;
        const int8_t* inptr4 = inptr3 + ldin;
        const int8_t* inptr5 = inptr4 + ldin;
        const int8_t* inptr6 = inptr5 + ldin;
        const int8_t* inptr7 = inptr6 + ldin;
        prefetch_2x(inptr0);
        prefetch_2x(inptr1);
        prefetch_2x(inptr2);
        prefetch_2x(inptr3);
        prefetch_2x(inptr4);
        prefetch_2x(inptr5);
        prefetch_2x(inptr6);
        prefetch_2x(inptr7);
        int K = (kmax - k0) / 2;
        //! read 4 bytes (8 int4 values) from each of the 8 rows per iteration
        for (; K > 3; K -= 4) {
            transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
                                         inptr5, inptr6, inptr7, outptr);
        }
        if (K > 0) {
            //! copy the 1-3 leftover bytes per row into zero-padded buffers
            std::memcpy(tmpbuff0, inptr0, K);
            std::memcpy(tmpbuff1, inptr1, K);
            std::memcpy(tmpbuff2, inptr2, K);
            std::memcpy(tmpbuff3, inptr3, K);
            std::memcpy(tmpbuff4, inptr4, K);
            std::memcpy(tmpbuff5, inptr5, K);
            std::memcpy(tmpbuff6, inptr6, K);
            std::memcpy(tmpbuff7, inptr7, K);
            inptr0 = tmpbuff0;
            inptr1 = tmpbuff1;
            inptr2 = tmpbuff2;
            inptr3 = tmpbuff3;
            inptr4 = tmpbuff4;
            inptr5 = tmpbuff5;
            inptr6 = tmpbuff6;
            inptr7 = tmpbuff7;
            transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
                                         inptr5, inptr6, inptr7, outptr);
        }
    }
    for (; y < ymax; y += 8) {
        const int8_t* inptr0 = inptr + y * ldin + k0;
        const int8_t* inptr1 = inptr0 + ldin;
        const int8_t* inptr2 = inptr1 + ldin;
        const int8_t* inptr3 = inptr2 + ldin;
        const int8_t* inptr4 = inptr3 + ldin;
        const int8_t* inptr5 = inptr4 + ldin;
        const int8_t* inptr6 = inptr5 + ldin;
        const int8_t* inptr7 = inptr6 + ldin;
        int K = (kmax - k0) / 2;
        //! rows past ymax are redirected to the zero buffer
        for (; K > 3; K -= 4) {
            if (y + 7 >= ymax) {
                switch (y + 7 - ymax) {
                    case 6:
                        inptr1 = zerobuff; MEGDNN_FALLTHRU
                    case 5:
                        inptr2 = zerobuff; MEGDNN_FALLTHRU
                    case 4:
                        inptr3 = zerobuff; MEGDNN_FALLTHRU
                    case 3:
                        inptr4 = zerobuff; MEGDNN_FALLTHRU
                    case 2:
                        inptr5 = zerobuff; MEGDNN_FALLTHRU
                    case 1:
                        inptr6 = zerobuff; MEGDNN_FALLTHRU
                    case 0:
                        inptr7 = zerobuff;
                        break;
                    default:
                        megdnn_assert(0);
                }
            }
            transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
                                         inptr5, inptr6, inptr7, outptr);
        }
        if (K > 0) {
            if (y + 7 >= ymax) {
                switch (y + 7 - ymax) {
                    case 6:
                        inptr1 = zerobuff; MEGDNN_FALLTHRU
                    case 5:
                        inptr2 = zerobuff; MEGDNN_FALLTHRU
                    case 4:
                        inptr3 = zerobuff; MEGDNN_FALLTHRU
                    case 3:
                        inptr4 = zerobuff; MEGDNN_FALLTHRU
                    case 2:
                        inptr5 = zerobuff; MEGDNN_FALLTHRU
                    case 1:
                        inptr6 = zerobuff; MEGDNN_FALLTHRU
                    case 0:
                        inptr7 = zerobuff;
                        break;
                    default:
                        megdnn_assert(0);
                }
            }
            std::memcpy(tmpbuff0, inptr0, K);
            std::memcpy(tmpbuff1, inptr1, K);
            std::memcpy(tmpbuff2, inptr2, K);
            std::memcpy(tmpbuff3, inptr3, K);
            std::memcpy(tmpbuff4, inptr4, K);
            std::memcpy(tmpbuff5, inptr5, K);
            std::memcpy(tmpbuff6, inptr6, K);
            std::memcpy(tmpbuff7, inptr7, K);
            inptr0 = tmpbuff0;
            inptr1 = tmpbuff1;
            inptr2 = tmpbuff2;
            inptr3 = tmpbuff3;
            inptr4 = tmpbuff4;
            inptr5 = tmpbuff5;
            inptr6 = tmpbuff6;
            inptr7 = tmpbuff7;
            transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
                                         inptr5, inptr6, inptr7, outptr);
        }
    }
}
//! pack B
static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
                                               int ldin, int x0, int xmax,
                                               int k0, int kmax) {
    int8_t zerobuff[8];
    int8_t tmpbuff0[8];
    int8_t tmpbuff1[8];
    int8_t tmpbuff2[8];
    int8_t tmpbuff3[8];
    int8_t tmpbuff4[8];
    int8_t tmpbuff5[8];
    int8_t tmpbuff6[8];
    int8_t tmpbuff7[8];
    std::memset(zerobuff, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff0, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff1, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff2, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff3, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff4, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff5, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff6, 0, sizeof(int8_t) * 8);
    std::memset(tmpbuff7, 0, sizeof(int8_t) * 8);
    const int ksize = kmax - k0;
    //! after unpacking int4 to int8, each panel of 8 columns occupies
    //! round_up(K, 8) * 8 bytes
    const int ksize8 = round_up(ksize, 8) * 8;
    int8_t* outptr = out;
    int8_t* outptr_interleave = nullptr;
    int k = k0;
    //! the input stride and x range are counted in bytes; one byte holds two
    //! int4 values
    ldin /= 2;
    xmax = xmax / 2;
    for (; k + 7 < kmax; k += 8) {
        const int8_t* inptr0 = in + k * ldin + x0;
        const int8_t* inptr1 = inptr0 + ldin;
        const int8_t* inptr2 = inptr1 + ldin;
        const int8_t* inptr3 = inptr2 + ldin;
        const int8_t* inptr4 = inptr3 + ldin;
        const int8_t* inptr5 = inptr4 + ldin;
        const int8_t* inptr6 = inptr5 + ldin;
        const int8_t* inptr7 = inptr6 + ldin;
        prefetch_2x(inptr0);
        prefetch_2x(inptr1);
        prefetch_2x(inptr2);
        prefetch_2x(inptr3);
        prefetch_2x(inptr4);
        prefetch_2x(inptr5);
        prefetch_2x(inptr6);
        prefetch_2x(inptr7);
        int x = x0;
        int8_t* outptr_inner = outptr;
        for (; x + 3 < xmax; x += 4) {
            outptr_interleave = outptr_inner;
            interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3,
                                          inptr4, inptr5, inptr6, inptr7,
                                          outptr_interleave);
            outptr_inner += ksize8;
        }
        if (x < xmax) {
            int remainx = xmax - x;
            std::memcpy(tmpbuff0, inptr0, remainx);
            std::memcpy(tmpbuff1, inptr1, remainx);
            std::memcpy(tmpbuff2, inptr2, remainx);
            std::memcpy(tmpbuff3, inptr3, remainx);
            std::memcpy(tmpbuff4, inptr4, remainx);
            std::memcpy(tmpbuff5, inptr5, remainx);
            std::memcpy(tmpbuff6, inptr6, remainx);
            std::memcpy(tmpbuff7, inptr7, remainx);
            inptr0 = tmpbuff0;
            inptr1 = tmpbuff1;
            inptr2 = tmpbuff2;
            inptr3 = tmpbuff3;
            inptr4 = tmpbuff4;
            inptr5 = tmpbuff5;
            inptr6 = tmpbuff6;
            inptr7 = tmpbuff7;
            outptr_interleave = outptr_inner;
            interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3,
                                          inptr4, inptr5, inptr6, inptr7,
                                          outptr_interleave);
            outptr_inner += ksize8;
        }
        outptr += 64;
    }
    if (k < kmax) {
        const int8_t* inptr0 = in + k * ldin + x0;
        const int8_t* inptr1 = inptr0 + ldin;
        const int8_t* inptr2 = inptr1 + ldin;
        const int8_t* inptr3 = inptr2 + ldin;
        const int8_t* inptr4 = inptr3 + ldin;
        const int8_t* inptr5 = inptr4 + ldin;
        const int8_t* inptr6 = inptr5 + ldin;
        const int8_t* inptr7 = inptr6 + ldin;
        int k_remain = kmax - k - 1;
        int x = x0;
        int8_t* outptr_inner = outptr;
        for (; x + 3 < xmax; x += 4) {
            //! rows past kmax are redirected to the zero buffer
            switch (k_remain) {
                case 0:
                    inptr1 = zerobuff;
                    MEGDNN_FALLTHRU;
                case 1:
                    inptr2 = zerobuff;
                    MEGDNN_FALLTHRU;
                case 2:
                    inptr3 = zerobuff;
                    MEGDNN_FALLTHRU;
                case 3:
                    inptr4 = zerobuff;
                    MEGDNN_FALLTHRU;
                case 4:
                    inptr5 = zerobuff;
                    MEGDNN_FALLTHRU;
                case 5:
                    inptr6 = zerobuff;
                    MEGDNN_FALLTHRU;
                case 6:
                    inptr7 = zerobuff;
                    break;
                default:
                    megdnn_assert(0);
                    break;
            }
            outptr_interleave = outptr_inner;
            interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3,
                                          inptr4, inptr5, inptr6, inptr7,
                                          outptr_interleave);
            outptr_inner += ksize8;
        }
        if (x < xmax) {
            switch (k_remain) {
                case 0:
                    inptr1 = zerobuff;
                    MEGDNN_FALLTHRU;
                case 1:
                    inptr2 = zerobuff;
                    MEGDNN_FALLTHRU;
                case 2:
                    inptr3 = zerobuff;
                    MEGDNN_FALLTHRU;
                case 3:
                    inptr4 = zerobuff;
                    MEGDNN_FALLTHRU;
                case 4:
                    inptr5 = zerobuff;
                    MEGDNN_FALLTHRU;
                case 5:
                    inptr6 = zerobuff;
                    MEGDNN_FALLTHRU;
                case 6:
                    inptr7 = zerobuff;
                    break;
                default:
                    megdnn_assert(0);
                    break;
            }
            int remainx = xmax - x;
            std::memcpy(tmpbuff0, inptr0, remainx);
            std::memcpy(tmpbuff1, inptr1, remainx);
            std::memcpy(tmpbuff2, inptr2, remainx);
            std::memcpy(tmpbuff3, inptr3, remainx);
            std::memcpy(tmpbuff4, inptr4, remainx);
            std::memcpy(tmpbuff5, inptr5, remainx);
            std::memcpy(tmpbuff6, inptr6, remainx);
            std::memcpy(tmpbuff7, inptr7, remainx);
            inptr0 = tmpbuff0;
            inptr1 = tmpbuff1;
            inptr2 = tmpbuff2;
            inptr3 = tmpbuff3;
            inptr4 = tmpbuff4;
            inptr5 = tmpbuff5;
            inptr6 = tmpbuff6;
            inptr7 = tmpbuff7;
            outptr_interleave = outptr_inner;
            interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3,
                                          inptr4, inptr5, inptr6, inptr7,
                                          outptr_interleave);
            outptr_inner += ksize8;
        }
    }
}

}  // namespace matmul_s4_4x4x16
}  // namespace aarch64
}  // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -0,0 +1,109 @@
/**
 * \file dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h"
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"

using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

// ===========================gemm_s4x4x16_s4_8x8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8);

void gemm_s4x4x16_s4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin,
                                   int y0, int ymax, int k0, int kmax,
                                   bool transpose) const {
    if (transpose) {
        matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, y0,
                                                             ymax, k0, kmax);
    } else {
        matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, y0,
                                                            ymax, k0, kmax);
    }
}

void gemm_s4x4x16_s4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                   int x0, int xmax, int k0, int kmax,
                                   bool transpose) const {
    if (transpose) {
        matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, x0,
                                                            xmax, k0, kmax);
    } else {
        matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, x0,
                                                             xmax, k0, kmax);
    }
}

void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
                                 size_t M, size_t N, size_t K, dt_int16* C,
                                 size_t LDC, bool is_first_k, const dt_int16*,
                                 dt_int16*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          (A_dtype.enumv() == DTypeEnum::QuantizedS4 &&
                           C_dtype.enumv() == DTypeEnum::QuantizedS16),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);

    constexpr size_t A_INTERLEAVE = 8;
    constexpr size_t B_INTERLEAVE = 8;
    //! the packing functions pad K up to a multiple of 8
    K = round_up<size_t>(K, 8);
    const int K8 = K * 8;

    size_t m = 0;
    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        int16_t* output = C + (m * LDC);
        size_t n = 0;
        const dt_int8* cur_packB = packB;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_s4_4x4x16::s4_kern_8x8(packA, cur_packB, K, output, LDC,
                                          is_first_k, A_INTERLEAVE,
                                          B_INTERLEAVE);
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        for (; n < N; n += B_INTERLEAVE) {
            matmul_s4_4x4x16::s4_kern_8x8_remain(
                    packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE,
                    std::min<size_t>(N - n, B_INTERLEAVE));
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        packA += K8;
    }
    for (; m < M; m += A_INTERLEAVE) {
        int16_t* output = C + (m * LDC);
        size_t n = 0;
        const dt_int8* cur_packB = packB;
        for (; n < N; n += B_INTERLEAVE) {
            matmul_s4_4x4x16::s4_kern_8x8_remain(
                    packA, cur_packB, K, output, LDC, is_first_k,
                    std::min<size_t>(M - m, A_INTERLEAVE),
                    std::min<size_t>(N - n, B_INTERLEAVE));
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        packA += K8;
    }
}
// vim: syntax=cpp.doxygen
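To make the panel arithmetic concrete: with K = 10, the strategy rounds K up to 16, so after unpacking to int8 each 8-row (or 8-column) panel occupies K8 = 16 x 8 = 128 bytes, and `packA`/`cur_packB` advance by exactly that amount per tile.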
@@ -0,0 +1,26 @@
/**
 * \file dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once

#include "src/fallback/matrix_mul/gemm_common.h"

namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true,
                         gemm_s4x4x16_s4_8x8x8);

}  // namespace matmul
}  // namespace aarch64
}  // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -50,6 +50,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#else
    AlgoQuint8K8x8x8 quint8_k8x8x8;
#endif
    AlgoInt4x4x16K8x8x8 int4x4x16_k8x8x8;

    SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
    fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;

@@ -87,6 +88,7 @@ public:
#else
        m_all_algos.emplace_back(&quint8_k8x8x8);
#endif
        m_all_algos.emplace_back(&int4x4x16_k8x8x8);

        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);

@@ -66,8 +66,8 @@ private:
#else
    class AlgoQuint8K8x8x8;  // Aarch64 Quint8 Kernel 8x8x8
#endif
    class AlgoInt8x8x16MK4_K8x8x8;  // Aarch64 Int8x8x16 Kernel MK4 8x8x8
    class AlgoInt4x4x16K8x8x8;      // Aarch64 Int4x4x16 Kernel 8x8x8
    class AlgoPack;

public:
    static const AlgoPack& algo_pack();
@@ -33,6 +33,8 @@ void MatrixMulForward::deduce_dtype(DType A, DType B, DType& C) {
        C_candi = dtype::QuantizedS32(mul_scale(A, B));
    } else if (A.enumv() == DTypeEnum::Quantized4Asymm) {
        C_candi = dtype::QuantizedS32(mul_scale(A, B));
    } else if (A.enumv() == DTypeEnum::QuantizedS4) {
        C_candi = dtype::QuantizedS16(mul_scale(A, B));
    }
    if (!C.valid()) {
        C = C_candi;

@@ -169,6 +171,8 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B,
               A.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
               A.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
        megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS32);
    } else if (A.dtype.enumv() == DTypeEnum::QuantizedS4) {
        megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS16);
    }
    megdnn_assert(param().compute_mode !=
                  Param::ComputeMode::FLOAT32 MEGDNN_INC_FLOAT16(

@@ -154,6 +154,7 @@ public:
        AARCH64_QUINT8_K8X8X4_DOTPROD,
        AARCH64_QUINT8_GEMV_DOTPROD,
        AARCH64_QUINT8_K8X8X8,
        AARCH64_INT4X4X16_K8X8X8,
#else
        ARMV7_F32 = 1 << 16,
        ARMV7_F32_MK4_PACK_4X12,
@@ -179,6 +179,42 @@ void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A,
                          C.compatible_ptr<dt_int32>(), M, N, K, LDA, LDB, LDC,
                          nA.layout.dtype, nB.layout.dtype);
}

template <bool transA, bool transB>
void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B,
                                       _megdnn_tensor_out C,
                                       _megdnn_workspace workspace,
                                       const param::MatrixMul& param) {
    auto convert_layout = [](const TensorLayout& layout) {
        auto ret = layout;
        auto param = layout.dtype.param<dtype::QuantizedS4>();
        ret.dtype = dtype::QuantizedS8(param.scale);
        return ret;
    };
    TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)};
    TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(),
                   convert_layout(B.layout)};
    auto convert_4to8 = [](const TensorND& in, const TensorND& out) {
        auto ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte;
        auto out_ptr =
                out.compatible_ptr<int8_t>() + out.layout.span().low_byte;
        for (size_t i = 0; i < in.layout.span().dist_elem(); i += 2) {
            int8_t cur = ptr[i / 2];
            //! sign-extend the low nibble via a left/right shift pair
            out_ptr[i] = cur << 4;
            out_ptr[i] = out_ptr[i] >> 4;
            //! the high nibble only needs the arithmetic right shift
            out_ptr[i + 1] = cur >> 4;
        }
    };
    convert_4to8(A, nA);
    convert_4to8(B, nB);
    auto M = C.layout.shape[0], N = C.layout.shape[1];
    auto K = A.layout.shape[param.transposeA ? 0 : 1];
    auto LDA = A.layout.stride[0], LDB = B.layout.stride[0],
         LDC = C.layout.stride[0];
    run_matrix_mul_tpl<int8_t, dt_int16, transA, transB, dt_int16>(
            nA.compatible_ptr<int8_t>(), nB.compatible_ptr<int8_t>(),
            C.compatible_ptr<dt_int16>(), M, N, K, LDA, LDB, LDC,
            nA.layout.dtype, nB.layout.dtype);
}

}  // namespace naive
}  // namespace megdnn
@@ -26,7 +26,8 @@ size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
    MIDOUT_BEGIN(
            megdnn_naive_matmul,
            midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) {
        if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm ||
            A.dtype.enumv() == DTypeEnum::QuantizedS4) {
            return (A.span().dist_elem() + B.span().dist_elem()) *
                   sizeof(uint8_t);
        }

@@ -104,6 +105,11 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B,
               param.format == param::MatrixMul::Format::DEFAULT) {
        exec_matrix_mul_quint4x4x32_helper<TA, TB>(A, B, C, workspace, param);
        return;
    } else if (A.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
               C.layout.dtype.enumv() == DTypeEnum::QuantizedS16 &&
               param.format == param::MatrixMul::Format::DEFAULT) {
        exec_matrix_mul_qint4x4x16_helper<TA, TB>(A, B, C, workspace, param);
        return;
    }
#undef cb
    megdnn_throw(ssprintf(
@@ -164,6 +164,55 @@ TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_K4x4x16) {
                                 handle(), "AARCH64_INT8X8X16_K4X4X16");
}

TEST_F(AARCH64, MATRIX_MUL_INT4x4x16_K8x8x8_QUANTIZEDS4) {
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Checker<MatrixMul> checker(handle());
    checker.set_dtype(0, dtype::QuantizedS4{0.6})
            .set_dtype(1, dtype::QuantizedS4{0.5})
            .set_dtype(2, dtype::QuantizedS16{0.6 * 0.5})
            .set_param(param);
    checker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT4X4X16_K8X8X8"));
    auto run = [&](size_t M, size_t N, size_t K) {
        printf("M N K %zu %zu %zu\n", M, N, K);
        TensorShape A, B;
        if (param.transposeA) {
            A = TensorShape{K, M};
        } else {
            A = TensorShape{M, K};
        }
        if (param.transposeB) {
            B = TensorShape{N, K};
        } else {
            B = TensorShape{K, N};
        }
        checker.exec({A, B, {}});
    };
    for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 16, 20})
        for (size_t n : {2, 4, 6, 8, 10, 12, 14, 16, 24})
            for (size_t k : {2, 4, 6, 8, 10, 12, 14, 16, 32})
                run(m, n, k);
    for (size_t k = 4; k <= 256; k *= 8) {
        for (size_t m = 4; m <= 256; m *= 4) {
            for (size_t n = 4; n <= 256; n *= 4) {
                run(m, n, k);
            }
        }
    }
    param.transposeA = true;
    run(8, 8, 8);
    run(16, 8, 16);
    param.transposeB = true;
    run(8, 8, 8);
    run(16, 16, 16);
}

TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_K12X8X1) {
    matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
                                 handle(), "AARCH64_INT16X16X32_K12X8X1");
@@ -410,6 +459,63 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) {
    run(384, 384, 384);
}

TEST_F(AARCH64, BENCHMARK_4x4x16_vs_8x8x16) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmarker(handle());
    Benchmarker<MatrixMul> benchmarker_int4_4x4x16(handle());
    benchmarker_int4_4x4x16.set_times(RUNS)
            .set_dtype(0, dtype::QuantizedS4{0.3})
            .set_dtype(1, dtype::QuantizedS4{0.3})
            .set_dtype(2, dtype::QuantizedS16{0.09})
            .set_param(param)
            .set_display(false);
    benchmarker.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    benchmarker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16"));
    auto run = [&](size_t M, size_t N, size_t K) {
        auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto int4416_used =
                benchmarker_int4_4x4x16.exec({{M, K}, {K, N}, {}}) / RUNS;
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} int8x8x16 used: %f ms %f Gflops, "
               "int4x4x16 used: %f ms %f Gflops, speedup %f\n",
               M, K, N, default_used, computations / default_used,
               int4416_used, computations / int4416_used,
               default_used / int4416_used);
    };
    for (int m = 32; m <= 1024; m += 32)
        for (int n = 32; n <= 1024; n += 32)
            for (int k = 32; k <= 512; k += 32)
                run(m, n, k);
    run(32, 32, 32);
    run(32, 32, 8);
    run(32, 32, 16);
    run(32, 32, 24);
    run(32 * 2, 32 * 2, 32);
    run(32 * 4, 32 * 4, 32);
    run(32 * 6, 32 * 6, 32);
    run(32 * 8, 32 * 8, 32);
    run(32 * 2, 32 * 2, 32 * 2);
    run(32 * 4, 32 * 4, 32 * 3);
    run(32 * 6, 32 * 6, 32 * 4);
    run(32 * 8, 32 * 8, 32 * 5);
    run(32 * 10, 32 * 10, 32 * 10);
    run(384, 384, 384);
    run(256, 256, 384);
    run(512, 512, 384);
    run(1024, 1024, 384);
}

TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x8_8x8x16_vs_4x4x16_8x8x16) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
@@ -183,6 +183,34 @@ void IIDRNG::gen(const TensorND& tensor) {
        }
        return;
    }
    if (tensor.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
        auto ptr = static_cast<int8_t*>(tensor.raw_ptr);
        if (output_is_float()) {
            for (size_t i = 0; i < nr_elems; i += 2) {
                int8_t val0 =
                        tensor.layout.dtype.param<dt_qint4>()
                                .quantize(static_cast<float>(gen_single_val()))
                                .as_int8();
                int8_t val1 =
                        tensor.layout.dtype.param<dt_qint4>()
                                .quantize(static_cast<float>(gen_single_val()))
                                .as_int8();
                ptr[(offset + i) / 2] = (val0 & 0xF) | (val1 << 4);
            }
        } else {
            for (size_t i = 0; i < nr_elems; i += 2) {
                int8_t val0 = static_cast<int8_t>(gen_single_val());
                int8_t val1 = static_cast<int8_t>(gen_single_val());
                val0 = std::min(val0, DTypeTrait<dtype::QuantizedS4>::max());
                val0 = std::max(val0, DTypeTrait<dtype::QuantizedS4>::min());
                val1 = std::min(val1, DTypeTrait<dtype::QuantizedS4>::max());
                val1 = std::max(val1, DTypeTrait<dtype::QuantizedS4>::min());
                ptr[(offset + i) / 2] = (val0 & 0xF) | (val1 << 4);
            }
        }
        return;
    }
    megdnn_assert(0, "IIDRNG does not know how to generate value for DType %s",
                  tensor.layout.dtype.name());
}
@@ -203,6 +203,67 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZED4x4x32) {
    });
}

TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) {
    Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
    auto GenTensorValueQint4 = [](const TensorShape& shape,
                                  dtype::QuantizedS4 dtype,
                                  const std::vector<int>& values) {
        TensorND tensor;
        tensor.layout = {shape, dtype};
        tensor.raw_ptr =
                static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte()));
        uint8_t* ptr = static_cast<uint8_t*>(tensor.raw_ptr);
        megdnn_assert(values.size() == tensor.layout.span().dist_elem());
        for (size_t i = 0; i < tensor.layout.span().dist_elem(); i += 2) {
            int val0 = values[i], val1 = values[i + 1];
            ptr[i / 2] = (val0 & 0xF) | (val1 << 4);
        }
        return tensor;
    };
    using Param = MatrixMul::Param;
    Param param;
    checker.set_param(param);
    checker.set_dtype(2, dtype::QuantizedS16(0.3f * 0.3f));
    checker.exect(
            Testcase{
                    GenTensorValueQint4(
                            {8, 8}, dtype::QuantizedS4(0.3f),
                            {-8, 7, 2, 1, 2, 3, 2, 7,
                             2, 5, 3, 3, 7, 4, -7, 1,
                             -5, 7, -4, -1, -1, 2, 4, 1,
                             7, 2, -6, -2, -6, 3, 4, 4,
                             -2, 2, 3, 0, 6, 5, 3, 4,
                             -1, -1, -5, 5, 2, 5, 1, 4,
                             6, 2, 0, 0, 3, 2, 2, 1,
                             -4, -3, 7, 5, 0, 3, 2, 3}),
                    GenTensorValueQint4(
                            {8, 8}, dtype::QuantizedS4(0.3f),
                            {5, -8, -7, -6, 4, 7, -5, -5,
                             -4, 7, -3, -2, 5, 6, 4, 2,
                             3, -1, 2, 2, 7, 3, 6, 0,
                             5, 4, 0, 2, 2, 3, 3, 2,
                             1, -8, -7, -6, 0, -5, -4, 4,
                             -3, 7, 1, 6, -2, 2, -1, 5,
                             2, 0, 7, 6, 5, 4, 3, 2,
                             0, 0, 1, 0, 5, 2, 2, 6}),
                    {}},
            Testcase{
                    {},
                    {},
                    TensorValue(
                            {8, 8}, dtype::QuantizedS16(0.3f * 0.3f),
                            {-60, 120, 49, 58, 58, 13, 92, 125,
                             -5, 0, -116, -70, 22, 9, -14, 46,
                             -69, 111, 44, 48, 6, 19, 42, 57,
                             -8, 25, 10, 16, 26, 97, -28, -12,
                             -12, 14, 2, 26, 48, 7, 24, 93,
                             -2, 45, 2, 32, -19, -1, -16, 72,
                             23, -44, -52, -34, 45, 53, -28, 6,
                             33, 45, 71, 84, 47, 10, 74, 61})
            });
}
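As a sanity check on the expected tensor: the stored int16 values are the raw integer accumulators before dequantization (the 0.09 output scale only annotates the dtype), e.g. C[0][0] = (-8)(5) + (7)(-4) + (2)(3) + (1)(5) + (2)(1) + (3)(-3) + (2)(2) + (7)(0) = -60, matching the first entry above.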
TEST_F(NAIVE, MATRIX_MUL_QUANTIZED8x8x32) {
    Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
    MatrixMul::Param param;