@@ -138,6 +138,63 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern( | |||||
return f32_gemv_kern; | return f32_gemv_kern; | ||||
} | } | ||||
/* ===================== F32 Gevm algo ===================== */ | |||||
namespace { | |||||
void gevm_fp32_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDB = kern_param.LDB; | |||||
const auto Aptr = kern_param.A<dt_float32>(), | |||||
Bptr = kern_param.B<dt_float32>(); | |||||
auto Cptr = kern_param.C<dt_float32>(); | |||||
arm_common::sgemm_sgemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1); | |||||
} | |||||
void gevm_int8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDB = kern_param.LDB; | |||||
const auto Aptr = kern_param.A<dt_int8>(), | |||||
Bptr = kern_param.B<dt_int8>(); | |||||
auto Cptr = kern_param.C<dt_int32>(); | |||||
arm_common::matmul::gemv_like_int8(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1); | |||||
} | |||||
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoGevm::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
// enumerate the M, N, K, only usable when preferred | |||||
bool fp32_ok = | |||||
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
kern_size_param.format == param::MatrixMul::Format::DEFAULT && | |||||
kern_size_param.B_type == kern_size_param.A_type && | |||||
kern_size_param.C_type == kern_size_param.A_type && | |||||
kern_size_param.A_type == dtype::Float32(); | |||||
return (fp32_ok || can_be_treated_as_int8x8x32(kern_size_param)) && | |||||
preferred(kern_size_param); | |||||
} | |||||
bool MatrixMulImpl::AlgoGevm::preferred( | |||||
const KernSizeParam& kern_size_param) const { | |||||
auto M = kern_size_param.M; | |||||
return kern_size_param.trB && M == 1; | |||||
} | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoGevm::get_kern(
        const KernSizeParam& kern_size_param) const {
    //! Dispatch on A's dtype: fp32 goes to the sgemv-like kernel, (quantized)
    //! int8 goes to the int8x8x32 gemv-like kernel. Any other dtype is a
    //! caller error, since usable() should have rejected it.
    if (kern_size_param.A_type == dtype::Float32()) {
        return gevm_fp32_kern;
    } else if (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
               kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) {
        return gevm_int8_kern;
    } else {
        megdnn_assert(
                false, "no available kern got A_type: %s B_type: %s C_type: %s",
                kern_size_param.A_type.name(), kern_size_param.B_type.name(),
                kern_size_param.C_type.name());
        //! unreachable when the assert fires; avoids falling off the end of a
        //! non-void function if asserts are compiled out
        return nullptr;
    }
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
/* ===================== F16 Gemv algo ===================== */ | /* ===================== F16 Gemv algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -70,6 +70,21 @@ public: | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
}; | }; | ||||
#endif | #endif | ||||
//! Arm_common Gevm (vector * matrix) algorithm, supporting fp32 and int8.
class MatrixMulImpl::AlgoGevm : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "ARM_COMMON_GEVM"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    // returns 0: this algorithm needs no extra workspace
    size_t get_workspace(const KernSizeParam&) const override { return 0; }
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
    // operands are consumed in place, no packing pass
    PackMode packmode() const override { return PackMode::NO_PACK; }
};
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -27,7 +27,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
AlgoF16Gemv f16gemv; | AlgoF16Gemv f16gemv; | ||||
#endif | #endif | ||||
AlgoInt8x8x32Gemv int8x8x32_gemv; | |||||
AlgoInt8x8x32Gemv int8x8x32_gemv; | |||||
AlgoGevm gevm; | |||||
public: | public: | ||||
AlgoPack() { | AlgoPack() { | ||||
all_algos.emplace_back(&int8x8x16); | all_algos.emplace_back(&int8x8x16); | ||||
@@ -35,6 +36,7 @@ public: | |||||
all_algos.emplace_back(&f16gemv); | all_algos.emplace_back(&f16gemv); | ||||
#endif | #endif | ||||
all_algos.emplace_back(&int8x8x32_gemv); | all_algos.emplace_back(&int8x8x32_gemv); | ||||
all_algos.emplace_back(&gevm); | |||||
} | } | ||||
SmallVector<AlgoBase*> all_algos; | SmallVector<AlgoBase*> all_algos; | ||||
}; | }; | ||||
@@ -27,6 +27,7 @@ protected: | |||||
static void* const sm_arm_common_algo_type; | static void* const sm_arm_common_algo_type; | ||||
class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv | class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv | ||||
class AlgoF32Gemv; // Arm_common F32 Gemv | class AlgoF32Gemv; // Arm_common F32 Gemv | ||||
class AlgoGevm; // Arm_common Gemv(support int8 and fp32) | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
class AlgoF16Gemv; | class AlgoF16Gemv; | ||||
#endif | #endif | ||||
@@ -164,6 +164,62 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV) { | |||||
run(M, K, N); | run(M, K, N); | ||||
} | } | ||||
TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {
    Checker<MatrixMul> checker(handle());
    using Param = MatrixMul::Param;
    checker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
    std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
    checker.set_rng(0, rng.get()).set_rng(1, rng.get());

    //! run one quantized (M,K) x (K,N) case with B stored transposed
    auto run_case = [&](size_t M, size_t K, size_t N) {
        Param param;
        param.transposeA = false;
        param.transposeB = true;
        TensorShape A{M, K};
        TensorShape B{N, K};
        checker.set_param(param)
                .set_dtype(0, dtype::QuantizedS8(2.5f))
                .set_dtype(1, dtype::QuantizedS8(2.5f))
                .set_dtype(2, dtype::QuantizedS32(6.25f))
                .execs({A, B, {}});
    };

    // gevm only handles a single output row, so M is fixed to 1
    const size_t M = 1;
    for (size_t N : {1, 10, 16, 33, 64})
        for (size_t K : {7, 512, 1024})
            run_case(M, K, N);
}
TEST_F(ARM_COMMON, FP32_GEVM) {
    Checker<MatrixMul> checker(handle());
    using Param = MatrixMul::Param;
    checker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
    // loose tolerance: fp32 accumulation order may differ between kernels
    checker.set_epsilon(1e-2);

    //! run one fp32 (M,K) x (K,N) case with B stored transposed
    auto run_case = [&](size_t M, size_t K, size_t N) {
        Param param;
        param.transposeA = false;
        param.transposeB = true;
        TensorShape A{M, K};
        TensorShape B{N, K};
        checker.set_param(param).execs({A, B, {}});
    };

    // gevm only handles a single output row, so M is fixed to 1
    const size_t M = 1;
    for (size_t K : {1000, 4096, 25088})
        for (size_t N : {1000, 4096})
            run_case(M, K, N);
}
#if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { | TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { | ||||