From 075c4562bcf35eadd1b5287c1186394a66ba7aca Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 3 Apr 2020 11:21:29 +0800 Subject: [PATCH] fix(mgb/core): fix stream propagation in seq_comp_node_opt for ForwardInToOutput op, StreamPropType of output(0) should be forwarded as input(0) GitOrigin-RevId: 13e95879612db46635a304c0e25990fb8ba30512 --- src/core/impl/graph/seq_comp_node_opt_impl.cpp | 105 ++++++++++++++++----- src/core/impl/graph/seq_comp_node_opt_impl.h | 3 + .../include/megbrain/graph/seq_comp_node_opt.h | 7 ++ src/core/test/graph/misc.cpp | 28 ++++++ src/opr/impl/internal/identical_fwd.cpp | 10 ++ .../include/megbrain/opr/internal/identical_fwd.h | 11 +++ 6 files changed, 138 insertions(+), 26 deletions(-) diff --git a/src/core/impl/graph/seq_comp_node_opt_impl.cpp b/src/core/impl/graph/seq_comp_node_opt_impl.cpp index 7a4883fc..5f166490 100644 --- a/src/core/impl/graph/seq_comp_node_opt_impl.cpp +++ b/src/core/impl/graph/seq_comp_node_opt_impl.cpp @@ -63,29 +63,21 @@ void SeqCompNodeOptimizerImpl::change_to_specific_stream( } ThinHashMap changed_vars; - auto cb = [this, &changed_vars](OperatorNodeBase *opr) { - if (opr->node_prop().contain( - OperatorNodeBase::NodeProp::Flag:: - DISALLOW_COMP_NODE_OPTIMIZE)) { - return; + std::pair prop_type_storage; + std::pair> input_props_storage; + + // both `propagate_single_opr` and `get_input_props` might be called any number + // of times(>=0) with the same \p opr in a cb(opr) function call, so we cache + // the result of the first call. + auto propagate_single_opr = [&](OperatorNodeBase* opr) { + mgb_assert(opr); + if (prop_type_storage.first == opr) { + return prop_type_storage.second; } + prop_type_storage.first = opr; - // first check whether any output var is registered for change - bool output_changed = false; - for (auto i: opr->output()) { - auto iter = m_var2prop_type.find(i); - if (iter != m_var2prop_type.end()) { - output_changed = true; - var_to_specific_stream(i, iter->second.stream); - changed_vars[i] = iter->second; - } - } - if (output_changed) - return; - - // check inputs bool any_strong_changed = false, all_weak_changed = true, - all_weak_changed_valid = false; + all_weak_changed_valid = false; auto &&dep_map = opr->node_prop().dep_map(); ThinHashSet inp_streams; @@ -100,6 +92,7 @@ void SeqCompNodeOptimizerImpl::change_to_specific_stream( if (iter == changed_vars.end()) { all_weak_changed = false; } else { + mgb_assert(iter->second.prop_type != StreamPropType::NONE); if (iter->second.prop_type == StreamPropType::STRONG) { any_strong_changed = true; } else { @@ -109,21 +102,75 @@ void SeqCompNodeOptimizerImpl::change_to_specific_stream( } } + auto type = StreamPropType::NONE; + int stream = 0; if (any_strong_changed || (all_weak_changed && all_weak_changed_valid)) { - auto type = any_strong_changed ? + type = any_strong_changed ? StreamPropType::STRONG : StreamPropType::WEAK; - int stream = 0; int copy_stream = CompNode::Stream::COPY; int nccl_stream = CompNode::Stream::NCCL; if (inp_streams.count(copy_stream)) stream = copy_stream; else if (inp_streams.count(nccl_stream)) stream = nccl_stream; - mgb_assert(stream != 0); - for (auto i: opr->output()) { - var_to_specific_stream(i, stream); - changed_vars[i] = StreamPropType{stream, type}; + mgb_assert(type != StreamPropType::NONE && stream != 0); + } + return prop_type_storage.second = StreamPropType{stream, type}; + }; + + auto get_input_props = [&](OperatorNodeBase *opr) { + mgb_assert(opr); + if (input_props_storage.first == opr) { + return input_props_storage.second; + } + input_props_storage.first = opr; + + auto &&props = input_props_storage.second; + props.clear(); + for (auto i : opr->input()) { + auto &&iter = changed_vars.find(i); + if (iter != changed_vars.end()) { + props.push_back(iter->second); + } else { + props.push_back(StreamPropType{0, StreamPropType::NONE}); + } + } + return input_props_storage.second; + }; + + auto cb = [&](OperatorNodeBase *opr) { + if (opr->node_prop().contain( + OperatorNodeBase::NodeProp::Flag:: + DISALLOW_COMP_NODE_OPTIMIZE)) { + return; + } + + // first check whether any output var is registered for change + bool output_changed = false; + for (auto i: opr->output()) { + auto iter = m_var2prop_type.find(i); + if (iter != m_var2prop_type.end()) { + output_changed = true; + var_to_specific_stream(i, iter->second.stream); + changed_vars[i] = iter->second; + } + } + if (output_changed) + return; + + for (auto i: opr->output()) { + StreamPropType prop; + auto &&iter = m_var2prop_func.find(i); + if (iter != m_var2prop_func.end()) { + iter->second(prop, get_input_props(opr)); + } + else { + prop = propagate_single_opr(opr); + } + if (prop.prop_type != StreamPropType::NONE) { + var_to_specific_stream(i, prop.stream); + changed_vars[i] = prop; } } }; @@ -152,6 +199,12 @@ void SeqCompNodeOptimizerImpl::register_stream_var( } } +void SeqCompNodeOptimizerImpl::register_propagate_function( + VarNode *var, PropFunction prop_func) { + mgb_assert(var->owner_graph() == m_owner_graph); + mgb_assert(m_var2prop_func.emplace(var, prop_func).second); +} + void SeqCompNodeOptimizerImpl::init_ready_event( const CompSeqExtraInfo &extra_info, const OprNodeArray &seq) { // clear existing synchronizers diff --git a/src/core/impl/graph/seq_comp_node_opt_impl.h b/src/core/impl/graph/seq_comp_node_opt_impl.h index 6eb7e222..8f81f453 100644 --- a/src/core/impl/graph/seq_comp_node_opt_impl.h +++ b/src/core/impl/graph/seq_comp_node_opt_impl.h @@ -23,6 +23,7 @@ class SeqCompNodeOptimizerImpl final: public SeqCompNodeOptimizer { ComputingGraphImpl *m_owner_graph; std::vector> m_comp_node_to_restore; ThinHashSet m_comp_node_changed_oprs; + ThinHashMap m_var2prop_func; /*! * cn0 -> (cn1 -> [(a, b)]): an opr at step \p a on \p cn0 is known to start @@ -56,6 +57,8 @@ class SeqCompNodeOptimizerImpl final: public SeqCompNodeOptimizer { void register_stream_var(VarNode* var, StreamPropType prop_type) override; + void register_propagate_function(VarNode* var, PropFunction prop_func) override; + StreamPropType stream_prop_type(VarNode *var) override { auto iter = m_var2prop_type.find(var); return iter == m_var2prop_type.end() diff --git a/src/core/include/megbrain/graph/seq_comp_node_opt.h b/src/core/include/megbrain/graph/seq_comp_node_opt.h index 0b312a9e..27fc7d8f 100644 --- a/src/core/include/megbrain/graph/seq_comp_node_opt.h +++ b/src/core/include/megbrain/graph/seq_comp_node_opt.h @@ -39,11 +39,18 @@ class SeqCompNodeOptimizer { int stream; //!< stream to change PropType prop_type; }; + using PropFunction = thin_function& /* srcs */)>; //! register a var that should be placed on the stream virtual void register_stream_var( VarNode* var, StreamPropType prop_type) = 0; + //! register a propagate function on given var_node + virtual void register_propagate_function( + VarNode* var, PropFunction prop_func) = 0; + //! check if a var has been registered in stream and get its //! propagation type virtual StreamPropType stream_prop_type(VarNode *var) = 0; diff --git a/src/core/test/graph/misc.cpp b/src/core/test/graph/misc.cpp index a22b6273..bd24b63c 100644 --- a/src/core/test/graph/misc.cpp +++ b/src/core/test/graph/misc.cpp @@ -38,6 +38,11 @@ class ComputingGraphImpl : public ComputingGraph { public: GraphExecutable::ExecEnv* current_exec_env(); }; +class SeqCompNodeOptimizerImpl : public SeqCompNodeOptimizer { + ~SeqCompNodeOptimizerImpl() = default; +public: + void optimize_comp_nodes(const VarNodeArray &endpoints); +}; } // namespace cg } // namespace mgb @@ -1746,4 +1751,27 @@ TEST(TestGraph, CPUGPUHybrid) { } +TEST(TestGraph, In2OutOpStreamPropagate) { + REQUIRE_GPU(1); // seq_comp_node_opt works on comp_node with HAS_COPY_STREAM + HostTensorGenerator<> gen; + SmallVector> host_v = {gen({233}), gen({23})}; + using PropType = cg::SeqCompNodeOptimizer::StreamPropType; + for (auto type : {PropType::STRONG, PropType::WEAK}) + for (size_t idx : {0, 1}) { + auto graph = ComputingGraph::make(); + SymbolVarArray inp(2); + for (size_t i = 0 ; i < 2; ++ i) { + inp[i] = opr::Host2DeviceCopy::make(*graph, host_v[i]); + } + auto out = opr::VirtualDep::make(inp); + auto &&mgr = static_cast(graph->seq_comp_node_optimizer()); + mgr.register_stream_var(inp[idx].node(), PropType{CompNode::Stream::COPY, type}); + mgr.optimize_comp_nodes({out.node()}); + ASSERT_EQ(inp[0].node()->comp_node(), out.node()->comp_node()); + auto o_stream = out.node()->comp_node().locator().stream; + int expect = idx ? 0 : int(CompNode::Stream::COPY); + ASSERT_EQ(o_stream, expect); + } +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/internal/identical_fwd.cpp b/src/opr/impl/internal/identical_fwd.cpp index a2289764..9d5dca91 100644 --- a/src/opr/impl/internal/identical_fwd.cpp +++ b/src/opr/impl/internal/identical_fwd.cpp @@ -168,4 +168,14 @@ void ForwardInputToOutput::mixin_scn_do_execute(OperatorNodeBase &opr) { void ForwardInputToOutput::scn_do_execute_finish(const DeviceTensorND&) {} +void ForwardInputToOutput::register_stream_propagate_in2out(OperatorNodeBase &opr) { + auto &&ovar = opr.output(0); + auto&& mgr = ovar->owner_graph()->seq_comp_node_optimizer(); + using PropType = cg::SeqCompNodeOptimizer::StreamPropType; + auto func = [](PropType& dst, const SmallVector& inp) { + dst = inp[0]; + }; + mgr.register_propagate_function(ovar, func); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/include/megbrain/opr/internal/identical_fwd.h b/src/opr/include/megbrain/opr/internal/identical_fwd.h index 34739c37..df99d0bc 100644 --- a/src/opr/include/megbrain/opr/internal/identical_fwd.h +++ b/src/opr/include/megbrain/opr/internal/identical_fwd.h @@ -87,6 +87,12 @@ class ForwardInputToOutput: public cg::OperatorNodeMixinBase { */ void set_ignore_side_effect(); + /*! + * \brief register stream propagate function which forwards the + * StreamPropType from \p opr input(0) to output(0). + */ + void register_stream_propagate_in2out(OperatorNodeBase &opr); + public: /*! @@ -178,6 +184,11 @@ MGB_DEFINE_CLS_WITH_SUPER(ForwardInputToOutput, void init_output_static_infer_desc() override { this->mixin_init_output_static_infer_desc(*this); } + + void init_output_comp_node() override { + Super::init_output_comp_node(); + this->register_stream_propagate_in2out(*this); + } }; } // namespace intl