|
|
@@ -167,12 +167,21 @@ void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec( |
|
|
|
bias_ptr = |
|
|
|
reinterpret_cast<void*>(args.workspace.raw_ptr + |
|
|
|
args.filter_layout->span().dist_byte()); |
|
|
|
reorder_imma_filter_bias<4, 64>( |
|
|
|
reinterpret_cast<int8_t*>(filter_ptr), |
|
|
|
reinterpret_cast<int32_t*>(bias_ptr), |
|
|
|
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), |
|
|
|
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw, |
|
|
|
stream); |
|
|
|
if (args.z_layout->ndim > 0) { |
|
|
|
reorder_imma_filter_bias<4, 64>( |
|
|
|
reinterpret_cast<int8_t*>(filter_ptr), |
|
|
|
reinterpret_cast<int32_t*>(bias_ptr), |
|
|
|
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), |
|
|
|
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw, |
|
|
|
stream); |
|
|
|
} else { |
|
|
|
reorder_imma_filter_bias<4, 64, true>( |
|
|
|
reinterpret_cast<int8_t*>(filter_ptr), |
|
|
|
reinterpret_cast<int32_t*>(bias_ptr), |
|
|
|
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), |
|
|
|
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw, |
|
|
|
stream); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh, |
|
|
|