|
|
@@ -38,22 +38,27 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B, |
|
|
|
auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], |
|
|
|
LDC = C.layout.stride[0]; |
|
|
|
|
|
|
|
#define cb(_itype, _otype, _comp_type) \ |
|
|
|
if (param.format == param::MatrixMul::Format::DEFAULT) { \ |
|
|
|
return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ |
|
|
|
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ |
|
|
|
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ |
|
|
|
A.layout.dtype, B.layout.dtype); \ |
|
|
|
} else if (param.format == param::MatrixMul::Format::MK4) { \ |
|
|
|
return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ |
|
|
|
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ |
|
|
|
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ |
|
|
|
A.layout.dtype, B.layout.dtype); \ |
|
|
|
} else if (param.format == param::MatrixMul::Format::MK8) { \ |
|
|
|
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ |
|
|
|
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ |
|
|
|
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ |
|
|
|
A.layout.dtype, B.layout.dtype); \ |
|
|
|
#define cb(_itype, _otype, _comp_type) \ |
|
|
|
if (param.format == param::MatrixMul::Format::DEFAULT) { \ |
|
|
|
return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ |
|
|
|
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ |
|
|
|
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ |
|
|
|
A.layout.dtype, B.layout.dtype); \ |
|
|
|
} else if (param.format == param::MatrixMul::Format::MK4) { \ |
|
|
|
return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ |
|
|
|
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ |
|
|
|
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ |
|
|
|
A.layout.dtype, B.layout.dtype); \ |
|
|
|
} else if (param.format == param::MatrixMul::Format::MK4_DOT) { \ |
|
|
|
return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ |
|
|
|
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ |
|
|
|
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ |
|
|
|
A.layout.dtype, B.layout.dtype); \ |
|
|
|
} else if (param.format == param::MatrixMul::Format::MK8) { \ |
|
|
|
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ |
|
|
|
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ |
|
|
|
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ |
|
|
|
A.layout.dtype, B.layout.dtype); \ |
|
|
|
} |
|
|
|
|
|
|
|
if (A.layout.dtype == dtype::Float32()) { |
|
|
|