GitOrigin-RevId: b6af21e8e3
release-1.1
@@ -1310,4 +1310,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8, | |||
int32_t); | |||
#endif | |||
/* ===================== Int8x8x16 K8x8x8 algo ===================== */ | |||
namespace { | |||
void int8x8x16_mk4_8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||
midout_iv("int8x8x16_mk4_8x8x8_kern"_hash)) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto trA = kern_param.trA, trB = kern_param.trB; | |||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||
C_type = kern_param.C_type; | |||
const auto Aptr = kern_param.A<dt_int8>(), | |||
Bptr = kern_param.B<dt_int8>(); | |||
auto Cptr = kern_param.C<dt_int16>(); | |||
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type, | |||
B_type, C_type); | |||
megdnn::matmul::GemmInterleaved< | |||
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8>(M, N, K, trA, trB, | |||
strategy) | |||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | |||
kern_param.workspace_ptr); | |||
} | |||
MIDOUT_END(); | |||
} | |||
} // anonymous namespace | |||
bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
return can_be_treated_as_int8x8x16(kern_size_param) && | |||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
!kern_size_param.trA && !kern_size_param.trB && | |||
kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; | |||
} | |||
bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::preferred( | |||
const KernSizeParam&) const { | |||
return true; | |||
} | |||
size_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_workspace( | |||
const KernSizeParam& kern_size_param) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||
midout_iv("AlgoInt8x8x16_MK4_8x8x8::get_workspace"_hash)) { | |||
auto M = kern_size_param.M, N = kern_size_param.N, | |||
K = kern_size_param.K; | |||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||
C_type = kern_size_param.C_type; | |||
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type, | |||
B_type, C_type); | |||
return megdnn::matmul::GemmInterleaved< | |||
matmul::gemm_s8x8x16_mk4_8x8x8>(M, N, K, trA, trB, | |||
strategy) | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern( | |||
const KernSizeParam&) const { | |||
return int8x8x16_mk4_8x8x8_kern; | |||
} | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoInt8x8x16MK4_K8x8x8Impl"_hash, | |||
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, int8_t, | |||
int16_t); | |||
// vim: syntax=cpp.doxygen |
@@ -202,6 +202,22 @@ public: | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
return "AARCH64_INT8X8X16_MK4_K8X8X8"; | |||
} | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::DEFAULT; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
@@ -2101,6 +2101,62 @@ static inline void transpos_12x4_s8(const int8_t* inptr0, int8_t* outptr) { | |||
vreinterpretq_s32_s8(input2), 3); | |||
} | |||
template <typename T> | |||
static inline void interleave_8x8_mk4_b(const T*& inptr0, const T*& inptr1, | |||
T*& outptr) { | |||
static_assert( | |||
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, | |||
"transpose_8x4_1_b only support uint8_t and int8_t"); | |||
asm volatile( | |||
"ld1 {v0.4s}, [%[inptr0]], #16\n" | |||
"ld1 {v1.4s}, [%[inptr1]], #16\n" | |||
"ld1 {v2.4s}, [%[inptr0]], #16\n" | |||
"ld1 {v3.4s}, [%[inptr1]], #16\n" | |||
"zip1 v4.4s, v0.4s, v1.4s \n" | |||
"zip2 v5.4s, v0.4s, v1.4s \n" | |||
"zip1 v6.4s, v2.4s, v3.4s\n" | |||
"zip2 v7.4s, v2.4s, v3.4s\n" | |||
"st1 {v4.4s},[%[outptr]],#16\n" | |||
"st1 {v5.4s},[%[outptr]],#16\n" | |||
"st1 {v6.4s},[%[outptr]],#16\n" | |||
"st1 {v7.4s},[%[outptr]],#16\n" | |||
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), | |||
[outptr] "+r"(outptr) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","memory"); | |||
} | |||
template <typename T> | |||
static inline void transpose_8x8_mk4_b(const T*& inptr0, const T*& inptr1, | |||
T* outptr) { | |||
static_assert( | |||
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, | |||
"transpose_8x4_1_b only support uint8_t and int8_t"); | |||
asm volatile( | |||
"ld4 {v0.8b-v3.8b}, [%[inptr0]], #32\n" | |||
"ld4 {v4.8b-v7.8b}, [%[inptr1]], #32\n" | |||
"st1 {v0.2s},[%[outptr]],#8\n" | |||
"st1 {v1.2s},[%[outptr]],#8\n" | |||
"st1 {v2.2s},[%[outptr]],#8\n" | |||
"st1 {v3.2s},[%[outptr]],#8\n" | |||
"st1 {v4.2s},[%[outptr]],#8\n" | |||
"st1 {v5.2s},[%[outptr]],#8\n" | |||
"st1 {v6.2s},[%[outptr]],#8\n" | |||
"st1 {v7.2s},[%[outptr]],#8\n" | |||
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), | |||
[outptr] "+r"(outptr) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","memory"); | |||
} | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
@@ -13,6 +13,7 @@ | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||
@@ -357,4 +358,81 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||
} | |||
} | |||
// ===========================gemm_s8x8x16_mk4_8x8x8================================== | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8); | |||
void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax, bool) const { | |||
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, | |||
ymax, k0, kmax); | |||
} | |||
void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, | |||
int ldin, int x0, int xmax, int k0, | |||
int kmax, bool) const { | |||
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, | |||
xmax, k0, kmax); | |||
} | |||
void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t M, size_t N, size_t K, dt_int16* C, | |||
size_t LDC, bool is_first_k, const dt_int16*, | |||
dt_int16*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
C_dtype.enumv() == DTypeEnum::Int16 && | |||
A_dtype.enumv() == DTypeEnum::Int8); | |||
megdnn_assert(is_first_k == true, "only impl is_first_k"); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); | |||
constexpr size_t pack_size = 4; | |||
constexpr size_t pack_m = 8; | |||
constexpr size_t pack_n = 8; | |||
const size_t remain_n = N % pack_n; | |||
size_t remain_m = M % pack_m; | |||
K = round_up<size_t>(K, 8); | |||
size_t KSIZE8 = K * pack_n; | |||
size_t m_idx = 0; | |||
for (; m_idx + pack_m <= M; m_idx += pack_m) { | |||
int16_t* output = C + (m_idx / pack_size * LDC); | |||
size_t n_idx = 0; | |||
const int8_t* cur_packB = packB; | |||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
matmul_mk4_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||
is_first_k, pack_m, pack_n); | |||
output += pack_n * pack_size; | |||
cur_packB += KSIZE8; | |||
} | |||
if (remain_n > 0) { | |||
matmul_mk4_8x8x8::kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||
is_first_k, pack_m, remain_n); | |||
output += remain_n * pack_size; | |||
cur_packB += KSIZE8; | |||
} | |||
packA += KSIZE8; | |||
} | |||
if (remain_m == 4) { | |||
int16_t* output = C + (m_idx / pack_size * LDC); | |||
size_t n_idx = 0; | |||
const int8_t* cur_packB = packB; | |||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
matmul_mk4_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, | |||
is_first_k, 4, pack_n); | |||
output += pack_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
if (remain_n > 0) { | |||
matmul_mk4_8x8x8::kern_4x8_remain(packA, cur_packB, K, output, LDC, | |||
is_first_k, 4, remain_n); | |||
output += remain_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -26,6 +26,8 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false, | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int16, | |||
16, 12, 4, false, false, | |||
gemm_s8x8x16_mk4_16x12_a53); | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, false, | |||
gemm_s8x8x16_mk4_8x8x8); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
@@ -39,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; | |||
AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4; | |||
AlgoInt8x8x16MK4_4x4x8 int8x8x16_mk4_4x4x8; | |||
AlgoInt8x8x16MK4_K8x8x8 int8x8x16_mk4_k8x8x8; | |||
AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1; | |||
AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8; | |||
@@ -73,6 +74,7 @@ public: | |||
#endif | |||
all_algos.emplace_back(&int8x8x16_k4x4x16); | |||
all_algos.emplace_back(&int8x8x16_k8x8x8); | |||
all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); | |||
all_algos.emplace_back(&int8x8x16_mk4_4x4x8); | |||
all_algos.emplace_back(&int8x8x16_mk4_16x12x4); | |||
@@ -57,6 +57,7 @@ private: | |||
#else | |||
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 | |||
#endif | |||
class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 | |||
class AlgoPack; | |||
}; | |||
@@ -122,6 +122,20 @@ TEST_F(AARCH64, MATRIX_MUL_INT8_MK4) { | |||
std::move(args)); | |||
} | |||
TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_MK4) { | |||
std::vector<matrix_mul::TestArg> args; | |||
for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}) | |||
for (size_t n : | |||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 24}) | |||
for (size_t k : | |||
{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, | |||
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29}) | |||
args.emplace_back(m, n, k, 0); | |||
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, | |||
handle(), "AARCH64_INT8X8X16_MK4_K8X8X8", | |||
param::MatrixMul::Format::MK4, 1, 1e-3, | |||
std::move(args)); | |||
} | |||
TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16_4x4) { | |||
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, | |||
handle(), "AARCH64_INT8X8X16_MK4_4X4X8", | |||
@@ -396,6 +410,71 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) { | |||
run(384, 384, 384); | |||
} | |||
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x8_8x8x16_vs_4x4x16_8x8x16) { | |||
constexpr size_t RUNS = 50; | |||
param::MatrixMul param; | |||
param.transposeA = false; | |||
param.transposeB = false; | |||
Benchmarker<MatrixMul> benchmarker(handle()); | |||
Benchmarker<MatrixMul> benchmarker_mk4(handle()); | |||
Benchmarker<MatrixMul> benchmarker_mk4_4x4x8(handle()); | |||
benchmarker.set_times(RUNS) | |||
.set_dtype(0, dtype::Int8{}) | |||
.set_dtype(1, dtype::Int8{}) | |||
.set_dtype(2, dtype::Int16{}) | |||
.set_param(param) | |||
.set_display(false); | |||
benchmarker.set_before_exec_callback( | |||
AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16")); | |||
param.format = MatrixMul::Param::Format::MK4; | |||
benchmarker_mk4.set_before_exec_callback( | |||
AlgoChecker<MatrixMul>( | |||
"AARCH64_INT8X8X16_MK4_K8X8X8" | |||
)); | |||
benchmarker_mk4.set_times(RUNS) | |||
.set_dtype(0, dtype::Int8{}) | |||
.set_dtype(1, dtype::Int8{}) | |||
.set_dtype(2, dtype::Int16{}) | |||
.set_param(param) | |||
.set_display(false); | |||
benchmarker_mk4_4x4x8.set_before_exec_callback( | |||
AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_4X4X8")); | |||
benchmarker_mk4_4x4x8.set_times(RUNS) | |||
.set_dtype(0, dtype::Int8{}) | |||
.set_dtype(1, dtype::Int8{}) | |||
.set_dtype(2, dtype::Int16{}) | |||
.set_param(param) | |||
.set_display(false); | |||
auto run = [&](size_t M, size_t N, size_t K) { | |||
auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS; | |||
auto mk_used = benchmarker_mk4.exec( | |||
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / | |||
RUNS; | |||
auto mk4_4x4x8_used = | |||
benchmarker_mk4_4x4x8.exec( | |||
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / | |||
RUNS; | |||
float computations = 2.f * M * K * N * 1e-6; | |||
printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms " | |||
"%f Gflops speedup: %f, mk4_4x4x8 %f Gflops %f ms speedup: %f\n", | |||
M, K, N, default_used, computations / default_used, mk_used, | |||
computations / mk_used, default_used / mk_used, | |||
computations / mk4_4x4x8_used, mk4_4x4x8_used , mk4_4x4x8_used/mk_used); | |||
}; | |||
run(384, 384, 384); | |||
run(512, 512, 512); | |||
run(1024, 1024, 384); | |||
run(256, 256, 384); | |||
for(int m = 32; m <= 512;m*=2) | |||
for(int n = 32; n <= 512;n*=2) | |||
for(int k = 32; k < 512;k*=2){ | |||
run(m,n,k); | |||
} | |||
} | |||
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) { | |||
constexpr size_t RUNS = 50; | |||
param::MatrixMul param; | |||