diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 0c0aaab6..16207f03 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1205,17 +1205,33 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { TensorFormat::Type::IMAGE2D_PACK4); return ret; }; - - auto replace_resize_opr = [](OperatorNodeBase* opr, + /* This helper function guarantees the format convert pass won't change + * output var's channel. Changing output's channel will cause channel + * mismatch problem for replacing conv/conv_bias operator. + */ + auto replace_helper = [](OperatorNodeBase* opr, + const VarNodeArray& new_inp) -> OperatorNodeBase* { + auto&& new_shp = new_inp[0]->shape(); + size_t inp_channel = new_shp[1]; + if (new_shp.eq_shape(opr->input(0)->shape())&& inp_channel % 4 != 0) { + auto new_opr = serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + return new_opr; + } + return nullptr; + }; + auto replace_resize_opr = [replace_helper](OperatorNodeBase* opr, const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); + if (auto opr_shallow_copy = replace_helper(opr, new_inp)) { + return opr_shallow_copy; + } auto& resize_opr = opr->cast_final_safe(); mgb_assert(resize_opr.param().format == megdnn::param::Resize::Format::NCHW, "ConvertFormat Pass only support converting NCHW to NHWCD4"); VarNode* inp = nullptr; if (new_inp[0]->shape().ndim == 4) { - // new input src is NCHW auto param = megdnn::param::RelayoutFormat(); param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I; auto rf = opr::RelayoutFormat::make(new_inp[0], param); @@ -1235,9 +1251,13 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { return new_resize_opr.node()->owner_opr(); }; - auto replace_warp_perspective_opr = [](OperatorNodeBase* opr, - const VarNodeArray& new_inp) { + auto replace_warp_perspective_opr = [replace_helper]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); + if (auto opr_shallow_copy = replace_helper(opr, new_inp)) { + return opr_shallow_copy; + } auto& warp_opr = opr->cast_final_safe(); mgb_assert(warp_opr.param().format == megdnn::param::WarpPerspective::Format::NCHW, @@ -1273,9 +1293,12 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { return new_warp_opr.node()->owner_opr(); }; - auto replace_warp_affine_opr = [](OperatorNodeBase* opr, + auto replace_warp_affine_opr = [replace_helper](OperatorNodeBase* opr, const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); + if (auto opr_shallow_copy = replace_helper(opr, new_inp)) { + return opr_shallow_copy; + } auto& warp_opr = opr->cast_final_safe(); mgb_assert(warp_opr.param().format == megdnn::param::WarpAffine::Format::NCHW, @@ -1303,9 +1326,12 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { return new_warp_opr.node()->owner_opr(); }; - auto replace_pooling_opr = [](OperatorNodeBase* opr, + auto replace_pooling_opr = [replace_helper](OperatorNodeBase* opr, const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); + if (auto opr_shallow_copy = replace_helper(opr, new_inp)) { + return opr_shallow_copy; + } auto& pooling_opr = opr->cast_final_safe(); mgb_assert(pooling_opr.param().format == megdnn::param::Pooling::Format::NCHW, @@ -1344,24 +1370,6 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { mgb_assert(new_inp[i]->shape().ndim == 5 && new_inp[i]->format().type() == TensorFormat::Type::IMAGE2D_PACK4); - // Oprs which will change the shape of input like concat, - // reshape etc. should not be used after cd4 convertion padding, - // due to the padding info will be lost and we cannot recover - // the origin unpadded data. For example, concat two tensors of - // shape {1, 6, 128, 128}, if both tensors convert to cd4 then - // the channel will be 8, and the result of concat channel will - // be 16, but there will be 2 padding zeros in the middle of - // channel axis, which will cause problems in succeding opr. - if (opr->dyn_typeinfo() == opr::Concat::typeinfo()) { - auto concat = try_cast_as_op(opr); - mgb_assert( - !(concat->param().axis == 1 && - concat->input(i)->shape()[1] % 4 != 0), - "We cannot concat tensor in channel axis which has " - "been padded, as it may lost padding pos if we " - "pass " - "the output to conv etc."); - } auto param = megdnn::param::RelayoutFormat(); param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW; auto rf = opr::RelayoutFormat::make(new_inp[i], param); diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 235ceb4b..5fe6fa64 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -1045,13 +1045,18 @@ TEST(TestGoptInference, ConvertFormatPadIC) { param.sparse = opr::Convolution::Param::Sparse::DENSE; auto w1 = mkcvar("w1", {12, 12, 3, 3}); auto y = opr::Convolution::make(concat, w1, param); - MGB_MARK_USED_VAR(y); SymbolVar y_opt; - ASSERT_THROW(unpack_vector(gopt::optimize_for_inference( - {y}, gopt::OptimizeForInferenceOptions{} - .enable_use_nhwcd4()), - y_opt), - AssertionError); + unpack_vector( + gopt::optimize_for_inference( + {y}, + gopt::OptimizeForInferenceOptions{}.enable_use_nhwcd4()), + y_opt); + + HostTensorND host_y_opt, host_y; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_opt, host_y_opt)}); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); } TEST(TestGoptInference, ConvertBatchNormPass) {