@@ -71,13 +71,13 @@ void gemm_avx2_s8s8s32_2x4x16::kern(const dt_int8* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_offset * roundup_k;
         for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_offset;
             matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16(iter_a_ptr, iter_b_ptr,
                                                          iter_c_ptr, ldc, k);
         }
         if (n_end < n) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_end;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_end;
             matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16_remain(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_tile,
                     n_remain);
@@ -87,14 +87,14 @@ void gemm_avx2_s8s8s32_2x4x16::kern(const dt_int8* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_end * roundup_k;
         for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_offset;
             matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16_remain(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain,
                     n_tile);
         }
         if (n_end < n) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_end;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_end;
             matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16_remain(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain,
                     n_remain);
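
The change above (and its repeats for the AVX2 4x16x2 and SSE 4x8x2 kernels in
the hunks that follow) fixes a classic leading-dimension bug: `n` is the
logical column count of C, while `ldc` is the distance in elements between
consecutive rows of C as it is actually laid out in memory. The two coincide
only when C is fully contiguous; once the destination is padded or a sub-view,
a tile addressed with `m_offset * n` lands in the wrong rows. A minimal sketch
of the indexing rule, using a hypothetical helper that is not part of the
patch:

    #include <cstddef>
    #include <cstdint>

    // Row-major C whose rows are `ldc` elements apart. ldc >= n always
    // holds; ldc == n only when the matrix is fully contiguous.
    static inline int32_t* tile_origin(int32_t* c_ptr, size_t m_offset,
                                       size_t n_offset, size_t ldc) {
        return c_ptr + m_offset * ldc + n_offset;  // never m_offset * n
    }
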
@@ -59,13 +59,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_offset * roundup_k;
         for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_offset;
             matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k);
         }
         if (n_remain > 0) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_end;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_end;
             if (n_remain <= 8) {
                 matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n(
                         iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, n_remain);
@@ -79,13 +79,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_end * roundup_k;
         for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_offset;
             matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2_remain_m(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain);
         }
         if (n_remain > 0) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_end;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_end;
             if (n_remain <= 8) {
                 matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n(
                         iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain,
@@ -59,13 +59,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_offset * roundup_k;
         for (int n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_offset;
             matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k);
         }
         if (n_remain > 0) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_end;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_end;
             matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2_remain_n(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, n_remain);
         }
@@ -74,13 +74,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_end * roundup_k;
         for (int n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_offset;
             matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2_remain_m(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain);
         }
         if (n_remain > 0) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_end;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_end;
             matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain,
                     n_remain);
@@ -78,6 +78,7 @@ protected:
     TensorsConstriant m_tensor_constraint;
     bool m_no_naive_and_check = false;
     bool m_stable_check = false;
+    bool m_force_deduce_dst = true;
     /**
      * the offset from the start of malloc memory
      *
@@ -236,6 +237,12 @@ public:
         return *this;
     }

+    //! force deduce dst
+    Checker& set_force_deduce_dst(bool force_deduce_dst) {
+        m_force_deduce_dst = force_deduce_dst;
+        return *this;
+    }
+
     Checker& set_no_naive_check(bool no_naive_and_check) {
         m_no_naive_and_check = no_naive_and_check;
         return *this;
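
A hedged usage sketch (assuming the chaining style of the existing Checker
setters such as `set_dtype`): a test that builds its own dst layout, e.g. one
with `ldc != n`, opts out of forced deduction so that layout survives into the
kernel:

    // Sketch only: keep the caller-supplied dst layout instead of letting
    // the naive operator re-deduce a contiguous one.
    Checker<MatrixMul> checker(handle());
    checker.set_force_deduce_dst(false)
            .set_dtype(0, dtype::Int8())
            .set_dtype(1, dtype::Int8())
            .set_dtype(2, dtype::Int32());
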
@@ -343,7 +350,10 @@ void Checker<Opr, Proxy>::exec(TensorLayoutArray layouts) {
     auto opr_cur = this->opr();
     opr_naive->param() = m_param;
     opr_cur->param() = m_param;
-    m_naive_proxy.deduce_layout(opr_naive.get(), layouts);
+    bool deduce_layout = layouts.back().ndim == 0;
+    if (deduce_layout || m_force_deduce_dst) {
+        m_naive_proxy.deduce_layout(opr_naive.get(), layouts);
+    }
     auto exec_naive = [this, &opr_naive, &layouts,
                        &opr_relayout](const TensorValueArray& values) {
         TensorValueArray contig_values = values;
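
The guarded deduction keeps the old behavior by default: the dst layout is
still deduced whenever it is empty (`layouts.back().ndim == 0`) or whenever
`m_force_deduce_dst` is left at `true`. Only a test that passes `false` and
supplies a complete dst layout gets that layout forwarded untouched to both
the naive operator and the operator under test.
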
@@ -101,7 +101,7 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_mask(
         size_t Astride = mask & 1 ? m + 2 : k + 2;
         // B: (k, n)
         size_t Bstride = mask & 2 ? k + 2 : n + 2;
-        size_t Cstride = n + 2;
+        size_t Cstride = n * 2 + 2;
         args.emplace_back(m, n, k, mask, Astride, Bstride, Cstride);
     }
     return args;
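
Doubling the stride makes C's rows unmistakably non-contiguous: for n = 4 the
row stride grows from n + 2 = 6 to n * 2 + 2 = 10. Once the x86 tests below
stop forcing dst deduction, this wide-stride layout actually reaches the
kernels, and a kernel that still advances C by `n` instead of `ldc` writes
into neighboring rows and fails the check.
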
@@ -183,9 +183,11 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
                                   Handle* handle,
                                   const ExecutionPolicyAlgoName& algo,
                                   param::MatrixMul::Format format, size_t nbase,
-                                  float eps, std::vector<TestArg>&& user_args) {
+                                  float eps, std::vector<TestArg>&& user_args,
+                                  bool force_deduce_dst) {
     megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
     Checker<Opr> checker(handle);
+    checker.set_force_deduce_dst(force_deduce_dst);
     if (!algo.name.empty()) {
         checker.set_before_exec_callback(AlgoChecker<Opr>(algo));
     }
@@ -245,16 +247,16 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
     for (auto& arg : args) {
         size_t m = arg.m, n = arg.n, k = arg.k;
-#if MEGDNN_WITH_CUDA
-        //[NOTE]: cublas can only process 4B aligned 8-bit input matrix;
-        bool is_dt_8bit = A_dtype.enumv() == DTypeEnum::Int8 ||
-                          A_dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                          A_dtype.enumv() == DTypeEnum::Uint8 ||
-                          A_dtype.enumv() == DTypeEnum::Quantized8Asymm;
-        if (is_dt_8bit && ((m % 4 != 0) || (n % 4 != 0))) {
-            continue;
+        if (handle->type() == Handle::HandleType::CUDA) {
+            //! NOTE: cublas can only process 4B aligned 8-bit input matrix;
+            bool is_dt_8bit = A_dtype.enumv() == DTypeEnum::Int8 ||
+                              A_dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                              A_dtype.enumv() == DTypeEnum::Uint8 ||
+                              A_dtype.enumv() == DTypeEnum::Quantized8Asymm;
+            if (is_dt_8bit && ((m % 4 != 0) || (n % 4 != 0))) {
+                continue;
+            }
         }
-#endif

         Param param;
         param.transposeA = arg.mask & 0x1;
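
Replacing the compile-time `#if MEGDNN_WITH_CUDA` guard with a runtime
`handle->type()` check scopes the cublas 4-byte-alignment workaround to CUDA
handles only. Previously a CUDA-enabled build skipped these unaligned 8-bit
shapes for every handle, including the x86 ones whose kernels this patch
fixes.
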
@@ -312,20 +314,22 @@ void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype,
                                           DType C_dtype, Handle* handle,
                                           const ExecutionPolicyAlgoName& algo,
                                           float eps,
-                                          std::vector<TestArg>&& args) {
+                                          std::vector<TestArg>&& args,
+                                          bool force_deduce_dst) {
     check_matrix_mul<megdnn::BatchedMatrixMul>(
             A_dtype, B_dtype, C_dtype, handle, algo,
             param::MatrixMul::Format::DEFAULT, 8, eps,
-            std::forward<decltype(args)>(args));
+            std::forward<decltype(args)>(args), force_deduce_dst);
 }

 void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
                                   Handle* handle,
                                   const ExecutionPolicyAlgoName& algo,
                                   param::MatrixMul::Format format, size_t nbase,
-                                  float eps) {
+                                  float eps, bool force_deduce_dst) {
     check_matrix_mul<megdnn::MatrixMul>(A_dtype, B_dtype, C_dtype, handle, algo,
-                                        format, nbase, eps);
+                                        format, nbase, eps, {},
+                                        force_deduce_dst);
 }

 #if MEGDNN_WITH_BENCHMARK
@@ -68,19 +68,21 @@ void check_matrix_mul(
         DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
         const ExecutionPolicyAlgoName& algo = {"", {}},
         param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
-        size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {});
+        size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {},
+        bool force_deduce_dst = true);

 void check_matrix_mul(
         DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
         const ExecutionPolicyAlgoName& algo = {"", {}},
         param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
-        size_t nbase = 8, float eps = 1e-3);
+        size_t nbase = 8, float eps = 1e-3, bool force_deduce_dst = true);

 void check_batched_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
                               Handle* handle,
                               const ExecutionPolicyAlgoName& algo = {"", {}},
                               float eps = 1e-3,
-                              std::vector<TestArg>&& args = {});
+                              std::vector<TestArg>&& args = {},
+                              bool force_deduce_dst = true);

 #if MEGDNN_WITH_BENCHMARK
 std::vector<TestArg> get_benchmark_matmul_args();
@@ -44,21 +44,31 @@ TEST_F(X86, MATRIX_MUL_MKLDNN_8X8X32) {
 //! FIXME: need to add tests of GEMV and QUINT8
 TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
-                                 handle(), "X86_INT8X8X32_AVX2_2X4X16");
+                                 handle(), "X86_INT8X8X32_AVX2_2X4X16",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
-                                 handle(), "X86_INT8X8X32_AVX2_4X16X2");
+                                 handle(), "X86_INT8X8X32_AVX2_4X16X2",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
 }
 TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
-                                 handle(), "X86_INT8X8X16_AVX2");
+                                 handle(), "X86_INT8X8X16_AVX2",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
 }
 TEST_F(X86, MATRIX_MUL_SSE_8X8X16) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
-                                 handle(), "X86_INT8X8X16_SSE");
+                                 handle(), "X86_INT8X8X16_SSE",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
 }
 TEST_F(X86, MATRIX_MUL_SSE_8X8X32) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
-                                 handle(), "X86_INT8X8X32_SSE_4X8X2");
+                                 handle(), "X86_INT8X8X32_SSE_4X8X2",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
 }

 #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
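
Each of these tests now spells out the defaults for `format`, `nbase`, and
`eps` in order to pass `force_deduce_dst = false`, so the wide-stride dst
layouts produced by the new `Cstride` reach the AVX2 and SSE kernels and
exercise the `ldc`-based indexing fixed at the top of the patch.
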
@@ -72,7 +82,7 @@ TEST_F(X86, MATRIX_MUL_MKL_PACKA) {
 TEST_F(X86, MATRIX_MUL_AVX2_MK8_8X8) {
     matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
                                  dtype::Float32{}, handle(), "X86_F32MK8_8X8",
-                                 param::MatrixMul::Format::MK8, 1);
+                                 param::MatrixMul::Format::MK8, 1, 1e-3, false);
 }

 #if MEGDNN_WITH_BENCHMARK