Browse Source

feat(dnn/common): add matmul impl for naive with matrix format mk4_dot

GitOrigin-RevId: 7c6fbdfa97
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
a6bc250d1c
4 changed files with 56 additions and 17 deletions
  1. +4
    -1
      dnn/scripts/opr_param_defs.py
  2. +2
    -0
      dnn/src/common/matrix_mul.cpp
  3. +29
    -0
      dnn/src/naive/matrix_mul/matrix_mul_helper.h
  4. +21
    -16
      dnn/src/naive/matrix_mul/opr_impl.cpp

+ 4
- 1
dnn/scripts/opr_param_defs.py View File

@@ -433,7 +433,10 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'),
Doc('MK8', 'Split 8 from M and K, better for neon compute:' Doc('MK8', 'Split 8 from M and K, better for neon compute:'
'(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the '
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'))
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'),
Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:'
'(M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the '
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'))
) )


(pdef('Winograd', 'winograd param used in convbias'). (pdef('Winograd', 'winograd param used in convbias').


+ 2
- 0
dnn/src/common/matrix_mul.cpp View File

@@ -186,6 +186,8 @@ size_t MatrixMulForward::pack_size(const Param::Format format) {
return 1; return 1;
case Param::Format::MK4: case Param::Format::MK4:
return 4; return 4;
case Param::Format::MK4_DOT:
return 4;
case Param::Format::MK8: case Param::Format::MK8:
return 8; return 8;
default: default:


+ 29
- 0
dnn/src/naive/matrix_mul/matrix_mul_helper.h View File

@@ -84,6 +84,35 @@ void run_matrix_mul_mk4_tpl(const itype* A, const itype* B, otype* C, size_t M,


template <typename itype, typename otype, bool transA, bool transB, template <typename itype, typename otype, bool transA, bool transB,
typename comp_type = otype> typename comp_type = otype>
void run_matrix_mul_mk4_dot_tpl(const itype* A, const itype* B, otype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, const DType& A_type,
const DType& B_type) {
Getter<itype, comp_type> getterA(A_type), getterB(B_type);
for (size_t m = 0; m < M; ++m) {
for (size_t n = 0; n < N; ++n) {
comp_type res[4] = {comp_type(0)};
for (size_t k = 0; k < K; ++k) {
for (size_t i = 0; i < 4; i++) {
comp_type av, bv;
for (size_t j = 0; j < 4; j++) {
av = transA ? getterA(A[k * LDA + m * 16 + 4 * i + j])
: getterA(A[m * LDA + k * 16 + 4 * i + j]),
bv = transB ? getterB(B[n * LDB + k * 4 + j])
: getterB(B[k * LDB + n * 4 + j]);
res[i] += av * bv;
}
}
}
for (size_t i = 0; i < 4; i++) {
C[m * LDC + n * 4 + i] = res[i];
}
}
}
}

template <typename itype, typename otype, bool transA, bool transB,
typename comp_type = otype>
void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M, void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M,
size_t N, size_t K, size_t LDA, size_t LDB, size_t N, size_t K, size_t LDA, size_t LDB,
size_t LDC, const DType& A_type, size_t LDC, const DType& A_type,


+ 21
- 16
dnn/src/naive/matrix_mul/opr_impl.cpp View File

@@ -38,22 +38,27 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B,
auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], auto LDA = A.layout.stride[0], LDB = B.layout.stride[0],
LDC = C.layout.stride[0]; LDC = C.layout.stride[0];


#define cb(_itype, _otype, _comp_type) \
if (param.format == param::MatrixMul::Format::DEFAULT) { \
return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK4) { \
return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK8) { \
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
// cb(_itype, _otype, _comp_type): dispatch on param.format to the matching
// naive kernel template (DEFAULT, MK4, MK4_DOT or MK8), instantiated with the
// given input/output/compute types and the TA/TB transpose flags, and return
// from the enclosing function right after the call. Every kernel receives the
// same arguments: A and B viewed as _itype, C viewed as _otype, the M/N/K
// extents, the leading strides LDA/LDB/LDC and the dtypes of A and B.
// If param.format matches none of the listed formats the macro does nothing
// and control falls through to whatever follows the expansion.
#define cb(_itype, _otype, _comp_type) \
if (param.format == param::MatrixMul::Format::DEFAULT) { \
return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK4) { \
return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK4_DOT) { \
return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK8) { \
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} }


if (A.layout.dtype == dtype::Float32()) { if (A.layout.dtype == dtype::Float32()) {


Loading…
Cancel
Save