From 5f15f759841579dae17c62d0fb3847305a7d8d66 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 20 Oct 2021 17:52:02 +0800
Subject: [PATCH] test(mgb/gopt): add a testcase for SubGraphExtractor with
 multiple outputs

GitOrigin-RevId: 7785bdc8c090467cf75c864cb4056a0cb2059199
---
 .../layout_transform_context.cpp             | 25 ++++++-------
 .../opr_tensor_formats_config.cpp            | 41 ++++++++--------------
 src/gopt/test/subgraph_extractor.cpp         | 38 ++++++++++++++++++++
 3 files changed, 66 insertions(+), 38 deletions(-)

diff --git a/src/gopt/impl/global_layout_transform/layout_transform_context.cpp b/src/gopt/impl/global_layout_transform/layout_transform_context.cpp
index 5034c7dc..20bc12f3 100644
--- a/src/gopt/impl/global_layout_transform/layout_transform_context.cpp
+++ b/src/gopt/impl/global_layout_transform/layout_transform_context.cpp
@@ -94,8 +94,8 @@ std::unique_ptr<LayoutTransformContext> make_arm_ctx(
             opr::TypeCvt::typeinfo(),
             opr::PoolingForward::typeinfo(),
             opr::Resize::typeinfo(),
-            opr::PowC::typeinfo(), 
-            opr::Concat::typeinfo(), 
+            opr::PowC::typeinfo(),
+            opr::Concat::typeinfo(),
     };
 
     SmallVector<TensorFormats> available_tensor_formats = {
@@ -103,22 +103,23 @@ std::unique_ptr<LayoutTransformContext> make_arm_ctx(
             DNN_INC_FLOAT16(TensorFormats::NCHWc8)};
     Attribute attribute = {base_opr_format, base_tensor_format, Target::ARM};
     auto ctx = std::make_unique<LayoutTransformContext>(
-            std::move(opr_list), std::move(available_tensor_formats),
-            attribute);
+            std::move(opr_list), std::move(available_tensor_formats), attribute);
     ctx->add_opr_config(
             opr::ConvBiasForward::typeinfo(),
-            {OprFormat::NCHW, OprFormat::NCHW44,
-             DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
+            {OprFormat::NCHW, OprFormat::NCHW44, DNN_INC_FLOAT16(OprFormat::NCHW88),
+             OprFormat::NCHW44_DOT})
             .add_opr_config(
                     opr::ConvolutionForward::typeinfo(),
                     {OprFormat::NCHW, OprFormat::NCHW44,
                      DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
-            .add_opr_config(opr::PoolingForward::typeinfo(),
-                            {OprFormat::NCHW, OprFormat::NCHW44,
-                             DNN_INC_FLOAT16(OprFormat::NCHW88)})
-            .add_opr_config(opr::ResizeForward::typeinfo(),
-                            {OprFormat::NCHW, OprFormat::NCHW44,
-                             DNN_INC_FLOAT16(OprFormat::NCHW88)});
+            .add_opr_config(
+                    opr::PoolingForward::typeinfo(),
+                    {OprFormat::NCHW, OprFormat::NCHW44,
+                     DNN_INC_FLOAT16(OprFormat::NCHW88)})
+            .add_opr_config(
+                    opr::ResizeForward::typeinfo(),
+                    {OprFormat::NCHW, OprFormat::NCHW44,
+                     DNN_INC_FLOAT16(OprFormat::NCHW88)});
     return ctx;
 }
 } // namespace
diff --git a/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
index c3a32730..f5f450f5 100644
--- a/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
+++ b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
@@ -80,8 +80,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl {
 
 template <>
 struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> {
-    static Maybe<OprTensorFormatsConfiguration> dispatch(
-            const OperatorNodeBase* opr) {
+    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW44;
@@ -101,8 +100,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl {
 #if !MEGDNN_DISABLE_FLOAT16
 template <>
 struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> {
-    static Maybe<OprTensorFormatsConfiguration> dispatch(
-            const OperatorNodeBase* opr) {
+    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW88;
@@ -440,8 +438,7 @@ struct ConvTensorFormatsDispatcherImpl {
 
 template <typename Opr>
 struct ConvTensorFormatsDispatcherImpl {
-    static Maybe<OprTensorFormatsConfiguration> dispatch(
-            const OperatorNodeBase* opr) {
+    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         const auto& conv = opr->cast_final_safe<Opr>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
@@ -451,8 +448,7 @@ struct ConvTensorFormatsDispatcherImpl {
         for (size_t i = 0; i < opr->input().size(); ++i) {
             available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
             config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
-            TensorType tensor_type =
-                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
+            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
             config.input_tensor_types.emplace_back(tensor_type);
         }
         available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
@@ -484,8 +480,7 @@ struct ConvTensorFormatsDispatcherImpl {
 #if !MEGDNN_DISABLE_FLOAT16
 template <typename Opr>
 struct ConvTensorFormatsDispatcherImpl {
-    static Maybe<OprTensorFormatsConfiguration> dispatch(
-            const OperatorNodeBase* opr) {
+    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         const auto& conv = opr->cast_final_safe<Opr>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
@@ -495,8 +490,7 @@ struct ConvTensorFormatsDispatcherImpl {
         for (size_t i = 0; i < opr->input().size(); ++i) {
             available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16;
             config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
-            TensorType tensor_type =
-                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
+            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
             config.input_tensor_types.emplace_back(tensor_type);
         }
         available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16;
@@ -528,8 +522,7 @@ struct ConvTensorFormatsDispatcherImpl {
 
 template <typename Opr>
 struct ConvTensorFormatsDispatcherImpl {
-    static Maybe<OprTensorFormatsConfiguration> dispatch(
-            const OperatorNodeBase* opr) {
+    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         const auto& conv = opr->cast_final_safe<Opr>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
@@ -538,22 +531,18 @@ struct ConvTensorFormatsDispatcherImpl {
         // setup dtypes
         for (size_t i = 0; i < opr->input().size(); ++i) {
             if (i == 2) {
-                available &= opr->input(i)->dtype().enumv() ==
-                             DTypeEnum::QuantizedS32;
+                available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
             } else {
-                available &= opr->input(i)->dtype().enumv() ==
-                                     DTypeEnum::QuantizedS8 ||
-                             opr->input(i)->dtype().enumv() ==
-                                     DTypeEnum::Quantized8Asymm;
+                available &=
+                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
+                        opr->input(i)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
             }
             config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
-            TensorType tensor_type =
-                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
+            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
             config.input_tensor_types.emplace_back(tensor_type);
         }
-        available &=
-                opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
-                opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
+        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
+                     opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         // setup tensor formats
         if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
@@ -747,7 +736,7 @@ StaticData::StaticData() {
 
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44);
-#if !MEGDNN_DISABLE_FLOAT16 
+#if !MEGDNN_DISABLE_FLOAT16
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88);
 #endif
 
diff --git a/src/gopt/test/subgraph_extractor.cpp b/src/gopt/test/subgraph_extractor.cpp
index fccc299b..1580eeb5 100644
--- a/src/gopt/test/subgraph_extractor.cpp
+++ b/src/gopt/test/subgraph_extractor.cpp
@@ -264,4 +264,42 @@ TEST(TestSubGraphExtractor, Complicated) {
             output_file(ssprintf("%s.json", prefix).c_str()));
 }
 
+TEST(TestSubGraphExtractor, SubGraphWithMultipleOutputs) {
+    HostTensorGenerator<> gen;
+    auto graph = ComputingGraph::make();
+
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
+    };
+
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name);
+    };
+
+    graph->options().graph_opt_level = 0;
+    auto x = mkvar("x", {8, 8, 8, 8}), w = mkcvar("w", {4, 8, 3, 3});
+
+    opr::Convolution::Param param;
+    param.pad_h = param.pad_w = 1;
+    auto c = opr::Convolution::make(x, w, param);
+    auto neg_c = -c;
+    auto z = opr::Concat::make({c, neg_c}, 1);
+
+    using OprList = SubGraphExtractor::OprList;
+    static const OprList opr_list = {
+            opr::ConvolutionForward::typeinfo(),
+            opr::Elemwise::typeinfo(),
+    };
+    SubGraphExtractor extractor(opr_list);
+    auto partitions = extractor.extract({z});
+    ASSERT_EQ(partitions.size(), 1u);
+    ASSERT_EQ(partitions[0].output().size(), 2u);
+    ASSERT_TRUE(partitions[0].output().count(c.node()) > 0);
+    ASSERT_TRUE(partitions[0].output().count(neg_c.node()) > 0);
+    ASSERT_EQ(partitions[0].input().size(), 2u);
+    ASSERT_TRUE(partitions[0].input().count(x.node()) > 0);
+    partitions[0].to_json()->writeto_fpath(
+            output_file("TestSubGraphExtractor.SubGraphMultipleOuputs.json"));
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}