GitOrigin-RevId: 7c6fbdfa97
tags/v0.5.0
@@ -433,7 +433,10 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | |||||
'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | ||||
Doc('MK8', 'Split 8 from M and K, better for neon compute:' | Doc('MK8', 'Split 8 from M and K, better for neon compute:' | ||||
'(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | ||||
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))')) | |||||
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | |||||
Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' | |||||
'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' | |||||
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) | |||||
) | ) | ||||
(pdef('Winograd', 'winograd param used in convbias'). | (pdef('Winograd', 'winograd param used in convbias'). | ||||
@@ -186,6 +186,8 @@ size_t MatrixMulForward::pack_size(const Param::Format format) { | |||||
return 1; | return 1; | ||||
case Param::Format::MK4: | case Param::Format::MK4: | ||||
return 4; | return 4; | ||||
case Param::Format::MK4_DOT: | |||||
return 4; | |||||
case Param::Format::MK8: | case Param::Format::MK8: | ||||
return 8; | return 8; | ||||
default: | default: | ||||
@@ -84,6 +84,35 @@ void run_matrix_mul_mk4_tpl(const itype* A, const itype* B, otype* C, size_t M, | |||||
template <typename itype, typename otype, bool transA, bool transB, | template <typename itype, typename otype, bool transA, bool transB, | ||||
typename comp_type = otype> | typename comp_type = otype> | ||||
void run_matrix_mul_mk4_dot_tpl(const itype* A, const itype* B, otype* C, | |||||
size_t M, size_t N, size_t K, size_t LDA, | |||||
size_t LDB, size_t LDC, const DType& A_type, | |||||
const DType& B_type) { | |||||
Getter<itype, comp_type> getterA(A_type), getterB(B_type); | |||||
for (size_t m = 0; m < M; ++m) { | |||||
for (size_t n = 0; n < N; ++n) { | |||||
comp_type res[4] = {comp_type(0)}; | |||||
for (size_t k = 0; k < K; ++k) { | |||||
for (size_t i = 0; i < 4; i++) { | |||||
comp_type av, bv; | |||||
for (size_t j = 0; j < 4; j++) { | |||||
av = transA ? getterA(A[k * LDA + m * 16 + 4 * i + j]) | |||||
: getterA(A[m * LDA + k * 16 + 4 * i + j]), | |||||
bv = transB ? getterB(B[n * LDB + k * 4 + j]) | |||||
: getterB(B[k * LDB + n * 4 + j]); | |||||
res[i] += av * bv; | |||||
} | |||||
} | |||||
} | |||||
for (size_t i = 0; i < 4; i++) { | |||||
C[m * LDC + n * 4 + i] = res[i]; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
template <typename itype, typename otype, bool transA, bool transB, | |||||
typename comp_type = otype> | |||||
void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M, | void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M, | ||||
size_t N, size_t K, size_t LDA, size_t LDB, | size_t N, size_t K, size_t LDA, size_t LDB, | ||||
size_t LDC, const DType& A_type, | size_t LDC, const DType& A_type, | ||||
@@ -38,22 +38,27 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], | auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], | ||||
LDC = C.layout.stride[0]; | LDC = C.layout.stride[0]; | ||||
#define cb(_itype, _otype, _comp_type) \ | |||||
if (param.format == param::MatrixMul::Format::DEFAULT) { \ | |||||
return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
A.layout.dtype, B.layout.dtype); \ | |||||
} else if (param.format == param::MatrixMul::Format::MK4) { \ | |||||
return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
A.layout.dtype, B.layout.dtype); \ | |||||
} else if (param.format == param::MatrixMul::Format::MK8) { \ | |||||
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
A.layout.dtype, B.layout.dtype); \ | |||||
#define cb(_itype, _otype, _comp_type) \ | |||||
if (param.format == param::MatrixMul::Format::DEFAULT) { \ | |||||
return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
A.layout.dtype, B.layout.dtype); \ | |||||
} else if (param.format == param::MatrixMul::Format::MK4) { \ | |||||
return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
A.layout.dtype, B.layout.dtype); \ | |||||
} else if (param.format == param::MatrixMul::Format::MK4_DOT) { \ | |||||
return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
A.layout.dtype, B.layout.dtype); \ | |||||
} else if (param.format == param::MatrixMul::Format::MK8) { \ | |||||
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
A.layout.dtype, B.layout.dtype); \ | |||||
} | } | ||||
if (A.layout.dtype == dtype::Float32()) { | if (A.layout.dtype == dtype::Float32()) { | ||||