|
|
@@ -1066,7 +1066,27 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
auto replace_conv_opr = [&filter_mode](OperatorNodeBase* opr, |
|
|
|
auto size_one_conv_to_dense_conv = |
|
|
|
[](VarNode* origin_filter_input, |
|
|
|
megdnn::param::Convolution::Sparse sparse) { |
|
|
|
VarNode* reshaped_filter = origin_filter_input; |
|
|
|
bool is_size_one_group_conv = false; |
|
|
|
if (sparse == megdnn::param::Convolution::Sparse::GROUP && |
|
|
|
origin_filter_input->shape()[0] == 1) { |
|
|
|
is_size_one_group_conv = true; |
|
|
|
auto new_shape = origin_filter_input->shape(); |
|
|
|
new_shape.ndim = 4; |
|
|
|
for (int i = 0; i < 4; i++) { |
|
|
|
new_shape[i] = origin_filter_input->shape()[i + 1]; |
|
|
|
} |
|
|
|
SymbolVar new_var(origin_filter_input); |
|
|
|
reshaped_filter = new_var.reshape(new_shape).node(); |
|
|
|
} |
|
|
|
return std::make_tuple(reshaped_filter, is_size_one_group_conv); |
|
|
|
}; |
|
|
|
|
|
|
|
auto replace_conv_opr = |
|
|
|
[&filter_mode, &size_one_conv_to_dense_conv](OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); |
|
|
@@ -1131,19 +1151,27 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
|
mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2); |
|
|
|
conv_src = new_inp[0]; |
|
|
|
} |
|
|
|
VarNode* reshaped_filter; |
|
|
|
bool is_size_one_group_conv; |
|
|
|
std::tie(reshaped_filter, is_size_one_group_conv) = |
|
|
|
size_one_conv_to_dense_conv(new_inp[1], |
|
|
|
conv_opr.param().sparse); |
|
|
|
auto new_conv_param = conv_opr.param(); |
|
|
|
if (is_size_one_group_conv) { |
|
|
|
new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE; |
|
|
|
} |
|
|
|
mgb_assert(new_inp[1]->format().type() != |
|
|
|
TensorFormat::Type::IMAGE2D_PACK4); |
|
|
|
auto param = megdnn::param::RelayoutFormat(); |
|
|
|
param.mode = filter_mode(conv_opr.param().sparse, new_inp[1]); |
|
|
|
auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param); |
|
|
|
param.mode = filter_mode(new_conv_param.sparse, reshaped_filter); |
|
|
|
auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param); |
|
|
|
conv_weights = relayout_weight.node(); |
|
|
|
auto new_param = conv_opr.param(); |
|
|
|
new_param.format = megdnn::param::Convolution::Format::NHWCD4; |
|
|
|
new_conv_param.format = megdnn::param::Convolution::Format::NHWCD4; |
|
|
|
mgb_assert(conv_src->shape().ndim == 5 && |
|
|
|
conv_src->format().type() == |
|
|
|
TensorFormat::Type::IMAGE2D_PACK4); |
|
|
|
auto new_conv_opr = opr::Convolution::make( |
|
|
|
conv_src, conv_weights, new_param, conv_opr.execution_policy(), |
|
|
|
conv_src, conv_weights, new_conv_param, conv_opr.execution_policy(), |
|
|
|
conv_opr.config()); |
|
|
|
OperatorNodeBase* ret = new_conv_opr.node()->owner_opr(); |
|
|
|
mgb_assert(new_conv_opr.shape().ndim == 5 && |
|
|
@@ -1152,7 +1180,8 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
|
return ret; |
|
|
|
}; |
|
|
|
|
|
|
|
auto replace_conv_bias_opr = [&filter_mode](OperatorNodeBase* opr, |
|
|
|
auto replace_conv_bias_opr = |
|
|
|
[&filter_mode, &size_one_conv_to_dense_conv](OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); |
|
|
@@ -1221,9 +1250,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
|
mgb_assert(new_inp[1]->format().type() != |
|
|
|
TensorFormat::Type::IMAGE2D_PACK4); |
|
|
|
|
|
|
|
VarNode* reshaped_filter; |
|
|
|
bool is_size_one_group_conv; |
|
|
|
std::tie(reshaped_filter, is_size_one_group_conv) = |
|
|
|
size_one_conv_to_dense_conv(new_inp[1], |
|
|
|
conv_bias_opr.param().sparse); |
|
|
|
auto new_conv_param = conv_bias_opr.param(); |
|
|
|
if (is_size_one_group_conv) { |
|
|
|
new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE; |
|
|
|
} |
|
|
|
auto param = megdnn::param::RelayoutFormat(); |
|
|
|
param.mode = filter_mode(conv_bias_opr.param().sparse, new_inp[1]); |
|
|
|
auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param); |
|
|
|
param.mode = filter_mode(new_conv_param.sparse, reshaped_filter); |
|
|
|
auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param); |
|
|
|
conv_bias_weights = relayout_weight.node(); |
|
|
|
|
|
|
|
mgb_assert(new_inp.size() < 4, |
|
|
@@ -1238,19 +1276,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
|
conv_bias_bias = new_inp[2]; |
|
|
|
} |
|
|
|
|
|
|
|
auto new_param = conv_bias_opr.param(); |
|
|
|
new_param.format = megdnn::param::ConvBias::Format::NHWCD4; |
|
|
|
new_conv_param.format = megdnn::param::ConvBias::Format::NHWCD4; |
|
|
|
mgb_assert(conv_bias_src->shape().ndim == 5 && |
|
|
|
conv_bias_src->format().type() == |
|
|
|
TensorFormat::Type::IMAGE2D_PACK4); |
|
|
|
SymbolVar new_conv_bias_opr; |
|
|
|
if (has_bias) { |
|
|
|
new_conv_bias_opr = opr::ConvBias::make( |
|
|
|
conv_bias_src, conv_bias_weights, conv_bias_bias, new_param, |
|
|
|
conv_bias_src, conv_bias_weights, conv_bias_bias, new_conv_param, |
|
|
|
conv_bias_opr.execution_policy(), conv_bias_opr.config()); |
|
|
|
} else { |
|
|
|
new_conv_bias_opr = opr::ConvBias::make( |
|
|
|
conv_bias_src, conv_bias_weights, new_param, |
|
|
|
conv_bias_src, conv_bias_weights, new_conv_param, |
|
|
|
conv_bias_opr.execution_policy(), conv_bias_opr.config()); |
|
|
|
} |
|
|
|
OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr(); |
|
|
|