@@ -50,15 +50,23 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::make_inner_layout(
     deduce_reformat_layout(relayout_src, *args.filter_layout,
                            inner_weight_layout,
                            RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT);
-    deduce_reformat_layout(relayout_src, *args.dst_layout, inner_dst_layout,
-                           RelayoutFormat::Param::Mode::NCHW_NCHW4, 0,
-                           args.filter_meta.group);
-    deduce_reformat_layout(relayout_src, *args.bias_layout, inner_bias_layout,
-                           RelayoutFormat::Param::Mode::NCHW_NCHW4, 0,
-                           args.filter_meta.group);
-    deduce_reformat_layout(relayout_src, *args.z_layout, inner_z_layout,
-                           RelayoutFormat::Param::Mode::NCHW_NCHW4, 0,
-                           args.filter_meta.group);
+    bool dst_float = args.dst_layout->dtype.enumv() == DTypeEnum::Float32;
+    if (dst_float) {
+        inner_dst_layout = *args.dst_layout;
+        inner_bias_layout = *args.bias_layout;
+        inner_z_layout = *args.z_layout;
+    } else {
+        deduce_reformat_layout(relayout_src, *args.dst_layout, inner_dst_layout,
+                               RelayoutFormat::Param::Mode::NCHW_NCHW4, 0,
+                               args.filter_meta.group);
+        deduce_reformat_layout(relayout_src, *args.bias_layout,
+                               inner_bias_layout,
+                               RelayoutFormat::Param::Mode::NCHW_NCHW4, 0,
+                               args.filter_meta.group);
+        deduce_reformat_layout(relayout_src, *args.z_layout, inner_z_layout,
+                               RelayoutFormat::Param::Mode::NCHW_NCHW4, 0,
+                               args.filter_meta.group);
+    }
 };
 
 bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
@@ -70,8 +78,7 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
     auto&& param = args.opr->param();
     bool is_format_ok = param.format == param::ConvBias::Format::NCHW;
    bool is_version_ok = CUDNN_VERSION >= 7500;
-    bool is_dtype_ok =
-            args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8;
+    bool is_dtype_ok = args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8;
     bool is_bias_ok =
             args.bias_layout->ndim == 0 ||
             (args.bias_layout->ndim == 4 && args.bias_layout->shape[0] == 1 &&
@@ -90,17 +97,23 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_bundle(
     TensorLayout inner_z_layout;
     make_inner_layout(args, inner_src_layout, inner_weight_layout,
                       inner_dst_layout, inner_bias_layout, inner_z_layout);
-    auto opr = args.handle->create_operator<ConvBiasForward>();
     Param inner_conv_param = args.opr->param();
-    inner_conv_param.format = Param::Format::NCHW4;
+    size_t ws_dst = 0, ws_bias = 0, ws_z = 0;
+    if (args.dst_layout->dtype.enumv() == DTypeEnum::Float32) {
+        inner_conv_param.format = Param::Format::NCHW4_NCHW;
+    } else {
+        inner_conv_param.format = Param::Format::NCHW4;
+        ws_dst = inner_dst_layout.span().dist_byte();
+        ws_bias = inner_bias_layout.span().dist_byte();
+        ws_z = inner_z_layout.span().dist_byte();
+    }
+    auto opr = args.handle->create_operator<ConvBiasForward>();
     opr->param() = inner_conv_param;
-    return WorkspaceBundle(ptr, {inner_src_layout.span().dist_byte(),
-                                 inner_weight_layout.span().dist_byte(),
-                                 inner_dst_layout.span().dist_byte(),
-                                 inner_bias_layout.span().dist_byte(),
-                                 inner_z_layout.span().dist_byte(),
-                                 opr->get_workspace_in_bytes(
-                                         inner_src_layout, inner_weight_layout,
+    return WorkspaceBundle(
+            ptr,
+            {inner_src_layout.span().dist_byte(),
+             inner_weight_layout.span().dist_byte(), ws_dst, ws_bias, ws_z,
+             opr->get_workspace_in_bytes(inner_src_layout, inner_weight_layout,
                                          inner_bias_layout, inner_z_layout,
                                          inner_dst_layout, nullptr)});
 }
@@ -145,22 +158,33 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec(
     TensorND inner_bias(bundle.get(3), inner_bias_layout);
     TensorND inner_z(bundle.get(4), inner_z_layout);
+
+    bool dst_float = args.dst_layout->dtype.enumv() == DTypeEnum::Float32;
+
     Param inner_conv_param = args.opr->param();
-    inner_conv_param.format = Param::Format::NCHW4;
+    inner_conv_param.format =
+            dst_float ? Param::Format::NCHW4_NCHW : Param::Format::NCHW4;
     auto inner_opr = args.handle->create_operator<ConvBiasForward>();
     inner_opr->param() = inner_conv_param;
 
     relayout_nchw_nchw4->exec(*args.src_tensor, inner_src, {});
     relayout_weight->exec(*args.filter_tensor, inner_weight, {});
-    if (inner_bias_layout.ndim > 0) {
-        relayout_nchw_nchw4->exec(*args.bias_tensor, inner_bias, {});
-    }
-    if (inner_z_layout.ndim > 0) {
-        relayout_nchw_nchw4->exec(*args.z_tensor, inner_z, {});
+
+    if (dst_float) {
+        inner_opr->exec(inner_src, inner_weight, *args.bias_tensor,
+                        *args.z_tensor, *args.dst_tensor, nullptr,
+                        Workspace((dt_byte*)bundle.get(5), bundle.get_size(5)));
+    } else {
+        if (inner_bias_layout.ndim > 0) {
+            relayout_nchw_nchw4->exec(*args.bias_tensor, inner_bias, {});
+        }
+        if (inner_z_layout.ndim > 0) {
+            relayout_nchw_nchw4->exec(*args.z_tensor, inner_z, {});
+        }
+        inner_opr->exec(inner_src, inner_weight, inner_bias, inner_z, inner_dst,
+                        nullptr,
+                        Workspace((dt_byte*)bundle.get(5), bundle.get_size(5)));
+        relayout_nchw4_nchw->exec(inner_dst, *args.dst_tensor, {});
     }
-    inner_opr->exec(inner_src, inner_weight, inner_bias, inner_z, inner_dst,
-                    nullptr,
-                    Workspace((dt_byte*)bundle.get(5), bundle.get_size(5)));
-    relayout_nchw4_nchw->exec(inner_dst, *args.dst_tensor, {});
 }
 
 // vim: syntax=cpp.doxygen
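
Note (not part of the patch): the relayout calls in this algorithm rely on the NCHW <-> NCHW4 reformat, where the channel dimension is padded up to a multiple of 4 and split into a trailing block of 4; with the float destination the patch selects NCHW4_NCHW, so the inner conv writes the NCHW float output directly and the dst/bias/z scratch buffers and the relayout back are skipped. Below is a minimal, self-contained sketch of that layout math in plain C++; the helper name nchw_to_nchw4 is illustrative only and is not a MegDNN API.

#include <array>
#include <cstddef>
#include <cstdio>

// Illustrative only: maps an NCHW shape to the blocked NCHW4 shape
// (N, ceil(C/4), H, W, 4) that the inner NCHW4 conv_bias operator consumes.
static std::array<std::size_t, 5> nchw_to_nchw4(
        const std::array<std::size_t, 4>& nchw) {
    const std::size_t n = nchw[0], c = nchw[1], h = nchw[2], w = nchw[3];
    const std::size_t c_blocks = (c + 3) / 4;  // channels padded to a multiple of 4
    return {n, c_blocks, h, w, 4};
}

int main() {
    const auto s = nchw_to_nchw4({1, 6, 32, 32});
    // Prints: NCHW4 shape: 1 x 2 x 32 x 32 x 4
    std::printf("NCHW4 shape: %zu x %zu x %zu x %zu x %zu\n", s[0], s[1], s[2],
                s[3], s[4]);
    return 0;
}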