Browse Source

perf(dnn/cuda): optimize int4 sass conv main loop and epilogue without fuse_z

GitOrigin-RevId: 4274e58d64
release-1.5
Megvii Engine Team 4 years ago
parent
commit
cff61a53d4
1 changed files with 15 additions and 6 deletions
  1. +15
    -6
      dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp

+ 15
- 6
dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp View File

@@ -167,12 +167,21 @@ void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec(
bias_ptr =
reinterpret_cast<void*>(args.workspace.raw_ptr +
args.filter_layout->span().dist_byte());
reorder_imma_filter_bias<4, 64>(
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
stream);
if (args.z_layout->ndim > 0) {
reorder_imma_filter_bias<4, 64>(
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
stream);
} else {
reorder_imma_filter_bias<4, 64, true>(
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
stream);
}
}

uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh,


Loading…
Cancel
Save