From 9320bf92af053e278747570ffddb0370ab757351 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 22 May 2020 14:16:48 +0800
Subject: [PATCH] feat(mgb/dnn): add matmul mk4 dot naive test

GitOrigin-RevId: 2f16d4f89b900101977270eb6446541e5d558a32
---
 dnn/test/naive/matrix_mul.cpp | 42 ++++++++++++++++++++++++++++++++----------
 1 file changed, 32 insertions(+), 10 deletions(-)

diff --git a/dnn/test/naive/matrix_mul.cpp b/dnn/test/naive/matrix_mul.cpp
index 2d04d013..4bfa4596 100644
--- a/dnn/test/naive/matrix_mul.cpp
+++ b/dnn/test/naive/matrix_mul.cpp
@@ -30,6 +30,7 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format,
     auto extra_impl = [](const TensorNDArray& tensors, param::MatrixMul param,
                          Handle* handle, size_t pack_size) {
         megdnn_assert((param.format == param::MatrixMul::Format::MK4 ||
+                       param.format == param::MatrixMul::Format::MK4_DOT ||
                        param.format == param::MatrixMul::Format::MK8) &&
                       tensors.size() == 3);
         param::MatrixMul new_param = param;
@@ -41,18 +42,34 @@
         TensorLayoutArray default_layouts, mk4_layouts;
         if (param.transposeA) {
             default_layouts.emplace_back(tensors[0].layout.reshape({K, M}));
-            mk4_layouts.emplace_back(
-                    default_layouts.back()
-                            .reshape({K / pack_size, M / pack_size, pack_size,
-                                      pack_size})
-                            .dimshuffle({0, 2, 1, 3}));
+            if (param.format == param::MatrixMul::Format::MK4_DOT) {
+                mk4_layouts.emplace_back(
+                        default_layouts.back()
+                                .reshape({K / pack_size, M / pack_size,
+                                          pack_size, pack_size})
+                                .dimshuffle({0, 3, 1, 2}));
+            } else {
+                mk4_layouts.emplace_back(
+                        default_layouts.back()
+                                .reshape({K / pack_size, M / pack_size,
+                                          pack_size, pack_size})
+                                .dimshuffle({0, 2, 1, 3}));
+            }
         } else {
             default_layouts.emplace_back(tensors[0].layout.reshape({M, K}));
-            mk4_layouts.emplace_back(
-                    default_layouts.back()
-                            .reshape({M / pack_size, K / pack_size, pack_size,
-                                      pack_size})
-                            .dimshuffle({0, 3, 1, 2}));
+            if (param.format == param::MatrixMul::Format::MK4_DOT) {
+                mk4_layouts.emplace_back(
+                        default_layouts.back()
+                                .reshape({M / pack_size, K / pack_size,
+                                          pack_size, pack_size})
+                                .dimshuffle({0, 2, 1, 3}));
+            } else {
+                mk4_layouts.emplace_back(
+                        default_layouts.back()
+                                .reshape({M / pack_size, K / pack_size,
+                                          pack_size, pack_size})
+                                .dimshuffle({0, 3, 1, 2}));
+            }
         }
         if (param.transposeB) {
             default_layouts.emplace_back(tensors[1].layout.reshape({N, K}));
@@ -238,6 +255,11 @@ TEST_F(NAIVE, MATRIX_MUL_MK8) {
                          dtype::Int16(), dtype::Int16(), dtype::Int32());
 }
 
+TEST_F(NAIVE, MATRIX_MUL_MK4_DOT) {
+    run_matmul_mk_format(handle(), param::MatrixMul::Format::MK4_DOT,
+                         dtype::Int8(), dtype::Int8(), dtype::Int32());
+}
+
 TEST_F(NAIVE, MATRIX_MUL_BFLOAT16) {
     Checker checker(handle(), /* check_dispatch */ false);
     MatrixMul::Param param, fp32_param;
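
Note on the new MK4_DOT branches (not part of the patch itself): the MK4 and MK4_DOT cases above differ only in the order of the two innermost dimensions after the reshape, i.e. whether four consecutive elements along M or along K end up adjacent in the packed buffer. The standalone sketch below, derived from the non-transposed-A branch of the diff, spells out the element mapping this implies. It is an illustration under stated assumptions: pack_size is fixed at 4, the helper name packed_offset and the concrete sizes are invented here, and the remark about dot-product kernels wanting four adjacent K elements is an inference from the Int8/Int32 dtypes of the new test, not something the patch states.

// Sketch only: element mapping implied by the reshape + dimshuffle pairs in
// extra_impl for a non-transposed, row-major {M, K} matrix A.
#include <cassert>
#include <cstdio>

constexpr int P = 4;  // pack_size shared by MK4 and MK4_DOT in this sketch

// Offset of plain element (m, k) inside a packed {M/P, K/P, P, P} buffer.
//   MK4     (dimshuffle {0, 3, 1, 2}): inner 4x4 block is [k % P][m % P]
//   MK4_DOT (dimshuffle {0, 2, 1, 3}): inner 4x4 block is [m % P][k % P]
int packed_offset(int m, int k, int K, bool dot) {
    int inner = dot ? (m % P) * P + (k % P) : (k % P) * P + (m % P);
    return (m / P) * (K * P) + (k / P) * (P * P) + inner;
}

int main() {
    const int K = 8;
    for (int k = 0; k < P; ++k)
        printf("A(0,%d): MK4 offset %2d, MK4_DOT offset %2d\n", k,
               packed_offset(0, k, K, /*dot=*/false),
               packed_offset(0, k, K, /*dot=*/true));
    // MK4_DOT keeps four consecutive k values of a row adjacent (what the
    // int8 dot-product kernels are assumed to consume); MK4 keeps four
    // consecutive m values of a column adjacent instead.
    assert(packed_offset(0, 1, K, true) == packed_offset(0, 0, K, true) + 1);
    assert(packed_offset(1, 0, K, false) == packed_offset(0, 0, K, false) + 1);
    return 0;
}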