@@ -138,6 +138,63 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern( | |||
return f32_gemv_kern; | |||
} | |||
/* ===================== F32 Gevm algo ===================== */ | |||
namespace { | |||
void gevm_fp32_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto LDB = kern_param.LDB; | |||
const auto Aptr = kern_param.A<dt_float32>(), | |||
Bptr = kern_param.B<dt_float32>(); | |||
auto Cptr = kern_param.C<dt_float32>(); | |||
arm_common::sgemm_sgemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1); | |||
} | |||
void gevm_int8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto LDB = kern_param.LDB; | |||
const auto Aptr = kern_param.A<dt_int8>(), | |||
Bptr = kern_param.B<dt_int8>(); | |||
auto Cptr = kern_param.C<dt_int32>(); | |||
arm_common::matmul::gemv_like_int8(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1); | |||
} | |||
} // anonymous namespace | |||
bool MatrixMulImpl::AlgoGevm::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
// enumerate the M, N, K, only usable when preferred | |||
bool fp32_ok = | |||
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
kern_size_param.format == param::MatrixMul::Format::DEFAULT && | |||
kern_size_param.B_type == kern_size_param.A_type && | |||
kern_size_param.C_type == kern_size_param.A_type && | |||
kern_size_param.A_type == dtype::Float32(); | |||
return (fp32_ok || can_be_treated_as_int8x8x32(kern_size_param)) && | |||
preferred(kern_size_param); | |||
} | |||
bool MatrixMulImpl::AlgoGevm::preferred( | |||
const KernSizeParam& kern_size_param) const { | |||
auto M = kern_size_param.M; | |||
return kern_size_param.trB && M == 1; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoGevm::get_kern(
        const KernSizeParam& kern_size_param) const {
    //! Dispatch on A's dtype: fp32 goes to the sgemv-like kernel; Int8 and
    //! QuantizedS8 go to the int8 gemv-like kernel. Any other dtype should
    //! have been rejected by usable() and triggers the assert below.
    if (kern_size_param.A_type == dtype::Float32()) {
        return gevm_fp32_kern;
    } else if (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
               kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) {
        return gevm_int8_kern;
    } else {
        // Fixed misspelled diagnostic: "avaliable" -> "available".
        megdnn_assert(
                false, "no available kern got A_type: %s B_type: %s C_type: %s",
                kern_size_param.A_type.name(), kern_size_param.B_type.name(),
                kern_size_param.C_type.name());
    }
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
/* ===================== F16 Gemv algo ===================== */ | |||
namespace { | |||
@@ -70,6 +70,21 @@ public: | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
}; | |||
#endif | |||
// Gevm algorithm (M == 1, trB): handles fp32 and int8x8x32 problems.
class MatrixMulImpl::AlgoGevm : public AlgoBase {
public:
    // Same inputs always yield the same outputs.
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "ARM_COMMON_GEVM"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    // Works entirely in-place on the operands; no scratch buffer needed.
    size_t get_workspace(const KernSizeParam&) const override { return 0; }
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
    // Operands are consumed directly, without a packing stage.
    PackMode packmode() const override { return PackMode::NO_PACK; }
};
} // namespace arm_common | |||
} // namespace megdnn | |||
@@ -27,7 +27,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
AlgoF16Gemv f16gemv; | |||
#endif | |||
AlgoInt8x8x32Gemv int8x8x32_gemv; | |||
AlgoInt8x8x32Gemv int8x8x32_gemv; | |||
AlgoGevm gevm; | |||
public: | |||
AlgoPack() { | |||
all_algos.emplace_back(&int8x8x16); | |||
@@ -35,6 +36,7 @@ public: | |||
all_algos.emplace_back(&f16gemv); | |||
#endif | |||
all_algos.emplace_back(&int8x8x32_gemv); | |||
all_algos.emplace_back(&gevm); | |||
} | |||
SmallVector<AlgoBase*> all_algos; | |||
}; | |||
@@ -27,6 +27,7 @@ protected: | |||
static void* const sm_arm_common_algo_type; | |||
class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv | |||
class AlgoF32Gemv; // Arm_common F32 Gemv | |||
class AlgoGevm; // Arm_common Gemv(support int8 and fp32) | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
class AlgoF16Gemv; | |||
#endif | |||
@@ -164,6 +164,62 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV) { | |||
run(M, K, N); | |||
} | |||
TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {
    Checker<MatrixMul> checker(handle());
    using Param = MatrixMul::Param;
    checker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
    std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
    checker.set_rng(0, rng.get()).set_rng(1, rng.get());
    // Run one qint8 gevm case: A is {M, K}, B is {N, K} with trB set.
    auto run = [&](size_t M, size_t K, size_t N) {
        Param param;
        param.transposeA = false;
        param.transposeB = true;
        TensorShape A{M, K};
        TensorShape B{N, K};
        checker.set_param(param)
                .set_dtype(0, dtype::QuantizedS8(2.5f))
                .set_dtype(1, dtype::QuantizedS8(2.5f))
                .set_dtype(2, dtype::QuantizedS32(6.25f))
                .execs({A, B, {}});
    };
    // gevm only covers M == 1
    constexpr size_t M = 1;
    for (size_t N : {1, 10, 16, 33, 64})
        for (size_t K : {7, 512, 1024})
            run(M, K, N);
}
TEST_F(ARM_COMMON, FP32_GEVM) {
    Checker<MatrixMul> checker(handle());
    using Param = MatrixMul::Param;
    checker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
    checker.set_epsilon(1e-2);
    // Run one fp32 gevm case: A is {M, K}, B is {N, K} with trB set.
    auto run = [&](size_t M, size_t K, size_t N) {
        Param param;
        param.transposeA = false;
        param.transposeB = true;
        TensorShape A{M, K};
        TensorShape B{N, K};
        checker.set_param(param).execs({A, B, {}});
    };
    // gevm only covers M == 1
    constexpr size_t M = 1;
    for (size_t K : {1000, 4096, 25088})
        for (size_t N : {1000, 4096})
            run(M, K, N);
}
#if MEGDNN_WITH_BENCHMARK | |||
TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { | |||