|
|
@@ -30,6 +30,7 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format, |
|
|
|
auto extra_impl = [](const TensorNDArray& tensors, param::MatrixMul param, |
|
|
|
Handle* handle, size_t pack_size) { |
|
|
|
megdnn_assert((param.format == param::MatrixMul::Format::MK4 || |
|
|
|
param.format == param::MatrixMul::Format::MK4_DOT || |
|
|
|
param.format == param::MatrixMul::Format::MK8) && |
|
|
|
tensors.size() == 3); |
|
|
|
param::MatrixMul new_param = param; |
|
|
@@ -41,18 +42,34 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format, |
|
|
|
TensorLayoutArray default_layouts, mk4_layouts; |
|
|
|
if (param.transposeA) { |
|
|
|
default_layouts.emplace_back(tensors[0].layout.reshape({K, M})); |
|
|
|
mk4_layouts.emplace_back( |
|
|
|
default_layouts.back() |
|
|
|
.reshape({K / pack_size, M / pack_size, pack_size, |
|
|
|
pack_size}) |
|
|
|
.dimshuffle({0, 2, 1, 3})); |
|
|
|
if (param.format == param::MatrixMul::Format::MK4_DOT) { |
|
|
|
mk4_layouts.emplace_back( |
|
|
|
default_layouts.back() |
|
|
|
.reshape({K / pack_size, M / pack_size, |
|
|
|
pack_size, pack_size}) |
|
|
|
.dimshuffle({0, 3, 1, 2})); |
|
|
|
} else { |
|
|
|
mk4_layouts.emplace_back( |
|
|
|
default_layouts.back() |
|
|
|
.reshape({K / pack_size, M / pack_size, |
|
|
|
pack_size, pack_size}) |
|
|
|
.dimshuffle({0, 2, 1, 3})); |
|
|
|
} |
|
|
|
} else { |
|
|
|
default_layouts.emplace_back(tensors[0].layout.reshape({M, K})); |
|
|
|
mk4_layouts.emplace_back( |
|
|
|
default_layouts.back() |
|
|
|
.reshape({M / pack_size, K / pack_size, pack_size, |
|
|
|
pack_size}) |
|
|
|
.dimshuffle({0, 3, 1, 2})); |
|
|
|
if (param.format == param::MatrixMul::Format::MK4_DOT) { |
|
|
|
mk4_layouts.emplace_back( |
|
|
|
default_layouts.back() |
|
|
|
.reshape({M / pack_size, K / pack_size, |
|
|
|
pack_size, pack_size}) |
|
|
|
.dimshuffle({0, 2, 1, 3})); |
|
|
|
} else { |
|
|
|
mk4_layouts.emplace_back( |
|
|
|
default_layouts.back() |
|
|
|
.reshape({M / pack_size, K / pack_size, |
|
|
|
pack_size, pack_size}) |
|
|
|
.dimshuffle({0, 3, 1, 2})); |
|
|
|
} |
|
|
|
} |
|
|
|
if (param.transposeB) { |
|
|
|
default_layouts.emplace_back(tensors[1].layout.reshape({N, K})); |
|
|
@@ -238,6 +255,11 @@ TEST_F(NAIVE, MATRIX_MUL_MK8) { |
|
|
|
dtype::Int16(), dtype::Int16(), dtype::Int32()); |
|
|
|
} |
|
|
|
|
|
|
|
// Test in the NAIVE fixture covering the MK4_DOT matrix-multiply format:
// forwards the fixture's handle(), param::MatrixMul::Format::MK4_DOT and the
// dtypes Int8, Int8, Int32 to run_matmul_mk_format.
// NOTE(review): presumably the three dtypes are the A, B and C (output) types
// in that order -- confirm against run_matmul_mk_format's full signature.
TEST_F(NAIVE, MATRIX_MUL_MK4_DOT) { |
|
|
|
run_matmul_mk_format(handle(), param::MatrixMul::Format::MK4_DOT, |
|
|
|
dtype::Int8(), dtype::Int8(), dtype::Int32()); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(NAIVE, MATRIX_MUL_BFLOAT16) { |
|
|
|
Checker<MatrixMul> checker(handle(), /* check_dispatch */ false); |
|
|
|
MatrixMul::Param param, fp32_param; |
|
|
|