Browse Source

feat(mgb/dnn): add matmul mk4 dot naive test

GitOrigin-RevId: 2f16d4f89b
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
9320bf92af
1 changed files with 32 additions and 10 deletions
  1. +32
    -10
      dnn/test/naive/matrix_mul.cpp

+ 32
- 10
dnn/test/naive/matrix_mul.cpp View File

@@ -30,6 +30,7 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format,
auto extra_impl = [](const TensorNDArray& tensors, param::MatrixMul param,
Handle* handle, size_t pack_size) {
megdnn_assert((param.format == param::MatrixMul::Format::MK4 ||
param.format == param::MatrixMul::Format::MK4_DOT ||
param.format == param::MatrixMul::Format::MK8) &&
tensors.size() == 3);
param::MatrixMul new_param = param;
@@ -41,18 +42,34 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format,
TensorLayoutArray default_layouts, mk4_layouts;
if (param.transposeA) {
default_layouts.emplace_back(tensors[0].layout.reshape({K, M}));
mk4_layouts.emplace_back(
default_layouts.back()
.reshape({K / pack_size, M / pack_size, pack_size,
pack_size})
.dimshuffle({0, 2, 1, 3}));
if (param.format == param::MatrixMul::Format::MK4_DOT) {
mk4_layouts.emplace_back(
default_layouts.back()
.reshape({K / pack_size, M / pack_size,
pack_size, pack_size})
.dimshuffle({0, 3, 1, 2}));
} else {
mk4_layouts.emplace_back(
default_layouts.back()
.reshape({K / pack_size, M / pack_size,
pack_size, pack_size})
.dimshuffle({0, 2, 1, 3}));
}
} else {
default_layouts.emplace_back(tensors[0].layout.reshape({M, K}));
mk4_layouts.emplace_back(
default_layouts.back()
.reshape({M / pack_size, K / pack_size, pack_size,
pack_size})
.dimshuffle({0, 3, 1, 2}));
if (param.format == param::MatrixMul::Format::MK4_DOT) {
mk4_layouts.emplace_back(
default_layouts.back()
.reshape({M / pack_size, K / pack_size,
pack_size, pack_size})
.dimshuffle({0, 2, 1, 3}));
} else {
mk4_layouts.emplace_back(
default_layouts.back()
.reshape({M / pack_size, K / pack_size,
pack_size, pack_size})
.dimshuffle({0, 3, 1, 2}));
}
}
if (param.transposeB) {
default_layouts.emplace_back(tensors[1].layout.reshape({N, K}));
@@ -238,6 +255,11 @@ TEST_F(NAIVE, MATRIX_MUL_MK8) {
dtype::Int16(), dtype::Int16(), dtype::Int32());
}

TEST_F(NAIVE, MATRIX_MUL_MK4_DOT) {
run_matmul_mk_format(handle(), param::MatrixMul::Format::MK4_DOT,
dtype::Int8(), dtype::Int8(), dtype::Int32());
}

TEST_F(NAIVE, MATRIX_MUL_BFLOAT16) {
Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
MatrixMul::Param param, fp32_param;


Loading…
Cancel
Save