diff --git a/dnn/test/naive/matrix_mul.cpp b/dnn/test/naive/matrix_mul.cpp index 2d04d013..4bfa4596 100644 --- a/dnn/test/naive/matrix_mul.cpp +++ b/dnn/test/naive/matrix_mul.cpp @@ -30,6 +30,7 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format, auto extra_impl = [](const TensorNDArray& tensors, param::MatrixMul param, Handle* handle, size_t pack_size) { megdnn_assert((param.format == param::MatrixMul::Format::MK4 || + param.format == param::MatrixMul::Format::MK4_DOT || /* MK4_DOT is now accepted by this assertion alongside MK4 and MK8 */ param.format == param::MatrixMul::Format::MK8) && tensors.size() == 3); param::MatrixMul new_param = param; @@ -41,18 +42,34 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format, TensorLayoutArray default_layouts, mk4_layouts; if (param.transposeA) { default_layouts.emplace_back(tensors[0].layout.reshape({K, M})); - mk4_layouts.emplace_back( - default_layouts.back() - .reshape({K / pack_size, M / pack_size, pack_size, - pack_size}) - .dimshuffle({0, 2, 1, 3})); + if (param.format == param::MatrixMul::Format::MK4_DOT) { /* transposeA with MK4_DOT: same reshape as the else branch, but dimshuffle {0, 3, 1, 2} instead of {0, 2, 1, 3} */ + mk4_layouts.emplace_back( + default_layouts.back() + .reshape({K / pack_size, M / pack_size, + pack_size, pack_size}) + .dimshuffle({0, 3, 1, 2})); + } else { + mk4_layouts.emplace_back( + default_layouts.back() + .reshape({K / pack_size, M / pack_size, + pack_size, pack_size}) + .dimshuffle({0, 2, 1, 3})); + } } else { default_layouts.emplace_back(tensors[0].layout.reshape({M, K})); - mk4_layouts.emplace_back( - default_layouts.back() - .reshape({M / pack_size, K / pack_size, pack_size, - pack_size}) - .dimshuffle({0, 3, 1, 2})); + if (param.format == param::MatrixMul::Format::MK4_DOT) { /* non-transposed A with MK4_DOT: same reshape as the else branch, but dimshuffle {0, 2, 1, 3} instead of {0, 3, 1, 2} */ + mk4_layouts.emplace_back( + default_layouts.back() + .reshape({M / pack_size, K / pack_size, + pack_size, pack_size}) + .dimshuffle({0, 2, 1, 3})); + } else { + mk4_layouts.emplace_back( + default_layouts.back() + .reshape({M / pack_size, K / pack_size, + pack_size, pack_size}) + .dimshuffle({0, 3, 1, 2})); + } } if (param.transposeB) { 
default_layouts.emplace_back(tensors[1].layout.reshape({N, K})); @@ -238,6 +255,11 @@ TEST_F(NAIVE, MATRIX_MUL_MK8) { dtype::Int16(), dtype::Int16(), dtype::Int32()); } +TEST_F(NAIVE, MATRIX_MUL_MK4_DOT) { /* Calls run_matmul_mk_format with Format::MK4_DOT and dtypes Int8, Int8, Int32; NOTE(review): presumably the A, B and C operand dtypes, confirm against the helper's full signature */ + run_matmul_mk_format(handle(), param::MatrixMul::Format::MK4_DOT, + dtype::Int8(), dtype::Int8(), dtype::Int32()); +} + TEST_F(NAIVE, MATRIX_MUL_BFLOAT16) { Checker checker(handle(), /* check_dispatch */ false); MatrixMul::Param param, fp32_param;