|
@@ -1205,17 +1205,33 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
TensorFormat::Type::IMAGE2D_PACK4); |
|
|
TensorFormat::Type::IMAGE2D_PACK4); |
|
|
return ret; |
|
|
return ret; |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
auto replace_resize_opr = [](OperatorNodeBase* opr, |
|
|
|
|
|
|
|
|
/* This helper function guarantees the format convert pass won't change |
|
|
|
|
|
* output var's channel. Changing output's channel will cause channel |
|
|
|
|
|
* mismatch problem for replacing conv/conv_bias operator. |
|
|
|
|
|
*/ |
|
|
|
|
|
auto replace_helper = [](OperatorNodeBase* opr, |
|
|
|
|
|
const VarNodeArray& new_inp) -> OperatorNodeBase* { |
|
|
|
|
|
auto&& new_shp = new_inp[0]->shape(); |
|
|
|
|
|
size_t inp_channel = new_shp[1]; |
|
|
|
|
|
if (new_shp.eq_shape(opr->input(0)->shape())&& inp_channel % 4 != 0) { |
|
|
|
|
|
auto new_opr = serialization::copy_opr_shallow(*opr, new_inp, |
|
|
|
|
|
opr->config()); |
|
|
|
|
|
return new_opr; |
|
|
|
|
|
} |
|
|
|
|
|
return nullptr; |
|
|
|
|
|
}; |
|
|
|
|
|
auto replace_resize_opr = [replace_helper](OperatorNodeBase* opr, |
|
|
const VarNodeArray& new_inp) { |
|
|
const VarNodeArray& new_inp) { |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
|
|
if (auto opr_shallow_copy = replace_helper(opr, new_inp)) { |
|
|
|
|
|
return opr_shallow_copy; |
|
|
|
|
|
} |
|
|
auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>(); |
|
|
auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>(); |
|
|
mgb_assert(resize_opr.param().format == |
|
|
mgb_assert(resize_opr.param().format == |
|
|
megdnn::param::Resize::Format::NCHW, |
|
|
megdnn::param::Resize::Format::NCHW, |
|
|
"ConvertFormat Pass only support converting NCHW to NHWCD4"); |
|
|
"ConvertFormat Pass only support converting NCHW to NHWCD4"); |
|
|
VarNode* inp = nullptr; |
|
|
VarNode* inp = nullptr; |
|
|
if (new_inp[0]->shape().ndim == 4) { |
|
|
if (new_inp[0]->shape().ndim == 4) { |
|
|
// new input src is NCHW |
|
|
|
|
|
auto param = megdnn::param::RelayoutFormat(); |
|
|
auto param = megdnn::param::RelayoutFormat(); |
|
|
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I; |
|
|
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I; |
|
|
auto rf = opr::RelayoutFormat::make(new_inp[0], param); |
|
|
auto rf = opr::RelayoutFormat::make(new_inp[0], param); |
|
@@ -1235,9 +1251,13 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
return new_resize_opr.node()->owner_opr(); |
|
|
return new_resize_opr.node()->owner_opr(); |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
auto replace_warp_perspective_opr = [](OperatorNodeBase* opr, |
|
|
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
|
|
|
|
|
auto replace_warp_perspective_opr = [replace_helper]( |
|
|
|
|
|
OperatorNodeBase* opr, |
|
|
|
|
|
const VarNodeArray& new_inp) { |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
|
|
if (auto opr_shallow_copy = replace_helper(opr, new_inp)) { |
|
|
|
|
|
return opr_shallow_copy; |
|
|
|
|
|
} |
|
|
auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>(); |
|
|
auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>(); |
|
|
mgb_assert(warp_opr.param().format == |
|
|
mgb_assert(warp_opr.param().format == |
|
|
megdnn::param::WarpPerspective::Format::NCHW, |
|
|
megdnn::param::WarpPerspective::Format::NCHW, |
|
@@ -1273,9 +1293,12 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
return new_warp_opr.node()->owner_opr(); |
|
|
return new_warp_opr.node()->owner_opr(); |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
auto replace_warp_affine_opr = [](OperatorNodeBase* opr, |
|
|
|
|
|
|
|
|
auto replace_warp_affine_opr = [replace_helper](OperatorNodeBase* opr, |
|
|
const VarNodeArray& new_inp) { |
|
|
const VarNodeArray& new_inp) { |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
|
|
if (auto opr_shallow_copy = replace_helper(opr, new_inp)) { |
|
|
|
|
|
return opr_shallow_copy; |
|
|
|
|
|
} |
|
|
auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>(); |
|
|
auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>(); |
|
|
mgb_assert(warp_opr.param().format == |
|
|
mgb_assert(warp_opr.param().format == |
|
|
megdnn::param::WarpAffine::Format::NCHW, |
|
|
megdnn::param::WarpAffine::Format::NCHW, |
|
@@ -1303,9 +1326,12 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
return new_warp_opr.node()->owner_opr(); |
|
|
return new_warp_opr.node()->owner_opr(); |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
auto replace_pooling_opr = [](OperatorNodeBase* opr, |
|
|
|
|
|
|
|
|
auto replace_pooling_opr = [replace_helper](OperatorNodeBase* opr, |
|
|
const VarNodeArray& new_inp) { |
|
|
const VarNodeArray& new_inp) { |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
|
|
if (auto opr_shallow_copy = replace_helper(opr, new_inp)) { |
|
|
|
|
|
return opr_shallow_copy; |
|
|
|
|
|
} |
|
|
auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>(); |
|
|
auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>(); |
|
|
mgb_assert(pooling_opr.param().format == |
|
|
mgb_assert(pooling_opr.param().format == |
|
|
megdnn::param::Pooling::Format::NCHW, |
|
|
megdnn::param::Pooling::Format::NCHW, |
|
@@ -1344,24 +1370,6 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
mgb_assert(new_inp[i]->shape().ndim == 5 && |
|
|
mgb_assert(new_inp[i]->shape().ndim == 5 && |
|
|
new_inp[i]->format().type() == |
|
|
new_inp[i]->format().type() == |
|
|
TensorFormat::Type::IMAGE2D_PACK4); |
|
|
TensorFormat::Type::IMAGE2D_PACK4); |
|
|
// Oprs which will change the shape of input like concat, |
|
|
|
|
|
// reshape etc. should not be used after cd4 convertion padding, |
|
|
|
|
|
// due to the padding info will be lost and we cannot recover |
|
|
|
|
|
// the origin unpadded data. For example, concat two tensors of |
|
|
|
|
|
// shape {1, 6, 128, 128}, if both tensors convert to cd4 then |
|
|
|
|
|
// the channel will be 8, and the result of concat channel will |
|
|
|
|
|
// be 16, but there will be 2 padding zeros in the middle of |
|
|
|
|
|
// channel axis, which will cause problems in succeding opr. |
|
|
|
|
|
if (opr->dyn_typeinfo() == opr::Concat::typeinfo()) { |
|
|
|
|
|
auto concat = try_cast_as_op<opr::Concat>(opr); |
|
|
|
|
|
mgb_assert( |
|
|
|
|
|
!(concat->param().axis == 1 && |
|
|
|
|
|
concat->input(i)->shape()[1] % 4 != 0), |
|
|
|
|
|
"We cannot concat tensor in channel axis which has " |
|
|
|
|
|
"been padded, as it may lost padding pos if we " |
|
|
|
|
|
"pass " |
|
|
|
|
|
"the output to conv etc."); |
|
|
|
|
|
} |
|
|
|
|
|
auto param = megdnn::param::RelayoutFormat(); |
|
|
auto param = megdnn::param::RelayoutFormat(); |
|
|
param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW; |
|
|
param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW; |
|
|
auto rf = opr::RelayoutFormat::make(new_inp[i], param); |
|
|
auto rf = opr::RelayoutFormat::make(new_inp[i], param); |
|
|