|
@@ -1619,6 +1619,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
megdnn::param::Convolution::Format::NCHW4; |
|
|
megdnn::param::Convolution::Format::NCHW4; |
|
|
megdnn::param::ConvBias::Format conv_bias_format = |
|
|
megdnn::param::ConvBias::Format conv_bias_format = |
|
|
megdnn::param::ConvBias::Format::NCHW4; |
|
|
megdnn::param::ConvBias::Format::NCHW4; |
|
|
|
|
|
megdnn::param::ConvBias::Format conv_bias_format_nchw4_nchw = |
|
|
|
|
|
megdnn::param::ConvBias::Format::NCHW4_NCHW; |
|
|
megdnn::param::BatchConvBias::Format batch_conv_bias_format = |
|
|
megdnn::param::BatchConvBias::Format batch_conv_bias_format = |
|
|
megdnn::param::BatchConvBias::Format::NCHW4; |
|
|
megdnn::param::BatchConvBias::Format::NCHW4; |
|
|
RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4; |
|
|
RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4; |
|
@@ -1821,6 +1823,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
return new_opr; |
|
|
return new_opr; |
|
|
}; |
|
|
}; |
|
|
auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format, |
|
|
auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format, |
|
|
|
|
|
conv_bias_format_nchw4_nchw, |
|
|
src_to_nchw4_mode]( |
|
|
src_to_nchw4_mode]( |
|
|
OperatorNodeBase* opr, |
|
|
OperatorNodeBase* opr, |
|
|
const VarNodeArray& new_inp) { |
|
|
const VarNodeArray& new_inp) { |
|
@@ -1851,19 +1854,27 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
conv_bias_filter = new_filter.node(); |
|
|
conv_bias_filter = new_filter.node(); |
|
|
// format: NCHW --> NCHW4 |
|
|
// format: NCHW --> NCHW4 |
|
|
auto new_param = conv_bias_opr.param(); |
|
|
auto new_param = conv_bias_opr.param(); |
|
|
new_param.format = conv_bias_format; |
|
|
|
|
|
|
|
|
if (conv_bias_opr.output().size() > 0 && |
|
|
|
|
|
conv_bias_opr.output(0)->dtype().enumv() == DTypeEnum::Float32) { |
|
|
|
|
|
new_param.format = conv_bias_format_nchw4_nchw; |
|
|
|
|
|
} else { |
|
|
|
|
|
new_param.format = conv_bias_format; |
|
|
|
|
|
} |
|
|
if (new_inp.size() == 2) { |
|
|
if (new_inp.size() == 2) { |
|
|
auto new_conv_bias_opr = opr::ConvBias::make( |
|
|
auto new_conv_bias_opr = opr::ConvBias::make( |
|
|
conv_bias_src, conv_bias_filter, new_param, |
|
|
conv_bias_src, conv_bias_filter, new_param, |
|
|
conv_bias_opr.execution_policy(), conv_bias_opr.config()); |
|
|
conv_bias_opr.execution_policy(), conv_bias_opr.config()); |
|
|
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); |
|
|
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); |
|
|
mgb_assert(new_conv_bias_opr.shape().ndim == 5, |
|
|
|
|
|
"The conv_bias dst dim is not trans to nchw4"); |
|
|
|
|
|
|
|
|
mgb_assert( |
|
|
|
|
|
new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 || |
|
|
|
|
|
new_conv_bias_opr.shape().ndim == 5, |
|
|
|
|
|
"The conv_bias dst dim is not trans to nchw4"); |
|
|
return new_opr; |
|
|
return new_opr; |
|
|
} |
|
|
} |
|
|
// bias: NCHW --> NCHW4 |
|
|
|
|
|
|
|
|
// bias: NCHW --> NCHW4 when bias_dtype is not Float32 |
|
|
VarNode* conv_bias_bias = new_inp[2]; |
|
|
VarNode* conv_bias_bias = new_inp[2]; |
|
|
if (new_inp[2]->shape().ndim == 4) { |
|
|
|
|
|
|
|
|
if (new_inp[2]->dtype().enumv() != DTypeEnum::Float32 && |
|
|
|
|
|
new_inp[2]->shape().ndim == 4) { |
|
|
auto new_bias = |
|
|
auto new_bias = |
|
|
RelayoutPlaceholder::make(new_inp[2], src_to_nchw4_mode); |
|
|
RelayoutPlaceholder::make(new_inp[2], src_to_nchw4_mode); |
|
|
conv_bias_bias = new_bias.node(); |
|
|
conv_bias_bias = new_bias.node(); |
|
@@ -1873,13 +1884,16 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param, |
|
|
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param, |
|
|
conv_bias_opr.execution_policy(), conv_bias_opr.config()); |
|
|
conv_bias_opr.execution_policy(), conv_bias_opr.config()); |
|
|
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); |
|
|
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); |
|
|
mgb_assert(new_conv_bias_opr.shape().ndim == 5, |
|
|
|
|
|
"The conv_bias dst dim is not trans to nchw4"); |
|
|
|
|
|
|
|
|
mgb_assert( |
|
|
|
|
|
new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 || |
|
|
|
|
|
new_conv_bias_opr.shape().ndim == 5, |
|
|
|
|
|
"The conv_bias dst dim is not trans to nchw4"); |
|
|
return new_opr; |
|
|
return new_opr; |
|
|
} |
|
|
} |
|
|
// z_inp: NCHW --> NCHW4 |
|
|
|
|
|
|
|
|
// z_inp: NCHW --> NCHW4 when z_inp dtype is not Float32 |
|
|
VarNode* z_inp = new_inp[3]; |
|
|
VarNode* z_inp = new_inp[3]; |
|
|
if (new_inp[3]->shape().ndim == 4) { |
|
|
|
|
|
|
|
|
if (new_inp[3]->dtype().enumv() != DTypeEnum::Float32 && |
|
|
|
|
|
new_inp[3]->shape().ndim == 4) { |
|
|
auto new_z = |
|
|
auto new_z = |
|
|
RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode); |
|
|
RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode); |
|
|
z_inp = new_z.node(); |
|
|
z_inp = new_z.node(); |
|
@@ -1889,8 +1903,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
new_param, conv_bias_opr.execution_policy(), |
|
|
new_param, conv_bias_opr.execution_policy(), |
|
|
conv_bias_opr.config()); |
|
|
conv_bias_opr.config()); |
|
|
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); |
|
|
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); |
|
|
mgb_assert(new_conv_bias_opr.shape().ndim == 5, |
|
|
|
|
|
"The conv_bias dst dim is not trans to nchw4"); |
|
|
|
|
|
|
|
|
mgb_assert( |
|
|
|
|
|
new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 || |
|
|
|
|
|
new_conv_bias_opr.shape().ndim == 5, |
|
|
|
|
|
"The conv_bias dst dim is not trans to nchw4"); |
|
|
return new_opr; |
|
|
return new_opr; |
|
|
}; |
|
|
}; |
|
|
auto replace_elemwise_opr = [=](OperatorNodeBase* opr, |
|
|
auto replace_elemwise_opr = [=](OperatorNodeBase* opr, |
|
|