|
|
@@ -49,15 +49,33 @@ void kern_naive(const MatrixMulImpl::KernParam& kern_param) { |
|
|
|
MIDOUT_BEGIN(megdnn_fb_matmul_naive, void) { |
|
|
|
size_t M = kern_param.M, N = kern_param.N, K = kern_param.K; |
|
|
|
size_t LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; |
|
|
|
|
|
|
|
#define DISPATCH(TA, TB) \ |
|
|
|
if (kern_param.trA == TA && kern_param.trB == TB) { \ |
|
|
|
naive::dispatch_ta_tb<TA, TB>( \ |
|
|
|
kern_param.A_ptr, kern_param.B_ptr, kern_param.C_ptr, \ |
|
|
|
kern_param.workspace_ptr, M, N, K, LDA, LDB, LDC, \ |
|
|
|
kern_param.A_type, kern_param.B_type, kern_param.C_type, \ |
|
|
|
kern_param.format, kern_param.compute_mode); \ |
|
|
|
return; \ |
|
|
|
auto get_pack_size = [kern_param]() -> size_t { |
|
|
|
switch (kern_param.format) { |
|
|
|
case param::MatrixMul::Format::MK4: |
|
|
|
case param::MatrixMul::Format::MK4_DOT: |
|
|
|
return 4_z; |
|
|
|
case param::MatrixMul::Format::MK8: |
|
|
|
return 8_z; |
|
|
|
default: |
|
|
|
return 1_z; |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
size_t pack_size = get_pack_size(); |
|
|
|
megdnn_assert( |
|
|
|
(M % pack_size == 0 && K % pack_size == 0), |
|
|
|
"M and N must time of pack_size M: %zu N: %zu pack_size: %zu", |
|
|
|
M, N, pack_size); |
|
|
|
|
|
|
|
#define DISPATCH(TA, TB) \ |
|
|
|
if (kern_param.trA == TA && kern_param.trB == TB) { \ |
|
|
|
naive::dispatch_ta_tb<TA, TB>( \ |
|
|
|
kern_param.A_ptr, kern_param.B_ptr, kern_param.C_ptr, \ |
|
|
|
kern_param.workspace_ptr, M / pack_size, N, K / pack_size, \ |
|
|
|
LDA, LDB, LDC, kern_param.A_type, kern_param.B_type, \ |
|
|
|
kern_param.C_type, kern_param.format, \ |
|
|
|
kern_param.compute_mode); \ |
|
|
|
return; \ |
|
|
|
} |
|
|
|
DISPATCH(true, true); |
|
|
|
DISPATCH(true, false); |
|
|
|