From bba04f02e5b305455ba094f0fa72e3fc4ae59d5b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 18 May 2021 20:46:10 +0800
Subject: [PATCH] feat(mgb/gopt): add fusion support for conv, astype(s4) and
 reformat

GitOrigin-RevId: 6329ca2c5fe85dc4f7c94610a293ab506a3c2945
---
 src/gopt/impl/tensor_reformat.cpp | 133 +++++++++++++++++++++++++++++++++++++-
 src/gopt/test/inference.cpp       |  87 +++++++++++++++++++++++++
 src/opr/impl/dnn/convolution.cpp  |   8 ++-
 src/plugin/impl/opr_footprint.cpp |   5 +-
 4 files changed, 230 insertions(+), 3 deletions(-)

diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index 75fb9451..8789444f 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -3550,6 +3550,35 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
         return y2.node();
     };
 
+    auto nchw42nhwc = [](VarNode* inp) -> VarNode* {
+        mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
+        auto y1 = opr::Reshape::make(y0, tshp);
+        return y1.node();
+    };
+
+    auto nhwc2nchw64 = [](VarNode* inp) -> VarNode* {
+        mgb_assert(inp->shape().ndim == 4);
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp = opr::Concat::make(
+                {sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
+        return y1.node();
+    };
+
     auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers,
                                                 &nchw42nchw](
                                                        OperatorNodeBase* opr) {
@@ -3721,6 +3750,106 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
         return true;
     };
 
+    auto try_conv_reformat_nchw42nchw64 = [&rewriter, &nchw42nhwc, &nhwc2nchw64,
+                                           &readers](OperatorNodeBase* opr) {
+        ThinHashSet<OperatorNodeBase*> opr_set;
+        ThinHashSet<OperatorNodeBase*> reader_set;
+        // check reshape
+        auto reshape1 =
+                try_cast_as_op<opr::Reshape>(opr);
+        if (reshape1 == nullptr)
+            return false;
+        opr_set.insert(opr);
+
+        // check dimshuffle
+        auto shuffle = try_cast_as_op<opr::Dimshuffle>(
+                reshape1->input(0)->owner_opr());
+        if (shuffle == nullptr)
+            return false;
+        auto&& param = shuffle->param();
+        if (param.pattern_len != 6)
+            return false;
+        bool is_nchw42nchw64 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
+                               param.pattern[2] == 3 && param.pattern[3] == 4 &&
+                               param.pattern[4] == 2 && param.pattern[5] == 5 &&
+                               shuffle->output(0)->shape()[5] == 4 &&
+                               shuffle->output(0)->shape()[4] == 16;
+        if (!is_nchw42nchw64)
+            return false;
+        opr_set.insert(shuffle);
+        for (auto&& i : readers[shuffle]) {
+            if (i.second & DepType::DEV_VALUE) {
+                reader_set.insert(i.first);
+            }
+        }
+
+        // check reshape
+        auto reshape2 =
+                try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
+        if (reshape2 == nullptr)
+            return false;
+        opr_set.insert(reshape2);
+        for (auto&& i : readers[reshape2]) {
+            if (i.second & DepType::DEV_VALUE) {
+                reader_set.insert(i.first);
+            }
+        }
+
+        auto typecvt =
+                try_cast_as_op<opr::TypeCvt>(reshape2->input(0)->owner_opr());
+        if (typecvt == nullptr)
+            return false;
+        auto in_dtype = typecvt->input(0)->dtype(),
+             out_dtype = typecvt->output(0)->dtype();
+        printf("%s, %s\n", in_dtype.name(), out_dtype.name());
+        bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 &&
+                        (out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
+                         out_dtype.enumv() == DTypeEnum::Quantized4Asymm);
+        if (!is_s82s4)
+            return false;
+        opr_set.insert(typecvt);
+
+        // check conv bias
+        auto conv_bias =
+                try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
+        if (conv_bias == nullptr)
+            return false;
+        auto inp_dtype = conv_bias->input(0)->dtype();
+        bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
+                          conv_bias->param().format ==
+                                  megdnn::param::ConvBias::Format::NCHW4;
+        if (!is_s8nchw4)
+            return false;
+        if (conv_bias->input().size() != 3)
+            return false;
+        opr_set.insert(conv_bias);
+        for (auto&& i : readers[conv_bias]) {
+            if (i.second & DepType::DEV_VALUE) {
+                reader_set.insert(i.first);
+            }
+        }
+        for (auto reader : reader_set) {
+            if (opr_set.count(reader) <= 0) {
+                return false;
+            }
+        }
+        auto src = rewriter.get_var(conv_bias->input(0)),
+             filter = rewriter.get_var(conv_bias->input(1)),
+             bias = rewriter.get_var(conv_bias->input(2));
+        auto new_bias = nchw42nhwc(bias);
+        auto new_param = conv_bias->param();
+        new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC;
+        auto conv_bias_shuffle = opr::ConvBias::make(
+                src, filter, new_bias, new_param, conv_bias->execution_policy(),
+                OperatorNodeConfig{out_dtype});
+        auto new_var = nhwc2nchw64(conv_bias_shuffle.node());
+        rewriter.replace_var(
+                opr->output(0), new_var,
+                mgb_cstr_log("replace conv_bias + "
+                             "reformat to conv_bias(NCHW4_NHWC)"));
+        return true;
+    };
+
     auto try_conv_reformat_nchw322nchw4 = [&rewriter, &readers, &nchw322nchw4](
                                                   OperatorNodeBase* opr) {
         ThinHashSet<OperatorNodeBase*> opr_set;
         ThinHashSet<OperatorNodeBase*> reader_set;
         // check reshape
         auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
         if (reshape1 == nullptr)
             return false;
         opr_set.insert(opr);
@@ -3805,12 +3934,14 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
 
     auto on_opr = [&try_conv_dimshuffle_reshape_typecvt,
                    &try_conv_reformat_nchw42nchw32,
+                   &try_conv_reformat_nchw42nchw64,
 #if CUDA_VERSION >= 10020
                    &try_conv_reformat_nchw322nchw4,
 #endif
                    &rewriter](OperatorNodeBase* opr) {
         if (!try_conv_dimshuffle_reshape_typecvt(opr) &&
-            !try_conv_reformat_nchw42nchw32(opr)
+            !try_conv_reformat_nchw42nchw32(opr) &&
+            !try_conv_reformat_nchw42nchw64(opr)
 #if CUDA_VERSION >= 10020
             && !try_conv_reformat_nchw322nchw4(opr)
 #endif
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index af6ef436..b8598a45 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -4400,6 +4400,93 @@ TEST(TestGoptInference, FoldingConvDimshuffleNCHW32NCHW4) {
     func->execute();
     MGB_ASSERT_TENSOR_EQ(host_y_fuse, host_y_non_fuse);
 }
+
+TEST(TestGoptInference, FoldingConvDimshuffleNCHW4NHWC) {
+    REQUIRE_GPU(1);
+    auto cn = CompNode::load("gpu0");
+    cn.activate();
+    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
+    auto sm_ver = prop.major * 10 + prop.minor;
+    if (sm_ver < 75) {
+        printf("This testcase is ignored due to insufficient cuda cap (got: %d, "
+               "expected: %d)\n",
+               sm_ver, 75);
+        return;
+    }
+
+    HostTensorGenerator<dtype::Int8> gen;
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp,
+                     const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
+                dtype);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp,
+                      const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                        .rename(name),
+                dtype);
+    };
+
+    auto x = mkvar("x", {32, 4, 23, 40}, dtype::QuantizedS8(2.5f)),
+         w = mkcvar("w", {64, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
+         b = mkcvar("b", {1, 64, 1, 1}, dtype::QuantizedS32(6.25f)),
+         w1 = mkcvar("w1", {64, 64, 3, 3}, dtype::QuantizedS4(1.234f)),
+         b1 = mkcvar("b1", {1, 64, 1, 1}, dtype::QuantizedS32(12.34567f * 1.234f));
+    opr::ConvBias::Param param;
+    param.format = opr::ConvBias::Param::Format::NCHW;
+    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
+    param.stride_h = param.stride_w = 1;
+    param.pad_h = param.pad_w = 1;
+
+    auto y = opr::ConvBias::make(
+            x, w, b, param, {},
+            OperatorNodeConfig{dtype::QuantizedS8(12.34567f)});
+    y = opr::TypeCvt::make(y, dtype::QuantizedS4(12.34567f));
+    y = opr::ConvBias::make(y, w1, b1, param, {},
+                            OperatorNodeConfig{dtype::QuantizedS4(56.71234f)});
+    y = opr::TypeCvt::make(y, dtype::Float32());
+    SymbolVar y_fuse, y_non_fuse;
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_nchw64();
+        unpack_vector(gopt::optimize_for_inference({y}, options), y_fuse);
+    }
+    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
+    S strategy = S::PROFILE;
+    gopt::modify_opr_algo_strategy_inplace({y_fuse}, strategy);
+    HostTensorND host_y_fuse;
+    auto func1 = graph->compile({make_callback_copy(y_fuse, host_y_fuse)});
+    func1->execute();
+    graph->compile({{y_fuse, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.FoldingConvDimshuffleNCHW4NHWC.json"));
+    size_t nr_dimshuffle = find_opr_num<opr::Dimshuffle>(y_fuse);
+    printf("%zu \n", nr_dimshuffle);
+    ASSERT_EQ(3u, find_opr_num<opr::Dimshuffle>(y_fuse));
+    bool found = false;
+    cg::DepOprIter{[&found](cg::OperatorNodeBase* opr) {
+        if (!found && opr->same_type<opr::ConvBias>()) {
+            opr::ConvBias* cb = &opr->cast_final_safe<opr::ConvBias>();
+            if (cb->param().format == opr::ConvBias::Param::Format::NCHW4_NHWC)
+                found = true;
+        }
+    }}
+            .add(y_fuse.node()->owner_opr());
+    EXPECT_TRUE(found);
+    unpack_vector(gopt::GraphOptimizer{}.apply({{y}}).endpoint_vars(),
+                  y_non_fuse);
+    gopt::modify_opr_algo_strategy_inplace({y_non_fuse}, strategy);
+    HostTensorND host_y_non_fuse;
+    auto func2 =
+            graph->compile({make_callback_copy(y_non_fuse, host_y_non_fuse)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(host_y_fuse, host_y_non_fuse);
+}
 #endif
 
 TEST(TestGoptInference, PaddingChannels) {
diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp
index 64433b9e..8d0121b1 100644
--- a/src/opr/impl/dnn/convolution.cpp
+++ b/src/opr/impl/dnn/convolution.cpp
@@ -864,7 +864,13 @@ void ConvBiasForward::init_output_static_infer_desc() {
 
 void ConvBiasForward::init_output_format() {
     mgb_assert(output().size() == 2);
-    output(0)->format(input(0)->format());
+    auto format = input(0)->format();
+    if (!format.is_default() && !format.is_lowbit_aligned()) {  // propagate
+        output(0)->format(input(0)->format());
+    } else {
+        mgb_assert(output(0)->dtype().valid());
+        output(0)->format(TensorFormat(output(0)->dtype()));
+    }
 }
 
 void ConvBiasForward::check_winograd_param_valid(
diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp
index 752806bb..36d3e009 100644
--- a/src/plugin/impl/opr_footprint.cpp
+++ b/src/plugin/impl/opr_footprint.cpp
@@ -147,9 +147,11 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
             packed_size = 32;
         } else {
             mgb_assert(param.format == Param::Format::NCHW4 ||
+                               param.format == Param::Format::NCHW4_NHWC ||
                                param.format == Param::Format::NCHW4_NCHW ||
                                param.format == Param::Format::NCHW4_NCHW32,
-                       "format should be NCHW4/NCHW4_NCHW/NCHW4_NCHW32");
+                       "format should be "
+                       "NCHW4/NCHW4_NCHW/NCHW4_NHWC/NCHW4_NCHW32");
             packed_size = 4;
         }
         return dst_shape.total_nr_elems() * fh * fw * src_shape[1] *
                packed_size / group *
@@ -174,6 +176,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
     };
     if (param.format == Param::Format::NCHW4 ||
         param.format == Param::Format::NCHW4_NCHW ||
+        param.format == Param::Format::NCHW4_NHWC ||
         param.format == Param::Format::NCHW4_NCHW32 ||
         param.format == Param::Format::NCHW88 ||
         param.format == Param::Format::NCHW44 ||
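Note on the two new layout helpers in tensor_reformat.cpp: nchw42nhwc turns an NCHW4 tensor (N, C/4, H, W, 4) into NHWC with Dimshuffle {0, 2, 3, 1, 4} followed by a Reshape, and nhwc2nchw64 turns NHWC into NCHW64 with a Reshape to (N, H, W, C/64, 64) followed by Dimshuffle {0, 3, 1, 2, 4}; chained together they realize the NCHW4 -> NCHW64 reformat that the pass folds around the new NCHW4_NHWC conv_bias. The sketch below replays the same index math with plain loops so the composition can be checked in isolation; it has no MegBrain dependency, and the file name, the value encoding, and the sizes in main() are assumptions chosen for illustration, not code from this patch.

// layout_roundtrip_sketch.cpp -- illustrative only, not part of the patch.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    // small logical tensor (N, C, H, W) with C a multiple of 64 (assumed sizes)
    const int N = 2, C = 64, H = 3, W = 5;
    const int C4 = C / 4, C64 = C / 64;

    // NCHW4 storage (N, C/4, H, W, 4); each element encodes its logical index
    std::vector<int> nchw4(N * C4 * H * W * 4);
    for (int n = 0; n < N; ++n)
        for (int c = 0; c < C; ++c)
            for (int h = 0; h < H; ++h)
                for (int w = 0; w < W; ++w)
                    nchw4[(((n * C4 + c / 4) * H + h) * W + w) * 4 + c % 4] =
                            ((n * C + c) * H + h) * W + w;

    // nchw42nhwc: Dimshuffle {0, 2, 3, 1, 4}, then Reshape to (N, H, W, C)
    std::vector<int> nhwc(N * H * W * C);
    for (int n = 0; n < N; ++n)
        for (int c4 = 0; c4 < C4; ++c4)
            for (int h = 0; h < H; ++h)
                for (int w = 0; w < W; ++w)
                    for (int i = 0; i < 4; ++i)
                        nhwc[((n * H + h) * W + w) * C + c4 * 4 + i] =
                                nchw4[(((n * C4 + c4) * H + h) * W + w) * 4 + i];

    // nhwc2nchw64: Reshape to (N, H, W, C/64, 64), then Dimshuffle {0, 3, 1, 2, 4}
    std::vector<int> nchw64(N * C64 * H * W * 64);
    for (int n = 0; n < N; ++n)
        for (int h = 0; h < H; ++h)
            for (int w = 0; w < W; ++w)
                for (int c = 0; c < C; ++c)
                    nchw64[(((n * C64 + c / 64) * H + h) * W + w) * 64 + c % 64] =
                            nhwc[((n * H + h) * W + w) * C + c];

    // the composition puts every logical (n, c, h, w) where NCHW64 expects it
    for (int n = 0; n < N; ++n)
        for (int c = 0; c < C; ++c)
            for (int h = 0; h < H; ++h)
                for (int w = 0; w < W; ++w)
                    assert(nchw64[(((n * C64 + c / 64) * H + h) * W + w) * 64 +
                                  c % 64] == ((n * C + c) * H + h) * W + w);
    std::printf("NCHW4 -> NHWC -> NCHW64 round trip checked\n");
    return 0;
}

The bias reformat in the pass is only the first half of this round trip: nchw42nhwc(bias) reshapes the (1, C/4, 1, 1, 4) NCHW4 bias into the (1, 1, 1, C) NHWC bias that the NCHW4_NHWC conv_bias consumes.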
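Note on the matched subgraph: try_conv_reformat_nchw42nchw64 only fires on the Reshape -> Dimshuffle -> Reshape chain that the NCHW64 layout pass emits behind a QuantizedS8 NCHW4 conv_bias whose output is type-converted to a 4-bit dtype, which is what the pattern_len == 6 check and the shape[4] == 16, shape[5] == 4 checks encode. The short sketch below replays that shape bookkeeping for the sizes used in the new test (the second conv produces 32x64x23x40); the dimshuffle helper and the concrete numbers are illustrative assumptions, not MegBrain code.

// pattern_shape_sketch.cpp -- illustrative only, not part of the patch.
#include <cassert>
#include <cstdio>
#include <vector>

// apply a dimshuffle pattern to a shape (helper assumed for this sketch)
static std::vector<size_t> dimshuffle(const std::vector<size_t>& shp,
                                      const std::vector<int>& pattern) {
    std::vector<size_t> out;
    for (int p : pattern)
        out.push_back(shp[p]);
    return out;
}

int main() {
    const size_t N = 32, C = 64, H = 23, W = 40;  // logical NCHW (assumed)
    // ConvBias(NCHW4) output, then TypeCvt to a 4-bit dtype: (N, C/4, H, W, 4)
    std::vector<size_t> conv_out = {N, C / 4, H, W, 4};
    // first Reshape splits C/4 into (C/64, 16)
    std::vector<size_t> split = {conv_out[0], C / 64, 16, conv_out[2],
                                 conv_out[3], conv_out[4]};
    std::vector<size_t> shuffled = dimshuffle(split, {0, 1, 3, 4, 2, 5});
    // the pass requires pattern_len == 6 with output shape[4] == 16, shape[5] == 4
    assert(shuffled.size() == 6 && shuffled[4] == 16 && shuffled[5] == 4);
    // second Reshape merges the trailing 16 x 4 into 64 -> NCHW64
    std::vector<size_t> nchw64 = {N, C / 64, H, W, 64};
    std::printf("NCHW64 result: %zux%zux%zux%zux%zu\n", nchw64[0], nchw64[1],
                nchw64[2], nchw64[3], nchw64[4]);
    return 0;
}

After the match, the whole chain collapses into a single ConvBias with Format::NCHW4_NHWC plus one nhwc2nchw64 reformat on its output, which is what the new test asserts by counting Dimshuffle operators and looking for an NCHW4_NHWC conv_bias.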