From 80c4705317d9c85164d2540059053037e6bcfac2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 21 Jul 2020 16:11:58 +0800 Subject: [PATCH] perf(mgb): use midout in megbrain to reduce binary size GitOrigin-RevId: ddc8af79af90737cb6f55ae5c1fc95cba722eef9 --- src/CMakeLists.txt | 1 + src/gopt/impl/basic_arith/chain.cpp | 23 ++++++++++++- src/gopt/impl/basic_arith/inplace.cpp | 12 +++++++ src/gopt/impl/basic_arith/trans.cpp | 18 ++++++++++ src/gopt/impl/inference.cpp | 31 +++++++++++++++++ src/gopt/impl/misc.cpp | 28 ++++++++++++++++ src/gopt/impl/tensor_reformat.cpp | 26 ++++++++++++++- src/gopt/impl/weights_preprocess.cpp | 12 +++++++ src/opr-mm/impl/collective_comm.cpp | 2 ++ src/opr-mm/impl/io_remote.cpp | 2 ++ src/opr/impl/basic_arith.cpp | 11 +++++- src/opr/impl/blas.cpp | 13 ++++++++ src/opr/impl/cond.cpp | 4 +++ src/opr/impl/dnn/batch_norm.cpp | 2 ++ src/opr/impl/dnn/convolution.cpp | 63 ++++++++++++++++++++++++++++------- src/opr/impl/dnn/images2neibs.cpp | 2 ++ src/opr/impl/dnn/local.cpp | 6 ++++ src/opr/impl/dnn/lrn.cpp | 2 ++ src/opr/impl/dnn/pooling.cpp | 2 ++ src/opr/impl/dnn/roi_align.cpp | 2 ++ src/opr/impl/dnn/roi_pooling.cpp | 4 +++ src/opr/impl/imgproc.cpp | 4 +++ src/opr/impl/indexing.cpp | 26 +++++++++++++++ src/opr/impl/io.cpp | 2 ++ src/opr/impl/loop/forward.cpp | 2 ++ src/opr/impl/misc.cpp | 13 +++++++- src/opr/impl/muxing.cpp | 2 ++ src/opr/impl/rand.cpp | 16 +++++---- src/opr/impl/tensor_gen.cpp | 7 +++- src/opr/impl/tensor_manip.cpp | 23 ++++++++++++- src/opr/impl/utility.cpp | 12 +++++++ src/plugin/impl/opr_footprint.cpp | 13 ++++++++ 32 files changed, 361 insertions(+), 25 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 11a4cef9..40ed0dc3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -43,6 +43,7 @@ add_library(megbrain OBJECT EXCLUDE_FROM_ALL ${SOURCES}) target_link_libraries(megbrain PUBLIC mgb_opr_param_defs) target_include_directories(megbrain PUBLIC $ + PRIVATE ${PROJECT_SOURCE_DIR}/third_party/midout/src ) foreach (INCPATH IN LISTS MGB_INC) target_include_directories(megbrain diff --git a/src/gopt/impl/basic_arith/chain.cpp b/src/gopt/impl/basic_arith/chain.cpp index c345e17f..654b3a46 100644 --- a/src/gopt/impl/basic_arith/chain.cpp +++ b/src/gopt/impl/basic_arith/chain.cpp @@ -15,6 +15,20 @@ #include +//! TODO: here has to be know some megdnn::opr when there is produced midout.h +//! fix it if there is another graceful way. 
+#include "megdnn/oprs.h" + +#include "megbrain/utils/hash_ct.h" +#include "midout.h" + +MIDOUT_DECL(megbrain_chain) +#define MIDOUT_B(tag) \ + MIDOUT_BEGIN(megbrain_chain, midout_iv(MGB_HASH_STR(tag))) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); + using namespace mgb; using namespace gopt; using namespace opr; @@ -132,6 +146,7 @@ const char* ExpandFusedArithPass::name() const { } void ExpandFusedArithPass::apply(OptState &opt) const { + MIDOUT_B("ExpandFusedArithPass::apply") auto rewriter = opt.graph().make_rewriter(); auto on_opr = [&](OperatorNodeBase *opr) { using Mode = Elemwise::Mode; @@ -172,6 +187,7 @@ void ExpandFusedArithPass::apply(OptState &opt) const { opt.graph().iter(on_opr); rewriter.apply_inplace(); + MIDOUT_E } /* ================ NormalizeArithChainPass ================ */ @@ -529,7 +545,9 @@ const char* NormalizeArithChainPass::name() const { } void NormalizeArithChainPass::apply(OptState &opt) const { + MIDOUT_B("NormalizeArithChainPass::apply") Impl{opt}; + MIDOUT_E } /* ================ ReorderArithChainPass ================ */ @@ -737,7 +755,9 @@ const char* ReorderArithChainPass::name() const { } void ReorderArithChainPass::apply(OptState &opt) const { + MIDOUT_B("ReorderArithChainPass::apply") Impl{*this, opt}; + MIDOUT_E } /* ================ ArithFusePass ================ */ @@ -944,8 +964,9 @@ const char* ArithFusePass::name() const { } void ArithFusePass::apply(OptState &opt) const { + MIDOUT_B("ArithFusePass::apply") Impl{opt}; + MIDOUT_E } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} - diff --git a/src/gopt/impl/basic_arith/inplace.cpp b/src/gopt/impl/basic_arith/inplace.cpp index da62ee51..5d72d23d 100644 --- a/src/gopt/impl/basic_arith/inplace.cpp +++ b/src/gopt/impl/basic_arith/inplace.cpp @@ -19,6 +19,16 @@ #include +#include "megbrain/utils/hash_ct.h" +#include "midout.h" + +MIDOUT_DECL(megbrain_inplace) +#define MIDOUT_B(tag) \ + MIDOUT_BEGIN(megbrain_inplace, midout_iv(MGB_HASH_STR(tag))) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); + using namespace mgb; using namespace opr; using namespace gopt; @@ -150,8 +160,10 @@ bool gopt::has_inplace_basic_arith_opt(const cg::OperatorNodeBase& opr) { const inplace_optimize::OptimizerRegistry& inplace_optimize::optimizer_registry() { + MIDOUT_B("inplace_optimize::optimizer_registry") static OptimizerRegistry ret = make_optimizer_registry(); return ret; + MIDOUT_E } inplace_optimize::OptimizerRegistry diff --git a/src/gopt/impl/basic_arith/trans.cpp b/src/gopt/impl/basic_arith/trans.cpp index 705b50b7..6a52e41e 100644 --- a/src/gopt/impl/basic_arith/trans.cpp +++ b/src/gopt/impl/basic_arith/trans.cpp @@ -13,6 +13,20 @@ #include "megbrain/gopt/basic_arith.h" #include "megbrain/serialization/serializer.h" +//! TODO: here has to be know some megdnn::opr when there is produced midout.h +//! fix it if there is another graceful way. 
+#include "megdnn/oprs.h" + +#include "megbrain/utils/hash_ct.h" +#include "midout.h" + +MIDOUT_DECL(megbrain_trans) +#define MIDOUT_B(tag) \ + MIDOUT_BEGIN(megbrain_trans, midout_iv(MGB_HASH_STR(tag))) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); + using namespace mgb; using namespace gopt; @@ -284,7 +298,9 @@ const char* ArithMulDistributePass::name() const { } void ArithMulDistributePass::apply(OptState &opt) const { + MIDOUT_B("ArithMulDistributePass::apply") Impl{*this, opt}; + MIDOUT_E } /* ================ FinalArithTransformPass ================ */ @@ -488,7 +504,9 @@ const char* FinalArithTransformPass::name() const { } void FinalArithTransformPass::apply(OptState &opt) const { + MIDOUT_B("FinalArithTransformPass::apply") Impl{*this, opt}; + MIDOUT_E } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index fa9c33b5..88117d19 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -27,6 +27,7 @@ #include "megbrain/opr/imgproc.h" #include "megbrain/opr/nn_int.h" #include "megbrain/opr/tensor_gen.h" +#include "megbrain/utils/hash_ct.h" #include "megdnn/tensor_format.h" @@ -36,6 +37,16 @@ #include "megbrain/gopt/misc.h" +#include "megbrain/utils/hash_ct.h" +#include "midout.h" + +MIDOUT_DECL(megbrain_inference) +#define MIDOUT_B(tag) \ + MIDOUT_BEGIN(megbrain_inference, midout_iv(MGB_HASH_STR(tag))) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); + using namespace mgb; using namespace gopt; @@ -430,7 +441,9 @@ ParamRedistributePass::Impl::Impl(OptState &state): } void ParamRedistributePass::apply(OptState &state) const { + MIDOUT_B("ParamRedistributePass::apply") Impl{state}; + MIDOUT_E } /* ================ ParamFusePass ================ */ @@ -512,6 +525,7 @@ const char* ParamFusePass::name() const { } void ParamFusePass::apply(OptState &state) const { + MIDOUT_B("ParamFusePass::apply") auto rewriter = state.graph().make_rewriter(); auto cg = state.graph().comp_graph(); @@ -613,6 +627,7 @@ void ParamFusePass::apply(OptState &state) const { state.graph().iter(replace_opr); rewriter.apply_inplace(); + MIDOUT_E } /* ================ One2OneOprReplacePass ================ */ @@ -621,6 +636,7 @@ const char* ConvertF32ToF16Pass::name() const { } void ConvertF32ToF16Pass::apply(OptState& state) const { + MIDOUT_B("ConvertF32ToF16Pass::apply") state.set_var_replace_check_flag(m_var_replace_check_flag); auto rewriter = state.graph().make_rewriter(); VarNodeArray new_inp_cache; @@ -674,6 +690,7 @@ void ConvertF32ToF16Pass::apply(OptState& state) const { auto opr = endpoints[0].node()->owner_opr(); state.call_with_opr(opr, replace_output, OprPropertyFlag::NONE); rewriter.apply_inplace(); + MIDOUT_E } std::unique_ptr ConvertF32ToF16Pass::make( @@ -940,6 +957,7 @@ std::unique_ptr ConvertF32ToF16Pass::make( /* ================ ConvertFormatPass ================ */ void ConvertFormatPass::apply(OptState& state) const { + MIDOUT_B("ConvertFormatPass::apply") state.set_var_replace_check_flag(m_var_replace_check_flag); auto rewriter = state.graph().make_rewriter(); VarNodeArray new_inp_cache; @@ -994,9 +1012,11 @@ void ConvertFormatPass::apply(OptState& state) const { }; state.graph().iter(on_opr); rewriter.apply_inplace(); + MIDOUT_E } std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { + MIDOUT_B("ConvertFormatPass::make") auto filter_mode = [](const megdnn::param::Convolution::Sparse conv_mode, const VarNode* filter) -> megdnn::param::RelayoutFormat::Mode { @@ -1551,6 +1571,7 @@ 
std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { replace_func[opr::GroupLocalForward::typeinfo()] = relayout_first_inp_to_chw; return ret; + MIDOUT_E } /* ================ ConvertBatchNormPass ================ */ @@ -1559,6 +1580,7 @@ const char* ConvertBatchNormToElemwisePass::name() const { } void ConvertBatchNormToElemwisePass::apply(OptState& state) const { + MIDOUT_B("ConvertBatchNormToElemwisePass::apply") auto rewriter = state.graph().make_rewriter(); auto on_opr = [&](OperatorNodeBase* opr) { if (auto bn = try_cast_as_op(opr)) { @@ -1586,6 +1608,7 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const { state.graph().iter(on_opr); rewriter.apply_inplace(); + MIDOUT_E } /* ================ FuseConvBiasNonlinPass ================ */ @@ -1594,6 +1617,7 @@ const char* FuseConvBiasNonlinPass::name() const { } void FuseConvBiasNonlinPass::apply(OptState& state) const { + MIDOUT_B("FuseConvBiasNonlinPass::apply") std::unordered_map> m_deps; state.graph().iter([&m_deps](OperatorNodeBase* opr) { for (auto& inp : opr->input()) { @@ -1843,6 +1867,7 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const { state.graph().iter(on_opr); rewriter.apply_inplace(); + MIDOUT_E } /* ================ FuseConvBiasZPass ================ */ @@ -1851,6 +1876,7 @@ const char* FuseConvBiasZPass::name() const { } void FuseConvBiasZPass::apply(OptState& state) const { + MIDOUT_B("FuseConvBiasZPass::apply") UniqReaderCheck uniq_reader_check{state.graph()}; auto rewriter = state.graph().make_rewriter(); @@ -1977,6 +2003,7 @@ void FuseConvBiasZPass::apply(OptState& state) const { state.graph().iter(on_opr); rewriter.apply_inplace(); + MIDOUT_E } /* ================ FuseDeconvCvtPass ================ */ @@ -1986,6 +2013,7 @@ const char* FuseDeconvCvtPass::name() const { void FuseDeconvCvtPass::apply(OptState& state) const { + MIDOUT_B("FuseDeconvCvtPass::apply") std::unordered_map> m_deps; state.graph().iter([&m_deps](OperatorNodeBase* opr) { for (auto& inp : opr->input()) { @@ -2036,6 +2064,7 @@ void FuseDeconvCvtPass::apply(OptState& state) const { state.graph().iter(on_opr); rewriter.apply_inplace(); + MIDOUT_E } /* ================ ParamMergePass ================ */ @@ -2044,10 +2073,12 @@ const char* ParamMergePass::name() const { } void ParamMergePass::apply(OptState& opt_state) const { + MIDOUT_B("ParamMergePass::apply") param_merge( opt_state); param_merge(opt_state); + MIDOUT_E } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/impl/misc.cpp b/src/gopt/impl/misc.cpp index af47b924..6ab3f155 100644 --- a/src/gopt/impl/misc.cpp +++ b/src/gopt/impl/misc.cpp @@ -19,6 +19,16 @@ #include "megbrain/serialization/opr_shallow_copy.h" #include "../../core/impl/graph/cg_impl.h" +#include "megbrain/utils/hash_ct.h" +#include "midout.h" + +MIDOUT_DECL(megbrain_misc) +#define MIDOUT_B(tag) \ + MIDOUT_BEGIN(megbrain_misc, midout_iv(MGB_HASH_STR(tag))) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); + using namespace mgb; using namespace gopt; @@ -29,6 +39,7 @@ const char* RemoveNonComputingOprPass::name() const { } void RemoveNonComputingOprPass::apply(OptState& opt) const { + MIDOUT_B("RemoveNonComputingOprPass::apply") auto rewriter = opt.graph().make_rewriter(); auto on_opr = [&](OperatorNodeBase* opr) { auto type = opr->dyn_typeinfo(); @@ -75,6 +86,7 @@ void RemoveNonComputingOprPass::apply(OptState& opt) const { opt.graph().iter(on_opr); rewriter.apply_inplace(); + MIDOUT_E } /* ================ ExpandVirtualGradPass ================ */ @@ -84,6 
+96,7 @@ const char* ExpandVirtualGradPass::name() const { } void ExpandVirtualGradPass::apply(OptState& opt) const { + MIDOUT_B("ExpandVirtualGradPass::apply") #if MGB_ENABLE_GRAD opt.set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); auto rewriter = opt.graph().make_rewriter(); @@ -111,6 +124,7 @@ void ExpandVirtualGradPass::apply(OptState& opt) const { #else MGB_MARK_USED_VAR(opt); #endif + MIDOUT_E } /* ================= DelayBroadcastPass ================ */ @@ -144,6 +158,7 @@ void DelayBroadcastPass::apply(OptState& opt) const { // remove them from the chain, and add them back right after the endpoint. // TypeCvt's order may change, so disable the check. + MIDOUT_B("DelayBroadcastPass::apply") opt.set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); auto unique_reader_chk = UniqReaderCheck{opt.graph()}; @@ -325,6 +340,7 @@ void DelayBroadcastPass::apply(OptState& opt) const { opt.graph().iter(on_opr); rewriter.apply_inplace(); + MIDOUT_E } /* ======================= RecompTypeCvtPass ====================== */ @@ -334,6 +350,7 @@ const char* RecompTypeCvtPass::name() const { } void RecompTypeCvtPass::apply(OptState& opt) const { + MIDOUT_B("RecompTypeCvtPass::apply") auto rewriter = opt.graph().make_rewriter(); auto allowed_typecvt = [](OperatorNodeBase* opr) -> OperatorNodeBase* { @@ -399,6 +416,7 @@ void RecompTypeCvtPass::apply(OptState& opt) const { }; opt.graph().iter(on_opr); rewriter.apply_inplace(); + MIDOUT_E } /* ======================= CombineAstypeAndReducePass ====================== */ @@ -408,6 +426,7 @@ const char* CombineAstypeAndReducePass::name() const { } void CombineAstypeAndReducePass::apply(OptState& opt) const { + MIDOUT_B("CombineAstypeAndReducePass::apply") auto rewriter = opt.graph().make_rewriter(); using DataType = opr::Reduce::Param::DataType; @@ -453,6 +472,7 @@ void CombineAstypeAndReducePass::apply(OptState& opt) const { opt.graph().iter(on_opr); rewriter.apply_inplace(); + MIDOUT_E } /* ================ CondExecConstPredicateFolding ================ */ @@ -462,6 +482,7 @@ const char* CondExecConstPredicateFolding::name() const { void CondExecConstPredicateFolding::apply(OptState& opt) const { #if MGB_ENABLE_COND_EXEC + MIDOUT_B("CondExecConstPredicateFolding::apply") if (!cg::ExecutionMask::have_alive_instance()) { return; } @@ -605,6 +626,7 @@ void CondExecConstPredicateFolding::apply(OptState& opt) const { } rewriter.apply_inplace(); + MIDOUT_E #endif // MGB_ENABLE_COND_EXEC } @@ -632,6 +654,7 @@ bool RemoveRedundantTypeCvtPass::should_remove(DType A, DType B) { } void RemoveRedundantTypeCvtPass::apply(OptState& opt) const { + MIDOUT_B("RemoveRedundantTypeCvtPass::apply") auto rewriter = opt.graph().make_rewriter(); auto on_opr = [&](OperatorNodeBase* opr) { @@ -656,6 +679,7 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const { opt.graph().iter(on_opr); rewriter.apply_inplace(); + MIDOUT_E } #if MGB_ENABLE_OPR_MM @@ -668,6 +692,7 @@ const char* PackAllReduceScanPass::name() const { } void PackAllReduceScanPass::apply(OptState& opt) const { + MIDOUT_B("PackAllReduceScanPass::apply") auto comp_graph = opt.graph().comp_graph(); if (comp_graph->options().allreduce_pack_max_size == 0) return; auto cb_scan = [this] (OperatorNodeBase* opr) { @@ -682,6 +707,7 @@ void PackAllReduceScanPass::apply(OptState& opt) const { } }; opt.graph().iter(cb_scan); + MIDOUT_E } bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) { @@ -856,6 +882,7 @@ void PackAllReduceReplacePass::insert_packed_oprs( } void 
PackAllReduceReplacePass::apply(OptState& opt) const { + MIDOUT_B("PackAllReduceReplacePass::apply") // get graph options auto comp_graph = opt.graph().comp_graph(); size_t max_size = comp_graph->options().allreduce_pack_max_size * 1024 * 1024; @@ -917,6 +944,7 @@ void PackAllReduceReplacePass::apply(OptState& opt) const { }; opt.graph().iter(cb_replace); rewriter.apply_inplace(); + MIDOUT_E } #else diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index ec8ad015..a5f1f79a 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -36,6 +36,16 @@ #endif #include "megbrain/gopt/misc.h" +#include "megbrain/utils/hash_ct.h" + +#include "midout.h" + +MIDOUT_DECL(megbrain_tensor_reformat) +#define MIDOUT_B(tag) \ + MIDOUT_BEGIN(megbrain_tensor_reformat, midout_iv(MGB_HASH_STR(tag))) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); using namespace mgb; using namespace gopt; @@ -755,8 +765,10 @@ void TensorReformatPass::translate_pass(OptState& opt) const { } void TensorReformatPass::apply(OptState& opt) const { + MIDOUT_B("TensorReformatPass::apply") insert_pass(opt); translate_pass(opt); + MIDOUT_E } /* ================ EnableTensorCorePass =============== */ @@ -773,6 +785,7 @@ VarNode* EnableTensorCorePass::on_graph_endpoint_var(VarNode* new_var, std::unique_ptr EnableTensorCorePass::make_tensorcore_converter() { + MIDOUT_B("EnableTensorCorePass::make") // replace rule for conv bias opr auto replace_conv_bias_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) { @@ -1111,6 +1124,7 @@ EnableTensorCorePass::make_tensorcore_converter() { replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4; replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4; return ret; + MIDOUT_E } /* ================ EnableCHWN4Pass =============== */ @@ -1125,6 +1139,7 @@ VarNode* EnableCHWN4Pass::on_graph_endpoint_var(VarNode* new_var, } std::unique_ptr EnableCHWN4Pass::make_chwn4_converter() { + MIDOUT_B("EnableCHWN4Pass::make") auto ret = std::make_unique(); ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); auto&& replace_func = ret->m_opr_replace_func; @@ -1381,6 +1396,7 @@ std::unique_ptr EnableCHWN4Pass::make_chwn4_converter() { replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4; replace_func[opr::BatchConvBias::typeinfo()] = replace_inps_to_nchw4; return ret; + MIDOUT_E } /* ================ EnableNCHW4Pass ================ */ @@ -1395,6 +1411,7 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var, } std::unique_ptr EnableNCHW4Pass::make_nchw4_converter(){ + MIDOUT_B("EnableNCHW4Pass::make") auto ret = std::make_unique(); ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); using RelayoutMode = RelayoutPlaceholder::LayoutType; @@ -1772,6 +1789,7 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter(){ replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; return ret; + MIDOUT_E } /* ================ EnableNchwxxPass =============== */ @@ -2140,6 +2158,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ std::unique_ptr EnableNchwxxPass::make_nchwxx_converter( size_t pack_c_size) { + MIDOUT_B("EnableNchwxxPass::make") auto ret = std::make_unique(pack_c_size); ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); std::string convter_pass_name = "conv_format_nchw88"; @@ -2149,6 +2168,7 @@ std::unique_ptr EnableNchwxxPass::make_nchwxx_converter( 
ret->fill_opr_convert_fun(pack_c_size); ret->set_name(convter_pass_name); return ret; + MIDOUT_E } /* ================ EnableNchw44DotPass =============== */ @@ -2164,6 +2184,7 @@ VarNode* EnableNchw44DotPass::on_graph_endpoint_var(VarNode* new_var, std::unique_ptr EnableNchw44DotPass::make_nchw44_dot_converter() { + MIDOUT_B("EnableNchw44DotPass::make") auto ret = std::make_unique(); ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); //! First is whether the conv can trans to nchwxx, second is the filter @@ -2384,6 +2405,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; return ret; + MIDOUT_E } /* ==================== ShuffleShuffleRemovePass ================= */ @@ -2961,9 +2983,11 @@ const char* ShuffleShuffleRemovePass::name() const { } void ShuffleShuffleRemovePass::apply(OptState& opt) const { + MIDOUT_B("ShuffleShuffleRemovePass::apply") opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_SHAPE | VarReplaceCheckFlag::CHECK_DTYPE); Impl{opt}; + MIDOUT_E } -// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/src/gopt/impl/weights_preprocess.cpp b/src/gopt/impl/weights_preprocess.cpp index c9c872d4..25e6d535 100644 --- a/src/gopt/impl/weights_preprocess.cpp +++ b/src/gopt/impl/weights_preprocess.cpp @@ -14,6 +14,16 @@ #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/tensor_manip.h" +#include "megbrain/utils/hash_ct.h" +#include "midout.h" + +MIDOUT_DECL(megbrain_weight_preprocess) +#define MIDOUT_B(tag) \ + MIDOUT_BEGIN(megbrain_weight_preprocess, midout_iv(MGB_HASH_STR(tag))) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); + using namespace mgb; using namespace gopt; using namespace cg; @@ -23,6 +33,7 @@ const char* WinogradTransformReplacePass::name() const { } void WinogradTransformReplacePass::apply(OptState& opt) const { + MIDOUT_B("WinogradTransformReplacePass::apply") auto rewriter = opt.graph().make_rewriter(); ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM}; opt.graph().iter([&cvprop](OperatorNodeBase *opr) { @@ -174,6 +185,7 @@ void WinogradTransformReplacePass::apply(OptState& opt) const { opt.graph().iter(on_opr); rewriter.apply_inplace(); + MIDOUT_E } /** diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index 9ae25222..f01094b3 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -855,10 +855,12 @@ VarNode* CollectiveComm::grad(VarNode* out_grad) const { return ModeTrait::from_mode(m_param.mode).grad(out_grad, this); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(CollectiveComm) { mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad"); return opr.grad(out_grad[0]); } +#endif /* ===================== shallow copy ===================== */ diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp index 49e3a5b2..0dead49c 100644 --- a/src/opr-mm/impl/io_remote.cpp +++ b/src/opr-mm/impl/io_remote.cpp @@ -109,6 +109,7 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const { return prop; } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(RemoteSend) { mgb_assert(opr.is_grad()); return RemoteRecv::make(opr.key() + ":grad", @@ -118,6 +119,7 @@ MGB_IMPL_OPR_GRAD(RemoteSend) { opr.input(0)->shape(), opr.input(0)->dtype()) .node(); } +#endif /* ===================== 
RemoteRecv ===================== */ diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index a1c8e4bb..d3a00c29 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -552,6 +552,7 @@ void Elemwise::call_megdnn_opr_exec( opr->exec(inp, out); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Elemwise) { SymbolVar i[5]; SymbolVar i0(opr.input(0)), i1, i2, out(opr.output(0)), @@ -730,6 +731,7 @@ MGB_IMPL_OPR_GRAD(Elemwise) { result = -result; return result.node(); } +#endif VarNode* Elemwise::sum_grad_list(VarNode *wrt, VarNodeArray &grads) { mgb_assert(!grads.empty()); @@ -814,6 +816,7 @@ TypeCvt::NodeProp* TypeCvt::do_make_node_prop() const { return ret; } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(TypeCvt) { MGB_MARK_USED_VAR(wrt_idx); auto itype = opr.input(0)->dtype(), otype = opr.output(0)->dtype(); @@ -826,6 +829,7 @@ MGB_IMPL_OPR_GRAD(TypeCvt) { } return TypeCvt::make(out_grad[0], opr.input(0)->dtype()).node(); } +#endif void TypeCvt::mem_plan_fwd_in2out_writable() { if (input(0)->dtype().size() == output(0)->dtype().size() && @@ -963,10 +967,12 @@ void AddUpdate::record_execute_deps(ExecDependencyArray& deps) { record_megdnn_opr(deps); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(AddUpdate) { // actually valid, just not implemented return InvalidGrad::make(opr, wrt_idx); } +#endif /* =========================== Reduce =========================== */ @@ -1698,6 +1704,7 @@ void Reduce::create_megdnn_opr() { create_operator()); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Reduce) { for (size_t i = 1; i < opr.output().size(); ++ i) mgb_assert(!out_grad[i]); @@ -1733,7 +1740,7 @@ MGB_IMPL_OPR_GRAD(Reduce) { grad = TypeCvt::make(grad, iv.dtype()); return grad.node(); } - +#endif void Reduce::record_execute_deps(ExecDependencyArray& deps) { record_megdnn_opr(deps); @@ -1783,11 +1790,13 @@ void PowC::init_output_static_infer_desc() { {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value}); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(PowC) { auto exp = opr.param().exp; return (exp * SymbolVar{out_grad[0]} * PowC::make(opr.input(0), exp - 1, opr.config())) .node(); } +#endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/blas.cpp b/src/opr/impl/blas.cpp index 81703453..ebc14294 100644 --- a/src/opr/impl/blas.cpp +++ b/src/opr/impl/blas.cpp @@ -106,6 +106,7 @@ void MatrixMul::scn_do_execute() { MGB_FINALLY({ tparam = this->param(); }); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(MatrixMul) { mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, "only float data type supported for grad"); @@ -128,6 +129,7 @@ MGB_IMPL_OPR_GRAD(MatrixMul) { } return grad.node(); } +#endif /* ================= BatchedMatrixMul ================= */ @@ -224,6 +226,7 @@ void BatchedMatrixMul::scn_do_execute() { MGB_FINALLY({ tparam = this->param(); }); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(BatchedMatrixMul) { mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, "only float data type supported for grad"); @@ -251,6 +254,7 @@ MGB_IMPL_OPR_GRAD(BatchedMatrixMul) { } return grad.node(); } +#endif /* ================= Dot ================= */ @@ -327,6 +331,7 @@ void Dot::add_input_layout_constraint() { input(1)->add_layout_constraint(check); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Dot) { auto other_input = opr.input(wrt_idx == 0 ? 
1 : 0); auto ishp0 = opr::GetVarShape::make(opr.input(0)), @@ -336,6 +341,7 @@ MGB_IMPL_OPR_GRAD(Dot) { Broadcast::make(mul(out_grad[0], other_input), max_ishp), wrt_idx ? ishp1 : ishp0).node(); } +#endif SymbolVar Dot::make(SymbolVar opr0, SymbolVar opr1, const OperatorNodeConfig &config) { @@ -350,6 +356,8 @@ void Dot::record_execute_deps(ExecDependencyArray &deps) { MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse); MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv") + +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(MatrixInverse) { SymbolVar a = opr.output(0); // TODO: use unified MatrixMul interface when we have it @@ -364,6 +372,7 @@ MGB_IMPL_OPR_GRAD(MatrixInverse) { a_bnn); return da.reshape(a.symshape()).node(); } +#endif /* ================= SVD ================= */ @@ -386,6 +395,7 @@ SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) : } } +#ifdef MGB_ENABLE_GRAD namespace { /*! @@ -477,7 +487,9 @@ OP(*, {}, {}) #undef OP } // anonymous namespace +#endif +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(SVD) { /** * The formula is copied from @@ -555,6 +567,7 @@ MGB_IMPL_OPR_GRAD(SVD) { I_n - matmul(v, v, param01))); return ret.reshape(a.symshape()).node(); } +#endif SymbolVarArray SVD::make(const SymbolVar& src, const Param& param, const OperatorNodeConfig& config) { diff --git a/src/opr/impl/cond.cpp b/src/opr/impl/cond.cpp index 2f9e4daf..03e3dd03 100644 --- a/src/opr/impl/cond.cpp +++ b/src/opr/impl/cond.cpp @@ -818,6 +818,7 @@ SymbolVar CondExecMark::mark_if_need(SymbolVar maybe_ppv, SymbolVar input, return input; } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(CondExecMark) { if (wrt_idx == opr.input().size() - 1 || !out_grad.at(wrt_idx)) { return nullptr; @@ -841,6 +842,7 @@ MGB_IMPL_OPR_GRAD(CondExecMark) { {1, grad_mode}, OperatorNodeConfig{}) ->output(0); } +#endif /* ============================= CondExecMerge ============================= */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondExecMerge); @@ -1225,6 +1227,7 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const { return ret; } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(CondExecMerge) { using Mode = CondExecMerge::Param::Mode; if (opr.param().mode == Mode::SUM_COND_OUT && @@ -1259,6 +1262,7 @@ MGB_IMPL_OPR_GRAD(CondExecMerge) { OperatorNodeConfig{og->comp_node()}) ->output(0); } +#endif void CondExecMerge::modify_grad_sum_list(VarNode* wrt, VarNodeArray& grads) { if (!ExecutionMask::have_alive_instance()) { diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp index 41fd87eb..f85a994a 100644 --- a/src/opr/impl/dnn/batch_norm.cpp +++ b/src/opr/impl/dnn/batch_norm.cpp @@ -230,6 +230,7 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() { } } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(BatchNormForward) { mgb_assert(wrt_idx < 5); if (wrt_idx < 3) { @@ -242,6 +243,7 @@ MGB_IMPL_OPR_GRAD(BatchNormForward) { return nullptr; } } +#endif MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNormBackward); diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp index 80719392..044577bc 100644 --- a/src/opr/impl/dnn/convolution.cpp +++ b/src/opr/impl/dnn/convolution.cpp @@ -18,6 +18,19 @@ #include "megdnn/oprs/utils.h" +//! TODO: here has to be know some megdnn::opr when there is produced midout.h +//! fix it if there is another graceful way. +#include "megdnn/oprs.h" + +#include "midout.h" + +MIDOUT_DECL(megbrain_opr_convolution) +#define MIDOUT_B(...) 
\ + MIDOUT_BEGIN(megbrain_opr_convolution, __VA_ARGS__) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); + #include "../internal/megdnn_opr_wrapper.inl" #include @@ -230,6 +243,7 @@ class TimedProfiler { static constexpr int arity_in = OprArityTrait::arity_in; static constexpr int arity_out = OprArityTrait::arity_out; static constexpr int arity = OprArityTrait::arity; + using ConvTensorShapes = std::array; public: @@ -295,6 +309,7 @@ double TimedProfiler::init_timeout_setting() { template typename TimedProfiler::TResult TimedProfiler::prof_impl( const TParam& raw_param) { + MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_impl"))) auto&& param = raw_param.as_single_pod(); CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc); auto megdnn_opr = intl::create_megdnn_opr(cn); @@ -401,14 +416,17 @@ typename TimedProfiler::TResult TimedProfiler::prof_impl( mgb_assert(ev_start->finished()); return TResult::from_pod(Result{ev_start->elapsed_time_until(*ev_end)}); + MIDOUT_E }; template void TimedProfiler::prof_init_device(const TParam& raw_param) { + MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_init_device"))) auto&& param = raw_param.as_single_pod(); CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc); // wait for cuda init, so its time does not get accounted in timeout cn.sync(); + MIDOUT_E } /* =================== AlgoChooser =================== */ @@ -426,6 +444,7 @@ class AlgoChooser { static constexpr int arity_in = OprArityTrait::arity_in; static constexpr int arity_out = OprArityTrait::arity_out; static constexpr int arity = OprArityTrait::arity; + using ImplAlgo = typename Opr::Algorithm*; using MGBOpr = typename MegDNNOpr2MGBOpr::MGBOpr; using ConvTensorLayouts = std::array; @@ -473,8 +492,8 @@ class AlgoChooser { //! put first std::vector get_all_candidates() const { auto heu = choose_by_heuristic(); - auto&& ret = OprArityTrait::get_all_algorithms( - m_megdnn_opr, m_layouts); + auto&& ret = OprArityTrait::get_all_algorithms(m_megdnn_opr, + m_layouts); bool found = false; for (size_t i = 0; i < ret.size(); ++i) { if (ret[i] == heu) { @@ -491,7 +510,7 @@ class AlgoChooser { //! get candidate algos with workspace limit. 
std::vector get_all_candidates_with_workspace_limit() const { - auto && all_algos = get_all_candidates(); + auto&& all_algos = get_all_candidates(); auto opr = m_mgb_opr; auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( opr->owner_graph(), opr->comp_node(), @@ -633,16 +652,16 @@ AlgoChooserProfileCache::Result AlgoChooser::get_profile_result( algo->name(), str_on_inp_shape.c_str()); timer.reset(); MGB_TRY { cur_rst = ctx.profile_single_algo(algo, cur_timeout); } - MGB_CATCH(std::exception & exc, - { - mgb_log_warn("caught exception during %s: %s", - msg.c_str(), exc.what()); - continue; - }) + MGB_CATCH(std::exception & exc, { + mgb_log_warn("caught exception during %s: %s", msg.c_str(), + exc.what()); + continue; + }) MGB_CATCH(..., { mgb_log_warn("caught exception during %s", msg.c_str()); continue; - }) if (!cur_rst.valid()) { + }) + if (!cur_rst.valid()) { mgb_log_warn("timeout when %s; timeout setting: %.3fsec", msg.c_str(), cur_timeout); continue; @@ -680,6 +699,7 @@ void AlgoChooser::get_origin_param_and_layouts( template typename AlgoChooser::ImplAlgo AlgoChooser::choose_by_profile( ExeContext& ctx, bool require_reproducible, bool enable_update) { + MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) auto opr = ctx.mgb_opr(); if (opr->owner_graph()->options().no_profiling_on_shape_change) { auto algo = ctx.megdnn_opr()->execution_policy().algorithm; @@ -720,6 +740,7 @@ typename AlgoChooser::ImplAlgo AlgoChooser::choose_by_profile( opr->owner_graph(), opr->comp_node(), opr->execution_policy().workspace_limit)); mgb_trap(); + MIDOUT_E } template <> @@ -748,7 +769,7 @@ void AlgoChooser::ExeContext:: if (m_layouts[1].dtype.enumv() == DTypeEnum::QuantizedS8 && param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW44) { if (winograd_preprocess_opr->param().format == - megdnn::param::MatrixMul::Format::MK4){ + megdnn::param::MatrixMul::Format::MK4) { winograd_preprocess_opr->param().compute_mode = ConvBias::Param::ComputeMode::FLOAT32; param.opr_param.compute_mode = @@ -941,6 +962,7 @@ void ConvolutionForward::init_output_dtype() { output(0)->dtype(output_dtype); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ConvolutionForward) { mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, "only float data type supported for grad"); @@ -960,6 +982,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionForward) { return grad.node(); } } +#endif size_t ConvolutionForward::get_workspace_size_bytes( const TensorShapeArray& input_shapes, @@ -1086,6 +1109,7 @@ void ConvolutionBackwardData::scn_do_execute() { intl::get_megdnn_workspace_from_var(output(1))); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) { mgb_assert(!out_grad[1]); if (wrt_idx == 0) { @@ -1101,6 +1125,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) { } return nullptr; } +#endif /* ==================== ConvolutionBackwardFilter ==================== */ IMPL_CONV(ConvolutionBackwardFilter, "conv_bwd_filter"); @@ -1138,6 +1163,7 @@ size_t ConvolutionBackwardFilter::get_workspace_size_bytes( megdnn_opr(), this); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) { mgb_assert(!out_grad[1]); if (wrt_idx == 0) { @@ -1153,6 +1179,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) { } return nullptr; } +#endif /* ==================== Convolution3DForward ==================== */ IMPL_CONV(Convolution3DForward, "conv3d_fwd"); @@ -1192,6 +1219,7 @@ void Convolution3DForward::init_output_dtype() { } } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Convolution3DForward) { 
mgb_assert(opr.param().data_type == Convolution3DForward::Param::DataType::FLOAT, @@ -1212,6 +1240,7 @@ MGB_IMPL_OPR_GRAD(Convolution3DForward) { return grad.node(); } } +#endif size_t Convolution3DForward::get_workspace_size_bytes( const TensorShapeArray& input_shapes, @@ -1285,6 +1314,7 @@ void Convolution3DBackwardData::scn_do_execute() { intl::get_megdnn_workspace_from_var(output(1))); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) { mgb_assert(!out_grad[1]); if (wrt_idx == 0) { @@ -1300,6 +1330,7 @@ MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) { } return nullptr; } +#endif /* ==================== Convolution3DBackwardFilter ==================== */ IMPL_CONV(Convolution3DBackwardFilter, "conv3d_bwd_filter"); @@ -1658,6 +1689,7 @@ size_t LocalShareForward::get_workspace_size_bytes( megdnn_opr(), this); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(LocalShareForward) { mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, "only float data type supported for grad"); @@ -1677,6 +1709,7 @@ MGB_IMPL_OPR_GRAD(LocalShareForward) { return grad.node(); } } +#endif /* ===================== LocalShareBackwardData ==================== */ @@ -1737,6 +1770,7 @@ void LocalShareBackwardData::scn_do_execute() { intl::get_megdnn_workspace_from_var(output(1))); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(LocalShareBackwardData) { mgb_assert(!out_grad[1]); if (wrt_idx == 0) { @@ -1752,6 +1786,7 @@ MGB_IMPL_OPR_GRAD(LocalShareBackwardData) { } return nullptr; } +#endif /* ==================== LocalShareBackwardFilter ==================== */ @@ -1792,6 +1827,7 @@ size_t LocalShareBackwardFilter::get_workspace_size_bytes( megdnn_opr(), this); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) { mgb_assert(!out_grad[1]); if (wrt_idx == 0) { @@ -1805,6 +1841,7 @@ MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) { } return nullptr; } +#endif /* ===================== DeformableConvForward ==================== */ @@ -1869,6 +1906,7 @@ size_t DeformableConvForward::get_workspace_size_bytes( megdnn_opr(), this); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(DeformableConvForward) { mgb_assert(opr.input(0)->dtype() == dtype::Float32(), "only float data type supported for grad"); @@ -1888,6 +1926,7 @@ MGB_IMPL_OPR_GRAD(DeformableConvForward) { SymbolVarArray grads = {grad_arr[0], filter_grad, grad_arr[1], grad_arr[2]}; return grads[wrt_idx].node(); } +#endif /* ==================== DeformableConvBackwardData ==================== */ @@ -2265,4 +2304,4 @@ void BatchConvBiasForward::init_output_format() { #undef IMPL_CONV #undef MGB_FOREACH_FASTRUN_OPR -// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/src/opr/impl/dnn/images2neibs.cpp b/src/opr/impl/dnn/images2neibs.cpp index 594e8681..904ba503 100644 --- a/src/opr/impl/dnn/images2neibs.cpp +++ b/src/opr/impl/dnn/images2neibs.cpp @@ -20,11 +20,13 @@ using namespace opr; MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsForward); MEGDNN_OPR_INIT1(Images2NeibsForward, "images2neibs") +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Images2NeibsForward) { mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]); return Images2NeibsBackward::make( out_grad[0], opr.input(0), opr.param()).node(); } +#endif MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsBackward); MEGDNN_OPR_INIT2(Images2NeibsBackward, "images2neibs_grad", 1, false); diff --git a/src/opr/impl/dnn/local.cpp b/src/opr/impl/dnn/local.cpp index 
af5edb85..44211c4f 100644 --- a/src/opr/impl/dnn/local.cpp +++ b/src/opr/impl/dnn/local.cpp @@ -20,10 +20,13 @@ using namespace opr; MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalForward); MEGDNN_OPR_INIT2(LocalForward, "local") + +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(LocalForward) { return intl::conv_grad( opr, wrt_idx, out_grad); } +#endif MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalBackwardData); MEGDNN_OPR_INIT3(LocalBackwardData, "local_bwd_data", 2, false); @@ -34,10 +37,13 @@ MEGDNN_OPR_INIT3(LocalBackwardFilter, "local_bwd_filter", 2, false); MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalForward); MEGDNN_OPR_INIT2(GroupLocalForward, "glocal") + +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(GroupLocalForward) { return intl::conv_grad( opr, wrt_idx, out_grad); } +#endif MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalBackwardData); MEGDNN_OPR_INIT3(GroupLocalBackwardData, "glocal_bwd_data", 2, false); diff --git a/src/opr/impl/dnn/lrn.cpp b/src/opr/impl/dnn/lrn.cpp index d57151cb..e192af2a 100644 --- a/src/opr/impl/dnn/lrn.cpp +++ b/src/opr/impl/dnn/lrn.cpp @@ -20,12 +20,14 @@ using namespace opr; MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNForward); MEGDNN_OPR_INIT1(LRNForward, "lrn") +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(LRNForward) { mgb_assert(wrt_idx == 0); SymbolVar grad = LRNBackward::make( opr.input(0), opr.output(0), out_grad[0], opr.param()); return grad.node(); } +#endif MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNBackward); MEGDNN_OPR_INIT3(LRNBackward, "lrn_bwd", 0, true); diff --git a/src/opr/impl/dnn/pooling.cpp b/src/opr/impl/dnn/pooling.cpp index 27e717c1..5045b498 100644 --- a/src/opr/impl/dnn/pooling.cpp +++ b/src/opr/impl/dnn/pooling.cpp @@ -19,12 +19,14 @@ using namespace opr; MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward); MEGDNN_OPR_INIT1(PoolingForward, "pooling") +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(PoolingForward) { mgb_assert(wrt_idx == 0); SymbolVar grad = PoolingBackward::make( opr.input(0), opr.output(0), out_grad[0], opr.param()); return grad.node(); } +#endif MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingBackward); MEGDNN_OPR_INIT3(PoolingBackward, "pooling_bwd", 0, true); diff --git a/src/opr/impl/dnn/roi_align.cpp b/src/opr/impl/dnn/roi_align.cpp index 21fae8dd..3c29417f 100644 --- a/src/opr/impl/dnn/roi_align.cpp +++ b/src/opr/impl/dnn/roi_align.cpp @@ -40,6 +40,7 @@ SymbolVar ROIAlignForward::make(SymbolVar src, SymbolVar rois, src.node(), rois.node(), param, config); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ROIAlignForward) { if (out_grad[1]) { return InvalidGrad::make(opr, wrt_idx); @@ -55,6 +56,7 @@ MGB_IMPL_OPR_GRAD(ROIAlignForward) { return nullptr; } } +#endif /* ==================== ROIAlignBackward ==================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROIAlignBackward); diff --git a/src/opr/impl/dnn/roi_pooling.cpp b/src/opr/impl/dnn/roi_pooling.cpp index f3621375..ab2801d5 100644 --- a/src/opr/impl/dnn/roi_pooling.cpp +++ b/src/opr/impl/dnn/roi_pooling.cpp @@ -84,6 +84,7 @@ size_t ROIPoolingForward::get_workspace_size_bytes( input_shapes, output_shapes); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ROIPoolingForward) { if (out_grad[1] || wrt_idx == 2) { return InvalidGrad::make(opr, wrt_idx); @@ -98,6 +99,7 @@ MGB_IMPL_OPR_GRAD(ROIPoolingForward) { return nullptr; } } +#endif void ROIPoolingForward::scn_do_execute() { return intl::MegDNNOprMethInvoker:: @@ -146,6 +148,7 @@ SymbolVar DeformablePSROIPoolingForward::make( return all[0]; } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) { mgb_assert(wrt_idx <= 2); // wrt_idx = 0 or 1 or 2 @@ -168,6 +171,7 @@ 
MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) { } return nullptr; } +#endif /* ==================== DeformablePSROIPoolingBackward ==================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(DeformablePSROIPoolingBackward); diff --git a/src/opr/impl/imgproc.cpp b/src/opr/impl/imgproc.cpp index e5c407f9..0166e67f 100644 --- a/src/opr/impl/imgproc.cpp +++ b/src/opr/impl/imgproc.cpp @@ -127,6 +127,7 @@ void WarpPerspectiveForward::record_execute_deps(ExecDependencyArray& deps) { record_megdnn_opr(deps); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) { mgb_assert(opr.input().size() == 3, "backward with mat_idx is currently unsupported"); @@ -145,6 +146,7 @@ MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) { } else return InvalidGrad::make(opr, wrt_idx); } +#endif /* ====================== WarpPerspectiveBackwardData ====================== */ @@ -234,6 +236,7 @@ void ResizeForward::record_execute_deps(ExecDependencyArray &deps) { record_megdnn_opr(deps); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ResizeForward) { mgb_assert(opr.input().size() == 2); if (wrt_idx == 0) { @@ -243,6 +246,7 @@ MGB_IMPL_OPR_GRAD(ResizeForward) { } else return InvalidGrad::make(opr, wrt_idx); } +#endif /* ====================== ResizeBackward ====================== */ diff --git a/src/opr/impl/indexing.cpp b/src/opr/impl/indexing.cpp index 133f175f..f7371c6f 100644 --- a/src/opr/impl/indexing.cpp +++ b/src/opr/impl/indexing.cpp @@ -83,6 +83,7 @@ void IndexingOneHot::init_output_dtype() { output(0)->dtype(input(0)->dtype()); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IndexingOneHot) { if (wrt_idx == 0) { return IndexingSetOneHot::make( @@ -91,6 +92,7 @@ MGB_IMPL_OPR_GRAD(IndexingOneHot) { } return InvalidGrad::make(opr, wrt_idx); } +#endif /* ==================== IndexingSetOneHot ==================== */ @@ -133,6 +135,7 @@ void IndexingSetOneHot::scn_do_execute() { intl::get_megdnn_workspace_from_var(output(1))); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IndexingSetOneHot) { SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)}; if (wrt_idx == 0) { @@ -144,6 +147,7 @@ MGB_IMPL_OPR_GRAD(IndexingSetOneHot) { } return InvalidGrad::make(opr, wrt_idx); } +#endif size_t IndexingSetOneHot::get_workspace_size_bytes( const TensorShapeArray &input_shapes, @@ -165,6 +169,7 @@ void IndexingRemap::init_output_dtype() { output(0)->dtype(input(0)->dtype()); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IndexingRemap) { if (wrt_idx == 1) return InvalidGrad::make(opr, wrt_idx); @@ -172,6 +177,7 @@ MGB_IMPL_OPR_GRAD(IndexingRemap) { return IndexingRemapBackward::make( out_grad[0], opr.input(1), opr.input(0), opr.param()).node(); } +#endif MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingRemapBackward); MEGDNN_OPR_INIT3(IndexingRemapBackward, "indexing_remap_bwd", 2, false); @@ -460,6 +466,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false); +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { if (wrt_idx) return InvalidGrad::make(opr, wrt_idx); @@ -468,7 +475,9 @@ MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { SymbolVar{opr.input(0)}.fill_retain_dtype(0), out_grad.at(0), opr.index_desc()).node(); } +#endif +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) { if (wrt_idx >= 2) return InvalidGrad::make(opr, wrt_idx); @@ -479,7 +488,9 @@ MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) { } return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node(); } +#endif +#ifdef MGB_ENABLE_GRAD 
MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) { if (wrt_idx >= 2) return InvalidGrad::make(opr, wrt_idx); @@ -488,6 +499,7 @@ MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) { } return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node(); } +#endif /* ============================= Mesh Indexing ============================ */ @@ -498,6 +510,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET( BatchedMeshIndexing, "batched_mesh_indexing", false, output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);); +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(MeshIndexing) { if (wrt_idx != 0) { return InvalidGrad::make(opr, wrt_idx); @@ -507,6 +520,9 @@ MGB_IMPL_OPR_GRAD(MeshIndexing) { opr.index_desc()) .node(); } +#endif + +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) { if (wrt_idx != 0) { return InvalidGrad::make(opr, wrt_idx); @@ -516,11 +532,14 @@ MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) { opr.index_desc()) .node(); } +#endif /* ========================= IncrMeshIndexing ========================= */ MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing", false); + +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IncrMeshIndexing) { if (wrt_idx > 2) { return opr::InvalidGrad::make(opr, wrt_idx); @@ -530,9 +549,11 @@ MGB_IMPL_OPR_GRAD(IncrMeshIndexing) { } return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node(); } +#endif MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing, "batched_incr_mesh_indexing", false); +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) { if (wrt_idx > 2) { return opr::InvalidGrad::make(opr, wrt_idx); @@ -542,10 +563,12 @@ MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) { } return BatchedMeshIndexing::make(out_grad.at(0), opr.index_desc()).node(); } +#endif /* ======================== SetMeshIndexing =========================== */ MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false); +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(SetMeshIndexing) { if (wrt_idx >= 2) { return opr::InvalidGrad::make(opr, wrt_idx); @@ -560,9 +583,11 @@ MGB_IMPL_OPR_GRAD(SetMeshIndexing) { return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node(); } } +#endif MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing, "batched_set_mesh_indexing", false); +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) { if (wrt_idx > 2) { return opr::InvalidGrad::make(opr, wrt_idx); @@ -578,5 +603,6 @@ MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) { .node(); } } +#endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/io.cpp b/src/opr/impl/io.cpp index c535c837..5e34884a 100644 --- a/src/opr/impl/io.cpp +++ b/src/opr/impl/io.cpp @@ -764,11 +764,13 @@ Copy::NodeProp* Copy::do_make_node_prop() const { return rst; } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Copy) { mgb_assert(wrt_idx == 0); return Copy::make(out_grad[0], OperatorNodeConfig{}.follow_comp_node(opr.input(0))).node(); } +#endif void Copy::add_input_layout_constraint() { if (input(0)->comp_node() != output(0)->comp_node()) { diff --git a/src/opr/impl/loop/forward.cpp b/src/opr/impl/loop/forward.cpp index d4e878d1..b52344ec 100644 --- a/src/opr/impl/loop/forward.cpp +++ b/src/opr/impl/loop/forward.cpp @@ -268,9 +268,11 @@ VarNode* Loop::grad(Loop &opr, size_t wrt_idx, const VarNodeArray &out_grad) { return gopr->get_grad_var(wrt_idx); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Loop) { return Loop::grad(const_cast(opr), wrt_idx, out_grad); } +#endif cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const { auto prop 
= LoopImpl::do_make_node_prop(); diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index 92673f2f..3f738d4e 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -48,23 +48,26 @@ namespace intl { /* ================= Argmxx ================= */ +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Argmax) { MGB_MARK_USED_VAR(out_grad); MGB_MARK_USED_VAR(opr); mgb_assert(!wrt_idx); return nullptr; } +#endif MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax); MEGDNN_OPR_INIT1(Argmax, "argmax") - +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Argmin) { MGB_MARK_USED_VAR(out_grad); MGB_MARK_USED_VAR(opr); mgb_assert(!wrt_idx); return nullptr; } +#endif MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmin); MEGDNN_OPR_INIT1(Argmin, "argmin") @@ -84,12 +87,14 @@ std::array ArgsortForward::make( return {node->output(0), node->output(1)}; } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ArgsortForward) { mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); if (!out_grad[0]) return nullptr; return ArgsortBackward::make(out_grad[0], opr.output(1)).node(); } +#endif /* ================= ArgsortBackward ================= */ @@ -107,12 +112,14 @@ Cumsum::Cumsum(VarNode* opr, const Param& param, add_input({opr}, AddInputSortType::CUR_ADDED); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Cumsum) { mgb_assert(out_grad[0] && !out_grad[1]); auto param = opr.param(); param.reverse = !param.reverse; return Cumsum::make(out_grad[0], param).node(); } +#endif SymbolVar Cumsum::make(SymbolVar opr, const Param& param, const OperatorNodeConfig& config) { @@ -170,6 +177,7 @@ CondTake::CondTake(VarNode *data, VarNode *mask, } } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(CondTake) { mgb_assert(out_grad.size() == 3 && !out_grad[2]); if (wrt_idx == 0 && out_grad[0]) { @@ -181,6 +189,7 @@ MGB_IMPL_OPR_GRAD(CondTake) { } return nullptr; } +#endif std::array CondTake::make( SymbolVar data, SymbolVar mask, @@ -318,6 +327,7 @@ void TopK::record_execute_deps(ExecDependencyArray& deps) { record_megdnn_opr(deps); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(TopK) { if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) { mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]); @@ -334,5 +344,6 @@ MGB_IMPL_OPR_GRAD(TopK) { return ArgsortBackward::make(out_grad[0], opr.output(1), opr.input(0)) .node(); } +#endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/muxing.cpp b/src/opr/impl/muxing.cpp index eb32f194..571dc7d1 100644 --- a/src/opr/impl/muxing.cpp +++ b/src/opr/impl/muxing.cpp @@ -316,9 +316,11 @@ VarNodeArray AllGather::grad(const VarNodeArray &out_grad) { OperatorNodeConfig().comp_node_arr(sp_cn))); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(AllGather) { return const_cast(opr).grad(out_grad); } +#endif void AllGather::on_output_comp_node_stream_changed() { } diff --git a/src/opr/impl/rand.cpp b/src/opr/impl/rand.cpp index 538d1167..b5c16c59 100644 --- a/src/opr/impl/rand.cpp +++ b/src/opr/impl/rand.cpp @@ -112,19 +112,21 @@ UniqPtrWithCN RNGOpr::create_megdnn_opr() { return opr; } -#define IMPL(_cls) \ -template class RNGOpr<::megdnn::_cls>; \ -MGB_IMPL_OPR_GRAD(_cls) { \ - MGB_MARK_USED_VAR(out_grad); \ - return InvalidGrad::make(opr, wrt_idx); \ -} \ - +#define IMPL(_cls) \ + MGB_IMPL_OPR_GRAD(_cls) { \ + MGB_MARK_USED_VAR(out_grad); \ + return InvalidGrad::make(opr, wrt_idx); \ + } namespace mgb { namespace opr { namespace intl { +template class RNGOpr<::megdnn::GaussianRNG>; +template class RNGOpr<::megdnn::UniformRNG>; +#ifdef MGB_ENABLE_GRAD IMPL(GaussianRNG); IMPL(UniformRNG); +#endif } } } 
diff --git a/src/opr/impl/tensor_gen.cpp b/src/opr/impl/tensor_gen.cpp index 6505fe07..24967f83 100644 --- a/src/opr/impl/tensor_gen.cpp +++ b/src/opr/impl/tensor_gen.cpp @@ -46,11 +46,13 @@ void Alloc::outshape_by_symvar_do_get_output_shape( void Alloc::scn_do_execute() { } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Alloc) { MGB_MARK_USED_VAR(wrt_idx); MGB_MARK_USED_VAR(out_grad); return InvalidGrad::make(opr, 0); } +#endif /* ======================= Linspace ======================= */ @@ -123,6 +125,7 @@ void Linspace::record_execute_deps(ExecDependencyArray& deps) { std::make_unique(std::move(m_megdnn_opr))); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Linspace) { if (wrt_idx == 2) return InvalidGrad::make(opr, wrt_idx); @@ -134,6 +137,7 @@ MGB_IMPL_OPR_GRAD(Linspace) { return opr::Dot::make(og, opr::Linspace::make(i0, i1, opr.input(2), opr.param())).node(); } +#endif /* ======================= Eye ======================= */ @@ -195,9 +199,10 @@ void Eye::record_execute_deps(ExecDependencyArray& deps) { std::make_unique(std::move(m_megdnn_opr))); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Eye) { return InvalidGrad::make(opr, wrt_idx); } - +#endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp index 2abdd581..e15d909b 100644 --- a/src/opr/impl/tensor_manip.cpp +++ b/src/opr/impl/tensor_manip.cpp @@ -165,12 +165,13 @@ void GetVarShape::init_output_static_infer_desc() { mgr.register_value_infer(output(0), {SourceType::DEP, deps, infer_value}); } - +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(GetVarShape) { MGB_MARK_USED_VAR(wrt_idx); MGB_MARK_USED_VAR(out_grad); return nullptr; } +#endif SymbolVar GetVarShape::make(const VarNodeArrayView& inp, Param param, const OperatorNodeConfig& config) { @@ -362,11 +363,13 @@ SymbolVar Reshape::make(SymbolVar inp, SymbolVar tshp, inp.node(), tshp.node(), unspec_axis, config); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Reshape) { if (wrt_idx) return InvalidGrad::make(opr, wrt_idx); return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node(); } +#endif Maybe Reshape::reshapebrdcast_get_dest_layout( const TensorLayout &src, const TensorShape &tshape) const { @@ -429,12 +432,14 @@ SymbolVar Broadcast::make(SymbolVar inp, SymbolVar tshp, inp.node(), tshp.node(), config); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Broadcast) { if (wrt_idx) return InvalidGrad::make(opr, wrt_idx); return Reduce::make(out_grad.at(0), Reduce::Mode::SUM, GetVarShape::make(opr.input(0))).node(); } +#endif Maybe Broadcast::reshapebrdcast_get_dest_layout( const TensorLayout &src, const TensorShape &tshape) const { @@ -562,9 +567,11 @@ VarNode* Dimshuffle::grad( return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node(); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Dimshuffle) { return opr.grad(wrt_idx, out_grad); } +#endif // f}}} @@ -631,10 +638,12 @@ AxisAddRemove::NodeProp* AxisAddRemove::do_make_node_prop() const { return ret; } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(AxisAddRemove) { MGB_MARK_USED_VAR(wrt_idx); return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node(); } +#endif // f}}} @@ -642,6 +651,7 @@ MGB_IMPL_OPR_GRAD(AxisAddRemove) { MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true); +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Subtensor) { if (wrt_idx) return InvalidGrad::make(opr, wrt_idx); @@ -650,6 +660,7 @@ MGB_IMPL_OPR_GRAD(Subtensor) { SymbolVar{opr.input(0)}.fill_retain_dtype(0), out_grad.at(0), opr.index_desc()).node(); 
} +#endif void Subtensor::init_output_static_infer_desc() { using namespace cg::static_infer; @@ -783,6 +794,7 @@ void SetSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) { sub.copy_from_fixlayout(val); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(SetSubtensor) { if (wrt_idx >= 2) return InvalidGrad::make(opr, wrt_idx); @@ -793,6 +805,7 @@ MGB_IMPL_OPR_GRAD(SetSubtensor) { } return Subtensor::make(out_grad.at(0), opr.index_desc()).node(); } +#endif // f}}} @@ -813,6 +826,7 @@ void IncrSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) { opr->exec(sub.as_megdnn(), val.as_megdnn()); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IncrSubtensor) { if (wrt_idx >= 2) return InvalidGrad::make(opr, wrt_idx); @@ -821,6 +835,7 @@ MGB_IMPL_OPR_GRAD(IncrSubtensor) { } return Subtensor::make(out_grad.at(0), opr.index_desc()).node(); } +#endif // f}}} @@ -1085,6 +1100,7 @@ void Split::do_execute(ExecEnv &env) { } } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Split) { if (wrt_idx) return InvalidGrad::make(opr, wrt_idx); @@ -1100,6 +1116,7 @@ MGB_IMPL_OPR_GRAD(Split) { return Concat::make(grad, opr.options().axis, OperatorNodeConfig{}.follow_comp_node(opr.input(0))).node(); } +#endif void Split::mem_plan_fwd_in2out_readonly() { m_readonly_fwd_called = true; @@ -1236,6 +1253,7 @@ SymbolVar Concat::make(const VarNodeArrayView& inp, int axis, axis, config); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Concat) { auto axis = opr.axis(); mgb_assert(out_grad.size() == 1); @@ -1250,6 +1268,7 @@ MGB_IMPL_OPR_GRAD(Concat) { OperatorNodeConfig().comp_node_arr(comp_node)); return cg::to_var_node_array(ret); } +#endif void Concat::scn_do_execute() { auto&& out = output(0)->dev_tensor(); @@ -1507,6 +1526,7 @@ void ParamPackSplit::init_output_static_infer_desc() { } } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ParamPackSplit) { mgb_assert(out_grad.size() == opr.output().size()); SmallVector grad; @@ -1531,6 +1551,7 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) { OperatorNodeConfig{}.follow_comp_node(opr.input(0))) .node(); } +#endif // f}}} /* f{{{ ======================= RelayoutFormat ======================= */ diff --git a/src/opr/impl/utility.cpp b/src/opr/impl/utility.cpp index ca2d174f..e9f1f6e0 100644 --- a/src/opr/impl/utility.cpp +++ b/src/opr/impl/utility.cpp @@ -255,9 +255,11 @@ void MarkDynamicVar::scn_do_execute() { o->dev_tensor().copy_from_fixlayout(i->dev_tensor()); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(MarkDynamicVar) { return MarkDynamicVar::make(out_grad.at(0)).node(); } +#endif MarkDynamicVar::MarkDynamicVar(VarNode *node, const OperatorNodeConfig &config): Super{node->owner_graph(), config, "mark_dyn", {node}} @@ -381,10 +383,12 @@ CallbackInjector::mixin_get_static_infer_desc(OperatorNodeBase &opr) { } } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(CallbackInjector) { MGB_MARK_USED_VAR(wrt_idx); return out_grad.at(0); } +#endif /* ===================== MarkNoBroadcastElemwise ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(MarkNoBroadcastElemwise); @@ -404,9 +408,11 @@ SymbolVar MarkNoBroadcastElemwise::make( input.node(), config); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) { return out_grad.at(0); } +#endif /* ===================== Identity ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity); @@ -429,9 +435,11 @@ SymbolVar Identity::make( return input.insert_single_output_opr(input.node(), config); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Identity) { return out_grad.at(0); } +#endif /* ===================== AssertEqual 
===================== */ @@ -530,6 +538,7 @@ SymbolVar SetGrad::make(SymbolVar input, const GradGetter& grad_getter, input.node(), grad_getter, config); } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(SetGrad) { MGB_MARK_USED_VAR(wrt_idx); MGB_MARK_USED_VAR(out_grad); @@ -538,6 +547,7 @@ MGB_IMPL_OPR_GRAD(SetGrad) { "var returned by grad_getter belongs to a different comp graph"); return grad.node(); } +#endif /* ===================== InvalidGrad ===================== */ @@ -690,6 +700,7 @@ VirtualLoss::NodeProp* VirtualLoss::do_make_node_prop() const { return ret; } +#ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(VirtualLoss) { mgb_assert(out_grad.size() == 1); auto mid = opr.input().size() / 2; @@ -698,6 +709,7 @@ MGB_IMPL_OPR_GRAD(VirtualLoss) { } return nullptr; } +#endif #else VarNode* InvalidGrad::make(const OperatorNodeBase&, size_t) { diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp index 990aa787..9a70b767 100644 --- a/src/plugin/impl/opr_footprint.cpp +++ b/src/plugin/impl/opr_footprint.cpp @@ -24,6 +24,16 @@ #include "megdnn/opr_param_json.h" #endif +#include "megbrain/utils/hash_ct.h" +#include "midout.h" + +MIDOUT_DECL(megbrain_opr_footprint) +#define MIDOUT_B(...) \ + MIDOUT_BEGIN(megbrain_opr_footprint, __VA_ARGS__) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); + using namespace mgb; namespace { @@ -581,9 +591,12 @@ std::shared_ptr opr_param_json_func( template void OprFootprint::add_single_comp_footprint() { + MIDOUT_B(OprType, + midout_iv(MGB_HASH_STR("OprFootprint::add_single_comp_footprint"))) auto&& record = m_type2comp_footprint.emplace(OprType::typeinfo(), opr_footprint_func); mgb_assert(record.second, "duplicate opr typeinfo"); + MIDOUT_E } #if MGB_ENABLE_JSON
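
The instrumentation added in every file above follows one pattern, sketched below for reference. MIDOUT_DECL, MIDOUT_BEGIN, MIDOUT_END and midout_iv come from third_party/midout, and MGB_HASH_STR is megbrain's compile-time string hash; the midout type "megbrain_example" and the function example_pass_apply are placeholder names for illustration only, not part of this diff.

    // One midout type is declared per translation unit; each instrumented
    // region is keyed by a compile-time hash of a tag string.
    #include "megbrain/utils/hash_ct.h"
    #include "midout.h"

    MIDOUT_DECL(megbrain_example)

    #define MIDOUT_B(tag) \
        MIDOUT_BEGIN(megbrain_example, midout_iv(MGB_HASH_STR(tag))) {
    #define MIDOUT_E \
        }             \
        MIDOUT_END();

    void example_pass_apply() {
        MIDOUT_B("ExamplePass::apply")
        // The original pass body runs here unchanged.  A tracing build records
        // whether this region was ever entered, and a later build can compile
        // out regions that were never reached, which is how the patch reduces
        // binary size.
        MIDOUT_E
    }

In the templated call sites (convolution.cpp, opr_footprint.cpp) MIDOUT_B is variadic and the operator type is folded into the region key, e.g. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_impl"))), so template instantiations that are never exercised can be trimmed in the same way.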