From fac67e7c2badcb5cef7b2e549076591fa6ba4470 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 22 Jul 2022 12:07:12 +0800 Subject: [PATCH] feat(gopt): support nchw44 global pooling with fuse_grain GitOrigin-RevId: 4c43a149f8214aae7f48ae72bbd6928c1e18f0f5 --- dnn/src/arm_common/adaptive_pooling/opr_impl.cpp | 29 +++ dnn/src/arm_common/adaptive_pooling/opr_impl.h | 20 ++ dnn/src/arm_common/handle.cpp | 2 + dnn/src/arm_common/pooling/opr_impl.cpp | 75 +++++++ dnn/src/common/adaptive_pooling.cpp | 7 +- dnn/src/common/pooling.cpp | 4 +- dnn/src/naive/adaptive_pooling/opr_impl.cpp | 14 +- dnn/src/naive/adaptive_pooling/opr_impl.h | 4 +- dnn/test/arm_common/adaptive_pooling.cpp | 98 ++++++++ dnn/test/common/adaptive_pooling.h | 30 +++ .../python/megengine/core/tensor/megbrain_graph.py | 8 +- imperative/python/megengine/utils/network.py | 5 +- imperative/python/src/graph_rt.cpp | 4 +- imperative/python/test/unit/utils/test_network.py | 1 + lite/load_and_run/src/models/model_mdl.cpp | 7 + lite/load_and_run/src/options/model_options.cpp | 43 +++- lite/load_and_run/src/options/model_options.h | 20 +- lite/load_and_run/src/options/optimize_options.cpp | 60 ++++- lite/load_and_run/src/options/optimize_options.h | 25 +++ src/core/include/megbrain/graph/bases.h | 2 +- src/core/include/megbrain/graph/cg.h | 4 + src/gopt/impl/folding_golbal_pooling.cpp | 180 +++++++++++++++ src/gopt/impl/folding_reduce_mean.cpp | 117 ++++++++++ src/gopt/impl/framework.cpp | 5 +- src/gopt/impl/gopt_helper.h | 35 +++ src/gopt/impl/tensor_reformat.cpp | 32 +++ src/gopt/include/megbrain/gopt/inference.h | 22 ++ src/gopt/test/gopt_old_model.cpp | 249 +++++++++++++++++++++ src/gopt/test/inference.cpp | 66 ++++++ src/opr/impl/dnn/adaptive_pooling.cpp | 15 +- src/opr/test/dnn/adaptive_pooling.cpp | 1 - 31 files changed, 1162 insertions(+), 22 deletions(-) create mode 100644 dnn/src/arm_common/adaptive_pooling/opr_impl.cpp create mode 100644 dnn/src/arm_common/adaptive_pooling/opr_impl.h create mode 100644 dnn/test/arm_common/adaptive_pooling.cpp create mode 100644 src/gopt/impl/folding_golbal_pooling.cpp create mode 100644 src/gopt/impl/folding_reduce_mean.cpp create mode 100644 src/gopt/impl/gopt_helper.h create mode 100644 src/gopt/test/gopt_old_model.cpp diff --git a/dnn/src/arm_common/adaptive_pooling/opr_impl.cpp b/dnn/src/arm_common/adaptive_pooling/opr_impl.cpp new file mode 100644 index 00000000..819ff5a8 --- /dev/null +++ b/dnn/src/arm_common/adaptive_pooling/opr_impl.cpp @@ -0,0 +1,29 @@ +#include "src/arm_common/adaptive_pooling/opr_impl.h" +#include "src/common/opr_delegate.h" +#include "src/naive/handle.h" +namespace megdnn { +namespace arm_common { + +void AdaptivePoolingImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { + auto adapt_fwd = [=]() { + auto opr = inplace_cpu_handle()->create_operator<Pooling>(); + opr->param() = deduce_pooling_param(src.layout, dst.layout); + opr->exec(src, dst, workspace); + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(adapt_fwd()); + return; +} + +size_t AdaptivePoolingImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) { + auto opr = inplace_cpu_handle()->create_operator<Pooling>(); + opr->param() = deduce_pooling_param(src, dst); + auto need_size = opr->get_workspace_in_bytes(src, dst); + return need_size; +} + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/adaptive_pooling/opr_impl.h b/dnn/src/arm_common/adaptive_pooling/opr_impl.h new file mode 100644 index
00000000..64c82d02 --- /dev/null +++ b/dnn/src/arm_common/adaptive_pooling/opr_impl.h @@ -0,0 +1,20 @@ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace arm_common { + +class AdaptivePoolingImpl final : public AdaptivePoolingForward { +public: + using AdaptivePoolingForward::AdaptivePoolingForward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) override; +}; + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/handle.cpp b/dnn/src/arm_common/handle.cpp index b2c13127..714001b7 100644 --- a/dnn/src/arm_common/handle.cpp +++ b/dnn/src/arm_common/handle.cpp @@ -2,6 +2,7 @@ #include "src/arm_common/handle.h" +#include "src/arm_common/adaptive_pooling/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/convolution/opr_impl.h" #include "src/arm_common/cvt_color/opr_impl.h" @@ -45,6 +46,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData) MEGDNN_SPECIALIZE_CREATE_OPERATOR(RNNCell) MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTMCell) MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTM) +MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePooling) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wpragmas" diff --git a/dnn/src/arm_common/pooling/opr_impl.cpp b/dnn/src/arm_common/pooling/opr_impl.cpp index 6e33a7f6..4a356893 100644 --- a/dnn/src/arm_common/pooling/opr_impl.cpp +++ b/dnn/src/arm_common/pooling/opr_impl.cpp @@ -2,6 +2,7 @@ #include "src/arm_common/pooling/algo.h" #include "src/common/algo_chooser.h" #include "src/common/metahelper.h" +#include "src/common/opr_delegate.h" using namespace megdnn; using namespace arm_common; @@ -48,10 +49,72 @@ public: }; PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack; +namespace { +TensorLayout merge_hw_layout(TensorLayout src) { + src.ndim -= 1; + src.shape[2] = src.shape[2] * src.shape[3]; + src.stride[2] = src.stride[3]; + for (size_t i = 3; i < src.ndim; ++i) { + src.shape[i] = src.shape[i + 1]; + src.stride[i] = src.stride[i + 1]; + } + return src; +} +std::pair<TensorND, TensorND> get_gloabl_pooling_reduce_tensor( + const TensorND& src, const TensorND& dst) { + auto reduce_src_layout = merge_hw_layout(src.layout); + auto reduce_dst_layout = merge_hw_layout(dst.layout); + return std::make_pair<TensorND, TensorND>( + {src.raw_ptr(), reduce_src_layout}, {dst.raw_ptr(), reduce_dst_layout}); +} +std::unique_ptr<Reduce> get_global_pooling_reduce_opr( + Handle* handle, const PoolingImpl::PoolingKernSizeParam& param) { + std::unique_ptr<Reduce> opr; + if (handle) { + opr = handle->create_operator<Reduce>(); + } else { + opr = inplace_cpu_handle()->create_operator<Reduce>(); + } + param::Reduce reduce_param; + reduce_param.axis = 2; + if (param.mode == PoolingImpl::Param::Mode::MAX) { + reduce_param.mode = param::Reduce::Mode::MAX; + } else { + megdnn_assert( + param.mode == PoolingImpl::Param::Mode::AVERAGE || + param.mode == PoolingImpl::Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING); + reduce_param.mode = param::Reduce::Mode::MEAN; + } + opr->param() = reduce_param; + return opr; +} +bool is_global_pooling_reduce(PoolingImpl::PoolingKernSizeParam& param) { + bool fmt_ok = param.format == PoolingImpl::Param::Format::NCHW || + param.format == PoolingImpl::Param::Format::NCHW44 || + param.format == PoolingImpl::Param::Format::NCHW88; + bool size_ok = param.filter[0] == param.isz[0] && param.filter[1] == param.isz[1] && + param.padding[0] == 0 && param.padding[1] == 0 && + param.osz[0]
== 1 && param.osz[1] == 1; + bool dtype_ok = param.src_type == param.dst_type && + param.src_type.enumv() != DTypeEnum::Int8; + return fmt_ok && size_ok && dtype_ok; +} +} // namespace size_t PoolingImpl::get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& dst) { auto param = make_pooling_kern_szie_param(this, src, dst); + bool fwd_reduce = is_global_pooling_reduce(param); + if (fwd_reduce) { + TensorND src_tensor{nullptr, src}; + TensorND dst_tensor{nullptr, dst}; + auto reduce_tensor = get_gloabl_pooling_reduce_tensor(src_tensor, dst_tensor); + auto&& opr = get_global_pooling_reduce_opr(nullptr, param); + auto reduce_need = opr->get_workspace_in_bytes( + reduce_tensor.first.layout, reduce_tensor.second.layout); + return reduce_need; + } + auto algo = get_algorithm(this, src, dst); if (!is_fallback_algo(algo)) { size_t arm_common_workspace = 0; @@ -93,6 +156,18 @@ void PoolingImpl::exec( _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); auto param = make_pooling_kern_param(this, src, dst, workspace); + + bool fwd_reduce = is_global_pooling_reduce(param); + if (fwd_reduce) { + auto global_pooling_fwd = [=]() { + auto reduce_tensor = get_gloabl_pooling_reduce_tensor(src, dst); + auto&& opr = get_global_pooling_reduce_opr(nullptr, param); + opr->exec(reduce_tensor.first, reduce_tensor.second, workspace); + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(global_pooling_fwd()); + return; + } + auto algo = get_algorithm(this, src.layout, dst.layout); if (!is_fallback_algo(algo)) { algo->exec(param); diff --git a/dnn/src/common/adaptive_pooling.cpp b/dnn/src/common/adaptive_pooling.cpp index 88862e1e..9eb2c590 100644 --- a/dnn/src/common/adaptive_pooling.cpp +++ b/dnn/src/common/adaptive_pooling.cpp @@ -8,7 +8,9 @@ param::Pooling AdaptivePoolingBase::deduce_pooling_param( const TensorLayout& src, const TensorLayout& dst) { auto param_format = param().format; size_t IH, IW, OH, OW; - if (param_format == param::AdaptivePooling::Format::NCHW) { + if (param_format == param::AdaptivePooling::Format::NCHW || + param_format == param::AdaptivePooling::Format::NCHW44 || + param_format == param::AdaptivePooling::Format::NCHW88) { IH = src.shape[2]; IW = src.shape[3]; OH = dst.shape[2]; @@ -19,7 +21,8 @@ param::Pooling AdaptivePoolingBase::deduce_pooling_param( OH = dst.shape[1]; OW = dst.shape[2]; } else { - megdnn_throw("AdaptivePooling only support NCHW or NHWC format"); + megdnn_throw( + "AdaptivePooling only supports NCHW, NHWC, NCHW44 or NCHW88 format"); } param::Pooling ret; diff --git a/dnn/src/common/pooling.cpp b/dnn/src/common/pooling.cpp index 109fe383..f076131b 100644 --- a/dnn/src/common/pooling.cpp +++ b/dnn/src/common/pooling.cpp @@ -140,7 +140,9 @@ void PoolingForward::check_exec( const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); - megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); + megdnn_assert( + workspace_in_bytes >= required_workspace_in_bytes, "need %zu, got %zu", + required_workspace_in_bytes, workspace_in_bytes); } void PoolingBackward::check_exec( diff --git a/dnn/src/naive/adaptive_pooling/opr_impl.cpp b/dnn/src/naive/adaptive_pooling/opr_impl.cpp index 1c1c5493..85bc19af 100644 --- a/dnn/src/naive/adaptive_pooling/opr_impl.cpp +++ b/dnn/src/naive/adaptive_pooling/opr_impl.cpp @@ -6,11 +6,17 @@ namespace megdnn { namespace naive { - +size_t
AdaptivePoolingForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) { + auto opr = inplace_cpu_handle(2)->create_operator<Pooling>(); + opr->param() = deduce_pooling_param(src, dst); + auto need_size = opr->get_workspace_in_bytes(src, dst); + return need_size; +} void AdaptivePoolingForwardImpl::exec( _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle()), { - auto opr = inplace_cpu_handle()->create_operator<Pooling>(); + auto opr = inplace_cpu_handle(2)->create_operator<Pooling>(); opr->param() = deduce_pooling_param(src.layout, dst.layout); opr->exec(src, dst, workspace); }); @@ -20,7 +26,7 @@ void AdaptivePoolingBackwardImpl::exec( _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) { MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle()), { - auto opr = inplace_cpu_handle()->create_operator<PoolingBackward>(); + auto opr = inplace_cpu_handle(2)->create_operator<PoolingBackward>(); opr->param() = deduce_pooling_param(src.layout, dst.layout); opr->exec(src, dst, diff, grad, workspace); }); @@ -29,7 +35,7 @@ size_t AdaptivePoolingBackwardImpl::get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, const TensorLayout& grad) { - auto opr = inplace_cpu_handle()->create_operator<PoolingBackward>(); + auto opr = inplace_cpu_handle(2)->create_operator<PoolingBackward>(); opr->param() = deduce_pooling_param(src, dst); return opr->get_workspace_in_bytes(src, dst, diff, grad); } diff --git a/dnn/src/naive/adaptive_pooling/opr_impl.h b/dnn/src/naive/adaptive_pooling/opr_impl.h index 7d1a2f03..fa7ebe8c 100644 --- a/dnn/src/naive/adaptive_pooling/opr_impl.h +++ b/dnn/src/naive/adaptive_pooling/opr_impl.h @@ -11,9 +11,7 @@ public: void exec( _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { - return 0; - } + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override; }; class AdaptivePoolingBackwardImpl : public AdaptivePoolingBackward { diff --git a/dnn/test/arm_common/adaptive_pooling.cpp b/dnn/test/arm_common/adaptive_pooling.cpp new file mode 100644 index 00000000..0796b985 --- /dev/null +++ b/dnn/test/arm_common/adaptive_pooling.cpp @@ -0,0 +1,98 @@ +#include "test/arm_common/fixture.h" + +#include "megdnn/tensor_iter.h" +#include "src/common/utils.h" +#include "test/common/adaptive_pooling.h" +#include "test/common/benchmarker.h" +#include "test/common/checker.h" + +namespace megdnn { +namespace test { + +TEST_F(ARM_COMMON, ADAPTIVE_POOLING_FORWARD_NCHW44) { + auto args = adaptive_pooling::get_args_nchw44(); + Checker<AdaptivePooling> checker(handle()); + checker.set_epsilon(1e-4); + for (DType dtype : {(DType)dtype::Float32(), (DType)dtype::QuantizedS8(1.0)}) + for (auto&& arg : args) { + auto param = arg.param; + auto src = arg.ishape; + auto dst = arg.oshape; + + checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec( + TensorShapeArray{src, dst, {}}); + } +} + +TEST_F(ARM_COMMON, ADAPTIVE_POOLING_FORWARD) { + auto args = adaptive_pooling::get_args(); + Checker<AdaptivePooling> checker(handle()); + checker.set_epsilon(1e-4); + for (DType dtype : {(DType)dtype::Float32(), (DType)dtype::QuantizedS8(1.0)}) + for (auto&& arg : args) { + auto param = arg.param; + auto src = arg.ishape; + auto dst = arg.oshape; + + checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec( + TensorShapeArray{src, dst, {}});
+ } +} + +#if MEGDNN_WITH_BENCHMARK +namespace { +void benchmark_globalpooling_nchw44_fp32(Handle* handle) { + using Param = param::AdaptivePooling; + auto run = [&](size_t n, size_t c, size_t h, size_t w, Param::Mode mode) { + Param param; + param.format = Param::Format::NCHW; + param.mode = mode; + TensorShape nchw_shape = {n, c, h, w}; + TensorShape nchw_dst_shape = {n, c, 1, 1}; + TensorShape nchw44_shape = {n, c / 4, h, w, 4}; + TensorShape nchw44_dst_shape = {n, c / 4, 1, 1, 4}; + TensorLayout dst_layout; + float calc_amount = n * c * h * w; + + Benchmarker<AdaptivePooling> benchmarker_float_nchw(handle); + Benchmarker<AdaptivePooling> benchmarker_float_nchw44(handle); + Benchmarker<AdaptivePooling> benchmarker_int_nchw44(handle); + size_t RUN = 500; + auto t1 = benchmarker_float_nchw.set_display(false) + .set_times(RUN) + .set_param(param) + .exec({nchw_shape, nchw_dst_shape}); + + param.format = Param::Format::NCHW44; + auto t2 = benchmarker_int_nchw44.set_display(false) + .set_times(RUN) + .set_param(param) + .execl({{nchw44_shape, dtype::QuantizedS8(1.0)}, + {nchw44_dst_shape, dtype::QuantizedS8(1.0)}}); + + auto t3 = benchmarker_float_nchw44.set_display(false) + .set_times(RUN) + .set_param(param) + .exec({nchw44_shape, nchw44_dst_shape}); + + printf("{%zu %zu %zu %zu} \n" + "nchw_fp32={%.3f ms, %.3f Mflops}, " + "nchw44_int={%.3f ms, %.3f Mflops}, " + "nchw44_fp32={%.3f ms, %.3f Mflops, speed_up %f}\n\n", + n, c, h, w, t1 / RUN, calc_amount / (t1 / RUN * 1000), t2 / RUN, + calc_amount / (t2 / RUN * 1000), t3 / RUN, + calc_amount / (t3 / RUN * 1000), t1 / t3); + }; + + run(1, 128, 25, 25, param::AdaptivePooling::Mode::AVERAGE); +} +} // namespace +TEST_F(ARM_COMMON, BENCHMARK_GLOBAL_POOLING_NCHW44_FP32) { + benchmark_globalpooling_nchw44_fp32(handle()); +} + +#endif +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/common/adaptive_pooling.h b/dnn/test/common/adaptive_pooling.h index 632f346d..214bb05e 100644 --- a/dnn/test/common/adaptive_pooling.h +++ b/dnn/test/common/adaptive_pooling.h @@ -40,6 +40,36 @@ inline std::vector<TestArg> get_args() { return args; } +inline std::vector<TestArg> get_args_nchw44() { + std::vector<TestArg> args; + using Param = param::AdaptivePooling; + using Mode = param::AdaptivePooling::Mode; + + for (size_t i = 36; i < 40; ++i) { + args.emplace_back( + Param{Mode::AVERAGE, Param::Format::NCHW44}, + TensorShape{2, 3, i, i + 1, 4}, TensorShape{2, 3, i - 4, i - 2, 4}); + args.emplace_back( + Param{Mode::MAX, Param::Format::NCHW44}, TensorShape{2, 3, i, i + 1, 4}, + TensorShape{2, 3, i - 4, i - 2, 4}); + args.emplace_back( + Param{Mode::AVERAGE, Param::Format::NCHW44}, + TensorShape{2, 3, i, i + 1, 4}, TensorShape{2, 3, 1, 1, 4}); + args.emplace_back( + Param{Mode::MAX, Param::Format::NCHW44}, TensorShape{2, 3, i, i + 1, 4}, + TensorShape{2, 3, 1, 1, 4}); + } + + for (size_t i = 5; i < 10; ++i) { + args.emplace_back( + Param{Mode::AVERAGE, Param::Format::NCHW44}, + TensorShape{2, 3, i, i + 1, 4}, TensorShape{2, 3, i - 3, i - 2, 4}); + args.emplace_back( + Param{Mode::MAX, Param::Format::NCHW44}, TensorShape{2, 3, i, i + 1, 4}, + TensorShape{2, 3, i - 3, i - 2, 4}); + } + return args; +} } // namespace adaptive_pooling } // namespace test } // namespace megdnn diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 3053c434..77fce034 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -254,7 +254,9 @@ def
optimize_for_inference(dest_vars, **kwargs): * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z input for inference on nvidia backend(this optimization pass will result in mismatch of the precision of output of training and - inference) + inference + * enable_fuse_grain: whether to fuse fine-grained operators into larger ones; enabled by default and can be disabled. + ) """ inference_options = GraphOptimizeOptions() inference_optimize_layout_transform_map = { @@ -282,6 +284,8 @@ def optimize_for_inference(dest_vars, **kwargs): inference_options.fuse_conv_bias_with_z = True if kwargs.pop("enable_fuse_preprocess", False): inference_options.fuse_preprocess = True + if kwargs.pop("enable_fuse_grain", True): + inference_options.fuse_grain = True if kwargs: raise ValueError("unknown options: %s" % list(kwargs)) @@ -330,6 +334,8 @@ def deserialize_infer_option(x: int) -> Dict[str, bool]: ret["enable_fuse_conv_bias_with_z"] = True if inference_options.fuse_preprocess: ret["enable_fuse_preprocess"] = True + if inference_options.fuse_grain: + ret["enable_fuse_grain"] = True return ret diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py index 398d925b..f26d4417 100644 --- a/imperative/python/megengine/utils/network.py +++ b/imperative/python/megengine/utils/network.py @@ -151,7 +151,9 @@ class Network: * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z input for inference on nvidia backend(this optimization pass will result in mismatch of the precision of output of training and - inference) + inference + * enable_fuse_grain: whether to fuse fine-grained operators into larger ones; enabled by default and can be disabled. + ) """ if not isinstance(dest_vars, Sequence): @@ -221,7 +223,6 @@ class Network: logger.warning( '"output_names" is not supported in Network.dump, rename output vars directly' ) - if optimize_for_inference: out, optimize_options = G.optimize_for_inference(out, **kwargs) diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 754f3f35..de7f28b9 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -292,7 +292,9 @@ void init_graph_rt(py::module m) { &_OptimizeForInferenceOptions::fuse_preprocess) .def_readwrite( "layout_transform", - &_OptimizeForInferenceOptions::layout_transform); + &_OptimizeForInferenceOptions::layout_transform) + .def_readwrite( + "fuse_grain", &_OptimizeForInferenceOptions::fuse_grain); py::enum_<_LayoutTransform>(GraphOptimizeOptions, "LayoutTransform") .value("DEFAULT", _LayoutTransform::DEFAULT) diff --git a/imperative/python/test/unit/utils/test_network.py b/imperative/python/test/unit/utils/test_network.py index 243b7164..c8ac9e4a 100644 --- a/imperative/python/test/unit/utils/test_network.py +++ b/imperative/python/test/unit/utils/test_network.py @@ -47,6 +47,7 @@ def test_metadata(): "user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None}, "graph_modified": True, # True: Network.dump "optimized_for_inference": True, + "enable_fuse_grain": True, "enable_nchw4": True, "enable_ioc16": True, } diff --git a/lite/load_and_run/src/models/model_mdl.cpp b/lite/load_and_run/src/models/model_mdl.cpp index 804c6d63..8097e9fa 100644 --- a/lite/load_and_run/src/models/model_mdl.cpp +++ b/lite/load_and_run/src/models/model_mdl.cpp @@ -76,6 +76,13 @@ void ModelMdl::make_output_spec() { } m_asyc_exec = m_load_result.graph_compile(m_output_spec); + auto new_output_vars = m_asyc_exec->get_output_vars();
+ mgb::cg::SymbolVarArray symbol_var_array; + symbol_var_array.reserve(new_output_vars.size()); + for (auto output_var : new_output_vars) { + symbol_var_array.emplace_back(output_var); + } + m_load_result.output_var_list = symbol_var_array; } std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader( diff --git a/lite/load_and_run/src/options/model_options.cpp b/lite/load_and_run/src/options/model_options.cpp index f138794c..72c6bc12 100644 --- a/lite/load_and_run/src/options/model_options.cpp +++ b/lite/load_and_run/src/options/model_options.cpp @@ -56,6 +56,43 @@ void PackModelOption::config_model( RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { CONFIG_MODEL_FUN; } +///////////////////// RawModelOption ////////////////////////// +std::shared_ptr<OptionBase> RawModelOption::create_option() { + static std::shared_ptr<RawModelOption> option(new RawModelOption); + if (RawModelOption::is_valid()) { + return std::static_pointer_cast<OptionBase>(option); + } else { + return nullptr; + } +} + +RawModelOption::RawModelOption() { + m_option_name = "raw_model"; + if (!FLAGS_model_dump.empty()) + model_dump = FLAGS_model_dump; +} +bool RawModelOption::is_valid() { + return !FLAGS_model_dump.empty(); +} +void RawModelOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { + CONFIG_MODEL_FUN; +} +template <typename ModelImpl> +void RawModelOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr<ModelImpl> model) { + if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { + auto model_data = model->get_model_data(); + std::ofstream ofs(model_dump, std::ios::binary); + if (!ofs.is_open()) { + mgb_log_warn("cannot open file %s to write model\n", model_dump.c_str()); + return; + } + ofs.write((char*)model_data.data(), model_data.size()); + ofs.close(); + mgb_log_warn("successfully wrote model to %s\n", model_dump.c_str()); + } +} ////////////////////// PackModel gflags //////////////////////// @@ -79,4 +116,8 @@ DEFINE_string( "https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/" "pack-lite-model.html for more details."); -REGIST_OPTION_CREATOR(pack_model, lar::PackModelOption::create_option); \ No newline at end of file +///////////////////// RawModel gflags /////////////////////////// +DEFINE_string(model_dump, "", "The output file path of raw model."); + +REGIST_OPTION_CREATOR(pack_model, lar::PackModelOption::create_option); +REGIST_OPTION_CREATOR(dump_model, lar::RawModelOption::create_option); \ No newline at end of file diff --git a/lite/load_and_run/src/options/model_options.h b/lite/load_and_run/src/options/model_options.h index f8753424..ac5df7a7 100644 --- a/lite/load_and_run/src/options/model_options.h +++ b/lite/load_and_run/src/options/model_options.h @@ -3,7 +3,7 @@ #include "megbrain/graph/operator_node.h" #include "models/model.h" #include "option_base.h" - +DECLARE_string(model_dump); DECLARE_string(packed_model_dump); DECLARE_string(pack_info_json); DECLARE_string(pack_cache); @@ -36,4 +36,22 @@ private: std::string pack_model_cryption; bool is_fast_run_cache = true; }; + +class RawModelOption : public OptionBase { +public: + static bool is_valid(); + static std::shared_ptr<OptionBase> create_option(); + void config_model( + RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; + std::string option_name() const override { return m_option_name; } + +private: + RawModelOption(); + + template <typename ModelImpl> + void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>); + + std::string m_option_name; + std::string model_dump; +}; } // namespace lar diff --git a/lite/load_and_run/src/options/optimize_options.cpp
b/lite/load_and_run/src/options/optimize_options.cpp index 5b1516d4..b1506cc2 100644 --- a/lite/load_and_run/src/options/optimize_options.cpp +++ b/lite/load_and_run/src/options/optimize_options.cpp @@ -124,6 +124,60 @@ void WeightPreprocessOption::config_model( CONFIG_MODEL_FUN; } +///////////////////////// fuse grain optimize options /////////////// +bool FuseGrainOption::m_valid; +namespace lar { +template <> +void FuseGrainOption::config_model_internel<ModelLite>( + RuntimeParam& runtime_param, std::shared_ptr<ModelLite>) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + if (m_fuse_grain) { + LITE_THROW("enable fuse-grain optimization not support in lite model"); + } + } +} + +template <> +void FuseGrainOption::config_model_internel<ModelMdl>( + RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto&& graph_option = model->get_mdl_config().comp_graph->options(); + if (m_fuse_grain) { + mgb_log_warn("enable fuse-grain optimization"); + graph_option.graph_opt.enable_fuse_grain(); + } + } +} +} // namespace lar + +FuseGrainOption::FuseGrainOption() { + m_option_name = "fuse_grain"; + m_fuse_grain = FLAGS_fuse_grain; + m_option = {{"fuse_grain", lar::Bool::make(false)}}; + std::static_pointer_cast<lar::Bool>(m_option["fuse_grain"]) + ->set_value(FLAGS_fuse_grain); +} + +bool FuseGrainOption::is_valid() { + return true; +} + +std::shared_ptr<OptionBase> FuseGrainOption::create_option() { + static std::shared_ptr<FuseGrainOption> option(new FuseGrainOption); + if (FuseGrainOption::is_valid()) { + return std::static_pointer_cast<OptionBase>(option); + } else { + return nullptr; + } +} + +void FuseGrainOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { + m_fuse_grain = + std::static_pointer_cast<lar::Bool>(m_option["fuse_grain"])->get_value(); + CONFIG_MODEL_FUN; +} + ///// fuse conv bias and nonlinear activation opr optimize options //////// bool FuseConvBiasNonlinearOption::m_valid; namespace lar { @@ -579,6 +633,7 @@ void TensorRTOption::config_model( DEFINE_bool( enable_fuse_preprocess, false, "Fusion astype | pad_channel | dimshuffle and etc opr from h2d opr"); +DEFINE_bool(fuse_grain, false, "Enable fusing fine-grained oprs into larger oprs"); DEFINE_bool( weight_preprocess, false, "Execute operators with weight preprocess, which can optimize the " @@ -589,7 +644,7 @@ DEFINE_bool( "whether to fuse conv+bias+nonlinearity"); DEFINE_bool( enable_fuse_conv_bias_with_z, false, - "fuse conv,bias (elemwise add),z(elemwise add) into one opr " + "fuse conv, bias (elemwise add), z(elemwise add) into one opr " "(only support on GPU)"); ///////////////////////// graph retrict options ///////////////////////// @@ -636,6 +691,9 @@ REGIST_OPTION_VALIDATER(fuse_preprocess, lar::FusePreprocessOption::set_valid); REGIST_OPTION_CREATOR(weight_preprocess, lar::WeightPreprocessOption::create_option); REGIST_OPTION_VALIDATER(weight_preprocess, lar::WeightPreprocessOption::set_valid); +REGIST_OPTION_CREATOR(disable_fuse_grain, lar::FuseGrainOption::create_option); +REGIST_OPTION_VALIDATER(disable_fuse_grain, lar::FuseGrainOption::set_valid); + REGIST_OPTION_CREATOR( fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::create_option); REGIST_OPTION_VALIDATER( diff --git a/lite/load_and_run/src/options/optimize_options.h b/lite/load_and_run/src/options/optimize_options.h index 18df9a25..34873321 100644 --- a/lite/load_and_run/src/options/optimize_options.h +++ b/lite/load_and_run/src/options/optimize_options.h @@ -5,6 +5,7 @@ #include "option_base.h" DECLARE_bool(enable_fuse_preprocess);
+DECLARE_bool(fuse_grain); DECLARE_bool(weight_preprocess); DECLARE_bool(enable_fuse_conv_bias_nonlinearity); DECLARE_bool(enable_fuse_conv_bias_with_z); @@ -79,7 +80,31 @@ private: static bool m_valid; OptionValMap m_option; }; +///////////////////////// fuse grain options ////////////// +class FuseGrainOption final : public OptionBase { +public: + static bool is_valid(); + + static std::shared_ptr<OptionBase> create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; + + std::string option_name() const override { return m_option_name; }; + static void set_valid(bool val) { m_valid = val; }; + OptionValMap* get_option() override { return &m_option; } + +private: + FuseGrainOption(); + template <typename ModelImpl> + void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; + + std::string m_option_name; + bool m_fuse_grain; + static bool m_valid; + OptionValMap m_option; +}; /////////////// fuse_conv_bias_nonlinearity optimize options /////////////// class FuseConvBiasNonlinearOption final : public OptionBase { public: diff --git a/src/core/include/megbrain/graph/bases.h b/src/core/include/megbrain/graph/bases.h index aa8c7fac..d79c90c6 100644 --- a/src/core/include/megbrain/graph/bases.h +++ b/src/core/include/megbrain/graph/bases.h @@ -91,7 +91,7 @@ public: }; class OutputVarsUserData final : public mgb::UserDataContainer::UserData { - MGB_TYPEINFO_OBJ_DECL; + MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT; private: VarNodeArray m_output_vars; diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h index 57e1a229..86e1b12f 100644 --- a/src/core/include/megbrain/graph/cg.h +++ b/src/core/include/megbrain/graph/cg.h @@ -91,6 +91,9 @@ struct GraphCommonOptimizeOptions { bool weight_preprocess = false; //! fuse preprocess patten, like astype + pad_channel + dimshuffle bool fuse_preprocess = false; + //!
fuse_grain pattern: replace fine-grained oprs with fused coarse-grained oprs + bool fuse_grain = false; + enum LayoutTransform : uint32_t { DEFAULT, NCHW4, ///< compute using NCHW4 tensor format @@ -124,6 +127,7 @@ struct GraphCommonOptimizeOptions { SET(fuse_conv_bias_with_z); SET(fuse_preprocess); SET(weight_preprocess); + SET(fuse_grain); #undef SET #define SET(_trans, _trans_capital) \ GraphCommonOptimizeOptions& enable_##_trans() { \ diff --git a/src/gopt/impl/folding_golbal_pooling.cpp b/src/gopt/impl/folding_golbal_pooling.cpp new file mode 100644 index 00000000..afc2d5d9 --- /dev/null +++ b/src/gopt/impl/folding_golbal_pooling.cpp @@ -0,0 +1,180 @@ +#include "megbrain/gopt/inference.h" +#include "megbrain/opr/basic_arith.h" +#include "megbrain/opr/dnn/adaptive_pooling.h" +#include "megbrain/opr/tensor_manip.h" +#include "megbrain/opr/utility.h" +#include "megbrain/serialization/opr_shallow_copy.h" + +#include "megdnn/opr_param_defs.h" + +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" + +#include "../../core/impl/graph/cg_impl.h" +#include "./gopt_helper.h" + +#include "megbrain/utils/hash_ct.h" + +#include "midout.h" + +MIDOUT_DECL(megbrain_folding_global_pooling) +#define MIDOUT_B(tag) \ + MIDOUT_BEGIN(megbrain_folding_global_pooling, midout_iv(MGB_HASH_STR(tag))) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); + +using namespace mgb; +using namespace gopt; + +/* ==================== FoldingGlobalPoolingPass ================= */ +const char* FoldingGlobalPoolingPass::name() const { + return mgb_cstr_log("folding global pooling pass"); +} + +void FoldingGlobalPoolingPass::apply(OptState& opt) const { + MIDOUT_B("FoldingGlobalPoolingPass::apply"); + + FindNext find_tool(opt); + + auto rewriter = opt.graph().make_rewriter(); + /** + * + * reshape+------>reduce(mean or max)+--->axis_add_remove*n + * || + * || + * || + * \/ + * adaptive_pooling(1,1) + */ + auto try_fuse_global_pooling_axis_add = [&rewriter, + &find_tool](OperatorNodeBase* opr) { + ThinHashSet<OperatorNodeBase*> opr_set; + ThinHashSet<OperatorNodeBase*> reader_set; + MGB_MARK_USED_VAR(rewriter); + MGB_MARK_USED_VAR(find_tool); + + auto axis_modi = try_cast_as_op<opr::AxisAddRemove>(opr); + CHECK_OR_RETURN(axis_modi); + CHECK_OR_RETURN(find_tool.used_count(axis_modi) <= 1); + auto output_shape = axis_modi->output(0)->shape(); + CHECK_OR_RETURN(output_shape.ndim == 4); + CHECK_OR_RETURN(output_shape[2] == output_shape[3] && output_shape[2] == 1); + + auto axis_input = axis_modi->input(0)->owner_opr(); + auto axis_modi_x = axis_input->try_cast_final<opr::AxisAddRemove>(); + auto reduce = axis_input->try_cast_final<opr::Reduce>(); + while (axis_modi_x) { + CHECK_OR_RETURN(find_tool.used_count(axis_modi_x) == 1); + auto axis_input_x = axis_modi_x->input(0)->owner_opr(); + reduce = axis_input_x->try_cast_final<opr::Reduce>(); + axis_modi_x = axis_input_x->try_cast_final<opr::AxisAddRemove>(); + } + + CHECK_OR_RETURN(reduce); + auto reduce_mode = reduce->param().mode; + CHECK_OR_RETURN( + reduce_mode == opr::Reduce::Param::Mode::MAX || + reduce_mode == opr::Reduce::Param::Mode::MEAN); + auto reduce_axis = reduce->param().axis; + CHECK_OR_RETURN(reduce_axis == 2); + + auto reshape = reduce->input(0)->owner_opr()->try_cast_final<opr::Reshape>(); + CHECK_OR_RETURN(reshape); + + auto reshape_in_shape = reshape->input(0)->shape(); + auto reshape_out_shape = reshape->output(0)->shape(); + bool merge_hw = + reshape_out_shape.ndim == 3 && reshape_in_shape.ndim == 4 && + reshape_in_shape[2] * reshape_in_shape[3] == reshape_out_shape[2]; + CHECK_OR_RETURN(merge_hw); + opr::AdaptivePooling::Param param; + if (reduce_mode == opr::Reduce::Param::Mode::MAX) { + param.mode =
opr::AdaptivePooling::Param::Mode::MAX; + } else { + mgb_assert(reduce_mode == opr::Reduce::Param::Mode::MEAN); + param.mode = opr::AdaptivePooling::Param::Mode::AVERAGE; + } + + auto new_node = opr::AdaptivePooling::make( + rewriter.get_var(reshape->input(0)), {1, 1}, param); + rewriter.replace_var( + axis_modi->output(0), new_node.node(), + mgb_cstr_log("replace reshape+reduce+add_axis -> adaptive pooling")); + return true; + }; + /** + * + * reshape+------>reduce(mean or max)+--->dimshuffle(0,1,-1,-1) + * || + * || + * || + * \/ + * adaptive_pooling(1,1) + */ + auto try_fuse_global_pooling_dimshuffle = [&rewriter, + &find_tool](OperatorNodeBase* opr) { + ThinHashSet<OperatorNodeBase*> opr_set; + ThinHashSet<OperatorNodeBase*> reader_set; + MGB_MARK_USED_VAR(rewriter); + MGB_MARK_USED_VAR(find_tool); + + auto dimshuffle = try_cast_as_op<opr::Dimshuffle>(opr); + CHECK_OR_RETURN(dimshuffle); + auto patten_param = dimshuffle->param(); + + CHECK_OR_RETURN(patten_param.pattern_len == 4); + auto patten = patten_param.pattern; + CHECK_OR_RETURN( + patten[0] == 0 && patten[1] == 1 && patten[2] == -1 && patten[3] == -1); + auto axis_remove = + dimshuffle->input(0)->owner_opr()->try_cast_final<opr::AxisAddRemove>(); + CHECK_OR_RETURN(axis_remove); + + auto reduce = axis_remove->input(0)->owner_opr()->try_cast_final<opr::Reduce>(); + CHECK_OR_RETURN(reduce); + auto reduce_mode = reduce->param().mode; + CHECK_OR_RETURN( + reduce_mode == opr::Reduce::Param::Mode::MAX || + reduce_mode == opr::Reduce::Param::Mode::MEAN); + auto reduce_axis = reduce->param().axis; + CHECK_OR_RETURN(reduce_axis == 2); + + auto reshape = reduce->input(0)->owner_opr()->try_cast_final<opr::Reshape>(); + CHECK_OR_RETURN(reshape); + + auto reshape_in_shape = reshape->input(0)->shape(); + auto reshape_out_shape = reshape->output(0)->shape(); + bool merge_hw = + reshape_out_shape.ndim == 3 && reshape_in_shape.ndim == 4 && + reshape_in_shape[2] * reshape_in_shape[3] == reshape_out_shape[2]; + CHECK_OR_RETURN(merge_hw); + opr::AdaptivePooling::Param param; + if (reduce_mode == opr::Reduce::Param::Mode::MAX) { + param.mode = opr::AdaptivePooling::Param::Mode::MAX; + } else { + mgb_assert(reduce_mode == opr::Reduce::Param::Mode::MEAN); + param.mode = opr::AdaptivePooling::Param::Mode::AVERAGE; + } + auto new_node = opr::AdaptivePooling::make( + rewriter.get_var(reshape->input(0)), {1, 1}, param); + rewriter.replace_var( + dimshuffle->output(0), new_node.node(), + mgb_cstr_log("replace reshape+reduce+dimshuffle -> adaptive pooling")); + return true; + }; + + auto on_opr = [&try_fuse_global_pooling_axis_add, + &try_fuse_global_pooling_dimshuffle, + &rewriter](OperatorNodeBase* opr) { + if (!try_fuse_global_pooling_axis_add(opr) && + !try_fuse_global_pooling_dimshuffle(opr)) { + rewriter.auto_replace_outputs(opr); + } + }; + opt.graph().iter(on_opr); + rewriter.apply_inplace(); + + MIDOUT_E +} + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/impl/folding_reduce_mean.cpp b/src/gopt/impl/folding_reduce_mean.cpp new file mode 100644 index 00000000..bfea8f13 --- /dev/null +++ b/src/gopt/impl/folding_reduce_mean.cpp @@ -0,0 +1,117 @@ +#include "megbrain/gopt/inference.h" +#include "megbrain/opr/basic_arith.h" +#include "megbrain/opr/dnn/convolution.h" +#include "megbrain/opr/tensor_manip.h" +#include "megbrain/opr/utility.h" +#include "megbrain/serialization/opr_shallow_copy.h" + +#include "megdnn/opr_param_defs.h" + +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" + +#include "../../core/impl/graph/cg_impl.h" +#include "./gopt_helper.h" + +#include "megbrain/utils/hash_ct.h" + +#include
"midout.h" + +MIDOUT_DECL(megbrain_folding_reduce_mean) +#define MIDOUT_B(tag) \ + MIDOUT_BEGIN(megbrain_folding_reduce_mean, midout_iv(MGB_HASH_STR(tag))) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); + +using namespace mgb; +using namespace gopt; + +/* ==================== FoldingReduceMeanPass ================= */ +const char* FoldingReduceMeanPass::name() const { + return mgb_cstr_log("folding reduce mean pass"); +} + +void FoldingReduceMeanPass::apply(OptState& opt) const { + MIDOUT_B("FoldingReduceMeanPass::apply"); + FindNext find_tool(opt); + + auto rewriter = opt.graph().make_rewriter(); + + /** + * reshape+---------->reduce(axis, sum)+--------->axis_remove+----------->true_div + * | ^ + * | | + * +--------------> get_var_shape(axis)+------------>type_cvt(fp32)+-------+ + * || + * || + * \/ + * reshape+-------->reduce(axis, mean)+--------->axis_remove + * + * + **/ + auto try_fuse_reduce_mean = [&rewriter, &find_tool](OperatorNodeBase* opr) { + ThinHashSet opr_set; + ThinHashSet reader_set; + MGB_MARK_USED_VAR(rewriter); + // check true_div + auto elemwise = try_cast_as_op(opr); + CHECK_OR_RETURN(elemwise); + auto mode_ok = elemwise->param().mode == opr::Elemwise::Mode::TRUE_DIV; + CHECK_OR_RETURN(mode_ok); + + auto input0 = elemwise->input(0)->owner_opr(); + auto remove_axis = input0->try_cast_final(); + auto reduce = input0->try_cast_final(); + if (remove_axis) { + reduce = remove_axis->input(0)->owner_opr()->try_cast_final(); + } + CHECK_OR_RETURN(reduce); + + bool reduce_sum = reduce->param().mode == opr::Reduce::Param::Mode::SUM; + CHECK_OR_RETURN(reduce_sum); + + auto input1 = elemwise->input(1)->owner_opr(); + auto typecvt = input1->try_cast_final(); + CHECK_OR_RETURN(typecvt); + auto is_typecvt_f32 = typecvt->param().enumv() == DTypeEnum::Float32; + CHECK_OR_RETURN(is_typecvt_f32); + + auto get_var_shape = + typecvt->input(0)->owner_opr()->try_cast_final(); + CHECK_OR_RETURN(get_var_shape); + + bool same_parent = + get_var_shape->input(0)->owner_opr() == reduce->input(0)->owner_opr(); + CHECK_OR_RETURN(same_parent); + + CHECK_OR_RETURN( + find_tool.used_count(get_var_shape->input(0)->owner_opr()) == 2); + + bool same_axis = get_var_shape->param().axis == reduce->param().axis; + CHECK_OR_RETURN(same_axis); + + auto new_reduce_param = reduce->param(); + new_reduce_param.mode = opr::Reduce::Mode::MEAN; + auto new_node = + opr::Reduce::make(rewriter.get_var(reduce->input(0)), new_reduce_param); + if (remove_axis) { + new_node = opr::AxisAddRemove::make( + new_node, remove_axis->param(), remove_axis->config()); + } + rewriter.replace_var( + opr->output(0), new_node.node(), + mgb_cstr_log("replace reduce_sum+div_axis -> reduce_mean")); + return true; + }; + + auto on_opr = [&try_fuse_reduce_mean, &rewriter](OperatorNodeBase* opr) { + if (!try_fuse_reduce_mean(opr)) { + rewriter.auto_replace_outputs(opr); + } + }; + opt.graph().iter(on_opr); + rewriter.apply_inplace(); + MIDOUT_E +} + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index c75f83ad..becf75d2 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -722,7 +722,10 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( options.disable_##_option(); \ } \ } - + cb(fuse_grain, { + add_pass(); + add_pass(); + }); cb(fuse_preprocess, { add_pass(FuseNCHW4Int8Preprocess::make()); add_pass(); diff --git a/src/gopt/impl/gopt_helper.h b/src/gopt/impl/gopt_helper.h new file mode 100644 index 
00000000..37ae516e --- /dev/null +++ b/src/gopt/impl/gopt_helper.h @@ -0,0 +1,35 @@ +#pragma once + +#include "../../core/impl/graph/cg_impl.h" +#include "megbrain/gopt/inference.h" +#include "megbrain/utils/hash_ct.h" + +namespace mgb { +namespace gopt { +namespace { +#define CHECK_OR_RETURN(x) \ + if (!(x)) { \ + return false; \ + } + +class FindNext { + using DepType = cg::OperatorNodeProp::DepType; + +public: + FindNext(OptState& opt) { + opt.graph().iter([&](OperatorNodeBase* opr) { + for (auto&& i : opr->node_prop().dep_map()) { + m_readers[i.first->owner_opr()].emplace_back(opr, i.second); + } + }); + } + size_t used_count(OperatorNodeBase* opr) { return m_readers[opr].size(); } + +private: + ThinHashMap<OperatorNodeBase*, SmallVector<std::pair<OperatorNodeBase*, DepType>>> + m_readers; +}; +} // namespace +} // namespace gopt + +} // namespace mgb \ No newline at end of file diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index fbdc84ef..ec833cfd 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -4,6 +4,7 @@ #include "megbrain/graph/event.h" #include "megbrain/opr/basic_arith.h" #include "megbrain/opr/blas.h" +#include "megbrain/opr/dnn/adaptive_pooling.h" #include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/local.h" @@ -1368,6 +1369,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { megdnn::param::Convolution::Format::NCHW88; megdnn::param::Pooling::Format pooling_format = megdnn::param::Pooling::Format::NCHW88; + megdnn::param::AdaptivePooling::Format adapt_pooling_format = + megdnn::param::AdaptivePooling::Format::NCHW88; megdnn::param::Resize::Format resize_format = megdnn::param::Resize::Format::NCHW88; std::string convter_pass_name = "conv_format_nchw88"; @@ -1381,6 +1384,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { conv_bias_format = megdnn::param::ConvBias::Format::NCHW44; conv_format = megdnn::param::Convolution::Format::NCHW44; pooling_format = megdnn::param::Pooling::Format::NCHW44; + adapt_pooling_format = megdnn::param::AdaptivePooling::Format::NCHW44; resize_format = megdnn::param::Resize::Format::NCHW44; convter_pass_name = "conv_format_nchw44"; } @@ -1646,6 +1650,33 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { return new_opr; } }; + auto replace_adapt_pooling_opr = [=](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + auto& pooling_opr = opr->cast_final_safe<opr::AdaptivePoolingForward>(); + mgb_throw_if( + pooling_opr.param().format != + opr::AdaptivePoolingForward::Param::Format::NCHW, + MegBrainError, + "ConvertFormat Pass only supports converting NCHW to NCHWxx"); + VarNode* inp_0 = new_inp[0]; + VarNode* inp_1 = new_inp[1]; + //!
if input is NCHWxx + if (inp_0->shape().ndim == 5) { + auto new_param = pooling_opr.param(); + new_param.format = adapt_pooling_format; + auto new_pooling_opr = opr::AdaptivePoolingForward::make( + inp_0, inp_1, new_param, opr->config()); + mgb_assert( + new_pooling_opr.shape().ndim == 5, + "the AdaptivePooling dst is not converted to NCHWxx"); + return new_pooling_opr.node()->owner_opr(); + } else { + auto new_opr = + serialization::copy_opr_shallow(*opr, new_inp, opr->config()); + return new_opr; + } + }; auto replace_resize_opr = [=](OperatorNodeBase* opr, const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); @@ -1763,6 +1794,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; + replace_func[opr::AdaptivePooling::typeinfo()] = replace_adapt_pooling_opr; replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr; replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr; diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h index b19122b4..791f5a8e 100644 --- a/src/gopt/include/megbrain/gopt/inference.h +++ b/src/gopt/include/megbrain/gopt/inference.h @@ -327,6 +327,8 @@ struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions { ret |= 1u << 4; if (fuse_preprocess) ret |= 1u << 5; + if (fuse_grain) + ret |= 1u << 6; return ret; } @@ -338,6 +340,7 @@ struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions { ret.fuse_conv_bias_with_z = buf & 1u << 3; ret.weight_preprocess = buf & 1u << 4; ret.fuse_preprocess = buf & 1u << 5; + ret.fuse_grain = buf & 1u << 6; ret.layout_transform = (LayoutTransform)(buf >> 32); return ret; } @@ -477,6 +480,25 @@ public: void apply(OptState& opt) const override; }; #endif +/** + * \brief old MegBrain models implement reduce_mean as reduce_sum followed by a + * true_div; fuse them back into a single reduce_mean for efficiency + * + */ +class FoldingReduceMeanPass final : public Pass { +public: + const char* name() const override; + void apply(OptState& opt) const override; +}; + +/** + * \brief fold a reduce over the merged HW axis into global pooling, enabling the + * NCHWxx optimization + * + */ +class FoldingGlobalPoolingPass final : public Pass { +public: + const char* name() const override; + void apply(OptState& opt) const override; +}; /*! * \brief padding channel to enable fast int8/int4 support diff --git a/src/gopt/test/gopt_old_model.cpp b/src/gopt/test/gopt_old_model.cpp new file mode 100644 index 00000000..c0399182 --- /dev/null +++ b/src/gopt/test/gopt_old_model.cpp @@ -0,0 +1,249 @@ +#include "megbrain/opr/dnn/local.h" +#include "megbrain/test/helper.h" + +#include "megbrain/gopt/basic_arith.h" +#include "megbrain/gopt/gtrans.h" +#include "megbrain/gopt/inference.h" + +#include "megbrain/opr/basic_arith_wrapper.h" +#include "megbrain/opr/blas.h" +#include "megbrain/opr/dnn/adaptive_pooling.h" +#include "megbrain/opr/io.h" +#include "megbrain/opr/nn_int.h" +#include "megbrain/opr/tensor_gen.h" +#include "megbrain/opr/tensor_manip.h" +#include "megbrain/opr/utility.h" + +#include "./helper.h" +#include "megbrain/comp_node_env.h" + +#include "megdnn/tensor_format.h" + +#include +#include + +using namespace mgb; + +namespace { +//!
find the first operator of a specific type; raise an exception if not found +template <typename T> +T& find_opr(SymbolVar endpoint) { + T* found = nullptr; + auto cb = [&found](cg::OperatorNodeBase* opr) { + if (!found && opr->same_type<T>()) { + found = &opr->cast_final_safe<T>(); + } + }; + cg::DepOprIter{cb}.add(endpoint.node()->owner_opr()); + mgb_assert(found, "not found opr from %s", endpoint.node()->name().c_str()); + return *found; +} + +template <typename T> +T& find_opr(SymbolVar endpoint, const std::string& node_name) { + T* found = nullptr; + auto cb = [&found, &node_name](cg::OperatorNodeBase* opr) { + if (!found && opr->same_type<T>() && opr->name() == node_name) { + found = &opr->cast_final_safe<T>(); + } + }; + cg::DepOprIter{cb}.add(endpoint.node()->owner_opr()); + mgb_assert( + found, "not found opr %s from %s", node_name.c_str(), + endpoint.node()->name().c_str()); + return *found; +} + +template <typename T> +size_t find_opr_num(SymbolVar endpoint) { + size_t opr_num = 0; + auto cb = [&opr_num](cg::OperatorNodeBase* opr) { + if (opr->same_type<T>()) { + opr_num++; + } + }; + cg::DepOprIter{cb}.add(endpoint.node()->owner_opr()); + return opr_num; +} +} // namespace + +TEST(TestGoptOldModel, FoldingGlobalPooling) { + HostTensorGenerator<> gen; + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkcvar = [&](const char* name, const TensorShape& shp) { + return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name); + }; + + auto host_x = gen({2, 3, 16, 16}, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x); + + opr::Convolution::Param param_conv; + param_conv.stride_h = param_conv.stride_w = 1; + param_conv.pad_h = param_conv.pad_w = 1; + auto w1 = mkcvar("w1", {8, 3, 3, 3}); + auto conv1 = + opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1")); + auto conv_n = opr::GetVarShape::make(conv1, 0); + auto conv_c = opr::GetVarShape::make(conv1, 1); + auto conv_h = opr::GetVarShape::make(conv1, 2); + auto conv_w = opr::GetVarShape::make(conv1, 3); + auto hxw = conv_h * conv_w; + auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0); + + auto reshape1 = opr::Reshape::make(conv1, reshape_shape); + + opr::Reduce::Param param_reduce; + param_reduce.axis = 2; + param_reduce.mode = opr::Reduce::Mode::SUM; + auto reduce = opr::Reduce::make(reshape1, param_reduce); + auto reduce_remove_axis = opr::AxisAddRemove::make( + reduce, {opr::AxisAddRemove::AxisDesc::make_remove(2)}); + auto hw_count = opr::GetVarShape::make(reshape1, 2); + + auto fp32_hw_count = opr::TypeCvt::make(hw_count, dtype::Float32()); + auto true_div = reduce_remove_axis / fp32_hw_count; + auto y = opr::AxisAddRemove::make( + true_div, {opr::AxisAddRemove::AxisDesc::make_add(2), + opr::AxisAddRemove::AxisDesc::make_add(3)}); + + SymbolVar y_opt = y; + { + auto options = gopt::OptimizeForInferenceOptions{}; + options.fuse_grain = true; + unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); + } + ASSERT_EQ( + opr::AdaptivePooling::Param::Mode::AVERAGE, + find_opr<opr::AdaptivePooling>(y_opt).param().mode); + + graph->compile({{y_opt, {}}}) + ->to_json() + ->writeto_fpath(output_file("TestGoptOldModel.FoldingGlobalPooling.json")); + + HostTensorND host_y_opt, host_y; + auto func = graph->compile( + {make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)}); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); +} + +TEST(TestGoptOldModel, FoldingGlobalPooling2) { + HostTensorGenerator<> gen; + auto cn = CompNode::load("cpu0"); + auto graph
= ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkcvar = [&](const char* name, const TensorShape& shp) { + return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name); + }; + + auto host_x = gen({2, 3, 16, 16}, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x); + + opr::Convolution::Param param_conv; + param_conv.stride_h = param_conv.stride_w = 1; + param_conv.pad_h = param_conv.pad_w = 1; + auto w1 = mkcvar("w1", {8, 3, 3, 3}); + auto conv1 = + opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1")); + auto conv_n = opr::GetVarShape::make(conv1, 0); + auto conv_c = opr::GetVarShape::make(conv1, 1); + auto conv_h = opr::GetVarShape::make(conv1, 2); + auto conv_w = opr::GetVarShape::make(conv1, 3); + auto hxw = conv_h * conv_w; + auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0); + + auto reshape1 = opr::Reshape::make(conv1, reshape_shape); + + opr::Reduce::Param param_reduce; + param_reduce.axis = 2; + param_reduce.mode = opr::Reduce::Mode::SUM; + auto reduce = opr::Reduce::make(reshape1, param_reduce); + auto reduce_remove_axis = opr::AxisAddRemove::make( + reduce, {opr::AxisAddRemove::AxisDesc::make_remove(2)}); + auto hw_count = opr::GetVarShape::make(reshape1, 2); + + auto fp32_hw_count = opr::TypeCvt::make(hw_count, dtype::Float32()); + auto true_div = reduce_remove_axis / fp32_hw_count; + auto y = opr::Dimshuffle::make(true_div, {0, 1, -1, -1}); + + SymbolVar y_opt = y; + { + auto options = gopt::OptimizeForInferenceOptions{}; + options.fuse_grain = true; + unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); + } + ASSERT_EQ( + opr::AdaptivePooling::Param::Mode::AVERAGE, + find_opr<opr::AdaptivePooling>(y_opt).param().mode); + + graph->compile({{y_opt, {}}}) + ->to_json() + ->writeto_fpath(output_file("TestGoptOldModel.FoldingGlobalPooling2.json")); + + HostTensorND host_y_opt, host_y; + auto func = graph->compile( + {make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)}); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); +} + +TEST(TestGoptOldModel, FoldingReduceMean) { + HostTensorGenerator<> gen; + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkcvar = [&](const char* name, const TensorShape& shp) { + return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name); + }; + + auto host_x = gen({2, 3, 16, 16}, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x); + + opr::Convolution::Param param_conv; + param_conv.stride_h = param_conv.stride_w = 1; + param_conv.pad_h = param_conv.pad_w = 1; + auto w1 = mkcvar("w1", {8, 3, 3, 3}); + auto conv1 = + opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1")); + auto conv_n = opr::GetVarShape::make(conv1, 0); + auto conv_c = opr::GetVarShape::make(conv1, 1); + auto conv_h = opr::GetVarShape::make(conv1, 2); + auto conv_w = opr::GetVarShape::make(conv1, 3); + auto hxw = conv_h * conv_w; + auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0); + + auto reshape1 = opr::Reshape::make(conv1, reshape_shape); + + opr::Reduce::Param param_reduce; + param_reduce.axis = 2; + param_reduce.mode = opr::Reduce::Mode::SUM; + auto reduce = opr::Reduce::make(reshape1, param_reduce); + auto hw_count = opr::GetVarShape::make(reshape1, 2); + + auto y = reduce / hw_count; + + SymbolVar y_opt = y; + { + auto options = gopt::OptimizeForInferenceOptions{}; + options.fuse_grain = true; + unpack_vector(gopt::optimize_for_inference({y},
options), y_opt); + } + ASSERT_EQ( + opr::Reduce::Param::Mode::MEAN, find_opr<opr::Reduce>(y_opt).param().mode); + + graph->compile({{y_opt, {}}}) + ->to_json() + ->writeto_fpath(output_file("TestGoptOldModel.FoldingReduceMean.json")); + + HostTensorND host_y_opt, host_y; + auto func = graph->compile( + {make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)}); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); + + *host_x = *gen({2, 3, 16, 16}, cn); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); +} diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 4c8d5ec1..45c43f84 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -7,6 +7,7 @@ #include "megbrain/opr/basic_arith_wrapper.h" #include "megbrain/opr/blas.h" +#include "megbrain/opr/dnn/adaptive_pooling.h" #include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/pooling.h" @@ -4113,6 +4114,71 @@ TEST(TestGoptInference, ConvertFormatNCHW44Reshape) { MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); } +TEST(TestGoptInference, ConvertFormatNCHW44GlobalPooling) { + HostTensorGenerator<> gen; + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkcvar = [&](const char* name, const TensorShape& shp) { + return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name); + }; + + auto host_x1 = gen({1, 4, 16, 16}, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x1); + opr::Convolution::Param param_conv; + param_conv.stride_h = param_conv.stride_w = 1; + param_conv.pad_h = param_conv.pad_w = 1; + auto w1 = mkcvar("w1", {8, 4, 3, 3}); + auto conv1 = + opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1")); + auto conv_n = opr::GetVarShape::make(conv1, 0); + auto conv_c = opr::GetVarShape::make(conv1, 1); + auto conv_h = opr::GetVarShape::make(conv1, 2); + auto conv_w = opr::GetVarShape::make(conv1, 3); + auto hxw = conv_h * conv_w; + auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0); + auto reshape1 = opr::Reshape::make(conv1, reshape_shape); + + opr::Reduce::Param param_reduce; + param_reduce.axis = 2; + param_reduce.mode = opr::Reduce::Mode::SUM; + auto reduce = opr::Reduce::make(reshape1, param_reduce); + auto reduce_remove_axis = opr::AxisAddRemove::make( + reduce, {opr::AxisAddRemove::AxisDesc::make_remove(2)}); + auto hw_count = opr::GetVarShape::make(reshape1, 2); + + auto fp32_hw_count = opr::TypeCvt::make(hw_count, dtype::Float32()); + auto reduce_mean = reduce_remove_axis / fp32_hw_count; + auto global_pool = opr::AxisAddRemove::make( + reduce_mean, {opr::AxisAddRemove::AxisDesc::make_add(2), + opr::AxisAddRemove::AxisDesc::make_add(3)}); + + opr::Elemwise::Param elem_param; + elem_param.mode = opr::Elemwise::Param::Mode::RELU; + auto y = opr::Elemwise::make({global_pool}, elem_param); + + SymbolVar y_opt; + auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_fuse_grain(); + options.enable_nchw44(); + unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); + + ASSERT_EQ( + opr::AdaptivePooling::Param::Format::NCHW44, + find_opr<opr::AdaptivePooling>(y_opt).param().format); + + graph->compile({{y_opt, {}}}) + ->to_json() + ->writeto_fpath(output_file( + "TestGoptInference.ConvertFormatNCHW44GlobalPooling.json")); + HostTensorND host_y_opt, host_y; + auto func = graph->compile( + {make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)}); + func->execute(); +
//! may go to winograd on x86-32, so set the error tolerance to 1e-1 + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); +} + TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { HostTensorGenerator<> gen; auto cn = CompNode::load("cpu0"); diff --git a/src/opr/impl/dnn/adaptive_pooling.cpp b/src/opr/impl/dnn/adaptive_pooling.cpp index fe6e0c1a..826b7330 100644 --- a/src/opr/impl/dnn/adaptive_pooling.cpp +++ b/src/opr/impl/dnn/adaptive_pooling.cpp @@ -39,7 +39,7 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0)); auto src = shpinfo.shape_inp_shp.at(0); mgb_assert( - src.ndim == 4 && (oshp2d.ndim == 2 || oshp2d.ndim == 1), + (src.ndim == 4 || src.ndim == 5) && (oshp2d.ndim == 2 || oshp2d.ndim == 1), "shape mismatch for AdaptivePooling: src=%s, out2d=%s", src.to_string().c_str(), oshp2d.to_string().c_str()); @@ -57,8 +57,19 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( dest.shape[1] = oshp2d.shape[0]; dest.shape[2] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1]; dest.shape[3] = src.shape[3]; + } else if ( + param_format == Param::Format::NCHW44 || + param_format == Param::Format::NCHW88) { + dest.ndim = 5; + dest.shape[0] = src.shape[0]; + dest.shape[1] = src.shape[1]; + dest.shape[2] = oshp2d.shape[0]; + dest.shape[3] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1]; + dest.shape[4] = src.shape[4]; } else { - mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format"); + mgb_throw( + MegBrainError, "AdaptivePooling does not support format %d", + (int)param_format); } } diff --git a/src/opr/test/dnn/adaptive_pooling.cpp b/src/opr/test/dnn/adaptive_pooling.cpp index 58dd3bfd..b615ce72 100644 --- a/src/opr/test/dnn/adaptive_pooling.cpp +++ b/src/opr/test/dnn/adaptive_pooling.cpp @@ -48,7 +48,6 @@ void run(Param::Mode mode) { Checker::RunOptions opt; opt.numdiff_max_err = 1e-2; - Checker checker{make_graph, fwd}; checker.set_input_allow_grad(1, false).set_input_generator(0, gen); checker.run({TensorShape{1, 1, 10, 7}, TensorShape{5, 4}}, opt);
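For reference, the core observation behind is_global_pooling_reduce() in dnn/src/arm_common/pooling/opr_impl.cpp above (a full-window, zero-padding, 1x1-output pooling over NCHW/NCHW44/NCHW88 data is exactly one reduce over the merged H*W axis that merge_hw_layout() constructs) can be reproduced in a few lines of standalone C++. This is a minimal sketch with no MegEngine dependencies; the function and variable names are illustrative only and are not part of this patch:

#include <cstdio>
#include <vector>

// Global average pooling on contiguous NCHW44 data (n, c/4, h, w, 4),
// written as the axis-2 reduce over the merged h*w axis, i.e. the same
// computation the patch delegates to a megdnn Reduce opr with axis = 2.
std::vector<float> global_avg_pool_nchw44(
        const std::vector<float>& src, size_t n, size_t c4, size_t h, size_t w) {
    const size_t hw = h * w;                  // merged spatial axis
    std::vector<float> dst(n * c4 * 4, 0.f);  // output shape: (n, c/4, 1, 1, 4)
    for (size_t i = 0; i < n * c4; ++i) {
        for (size_t s = 0; s < hw; ++s) {
            for (size_t k = 0; k < 4; ++k) {  // inner pack of 4 channels
                dst[i * 4 + k] += src[(i * hw + s) * 4 + k];
            }
        }
        for (size_t k = 0; k < 4; ++k) {
            dst[i * 4 + k] /= static_cast<float>(hw);
        }
    }
    return dst;
}

int main() {
    // 1x1x2x2x4 tensor: every pixel carries the packed values {1, 2, 3, 4},
    // so the global average must reproduce {1, 2, 3, 4}.
    std::vector<float> src;
    for (int s = 0; s < 4; ++s)
        for (int k = 1; k <= 4; ++k)
            src.push_back(static_cast<float>(k));
    auto dst = global_avg_pool_nchw44(src, 1, 1, 2, 2);
    printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]);  // prints: 1 2 3 4
    return 0;
}

For Mode::MAX the same loop applies with max-accumulation instead of sum-and-divide, which is why is_global_pooling_reduce() admits both MAX and the two AVERAGE modes; MEAN and AVERAGE_COUNT_EXCLUDE_PADDING coincide here because the predicate requires the padding to be zero.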