GitOrigin-RevId: 7812900244
release-1.2
@@ -88,6 +88,7 @@ enum class AlgoDataType : uint32_t { | |||
QUINT8X8X32 = 1 << 3, | |||
INT8X8X16 = 1 << 4, | |||
INT16X16X32 = 1 << 5, | |||
INT4X4X16 = 1 << 6, | |||
}; | |||
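// A minimal usage sketch (not part of the diff; `supports` is a hypothetical
// helper): because each AlgoDataType value occupies its own bit, a set of
// supported data types fits in a single uint32_t mask and is tested with
// bitwise AND.
static inline bool supports(uint32_t mask, AlgoDataType type) {
    return (mask & static_cast<uint32_t>(type)) != 0;
}
// e.g. a mask advertising both the int8 and the new int4 kernels:
//   uint32_t mask = static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
//                   static_cast<uint32_t>(AlgoDataType::INT4X4X16);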
/*! | |||
@@ -17,6 +17,7 @@ | |||
#include "src/aarch64/matrix_mul/int8/strategy.h" | |||
#include "src/aarch64/matrix_mul/int8_dot/strategy.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h" | |||
#include "src/aarch64/matrix_mul/quint8/strategy.h" | |||
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h" | |||
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h" | |||
@@ -1394,4 +1395,75 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8, | |||
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, | |||
int8_t, int16_t, AlgoDataType::INT8X8X16, | |||
MK4); | |||
/* ===================== Int4x4x16 K8x8x8 algo ===================== */ | |||
namespace { | |||
void int4x4x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||
midout_iv("int4x4x16_k8x8x8_kern"_hash)) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto trA = kern_param.trA, trB = kern_param.trB; | |||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||
C_type = kern_param.C_type; | |||
const auto Aptr = kern_param.A<dt_int8>(), | |||
Bptr = kern_param.B<dt_int8>(); | |||
auto Cptr = kern_param.C<dt_int16>(); | |||
aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(M, N, K, A_type, B_type, | |||
C_type); | |||
megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s4x4x16_s4_8x8x8>( | |||
M, N, K, trA, trB, strategy) | |||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | |||
kern_param.workspace_ptr); | |||
} | |||
MIDOUT_END(); | |||
} | |||
} // anonymous namespace | |||
bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && | |||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS4 && | |||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16 && | |||
kern_size_param.format == param::MatrixMul::Format::DEFAULT && | |||
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
(kern_size_param.K & 1) == 0 && (kern_size_param.N & 1) == 0; | |||
} | |||
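// Why usable() insists on even K and N (an explanatory sketch, not part of
// the diff; qint4_row_bytes is a hypothetical helper): QuantizedS4 packs two
// 4-bit values per byte, so a row of K int4 elements occupies a whole number
// of bytes only when K is even, and likewise for the N extent of B.
static inline size_t qint4_row_bytes(size_t elems) {
    // two signed 4-bit values per byte; only well-defined for even `elems`
    return elems / 2;
}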
bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::preferred( | |||
const KernSizeParam& kern_size_param) const { | |||
MEGDNN_MARK_USED_VAR(kern_size_param); | |||
return true; | |||
} | |||
size_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_workspace( | |||
const KernSizeParam& kern_size_param) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||
midout_iv("AlgoInt4x4x16K8x8x8::get_workspace"_hash)) { | |||
auto M = kern_size_param.M, N = kern_size_param.N, | |||
K = kern_size_param.K; | |||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||
C_type = kern_size_param.C_type; | |||
aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(M, N, K, A_type, B_type, | |||
C_type); | |||
return megdnn::matmul::GemmInterleaved<matmul::gemm_s4x4x16_s4_8x8x8>( | |||
M, N, K, trA, trB, strategy) | |||
.get_workspace_size(); | |||
} | |||
    MIDOUT_END();
    return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_kern( | |||
const KernSizeParam&) const { | |||
    return int4x4x16_k8x8x8_kern;
} | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt4x4x16K8x8x8, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoInt4x4x16K8x8x8Impl"_hash, | |||
aarch64::matmul::gemm_s4x4x16_s4_8x8x8, | |||
int8_t, int16_t, AlgoDataType::INT4X4X16, | |||
DEFAULT); | |||
// vim: syntax=cpp.doxygen |
@@ -192,6 +192,19 @@ public: | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K4X4X16) | |||
}; | |||
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
PackMode packmode() const override { return PackMode::DEFAULT; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT4X4X16_K8X8X8) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
@@ -925,6 +925,42 @@ static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1, | |||
: "v0", "v1", "v2", "v3", "memory"); | |||
} | |||
template <typename T> | |||
static inline void interleave_8x4_1_b_with_shift( | |||
const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, | |||
const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, | |||
T* outptr) { | |||
static_assert(sizeof(T) == 1, "only support size == 1"); | |||
asm volatile( | |||
"ld1 {v0.s}[0], [%[inptr0]], #4\n" | |||
"ld1 {v0.s}[1], [%[inptr1]], #4\n" | |||
"ld1 {v0.s}[2], [%[inptr2]], #4\n" | |||
"ld1 {v0.s}[3], [%[inptr3]], #4\n" | |||
"ld1 {v1.s}[0], [%[inptr4]], #4\n" | |||
"ld1 {v1.s}[1], [%[inptr5]], #4\n" | |||
"ld1 {v1.s}[2], [%[inptr6]], #4\n" | |||
"ld1 {v1.s}[3], [%[inptr7]], #4\n" | |||
"shl v2.16b, v0.16b, #4\n" | |||
"shl v5.16b, v1.16b, #4\n" | |||
"sshr v3.16b, v0.16b, #4\n" // hig | |||
"sshr v4.16b, v2.16b, #4\n" // low | |||
"sshr v6.16b, v1.16b, #4\n" // hig | |||
"sshr v7.16b, v5.16b, #4\n" // low | |||
"zip1 v8.16b, v4.16b, v3.16b\n" | |||
"zip2 v9.16b, v4.16b, v3.16b\n" | |||
"zip1 v10.16b, v7.16b, v6.16b\n" | |||
"zip2 v11.16b, v7.16b, v6.16b\n" | |||
"st1 {v8.16b-v11.16b},[%[outptr]],#64" | |||
: [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1), | |||
[ inptr2 ] "+r"(inptr2), [ inptr3 ] "+r"(inptr3), | |||
[ inptr4 ] "+r"(inptr4), [ inptr5 ] "+r"(inptr5), | |||
[ inptr6 ] "+r"(inptr6), [ inptr7 ] "+r"(inptr7), | |||
[ outptr ] "+r"(outptr) | |||
: | |||
: "v0", "v1","v2","v3","v4","v5","v6","v7","v8","v9","v10","v11","memory"); | |||
} | |||
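// Scalar model of interleave_8x4_1_b_with_shift (a readability sketch, not
// part of the diff; it assumes the low nibble of each byte is the earlier
// int4 element, matching the shl/sshr pairs above): each of the 8 rows
// contributes 4 packed bytes (8 int4 values), which are sign-extended to
// int8 and written out row after row, 64 bytes in total.
template <typename T>
static inline void interleave_8x4_1_b_with_shift_ref(
        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
        const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7,
        T* outptr) {
    static_assert(sizeof(T) == 1, "only support size == 1");
    const T** rows[8] = {&inptr0, &inptr1, &inptr2, &inptr3,
                         &inptr4, &inptr5, &inptr6, &inptr7};
    for (int r = 0; r < 8; ++r) {
        for (int b = 0; b < 4; ++b) {
            int8_t packed = static_cast<int8_t>((*rows[r])[b]);
            *outptr++ = static_cast<int8_t>(packed << 4) >> 4;  // low nibble
            *outptr++ = packed >> 4;                            // high nibble
        }
        *rows[r] += 4;  // each row consumes 4 packed bytes, like the asm
    }
}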
template <typename T> | |||
static inline void interleave_8x8_1_h(const T*& inptr0, const T*& inptr1, | |||
const T*& inptr2, const T*& inptr3, | |||
@@ -1059,6 +1095,7 @@ static inline void interleave_4x16_1_b(const T*& inptr0, const T*& inptr1, | |||
: "v0", "v1", "v2", "v3", "v4", "cc", "memory"); | |||
} | |||
template <typename T> | |||
static inline void interleave_4x16_1_s(const T*& inptr0, const T*& inptr1, | |||
const T*& inptr2, const T*& inptr3, | |||
@@ -1773,6 +1810,54 @@ static inline void transpose_8x4_1_b(const T*& inptr0, const T*& inptr1, | |||
} | |||
template <typename T> | |||
static inline void transpose_4x8_1_b_with_shift(const T*& inptr0, const T*& inptr1, | |||
const T*& inptr2, const T*& inptr3, | |||
const T*& inptr4, const T*& inptr5, | |||
const T*& inptr6, const T*& inptr7, | |||
T*& outptr) { | |||
static int8x16_t shuffle_idx = {0, 4, 8, 12, 1, 5, 9, 13, | |||
2, 6, 10, 14, 3, 7, 11, 15}; | |||
    static_assert(
            std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
            "transpose_4x8_1_b_with_shift only supports uint8_t and int8_t");
asm volatile( | |||
"ld1 {v0.s}[0], [%[inptr0]], #4\n" // A1A2A3A4 | |||
"ld1 {v0.s}[1], [%[inptr1]], #4\n" // B1B2B3B4 | |||
"ld1 {v0.s}[2], [%[inptr2]], #4\n" // C1C2C3C4 | |||
"ld1 {v0.s}[3], [%[inptr3]], #4\n" // D1D2D3D4 | |||
"ld1 {v1.s}[0], [%[inptr4]], #4\n" // E1E2E3E4 | |||
"ld1 {v1.s}[1], [%[inptr5]], #4\n" // F1F2F3F4 | |||
"ld1 {v1.s}[2], [%[inptr6]], #4\n" // G1G2G3G4 | |||
"ld1 {v1.s}[3], [%[inptr7]], #4\n" // H1H2H3H4 | |||
"tbl v2.16b, {v0.16b}, %[shuffle_idx].16b \n" // A1B1C1D1A2B2C2D2A3B3C3D3A4B4C4D4 | |||
"tbl v3.16b, {v1.16b}, %[shuffle_idx].16b \n" // E1F1G1H1E2F2G2H2E3F3G3H3E4F4G4H4 | |||
"zip1 v4.4s, v2.4s, v3.4s\n" // A1B1C1D1E1F1G1H1 A2B2C2D2E2F2G2H2 | |||
"zip2 v5.4s, v2.4s, v3.4s\n" // A3B3C3D3E3F3G3H3 A4B4C4D4E4F4G4H4 | |||
"shl v6.16b, v4.16b, #4\n" | |||
"sshr v7.16b, v4.16b, #4\n" // hig | |||
"sshr v8.16b, v6.16b, #4\n" // low | |||
"shl v9.16b, v5.16b, #4\n" | |||
"sshr v10.16b, v5.16b, #4\n" // hig | |||
"sshr v11.16b, v9.16b, #4\n" // low | |||
"zip1 v0.2d,v8.2d,v7.2d\n" | |||
"zip2 v1.2d,v8.2d,v7.2d\n" | |||
"zip1 v2.2d,v11.2d,v10.2d\n" | |||
"zip2 v3.2d,v11.2d,v10.2d\n" | |||
"st1 {v0.2d-v3.2d},[%[outptr]],#64\n" | |||
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), | |||
[inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), | |||
[inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), | |||
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [shuffle_idx]"+w"(shuffle_idx), | |||
[outptr] "+r"(outptr) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","v8","v9","v10","v11","memory"); | |||
} | |||
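// Scalar model of transpose_4x8_1_b_with_shift (a readability sketch, not
// part of the diff): the 8 rows each supply 4 packed bytes, i.e. the int4
// elements k = 0..7 of that row. The tbl/zip/shift sequence above emits the
// data k-major: for each k, the sign-extended values of all 8 rows are
// contiguous, which is the panel layout the 8x8x8 kernel consumes.
template <typename T>
static inline void transpose_4x8_1_b_with_shift_ref(
        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
        const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7,
        T*& outptr) {
    const T** rows[8] = {&inptr0, &inptr1, &inptr2, &inptr3,
                         &inptr4, &inptr5, &inptr6, &inptr7};
    for (int b = 0; b < 4; ++b) {      // packed byte index: holds k = 2b, 2b+1
        for (int r = 0; r < 8; ++r) {  // k = 2b: low nibbles across the rows
            int8_t packed = static_cast<int8_t>((*rows[r])[b]);
            *outptr++ = static_cast<int8_t>(packed << 4) >> 4;
        }
        for (int r = 0; r < 8; ++r) {  // k = 2b + 1: high nibbles
            *outptr++ = static_cast<int8_t>((*rows[r])[b]) >> 4;
        }
    }
    for (int r = 0; r < 8; ++r)
        *rows[r] += 4;
}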
template <typename T> | |||
static inline void transpose_8x8_1_b(const T*& inptr0, const T*& inptr1, | |||
const T*& inptr2, const T*& inptr3, | |||
const T*& inptr4, const T*& inptr5, | |||
@@ -0,0 +1,913 @@ | |||
/** | |||
 * \file dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <inttypes.h> | |||
#include <cstring> | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul_s4_4x4x16 { | |||
/**
 * Overview of register layout:
 *
 * An 8x8 block of Lhs is packed into v20-v23 (16 int8 values each) and
 * broadcast element-wise into v0-v15; the 8 Rhs values of each k-step are
 * loaded into v16-v19; the 8x8 int16 accumulator tile lives in v24-v31.
 *
 *                             Rhs +----------+
 *                                 | v16[0-7] |
 *                                 +----------+
 *        Lhs                           |
 *  +----------+ - - - - - - +---------------+
 *  | v0[0-7]  |             |   v24[0-7]    |
 *  | v1[0-7]  |             |   v25[0-7]    |
 *  | v2[0-7]  |             |   v26[0-7]    |
 *  | v3[0-7]  |             |   v27[0-7]    |
 *  | v4[0-7]  |             |   v28[0-7]    |
 *  | v5[0-7]  |             |   v29[0-7]    |
 *  | v6[0-7]  |             |   v30[0-7]    |
 *  | v7[0-7]  |             |   v31[0-7]    |
 *  +----------+ - - - - - - +---------------+
 *
 *                             Accumulator
 */
static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
K /= 8; | |||
LDC = LDC * sizeof(int16_t); | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
// clang-format off | |||
#define LOAD_LINE(reg_index, n) \ | |||
"cmp x8, #0 \n" \ | |||
"beq 105f\n" \ | |||
"cmp %w[n_remain], #4\n" \ | |||
"blt 100" n "f\n" \ | |||
"ld1 {v" reg_index ".8h}, [x" n "], #16\n" \ | |||
"b 101" n "f\n" \ | |||
"100" n ":\n" \ | |||
"cmp %w[n_remain], #0\n" \ | |||
"blt 101" n "f\n" \ | |||
"ld1 {v" reg_index ".h}[0], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #1\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" reg_index ".h}[1], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #2\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" reg_index ".h}[2], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #3\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" reg_index ".h}[3], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #4\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" reg_index ".h}[4], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #5\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" reg_index ".h}[5], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #6\n" \ | |||
"beq 101" n "f\n" \ | |||
"ld1 {v" reg_index ".h}[6], [x" n "], #2\n" \ | |||
"101" n ":\n" \ | |||
"sub x8, x8, #1\n" | |||
#define LOAD_C \ | |||
"mov x8, %x[m_remain]\n" \ | |||
LOAD_LINE("24", "0") \ | |||
LOAD_LINE("25", "1") \ | |||
LOAD_LINE("26", "2") \ | |||
LOAD_LINE("27", "3") \ | |||
LOAD_LINE("28", "4") \ | |||
LOAD_LINE("29", "5") \ | |||
LOAD_LINE("30", "6") \ | |||
LOAD_LINE("31", "7") \ | |||
"105:\n" | |||
#define STORE_LINE(reg_index, n) \ | |||
"cmp x8, #0 \n" \ | |||
"beq 105f\n" \ | |||
"cmp %w[n_remain], #8\n" \ | |||
"blt 102" n "f\n" \ | |||
"st1 {v" reg_index ".8h}, [x" n "], #16\n" \ | |||
"b 103" n "f\n" \ | |||
"102" n ":\n" \ | |||
"cmp %w[n_remain], #0\n" \ | |||
"beq 103" n "f\n" \ | |||
"st1 {v" reg_index ".h}[0], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #1\n" \ | |||
"beq 103" n "f\n" \ | |||
"st1 {v" reg_index ".h}[1], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #2\n" \ | |||
"beq 103" n "f\n" \ | |||
"st1 {v" reg_index ".h}[2], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #3\n" \ | |||
"beq 103" n "f\n" \ | |||
"st1 {v" reg_index ".h}[3], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #4\n" \ | |||
"beq 103" n "f\n" \ | |||
"st1 {v" reg_index ".h}[4], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #5\n" \ | |||
"beq 103" n "f\n" \ | |||
"st1 {v" reg_index ".h}[5], [x" n "], #2\n" \ | |||
"cmp %w[n_remain], #6\n" \ | |||
"beq 103" n "f\n" \ | |||
"st1 {v" reg_index ".h}[6], [x" n "], #2\n" \ | |||
"103" n ":\n" \ | |||
"sub x8, x8, #1\n" | |||
#define STORE_C \ | |||
"mov x8, %x[m_remain]\n" \ | |||
STORE_LINE("24", "0") \ | |||
STORE_LINE("25", "1") \ | |||
STORE_LINE("26", "2") \ | |||
STORE_LINE("27", "3") \ | |||
STORE_LINE("28", "4") \ | |||
STORE_LINE("29", "5") \ | |||
STORE_LINE("30", "6") \ | |||
STORE_LINE("31", "7") \ | |||
"105:\n" | |||
// clang-format on | |||
register int16_t* outptr asm("x0") = output; | |||
asm volatile( | |||
"add x1, x0, %x[LDC]\n" | |||
"add x2, x1, %x[LDC]\n" | |||
"add x3, x2, %x[LDC]\n" | |||
"add x4, x3, %x[LDC]\n" | |||
"add x5, x4, %x[LDC]\n" | |||
"add x6, x5, %x[LDC]\n" | |||
"add x7, x6, %x[LDC]\n" | |||
"cmp %w[is_first_k], #1\n" | |||
"beq 2f\n" LOAD_C | |||
"b 1f\n" | |||
"2:\n" // Clear the C regs. | |||
"eor v24.16b, v24.16b, v24.16b\n" | |||
"eor v25.16b, v25.16b, v25.16b\n" | |||
"eor v26.16b, v26.16b, v26.16b\n" | |||
"eor v27.16b, v27.16b, v27.16b\n" | |||
"eor v28.16b, v28.16b, v28.16b\n" | |||
"eor v29.16b, v29.16b, v29.16b\n" | |||
"eor v30.16b, v30.16b, v30.16b\n" | |||
"eor v31.16b, v31.16b, v31.16b\n" | |||
// General loop. | |||
"1:\n" | |||
"ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||
"ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||
"dup v0.8b,v20.b[0]\n" | |||
"dup v1.8b,v20.b[1]\n" | |||
"dup v2.8b,v20.b[2]\n" | |||
"dup v3.8b,v20.b[3]\n" | |||
"ld1 {v22.16b}, [%[a_ptr]],#16\n" | |||
"ld1 {v23.16b}, [%[a_ptr]],#16\n" | |||
"ld1 {v16.8b}, [%[b_ptr]], 8\n" | |||
"dup v4.8b,v20.b[4]\n" | |||
"dup v5.8b,v20.b[5]\n" | |||
"dup v6.8b,v20.b[6]\n" | |||
"dup v7.8b,v20.b[7]\n" | |||
"ld1 {v17.8b}, [%[b_ptr]], 8\n" | |||
"dup v8.8b,v20.b[8]\n" | |||
"smlal v24.8h, v0.8b, v16.8b\n" | |||
"dup v9.8b,v20.b[9]\n" | |||
"smlal v25.8h, v1.8b, v16.8b\n" | |||
"dup v10.8b,v20.b[10]\n" | |||
"smlal v26.8h, v2.8b, v16.8b\n" | |||
"dup v11.8b,v20.b[11]\n" | |||
"smlal v27.8h, v3.8b, v16.8b\n" | |||
"dup v12.8b,v20.b[12]\n" | |||
"smlal v28.8h, v4.8b, v16.8b\n" | |||
"dup v13.8b,v20.b[13]\n" | |||
"smlal v29.8h, v5.8b, v16.8b\n" | |||
"dup v14.8b,v20.b[14]\n" | |||
"smlal v30.8h, v6.8b, v16.8b\n" | |||
"dup v15.8b,v20.b[15]\n" | |||
"smlal v31.8h, v7.8b, v16.8b\n" | |||
"ld1 {v18.8b}, [%[b_ptr]], 8\n" | |||
"dup v0.8b,v21.b[0]\n" | |||
"smlal v24.8h, v8.8b, v17.8b\n" | |||
"dup v1.8b,v21.b[1]\n" | |||
"smlal v25.8h, v9.8b, v17.8b\n" | |||
"dup v2.8b,v21.b[2]\n" | |||
"smlal v26.8h, v10.8b, v17.8b\n" | |||
"dup v3.8b,v21.b[3]\n" | |||
"smlal v27.8h, v11.8b, v17.8b\n" | |||
"dup v4.8b,v21.b[4]\n" | |||
"smlal v28.8h, v12.8b, v17.8b\n" | |||
"dup v5.8b,v21.b[5]\n" | |||
"smlal v29.8h, v13.8b, v17.8b\n" | |||
"dup v6.8b,v21.b[6]\n" | |||
"smlal v30.8h, v14.8b, v17.8b\n" | |||
"dup v7.8b,v21.b[7]\n" | |||
"smlal v31.8h, v15.8b, v17.8b\n" | |||
"ld1 {v19.8b}, [%[b_ptr]], 8\n" | |||
"dup v8.8b,v21.b[8]\n" | |||
"smlal v24.8h, v0.8b, v18.8b\n" | |||
"dup v9.8b,v21.b[9]\n" | |||
"smlal v25.8h, v1.8b, v18.8b\n" | |||
"dup v10.8b,v21.b[10]\n" | |||
"smlal v26.8h, v2.8b, v18.8b\n" | |||
"dup v11.8b,v21.b[11]\n" | |||
"smlal v27.8h, v3.8b, v18.8b\n" | |||
"dup v12.8b,v21.b[12]\n" | |||
"smlal v28.8h, v4.8b, v18.8b\n" | |||
"dup v13.8b,v21.b[13]\n" | |||
"smlal v29.8h, v5.8b, v18.8b\n" | |||
"dup v14.8b,v21.b[14]\n" | |||
"smlal v30.8h, v6.8b, v18.8b\n" | |||
"dup v15.8b,v21.b[15]\n" | |||
"smlal v31.8h, v7.8b, v18.8b\n" | |||
"ld1 {v16.8b}, [%[b_ptr]], 8\n" | |||
"dup v0.8b,v22.b[0]\n" | |||
"smlal v24.8h, v8.8b, v19.8b\n" | |||
"dup v1.8b,v22.b[1]\n" | |||
"smlal v25.8h, v9.8b, v19.8b\n" | |||
"dup v2.8b,v22.b[2]\n" | |||
"smlal v26.8h, v10.8b, v19.8b\n" | |||
"dup v3.8b,v22.b[3]\n" | |||
"smlal v27.8h, v11.8b, v19.8b\n" | |||
"dup v4.8b,v22.b[4]\n" | |||
"smlal v28.8h, v12.8b, v19.8b\n" | |||
"dup v5.8b,v22.b[5]\n" | |||
"smlal v29.8h, v13.8b, v19.8b\n" | |||
"dup v6.8b,v22.b[6]\n" | |||
"smlal v30.8h, v14.8b, v19.8b\n" | |||
"dup v7.8b,v22.b[7]\n" | |||
"smlal v31.8h, v15.8b, v19.8b\n" | |||
"ld1 {v17.8b}, [%[b_ptr]], 8\n" | |||
"dup v8.8b,v22.b[8]\n" | |||
"smlal v24.8h, v0.8b, v16.8b\n" | |||
"dup v9.8b,v22.b[9]\n" | |||
"smlal v25.8h, v1.8b, v16.8b\n" | |||
"dup v10.8b,v22.b[10]\n" | |||
"smlal v26.8h, v2.8b, v16.8b\n" | |||
"dup v11.8b,v22.b[11]\n" | |||
"smlal v27.8h, v3.8b, v16.8b\n" | |||
"dup v12.8b,v22.b[12]\n" | |||
"smlal v28.8h, v4.8b, v16.8b\n" | |||
"dup v13.8b,v22.b[13]\n" | |||
"smlal v29.8h, v5.8b, v16.8b\n" | |||
"dup v14.8b,v22.b[14]\n" | |||
"smlal v30.8h, v6.8b, v16.8b\n" | |||
"dup v15.8b,v22.b[15]\n" | |||
"smlal v31.8h, v7.8b, v16.8b\n" | |||
"ld1 {v18.8b}, [%[b_ptr]], 8\n" | |||
"dup v0.8b,v23.b[0]\n" | |||
"smlal v24.8h, v8.8b, v17.8b\n" | |||
"dup v1.8b,v23.b[1]\n" | |||
"smlal v25.8h, v9.8b, v17.8b\n" | |||
"dup v2.8b,v23.b[2]\n" | |||
"smlal v26.8h, v10.8b, v17.8b\n" | |||
"dup v3.8b,v23.b[3]\n" | |||
"smlal v27.8h, v11.8b, v17.8b\n" | |||
"dup v4.8b,v23.b[4]\n" | |||
"smlal v28.8h, v12.8b, v17.8b\n" | |||
"dup v5.8b,v23.b[5]\n" | |||
"smlal v29.8h, v13.8b, v17.8b\n" | |||
"dup v6.8b,v23.b[6]\n" | |||
"smlal v30.8h, v14.8b, v17.8b\n" | |||
"dup v7.8b,v23.b[7]\n" | |||
"smlal v31.8h, v15.8b, v17.8b\n" | |||
"ld1 {v19.8b}, [%[b_ptr]], 8\n" | |||
"dup v8.8b,v23.b[8]\n" | |||
"smlal v24.8h, v0.8b, v18.8b\n" | |||
"dup v9.8b,v23.b[9]\n" | |||
"smlal v25.8h, v1.8b, v18.8b\n" | |||
"dup v10.8b,v23.b[10]\n" | |||
"smlal v26.8h, v2.8b, v18.8b\n" | |||
"dup v11.8b,v23.b[11]\n" | |||
"smlal v27.8h, v3.8b, v18.8b\n" | |||
"dup v12.8b,v23.b[12]\n" | |||
"smlal v28.8h, v4.8b, v18.8b\n" | |||
"dup v13.8b,v23.b[13]\n" | |||
"smlal v29.8h, v5.8b, v18.8b\n" | |||
"dup v14.8b,v23.b[14]\n" | |||
"smlal v30.8h, v6.8b, v18.8b\n" | |||
"dup v15.8b,v23.b[15]\n" | |||
"smlal v31.8h, v7.8b, v18.8b\n" | |||
"smlal v24.8h, v8.8b, v19.8b\n" | |||
"smlal v25.8h, v9.8b, v19.8b\n" | |||
"smlal v26.8h, v10.8b, v19.8b\n" | |||
"smlal v27.8h, v11.8b, v19.8b\n" | |||
"smlal v28.8h, v12.8b, v19.8b\n" | |||
"smlal v29.8h, v13.8b, v19.8b\n" | |||
"smlal v30.8h, v14.8b, v19.8b\n" | |||
"smlal v31.8h, v15.8b, v19.8b\n" | |||
"subs %w[K], %w[K], #1\n" | |||
"cbnz %w[K], 1b\n" | |||
"3:\n" | |||
// Store back into memory | |||
STORE_C | |||
: | |||
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||
[ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||
: | |||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
#undef STORE_C | |||
} | |||
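// How the tail masking above works (summary comment, not part of the diff):
// x8 is initialised from m_remain and counts rows; every LOAD_LINE /
// STORE_LINE decrements it and branches to the shared label 105 once it
// reaches zero, so rows beyond m_remain are never touched. Within a row,
// n_remain selects either the full {v.8h} vector access or the lane-by-lane
// accesses used for the last partial tile.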
static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
K /= 8; | |||
LDC = LDC * sizeof(int16_t); | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
// clang-format off | |||
#define LOAD_C_8 \ | |||
"ld1 {v24.8h}, [x0], #16\n" \ | |||
"ld1 {v25.8h}, [x1], #16\n" \ | |||
"ld1 {v26.8h}, [x2], #16\n" \ | |||
"ld1 {v27.8h}, [x3], #16\n" \ | |||
"ld1 {v28.8h}, [x4], #16\n" \ | |||
"ld1 {v29.8h}, [x5], #16\n" \ | |||
"ld1 {v30.8h}, [x6], #16\n" \ | |||
"ld1 {v31.8h}, [x7], #16\n" \ | |||
#define STORE_C_8 \ | |||
"st1 {v24.8h}, [x0], #16\n" \ | |||
"st1 {v25.8h}, [x1], #16\n" \ | |||
"st1 {v26.8h}, [x2], #16\n" \ | |||
"st1 {v27.8h}, [x3], #16\n" \ | |||
"st1 {v28.8h}, [x4], #16\n" \ | |||
"st1 {v29.8h}, [x5], #16\n" \ | |||
"st1 {v30.8h}, [x6], #16\n" \ | |||
"st1 {v31.8h}, [x7], #16\n" \ | |||
// clang-format on | |||
register int16_t* outptr asm("x0") = output; | |||
asm volatile( | |||
"add x1, x0, %x[LDC]\n" | |||
"add x2, x1, %x[LDC]\n" | |||
"add x3, x2, %x[LDC]\n" | |||
"add x4, x3, %x[LDC]\n" | |||
"add x5, x4, %x[LDC]\n" | |||
"add x6, x5, %x[LDC]\n" | |||
"add x7, x6, %x[LDC]\n" | |||
"cmp %w[is_first_k], #1\n" | |||
"beq 2f\n" LOAD_C_8 | |||
"b 1f\n" | |||
"2:\n" // Clear the C regs. | |||
"eor v24.16b, v24.16b, v24.16b\n" | |||
"eor v25.16b, v25.16b, v25.16b\n" | |||
"eor v26.16b, v26.16b, v26.16b\n" | |||
"eor v27.16b, v27.16b, v27.16b\n" | |||
"eor v28.16b, v28.16b, v28.16b\n" | |||
"eor v29.16b, v29.16b, v29.16b\n" | |||
"eor v30.16b, v30.16b, v30.16b\n" | |||
"eor v31.16b, v31.16b, v31.16b\n" | |||
// General loop. | |||
"ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||
"ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||
"PRFM PLDL1KEEP, [%[a_ptr], #512]\n" | |||
"PRFM PLDL1KEEP, [%[b_ptr], #512]\n" | |||
"1:\n" | |||
// "ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||
// "ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||
"dup v0.8b,v20.b[0]\n" | |||
"ld1 {v22.16b}, [%[a_ptr]],#16\n" | |||
"dup v1.8b,v20.b[1]\n" | |||
"ld1 {v23.16b}, [%[a_ptr]],#16\n" | |||
"dup v2.8b,v20.b[2]\n" | |||
"ld1 {v16.8b}, [%[b_ptr]], 8\n" | |||
"dup v3.8b,v20.b[3]\n" | |||
"dup v4.8b,v20.b[4]\n" | |||
"ld1 {v17.8b}, [%[b_ptr]], 8\n" | |||
"dup v5.8b,v20.b[5]\n" | |||
"dup v6.8b,v20.b[6]\n" | |||
"dup v7.8b,v20.b[7]\n" | |||
"dup v8.8b,v20.b[8]\n" | |||
"smlal v24.8h, v0.8b, v16.8b\n" | |||
"dup v9.8b,v20.b[9]\n" | |||
"smlal v25.8h, v1.8b, v16.8b\n" | |||
"dup v10.8b,v20.b[10]\n" | |||
"smlal v26.8h, v2.8b, v16.8b\n" | |||
"dup v11.8b,v20.b[11]\n" | |||
"smlal v27.8h, v3.8b, v16.8b\n" | |||
"dup v12.8b,v20.b[12]\n" | |||
"smlal v28.8h, v4.8b, v16.8b\n" | |||
"dup v13.8b,v20.b[13]\n" | |||
"smlal v29.8h, v5.8b, v16.8b\n" | |||
"dup v14.8b,v20.b[14]\n" | |||
"smlal v30.8h, v6.8b, v16.8b\n" | |||
"dup v15.8b,v20.b[15]\n" | |||
"smlal v31.8h, v7.8b, v16.8b\n" | |||
"ld1 {v16.8b}, [%[b_ptr]], 8\n" | |||
"dup v0.8b,v21.b[0]\n" | |||
"smlal v24.8h, v8.8b, v17.8b\n" | |||
"dup v1.8b,v21.b[1]\n" | |||
"smlal v25.8h, v9.8b, v17.8b\n" | |||
"dup v2.8b,v21.b[2]\n" | |||
"smlal v26.8h, v10.8b, v17.8b\n" | |||
"dup v3.8b,v21.b[3]\n" | |||
"smlal v27.8h, v11.8b, v17.8b\n" | |||
"dup v4.8b,v21.b[4]\n" | |||
"smlal v28.8h, v12.8b, v17.8b\n" | |||
"dup v5.8b,v21.b[5]\n" | |||
"smlal v29.8h, v13.8b, v17.8b\n" | |||
"dup v6.8b,v21.b[6]\n" | |||
"smlal v30.8h, v14.8b, v17.8b\n" | |||
"dup v7.8b,v21.b[7]\n" | |||
"smlal v31.8h, v15.8b, v17.8b\n" | |||
"ld1 {v17.8b}, [%[b_ptr]], 8\n" | |||
"dup v8.8b,v21.b[8]\n" | |||
"smlal v24.8h, v0.8b, v16.8b\n" | |||
"dup v9.8b,v21.b[9]\n" | |||
"smlal v25.8h, v1.8b, v16.8b\n" | |||
"dup v10.8b,v21.b[10]\n" | |||
"smlal v26.8h, v2.8b, v16.8b\n" | |||
"dup v11.8b,v21.b[11]\n" | |||
"smlal v27.8h, v3.8b, v16.8b\n" | |||
"dup v12.8b,v21.b[12]\n" | |||
"smlal v28.8h, v4.8b, v16.8b\n" | |||
"dup v13.8b,v21.b[13]\n" | |||
"smlal v29.8h, v5.8b, v16.8b\n" | |||
"dup v14.8b,v21.b[14]\n" | |||
"smlal v30.8h, v6.8b, v16.8b\n" | |||
"dup v15.8b,v21.b[15]\n" | |||
"smlal v31.8h, v7.8b, v16.8b\n" | |||
"ld1 {v16.8b}, [%[b_ptr]], 8\n" | |||
"dup v0.8b,v22.b[0]\n" | |||
"smlal v24.8h, v8.8b, v17.8b\n" | |||
"dup v1.8b,v22.b[1]\n" | |||
"smlal v25.8h, v9.8b, v17.8b\n" | |||
"dup v2.8b,v22.b[2]\n" | |||
"smlal v26.8h, v10.8b, v17.8b\n" | |||
"dup v3.8b,v22.b[3]\n" | |||
"smlal v27.8h, v11.8b, v17.8b\n" | |||
"dup v4.8b,v22.b[4]\n" | |||
"smlal v28.8h, v12.8b, v17.8b\n" | |||
"dup v5.8b,v22.b[5]\n" | |||
"smlal v29.8h, v13.8b, v17.8b\n" | |||
"dup v6.8b,v22.b[6]\n" | |||
"smlal v30.8h, v14.8b, v17.8b\n" | |||
"dup v7.8b,v22.b[7]\n" | |||
"smlal v31.8h, v15.8b, v17.8b\n" | |||
"ld1 {v17.8b}, [%[b_ptr]], 8\n" | |||
"dup v8.8b,v22.b[8]\n" | |||
"smlal v24.8h, v0.8b, v16.8b\n" | |||
"dup v9.8b,v22.b[9]\n" | |||
"smlal v25.8h, v1.8b, v16.8b\n" | |||
"dup v10.8b,v22.b[10]\n" | |||
"smlal v26.8h, v2.8b, v16.8b\n" | |||
"dup v11.8b,v22.b[11]\n" | |||
"smlal v27.8h, v3.8b, v16.8b\n" | |||
"dup v12.8b,v22.b[12]\n" | |||
"smlal v28.8h, v4.8b, v16.8b\n" | |||
"dup v13.8b,v22.b[13]\n" | |||
"smlal v29.8h, v5.8b, v16.8b\n" | |||
"dup v14.8b,v22.b[14]\n" | |||
"smlal v30.8h, v6.8b, v16.8b\n" | |||
"dup v15.8b,v22.b[15]\n" | |||
"smlal v31.8h, v7.8b, v16.8b\n" | |||
"ld1 {v16.8b}, [%[b_ptr]], 8\n" | |||
"dup v0.8b,v23.b[0]\n" | |||
"smlal v24.8h, v8.8b, v17.8b\n" | |||
"dup v1.8b,v23.b[1]\n" | |||
"smlal v25.8h, v9.8b, v17.8b\n" | |||
"dup v2.8b,v23.b[2]\n" | |||
"smlal v26.8h, v10.8b, v17.8b\n" | |||
"dup v3.8b,v23.b[3]\n" | |||
"smlal v27.8h, v11.8b, v17.8b\n" | |||
"dup v4.8b,v23.b[4]\n" | |||
"smlal v28.8h, v12.8b, v17.8b\n" | |||
"dup v5.8b,v23.b[5]\n" | |||
"smlal v29.8h, v13.8b, v17.8b\n" | |||
"dup v6.8b,v23.b[6]\n" | |||
"smlal v30.8h, v14.8b, v17.8b\n" | |||
"dup v7.8b,v23.b[7]\n" | |||
"smlal v31.8h, v15.8b, v17.8b\n" | |||
"ld1 {v17.8b}, [%[b_ptr]], 8\n" | |||
"dup v8.8b,v23.b[8]\n" | |||
"smlal v24.8h, v0.8b, v16.8b\n" | |||
"dup v9.8b,v23.b[9]\n" | |||
"smlal v25.8h, v1.8b, v16.8b\n" | |||
"dup v10.8b,v23.b[10]\n" | |||
"smlal v26.8h, v2.8b, v16.8b\n" | |||
"dup v11.8b,v23.b[11]\n" | |||
"smlal v27.8h, v3.8b, v16.8b\n" | |||
"dup v12.8b,v23.b[12]\n" | |||
"smlal v28.8h, v4.8b, v16.8b\n" | |||
"dup v13.8b,v23.b[13]\n" | |||
"smlal v29.8h, v5.8b, v16.8b\n" | |||
"dup v14.8b,v23.b[14]\n" | |||
"smlal v30.8h, v6.8b, v16.8b\n" | |||
"dup v15.8b,v23.b[15]\n" | |||
"smlal v31.8h, v7.8b, v16.8b\n" | |||
"ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||
"smlal v24.8h, v8.8b, v17.8b\n" | |||
"smlal v25.8h, v9.8b, v17.8b\n" | |||
"smlal v26.8h, v10.8b, v17.8b\n" | |||
"smlal v27.8h, v11.8b, v17.8b\n" | |||
"ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||
"smlal v28.8h, v12.8b, v17.8b\n" | |||
"smlal v29.8h, v13.8b, v17.8b\n" | |||
"smlal v30.8h, v14.8b, v17.8b\n" | |||
"smlal v31.8h, v15.8b, v17.8b\n" | |||
//"ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||
//"ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||
"subs %w[K], %w[K], #1\n" | |||
"cbnz %w[K], 1b\n" | |||
"3:\n" | |||
// Store back into memory | |||
STORE_C_8 | |||
: | |||
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||
[ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||
: | |||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31"); | |||
#undef LOAD_C_8
#undef STORE_C_8
} | |||
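// Scalar model of one K-block of the 8x8x8 micro-kernel (a sketch, not part
// of the diff; it assumes packA and packB hold sign-extended int8 data in
// the k-major panel layout produced by the packing routines below): each of
// the 8 k-steps broadcasts A values and accumulates into the int16 tile,
// exactly like the dup/smlal pairs above.
static inline void s4_kern_8x8_block_ref(const int8_t* a, const int8_t* b,
                                         int16_t c[8][8]) {
    for (int k = 0; k < 8; ++k) {
        for (int m = 0; m < 8; ++m) {
            for (int n = 0; n < 8; ++n) {
                c[m][n] += static_cast<int16_t>(a[k * 8 + m]) *
                           static_cast<int16_t>(b[k * 8 + n]);
            }
        }
    }
}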
//! pack A: k-major 8x8 panels, int4 widened to int8
static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int8_t zerobuff[8]; | |||
int8_t tmpbuff0[8]; | |||
int8_t tmpbuff1[8]; | |||
int8_t tmpbuff2[8]; | |||
int8_t tmpbuff3[8]; | |||
int8_t tmpbuff4[8]; | |||
int8_t tmpbuff5[8]; | |||
int8_t tmpbuff6[8]; | |||
int8_t tmpbuff7[8]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff0, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff1, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff2, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff3, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff4, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff5, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff6, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff7, 0, sizeof(int8_t) * 8); | |||
ldin /= 2; | |||
int y = y0; | |||
for (; y + 7 < ymax; y += 8) { | |||
const int8_t* inptr0 = inptr + y * ldin + k0; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
const int8_t* inptr2 = inptr1 + ldin; | |||
const int8_t* inptr3 = inptr2 + ldin; | |||
const int8_t* inptr4 = inptr3 + ldin; | |||
const int8_t* inptr5 = inptr4 + ldin; | |||
const int8_t* inptr6 = inptr5 + ldin; | |||
const int8_t* inptr7 = inptr6 + ldin; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
prefetch_2x(inptr2); | |||
prefetch_2x(inptr3); | |||
prefetch_2x(inptr4); | |||
prefetch_2x(inptr5); | |||
prefetch_2x(inptr6); | |||
prefetch_2x(inptr7); | |||
        int K = (kmax - k0) / 2;
        //! each iteration consumes 4 packed bytes (8 int4 values) per row
for (; K > 3; K -= 4) { | |||
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
inptr5, inptr6, inptr7, outptr); | |||
} | |||
if (K > 0) { | |||
            std::memcpy(tmpbuff0, inptr0, K);
            std::memcpy(tmpbuff1, inptr1, K);
            std::memcpy(tmpbuff2, inptr2, K);
            std::memcpy(tmpbuff3, inptr3, K);
            std::memcpy(tmpbuff4, inptr4, K);
            std::memcpy(tmpbuff5, inptr5, K);
            std::memcpy(tmpbuff6, inptr6, K);
            std::memcpy(tmpbuff7, inptr7, K);
inptr0 = tmpbuff0; | |||
inptr1 = tmpbuff1; | |||
inptr2 = tmpbuff2; | |||
inptr3 = tmpbuff3; | |||
inptr4 = tmpbuff4; | |||
inptr5 = tmpbuff5; | |||
inptr6 = tmpbuff6; | |||
inptr7 = tmpbuff7; | |||
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
inptr5, inptr6, inptr7, outptr); | |||
} | |||
} | |||
for (; y < ymax; y += 8) { | |||
const int8_t* inptr0 = inptr + y * ldin + k0; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
const int8_t* inptr2 = inptr1 + ldin; | |||
const int8_t* inptr3 = inptr2 + ldin; | |||
const int8_t* inptr4 = inptr3 + ldin; | |||
const int8_t* inptr5 = inptr4 + ldin; | |||
const int8_t* inptr6 = inptr5 + ldin; | |||
const int8_t* inptr7 = inptr6 + ldin; | |||
        int K = (kmax - k0) / 2;
        //! each iteration consumes 4 packed bytes (8 int4 values) per row
for (; K > 3; K -= 4) { | |||
if (y + 7 >= ymax) { | |||
switch (y + 7 - ymax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
inptr5, inptr6, inptr7, outptr); | |||
} | |||
if (K > 0) { | |||
if (y + 7 >= ymax) { | |||
switch (y + 7 - ymax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
} | |||
} | |||
            std::memcpy(tmpbuff0, inptr0, K);
            std::memcpy(tmpbuff1, inptr1, K);
            std::memcpy(tmpbuff2, inptr2, K);
            std::memcpy(tmpbuff3, inptr3, K);
            std::memcpy(tmpbuff4, inptr4, K);
            std::memcpy(tmpbuff5, inptr5, K);
            std::memcpy(tmpbuff6, inptr6, K);
            std::memcpy(tmpbuff7, inptr7, K);
inptr0 = tmpbuff0; | |||
inptr1 = tmpbuff1; | |||
inptr2 = tmpbuff2; | |||
inptr3 = tmpbuff3; | |||
inptr4 = tmpbuff4; | |||
inptr5 = tmpbuff5; | |||
inptr6 = tmpbuff6; | |||
inptr7 = tmpbuff7; | |||
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
inptr5, inptr6, inptr7, outptr); | |||
} | |||
} | |||
} | |||
//! pack B: k-major 8x8 panels, int4 widened to int8
static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
int8_t zerobuff[8]; | |||
int8_t tmpbuff0[8]; | |||
int8_t tmpbuff1[8]; | |||
int8_t tmpbuff2[8]; | |||
int8_t tmpbuff3[8]; | |||
int8_t tmpbuff4[8]; | |||
int8_t tmpbuff5[8]; | |||
int8_t tmpbuff6[8]; | |||
int8_t tmpbuff7[8]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff0, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff1, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff2, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff3, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff4, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff5, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff6, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff7, 0, sizeof(int8_t) * 8); | |||
const int ksize = kmax - k0; | |||
    //! once the int4 data is widened to int8, one 8-column panel spans
    //! round_up(K, 8) * 8 bytes
    const int ksize8 = round_up(ksize, 8) * 8;
int8_t* outptr = out; | |||
int8_t* outptr_interleave = nullptr; | |||
int k = k0; | |||
ldin /= 2; | |||
xmax = xmax / 2; | |||
for (; k + 7 < kmax; k += 8) { | |||
const int8_t* inptr0 = in + k * ldin + x0; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
const int8_t* inptr2 = inptr1 + ldin; | |||
const int8_t* inptr3 = inptr2 + ldin; | |||
const int8_t* inptr4 = inptr3 + ldin; | |||
const int8_t* inptr5 = inptr4 + ldin; | |||
const int8_t* inptr6 = inptr5 + ldin; | |||
const int8_t* inptr7 = inptr6 + ldin; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
prefetch_2x(inptr2); | |||
prefetch_2x(inptr3); | |||
prefetch_2x(inptr4); | |||
prefetch_2x(inptr5); | |||
prefetch_2x(inptr6); | |||
prefetch_2x(inptr7); | |||
int x = x0; | |||
int8_t* outptr_inner = outptr; | |||
for (; x + 3 < xmax; x += 4) { | |||
outptr_interleave = outptr_inner; | |||
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr_interleave); | |||
outptr_inner += ksize8; | |||
} | |||
if (x < xmax) { | |||
int remainx = xmax - x; | |||
            std::memcpy(tmpbuff0, inptr0, remainx);
            std::memcpy(tmpbuff1, inptr1, remainx);
            std::memcpy(tmpbuff2, inptr2, remainx);
            std::memcpy(tmpbuff3, inptr3, remainx);
            std::memcpy(tmpbuff4, inptr4, remainx);
            std::memcpy(tmpbuff5, inptr5, remainx);
            std::memcpy(tmpbuff6, inptr6, remainx);
            std::memcpy(tmpbuff7, inptr7, remainx);
inptr0 = tmpbuff0; | |||
inptr1 = tmpbuff1; | |||
inptr2 = tmpbuff2; | |||
inptr3 = tmpbuff3; | |||
inptr4 = tmpbuff4; | |||
inptr5 = tmpbuff5; | |||
inptr6 = tmpbuff6; | |||
inptr7 = tmpbuff7; | |||
outptr_interleave = outptr_inner; | |||
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr_interleave); | |||
outptr_inner += ksize8; | |||
} | |||
outptr += 64; | |||
} | |||
if (k < kmax) { | |||
const int8_t* inptr0 = in + k * ldin + x0; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
const int8_t* inptr2 = inptr1 + ldin; | |||
const int8_t* inptr3 = inptr2 + ldin; | |||
const int8_t* inptr4 = inptr3 + ldin; | |||
const int8_t* inptr5 = inptr4 + ldin; | |||
const int8_t* inptr6 = inptr5 + ldin; | |||
const int8_t* inptr7 = inptr6 + ldin; | |||
int k_remain = kmax - k - 1; | |||
int x = x0; | |||
int8_t* outptr_inner = outptr; | |||
for (; x + 3 < xmax; x += 4) { | |||
switch (k_remain) { | |||
case 0: | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU; | |||
case 1: | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU; | |||
case 2: | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU; | |||
case 3: | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU; | |||
case 4: | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU; | |||
case 5: | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU; | |||
case 6: | |||
inptr7 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
break; | |||
} | |||
outptr_interleave = outptr_inner; | |||
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr_interleave); | |||
outptr_inner += ksize8; | |||
} | |||
if (x < xmax) { | |||
switch (k_remain) { | |||
case 0: | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU; | |||
case 1: | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU; | |||
case 2: | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU; | |||
case 3: | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU; | |||
case 4: | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU; | |||
case 5: | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU; | |||
case 6: | |||
inptr7 = zerobuff; | |||
break; | |||
default: | |||
megdnn_assert(0); | |||
break; | |||
} | |||
int remainx = xmax - x; | |||
outptr_interleave = outptr_inner; | |||
            std::memcpy(tmpbuff0, inptr0, remainx);
            std::memcpy(tmpbuff1, inptr1, remainx);
            std::memcpy(tmpbuff2, inptr2, remainx);
            std::memcpy(tmpbuff3, inptr3, remainx);
            std::memcpy(tmpbuff4, inptr4, remainx);
            std::memcpy(tmpbuff5, inptr5, remainx);
            std::memcpy(tmpbuff6, inptr6, remainx);
            std::memcpy(tmpbuff7, inptr7, remainx);
inptr0 = tmpbuff0; | |||
inptr1 = tmpbuff1; | |||
inptr2 = tmpbuff2; | |||
inptr3 = tmpbuff3; | |||
inptr4 = tmpbuff4; | |||
inptr5 = tmpbuff5; | |||
inptr6 = tmpbuff6; | |||
inptr7 = tmpbuff7; | |||
outptr_interleave = outptr_inner; | |||
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr_interleave); | |||
outptr_inner += ksize8; | |||
} | |||
} | |||
} | |||
} // namespace matmul_s4_4x4x16
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,109 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h" | |||
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
// ===========================gemm_s4x4x16_s4_8x8x8================================== | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8); | |||
void gemm_s4x4x16_s4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||
int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
} else { | |||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
} | |||
} | |||
void gemm_s4x4x16_s4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
} else { | |||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
} | |||
} | |||
void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t M, size_t N, size_t K, dt_int16* C, | |||
size_t LDC, bool is_first_k, const dt_int16*, | |||
dt_int16*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS16), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
constexpr size_t A_INTERLEAVE = 8; | |||
constexpr size_t B_INTERLEAVE = 8; | |||
    //! K is rounded up to a multiple of 8 by the packing routines
K = round_up<size_t>(K, 8); | |||
const int K8 = K * 8; | |||
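    //! after packing, the int4 data has been widened to int8, so one 8-row
    //! panel of A or one 8-column panel of B spans K * 8 bytes; packA and
    //! cur_packB advance by K8 per panel below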
size_t m = 0; | |||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | |||
int16_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_s4_4x4x16::s4_kern_8x8(packA, cur_packB, K, output, LDC, | |||
is_first_k, A_INTERLEAVE, B_INTERLEAVE); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||
is_first_k, A_INTERLEAVE, | |||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
packA += K8; | |||
} | |||
for (; m < M; m += A_INTERLEAVE) { | |||
int16_t* output = C + (m * LDC); | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||
is_first_k, | |||
std::min<size_t>(M - m, A_INTERLEAVE), | |||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
packA += K8; | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,26 @@ | |||
/** | |||
* \file dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, | |||
gemm_s4x4x16_s4_8x8x8); | |||
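// Reading of the registration above (an inference from the neighbouring
// 8x8x8 strategies, not a documented contract): dt_int8 is the packed
// element type (int4 data is widened during packing), dt_int16 the output
// and compute type, and the three 8s the M/N/K block sizes of the
// micro-kernel tile in kernel_int4_8x8x8.h.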
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -50,6 +50,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
#else | |||
AlgoQuint8K8x8x8 quint8_k8x8x8; | |||
#endif | |||
AlgoInt4x4x16K8x8x8 int4x4x16_k8x8x8; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | |||
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | |||
@@ -87,6 +88,7 @@ public: | |||
#else | |||
m_all_algos.emplace_back(&quint8_k8x8x8); | |||
#endif | |||
m_all_algos.emplace_back(&int4x4x16_k8x8x8); | |||
for (auto&& algo : m_all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
@@ -66,8 +66,8 @@ private: | |||
#else | |||
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 | |||
#endif | |||
    class AlgoInt8x8x16MK4_K8x8x8;   // Aarch64 Int8x8x16 Kernel MK4 8x8x8
    class AlgoInt4x4x16K8x8x8;       // Aarch64 Int4x4x16 Kernel 8x8x8
class AlgoPack; | |||
public: | |||
static const AlgoPack& algo_pack(); | |||
@@ -33,6 +33,8 @@ void MatrixMulForward::deduce_dtype(DType A, DType B, DType& C) { | |||
C_candi = dtype::QuantizedS32(mul_scale(A, B)); | |||
} else if (A.enumv() == DTypeEnum::Quantized4Asymm) { | |||
C_candi = dtype::QuantizedS32(mul_scale(A, B)); | |||
} else if (A.enumv() == DTypeEnum::QuantizedS4) { | |||
C_candi = dtype::QuantizedS16(mul_scale(A, B)); | |||
} | |||
if (!C.valid()) { | |||
C = C_candi; | |||
@@ -169,6 +171,8 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B, | |||
A.dtype.enumv() == DTypeEnum::Quantized8Asymm || | |||
A.dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||
megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS32); | |||
    } else if (A.dtype.enumv() == DTypeEnum::QuantizedS4) {
megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS16); | |||
} | |||
megdnn_assert(param().compute_mode != | |||
Param::ComputeMode::FLOAT32 MEGDNN_INC_FLOAT16( | |||
@@ -154,6 +154,7 @@ public: | |||
AARCH64_QUINT8_K8X8X4_DOTPROD, | |||
AARCH64_QUINT8_GEMV_DOTPROD, | |||
AARCH64_QUINT8_K8X8X8, | |||
AARCH64_INT4X4X16_K8X8X8, | |||
#else | |||
ARMV7_F32 = 1 << 16, | |||
ARMV7_F32_MK4_PACK_4X12, | |||
@@ -179,6 +179,42 @@ void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A, | |||
C.compatible_ptr<dt_int32>(), M, N, K, LDA, LDB, LDC, | |||
nA.layout.dtype, nB.layout.dtype); | |||
} | |||
template <bool transA, bool transB> | |||
void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
_megdnn_tensor_out C, | |||
_megdnn_workspace workspace, | |||
const param::MatrixMul& param) { | |||
auto convert_layout = [](const TensorLayout& layout) { | |||
auto ret = layout; | |||
auto param = layout.dtype.param<dtype::QuantizedS4>(); | |||
ret.dtype = dtype::QuantizedS8(param.scale); | |||
return ret; | |||
}; | |||
TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)}; | |||
TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(), | |||
convert_layout(B.layout)}; | |||
auto convert_4to8 = [](const TensorND& in, const TensorND& out) { | |||
auto ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte; | |||
auto out_ptr = | |||
out.compatible_ptr<int8_t>() + out.layout.span().low_byte; | |||
for (size_t i = 0; i < in.layout.span().dist_elem(); i += 2) { | |||
            int8_t cur = ptr[i / 2];
            out_ptr[i] = cur << 4;         // move the low nibble to the top bits,
            out_ptr[i] = out_ptr[i] >> 4;  // then arithmetic-shift back to sign-extend
            out_ptr[i + 1] = cur >> 4;     // high nibble, sign-extended
} | |||
}; | |||
convert_4to8(A, nA); | |||
convert_4to8(B, nB); | |||
auto M = C.layout.shape[0], N = C.layout.shape[1]; | |||
auto K = A.layout.shape[param.transposeA ? 0 : 1]; | |||
auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], | |||
LDC = C.layout.stride[0]; | |||
run_matrix_mul_tpl<int8_t, dt_int16, transA, transB, dt_int16>( | |||
nA.compatible_ptr<int8_t>(), nB.compatible_ptr<int8_t>(), | |||
C.compatible_ptr<dt_int16>(), M, N, K, LDA, LDB, LDC, | |||
nA.layout.dtype, nB.layout.dtype); | |||
} | |||
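// The shift pair in convert_4to8 above is the portable idiom for
// sign-extending a 4-bit field; a standalone sketch of the unpack step
// (unpack_qint4 is a hypothetical helper, not part of the diff):
inline void unpack_qint4(int8_t packed, int8_t& lo, int8_t& hi) {
    lo = static_cast<int8_t>(packed << 4) >> 4;  // low nibble, sign-extended
    hi = packed >> 4;                            // high nibble (arithmetic shift)
}
// e.g. packed = int8_t(0xF7) unpacks to lo = 7, hi = -1.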
} // namespace naive | |||
} // namespace megdnn | |||
@@ -26,7 +26,8 @@ size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, | |||
MIDOUT_BEGIN( | |||
megdnn_naive_matmul, | |||
midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) { | |||
if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||
if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm || | |||
A.dtype.enumv() == DTypeEnum::QuantizedS4) { | |||
return (A.span().dist_elem() + B.span().dist_elem()) * | |||
sizeof(uint8_t); | |||
} | |||
@@ -104,6 +105,11 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
param.format == param::MatrixMul::Format::DEFAULT) { | |||
exec_matrix_mul_quint4x4x32_helper<TA, TB>(A, B, C, workspace, param); | |||
return; | |||
} else if (A.layout.dtype.enumv() == DTypeEnum::QuantizedS4 && | |||
C.layout.dtype.enumv() == DTypeEnum::QuantizedS16 && | |||
param.format == param::MatrixMul::Format::DEFAULT) { | |||
exec_matrix_mul_qint4x4x16_helper<TA, TB>(A, B, C, workspace, param); | |||
return; | |||
} | |||
#undef cb | |||
megdnn_throw(ssprintf( | |||
@@ -164,6 +164,55 @@ TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_K4x4x16) { | |||
handle(), "AARCH64_INT8X8X16_K4X4X16"); | |||
} | |||
TEST_F(AARCH64, MATRIX_MUL_INT4x4x16_K8x8x8_QUANTIZEDS4) { | |||
param::MatrixMul param; | |||
param.transposeA = false; | |||
param.transposeB = false; | |||
Checker<MatrixMul> checker(handle()); | |||
checker.set_dtype(0, dtype::QuantizedS4{0.6}) | |||
.set_dtype(1, dtype::QuantizedS4{0.5}) | |||
.set_dtype(2, dtype::QuantizedS16{0.6 * 0.5}) | |||
.set_param(param); | |||
checker.set_before_exec_callback( | |||
AlgoChecker<MatrixMul>("AARCH64_INT4X4X16_K8X8X8")); | |||
auto run = [&](size_t M, size_t N, size_t K) { | |||
printf("M N K %zu %zu %zu \n", M, N, K); | |||
TensorShape A, B; | |||
if (param.transposeA) { | |||
A = TensorShape{K, M}; | |||
} else { | |||
A = TensorShape{M, K}; | |||
} | |||
if (param.transposeB) { | |||
B = TensorShape{N, K}; | |||
} else { | |||
B = TensorShape{K, N}; | |||
} | |||
checker.exec({A, B, {}}); | |||
}; | |||
for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 16, 20}) | |||
for (size_t n : {2, 4, 6, 8, 10, 12, 14, 16, 24}) | |||
for (size_t k : {2, 4, 6, 8, 10, 12, 14, 16, 32}) | |||
run(m, n, k); | |||
for (size_t k = 4; k <= 256; k *= 8) { | |||
for (size_t m = 4; m <= 256; m *= 4) { | |||
for (size_t n = 4; n <= 256; n *= 4) { | |||
run(m, n, k); | |||
} | |||
} | |||
} | |||
    param.transposeA = true;
    checker.set_param(param);
    run(8, 8, 8);
    run(16, 8, 16);
    param.transposeB = true;
    checker.set_param(param);
    run(8, 8, 8);
    run(16, 16, 16);
} | |||
TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_K12X8X1) { | |||
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, | |||
handle(), "AARCH64_INT16X16X32_K12X8X1"); | |||
@@ -410,6 +459,63 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) { | |||
run(384, 384, 384); | |||
} | |||
TEST_F(AARCH64, BENCHMARK_4x4x16_vs_8x8x16) { | |||
constexpr size_t RUNS = 50; | |||
param::MatrixMul param; | |||
param.transposeA = false; | |||
param.transposeB = false; | |||
Benchmarker<MatrixMul> benchmarker(handle()); | |||
Benchmarker<MatrixMul> benchmarker_int4_4x4x16(handle()); | |||
benchmarker_int4_4x4x16.set_times(RUNS) | |||
.set_dtype(0, dtype::QuantizedS4{0.3}) | |||
.set_dtype(1, dtype::QuantizedS4{0.3}) | |||
.set_dtype(2, dtype::QuantizedS16{0.09}) | |||
.set_param(param) | |||
.set_display(false); | |||
benchmarker.set_times(RUNS) | |||
.set_dtype(0, dtype::Int8{}) | |||
.set_dtype(1, dtype::Int8{}) | |||
.set_dtype(2, dtype::Int16{}) | |||
.set_param(param) | |||
.set_display(false); | |||
benchmarker.set_before_exec_callback( | |||
AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16")); | |||
auto run = [&](size_t M, size_t N, size_t K) { | |||
auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS; | |||
auto int4416_used = | |||
benchmarker_int4_4x4x16.exec({{M, K}, {K, N}, {}}) / RUNS; | |||
float computations = 2.f * M * K * N * 1e-6; | |||
printf("run: {%zu{M} %zu{K} %zu{N}} normal 8x8x16 used: %f ms %f " | |||
"Gflops int4416 used %f int4416_gflops %f speedup %f\n", | |||
M, K, N, default_used, computations / default_used, int4416_used, | |||
computations / int4416_used, default_used / int4416_used); | |||
}; | |||
for (int m = 32; m <= 1024; m += 32) | |||
for (int n = 32; n <= 1024; n += 32) | |||
for (int k = 32; k <= 512; k += 32) | |||
run(m, n, k); | |||
run(32, 32, 32); | |||
run(32, 32, 8); | |||
run(32, 32, 16); | |||
run(32, 32, 24); | |||
run(32 * 2, 32 * 2, 32); | |||
run(32 * 4, 32 * 4, 32); | |||
run(32 * 6, 32 * 6, 32); | |||
run(32 * 8, 32 * 8, 32); | |||
run(32 * 2, 32 * 2, 32 * 2); | |||
run(32 * 4, 32 * 4, 32 * 3); | |||
run(32 * 6, 32 * 6, 32 * 4); | |||
run(32 * 8, 32 * 8, 32 * 5); | |||
run(32 * 10, 32 * 10, 32 * 10); | |||
run(384, 384, 384); | |||
run(256, 256, 384); | |||
run(512, 512, 384); | |||
run(1024, 1024, 384); | |||
} | |||
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x8_8x8x16_vs_4x4x16_8x8x16) { | |||
constexpr size_t RUNS = 50; | |||
param::MatrixMul param; | |||
@@ -183,6 +183,34 @@ void IIDRNG::gen(const TensorND& tensor) { | |||
} | |||
return; | |||
} | |||
if (tensor.layout.dtype.enumv() == DTypeEnum::QuantizedS4) { | |||
auto ptr = static_cast<int8_t*>(tensor.raw_ptr); | |||
if (output_is_float()) { | |||
for (size_t i = 0; i < nr_elems; i += 2) { | |||
int8_t val0 = | |||
tensor.layout.dtype.param<dt_qint4>() | |||
.quantize(static_cast<float>(gen_single_val())) | |||
.as_int8(); | |||
int8_t val1 = | |||
tensor.layout.dtype.param<dt_qint4>() | |||
.quantize(static_cast<float>(gen_single_val())) | |||
.as_int8(); | |||
ptr[(offset + i) / 2] = (val0 & 0xF) | (val1 << 4); | |||
} | |||
} else { | |||
for (size_t i = 0; i < nr_elems; i += 2) { | |||
int8_t val0 = static_cast<int8_t>(gen_single_val()); | |||
int8_t val1 = static_cast<int8_t>(gen_single_val()); | |||
                val0 = std::min(val0, DTypeTrait<dtype::QuantizedS4>::max());
                val0 = std::max(val0, DTypeTrait<dtype::QuantizedS4>::min());
                val1 = std::min(val1, DTypeTrait<dtype::QuantizedS4>::max());
                val1 = std::max(val1, DTypeTrait<dtype::QuantizedS4>::min());
ptr[(offset + i) / 2] = (val0 & 0xF) | (val1 << 4); | |||
} | |||
} | |||
return; | |||
} | |||
megdnn_assert(0, "IIDRNG does not know how to generate value for DType %s", | |||
tensor.layout.dtype.name()); | |||
} | |||
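// Packing counterpart of the 4-bit unpack used by the naive helper (a
// sketch, not part of the diff; pack_qint4 is a hypothetical helper): two
// clamped int4 values share one byte, low nibble first, which is exactly
// what the QuantizedS4 branch above writes with (val0 & 0xF) | (val1 << 4).
inline uint8_t pack_qint4(int8_t lo, int8_t hi) {
    return static_cast<uint8_t>((lo & 0xF) | (hi << 4));
}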
@@ -203,6 +203,67 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZED4x4x32) { | |||
}); | |||
} | |||
TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) { | |||
Checker<MatrixMul> checker(handle(), /* check_dispatch */ false); | |||
    auto GenTensorValueQint4 = [](const TensorShape& shape,
dtype::QuantizedS4 dtype, | |||
const std::vector<int>& values) { | |||
TensorND tensor; | |||
tensor.layout = {shape, dtype}; | |||
tensor.raw_ptr = | |||
static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte())); | |||
uint8_t* ptr = static_cast<uint8_t*>(tensor.raw_ptr); | |||
megdnn_assert(values.size() == tensor.layout.span().dist_elem()); | |||
for (size_t i = 0; i < tensor.layout.span().dist_elem(); i += 2) { | |||
int val0 = values[i], val1 = values[i + 1]; | |||
            ptr[i / 2] = (val0 & 0xF) | (val1 << 4);
} | |||
return tensor; | |||
}; | |||
using Param = MatrixMul::Param; | |||
Param param; | |||
checker.set_param(param); | |||
checker.set_dtype(2, dtype::QuantizedS16(0.3f * 0.3f)); | |||
checker.exect( | |||
Testcase{ | |||
                    GenTensorValueQint4(
{8, 8}, dtype::QuantizedS4(0.3f), | |||
{-8, 7, 2, 1, 2, 3, 2, 7, | |||
2, 5, 3, 3, 7, 4, -7, 1, | |||
-5, 7, -4, -1, -1, 2, 4, 1, | |||
7, 2, -6, -2, -6, 3, 4, 4, | |||
-2, 2, 3, 0, 6, 5, 3, 4, | |||
-1, -1, -5, 5, 2, 5, 1, 4, | |||
6, 2, 0, 0, 3, 2, 2, 1, | |||
-4, -3, 7, 5, 0, 3, 2, 3}), | |||
                    GenTensorValueQint4(
{8, 8}, dtype::QuantizedS4(0.3f), | |||
{5, -8, -7, -6, 4, 7, -5, -5, | |||
-4, 7, -3, -2, 5, 6, 4, 2, | |||
3, -1, 2, 2, 7, 3, 6, 0, | |||
5, 4, 0, 2, 2, 3, 3, 2, | |||
1, -8, -7, -6, 0, -5, -4, 4, | |||
-3, 7, 1, 6, -2, 2, -1, 5, | |||
2, 0, 7, 6, 5, 4, 3, 2, | |||
0, 0, 1, 0, 5, 2, 2, 6}), | |||
{}}, | |||
Testcase{ | |||
{}, | |||
{}, | |||
TensorValue( | |||
{8, 8}, dtype::QuantizedS16(0.3f * 0.3f), | |||
{-60, 120, 49, 58, 58, 13, 92, 125, | |||
-5, 0, -116, -70, 22, 9, -14, 46, | |||
-69, 111, 44, 48, 6, 19, 42, 57, | |||
-8, 25, 10, 16, 26, 97, -28, -12, | |||
-12, 14, 2, 26, 48, 7, 24, 93, | |||
-2, 45, 2, 32, -19, -1, -16, 72, | |||
23, -44, -52, -34, 45, 53, -28, 6, | |||
33, 45, 71, 84, 47, 10, 74, 61}) | |||
}); | |||
} | |||
TEST_F(NAIVE, MATRIX_MUL_QUANTIZED8x8x32) { | |||
Checker<MatrixMul> checker(handle(), /* check_dispatch */ false); | |||
MatrixMul::Param param; | |||