for ForwardInToOutput op, StreamPropType of output(0) should be forwarded as input(0)
GitOrigin-RevId: 13e9587961
tags/v0.3.2
@@ -63,29 +63,21 @@ void SeqCompNodeOptimizerImpl::change_to_specific_stream( | |||||
} | } | ||||
ThinHashMap<VarNode*, StreamPropType> changed_vars; | ThinHashMap<VarNode*, StreamPropType> changed_vars; | ||||
auto cb = [this, &changed_vars](OperatorNodeBase *opr) { | |||||
if (opr->node_prop().contain( | |||||
OperatorNodeBase::NodeProp::Flag:: | |||||
DISALLOW_COMP_NODE_OPTIMIZE)) { | |||||
return; | |||||
std::pair<OperatorNodeBase*, StreamPropType> prop_type_storage; | |||||
std::pair<OperatorNodeBase*, SmallVector<StreamPropType>> input_props_storage; | |||||
// both `propagate_single_opr` and `get_input_props` might be called any number | |||||
// of times(>=0) with the same \p opr in a cb(opr) function call, so we cache | |||||
// the result of the first call. | |||||
auto propagate_single_opr = [&](OperatorNodeBase* opr) { | |||||
mgb_assert(opr); | |||||
if (prop_type_storage.first == opr) { | |||||
return prop_type_storage.second; | |||||
} | } | ||||
prop_type_storage.first = opr; | |||||
// first check whether any output var is registered for change | |||||
bool output_changed = false; | |||||
for (auto i: opr->output()) { | |||||
auto iter = m_var2prop_type.find(i); | |||||
if (iter != m_var2prop_type.end()) { | |||||
output_changed = true; | |||||
var_to_specific_stream(i, iter->second.stream); | |||||
changed_vars[i] = iter->second; | |||||
} | |||||
} | |||||
if (output_changed) | |||||
return; | |||||
// check inputs | |||||
bool any_strong_changed = false, all_weak_changed = true, | bool any_strong_changed = false, all_weak_changed = true, | ||||
all_weak_changed_valid = false; | |||||
all_weak_changed_valid = false; | |||||
auto &&dep_map = opr->node_prop().dep_map(); | auto &&dep_map = opr->node_prop().dep_map(); | ||||
ThinHashSet<int> inp_streams; | ThinHashSet<int> inp_streams; | ||||
@@ -100,6 +92,7 @@ void SeqCompNodeOptimizerImpl::change_to_specific_stream( | |||||
if (iter == changed_vars.end()) { | if (iter == changed_vars.end()) { | ||||
all_weak_changed = false; | all_weak_changed = false; | ||||
} else { | } else { | ||||
mgb_assert(iter->second.prop_type != StreamPropType::NONE); | |||||
if (iter->second.prop_type == StreamPropType::STRONG) { | if (iter->second.prop_type == StreamPropType::STRONG) { | ||||
any_strong_changed = true; | any_strong_changed = true; | ||||
} else { | } else { | ||||
@@ -109,21 +102,75 @@ void SeqCompNodeOptimizerImpl::change_to_specific_stream( | |||||
} | } | ||||
} | } | ||||
auto type = StreamPropType::NONE; | |||||
int stream = 0; | |||||
if (any_strong_changed || | if (any_strong_changed || | ||||
(all_weak_changed && all_weak_changed_valid)) { | (all_weak_changed && all_weak_changed_valid)) { | ||||
auto type = any_strong_changed ? | |||||
type = any_strong_changed ? | |||||
StreamPropType::STRONG : StreamPropType::WEAK; | StreamPropType::STRONG : StreamPropType::WEAK; | ||||
int stream = 0; | |||||
int copy_stream = CompNode::Stream::COPY; | int copy_stream = CompNode::Stream::COPY; | ||||
int nccl_stream = CompNode::Stream::NCCL; | int nccl_stream = CompNode::Stream::NCCL; | ||||
if (inp_streams.count(copy_stream)) | if (inp_streams.count(copy_stream)) | ||||
stream = copy_stream; | stream = copy_stream; | ||||
else if (inp_streams.count(nccl_stream)) | else if (inp_streams.count(nccl_stream)) | ||||
stream = nccl_stream; | stream = nccl_stream; | ||||
mgb_assert(stream != 0); | |||||
for (auto i: opr->output()) { | |||||
var_to_specific_stream(i, stream); | |||||
changed_vars[i] = StreamPropType{stream, type}; | |||||
mgb_assert(type != StreamPropType::NONE && stream != 0); | |||||
} | |||||
return prop_type_storage.second = StreamPropType{stream, type}; | |||||
}; | |||||
auto get_input_props = [&](OperatorNodeBase *opr) { | |||||
mgb_assert(opr); | |||||
if (input_props_storage.first == opr) { | |||||
return input_props_storage.second; | |||||
} | |||||
input_props_storage.first = opr; | |||||
auto &&props = input_props_storage.second; | |||||
props.clear(); | |||||
for (auto i : opr->input()) { | |||||
auto &&iter = changed_vars.find(i); | |||||
if (iter != changed_vars.end()) { | |||||
props.push_back(iter->second); | |||||
} else { | |||||
props.push_back(StreamPropType{0, StreamPropType::NONE}); | |||||
} | |||||
} | |||||
return input_props_storage.second; | |||||
}; | |||||
auto cb = [&](OperatorNodeBase *opr) { | |||||
if (opr->node_prop().contain( | |||||
OperatorNodeBase::NodeProp::Flag:: | |||||
DISALLOW_COMP_NODE_OPTIMIZE)) { | |||||
return; | |||||
} | |||||
// first check whether any output var is registered for change | |||||
bool output_changed = false; | |||||
for (auto i: opr->output()) { | |||||
auto iter = m_var2prop_type.find(i); | |||||
if (iter != m_var2prop_type.end()) { | |||||
output_changed = true; | |||||
var_to_specific_stream(i, iter->second.stream); | |||||
changed_vars[i] = iter->second; | |||||
} | |||||
} | |||||
if (output_changed) | |||||
return; | |||||
for (auto i: opr->output()) { | |||||
StreamPropType prop; | |||||
auto &&iter = m_var2prop_func.find(i); | |||||
if (iter != m_var2prop_func.end()) { | |||||
iter->second(prop, get_input_props(opr)); | |||||
} | |||||
else { | |||||
prop = propagate_single_opr(opr); | |||||
} | |||||
if (prop.prop_type != StreamPropType::NONE) { | |||||
var_to_specific_stream(i, prop.stream); | |||||
changed_vars[i] = prop; | |||||
} | } | ||||
} | } | ||||
}; | }; | ||||
@@ -152,6 +199,12 @@ void SeqCompNodeOptimizerImpl::register_stream_var( | |||||
} | } | ||||
} | } | ||||
void SeqCompNodeOptimizerImpl::register_propagate_function( | |||||
VarNode *var, PropFunction prop_func) { | |||||
mgb_assert(var->owner_graph() == m_owner_graph); | |||||
mgb_assert(m_var2prop_func.emplace(var, prop_func).second); | |||||
} | |||||
void SeqCompNodeOptimizerImpl::init_ready_event( | void SeqCompNodeOptimizerImpl::init_ready_event( | ||||
const CompSeqExtraInfo &extra_info, const OprNodeArray &seq) { | const CompSeqExtraInfo &extra_info, const OprNodeArray &seq) { | ||||
// clear existing synchronizers | // clear existing synchronizers | ||||
@@ -23,6 +23,7 @@ class SeqCompNodeOptimizerImpl final: public SeqCompNodeOptimizer { | |||||
ComputingGraphImpl *m_owner_graph; | ComputingGraphImpl *m_owner_graph; | ||||
std::vector<std::pair<VarNode*, CompNode>> m_comp_node_to_restore; | std::vector<std::pair<VarNode*, CompNode>> m_comp_node_to_restore; | ||||
ThinHashSet<OperatorNodeBase*> m_comp_node_changed_oprs; | ThinHashSet<OperatorNodeBase*> m_comp_node_changed_oprs; | ||||
ThinHashMap<VarNode*, PropFunction> m_var2prop_func; | |||||
/*! | /*! | ||||
* cn0 -> (cn1 -> [(a, b)]): an opr at step \p a on \p cn0 is known to start | * cn0 -> (cn1 -> [(a, b)]): an opr at step \p a on \p cn0 is known to start | ||||
@@ -56,6 +57,8 @@ class SeqCompNodeOptimizerImpl final: public SeqCompNodeOptimizer { | |||||
void register_stream_var(VarNode* var, StreamPropType prop_type) override; | void register_stream_var(VarNode* var, StreamPropType prop_type) override; | ||||
void register_propagate_function(VarNode* var, PropFunction prop_func) override; | |||||
StreamPropType stream_prop_type(VarNode *var) override { | StreamPropType stream_prop_type(VarNode *var) override { | ||||
auto iter = m_var2prop_type.find(var); | auto iter = m_var2prop_type.find(var); | ||||
return iter == m_var2prop_type.end() | return iter == m_var2prop_type.end() | ||||
@@ -39,11 +39,18 @@ class SeqCompNodeOptimizer { | |||||
int stream; //!< stream to change | int stream; //!< stream to change | ||||
PropType prop_type; | PropType prop_type; | ||||
}; | }; | ||||
using PropFunction = thin_function<void( | |||||
StreamPropType& /* dest */, | |||||
const SmallVector<StreamPropType>& /* srcs */)>; | |||||
//! register a var that should be placed on the stream | //! register a var that should be placed on the stream | ||||
virtual void register_stream_var( | virtual void register_stream_var( | ||||
VarNode* var, StreamPropType prop_type) = 0; | VarNode* var, StreamPropType prop_type) = 0; | ||||
//! register a propagate function on given var_node | |||||
virtual void register_propagate_function( | |||||
VarNode* var, PropFunction prop_func) = 0; | |||||
//! check if a var has been registered in stream and get its | //! check if a var has been registered in stream and get its | ||||
//! propagation type | //! propagation type | ||||
virtual StreamPropType stream_prop_type(VarNode *var) = 0; | virtual StreamPropType stream_prop_type(VarNode *var) = 0; | ||||
@@ -38,6 +38,11 @@ class ComputingGraphImpl : public ComputingGraph { | |||||
public: | public: | ||||
GraphExecutable::ExecEnv* current_exec_env(); | GraphExecutable::ExecEnv* current_exec_env(); | ||||
}; | }; | ||||
class SeqCompNodeOptimizerImpl : public SeqCompNodeOptimizer { | |||||
~SeqCompNodeOptimizerImpl() = default; | |||||
public: | |||||
void optimize_comp_nodes(const VarNodeArray &endpoints); | |||||
}; | |||||
} // namespace cg | } // namespace cg | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -1746,4 +1751,27 @@ TEST(TestGraph, CPUGPUHybrid) { | |||||
} | } | ||||
TEST(TestGraph, In2OutOpStreamPropagate) { | |||||
REQUIRE_GPU(1); // seq_comp_node_opt works on comp_node with HAS_COPY_STREAM | |||||
HostTensorGenerator<> gen; | |||||
SmallVector<std::shared_ptr<HostTensorND>> host_v = {gen({233}), gen({23})}; | |||||
using PropType = cg::SeqCompNodeOptimizer::StreamPropType; | |||||
for (auto type : {PropType::STRONG, PropType::WEAK}) | |||||
for (size_t idx : {0, 1}) { | |||||
auto graph = ComputingGraph::make(); | |||||
SymbolVarArray inp(2); | |||||
for (size_t i = 0 ; i < 2; ++ i) { | |||||
inp[i] = opr::Host2DeviceCopy::make(*graph, host_v[i]); | |||||
} | |||||
auto out = opr::VirtualDep::make(inp); | |||||
auto &&mgr = static_cast<cg::SeqCompNodeOptimizerImpl&>(graph->seq_comp_node_optimizer()); | |||||
mgr.register_stream_var(inp[idx].node(), PropType{CompNode::Stream::COPY, type}); | |||||
mgr.optimize_comp_nodes({out.node()}); | |||||
ASSERT_EQ(inp[0].node()->comp_node(), out.node()->comp_node()); | |||||
auto o_stream = out.node()->comp_node().locator().stream; | |||||
int expect = idx ? 0 : int(CompNode::Stream::COPY); | |||||
ASSERT_EQ(o_stream, expect); | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -168,4 +168,14 @@ void ForwardInputToOutput::mixin_scn_do_execute(OperatorNodeBase &opr) { | |||||
void ForwardInputToOutput::scn_do_execute_finish(const DeviceTensorND&) {} | void ForwardInputToOutput::scn_do_execute_finish(const DeviceTensorND&) {} | ||||
void ForwardInputToOutput::register_stream_propagate_in2out(OperatorNodeBase &opr) { | |||||
auto &&ovar = opr.output(0); | |||||
auto&& mgr = ovar->owner_graph()->seq_comp_node_optimizer(); | |||||
using PropType = cg::SeqCompNodeOptimizer::StreamPropType; | |||||
auto func = [](PropType& dst, const SmallVector<PropType>& inp) { | |||||
dst = inp[0]; | |||||
}; | |||||
mgr.register_propagate_function(ovar, func); | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -87,6 +87,12 @@ class ForwardInputToOutput: public cg::OperatorNodeMixinBase { | |||||
*/ | */ | ||||
void set_ignore_side_effect(); | void set_ignore_side_effect(); | ||||
/*! | |||||
* \brief register stream propagate function which forwards the | |||||
* StreamPropType from \p opr input(0) to output(0). | |||||
*/ | |||||
void register_stream_propagate_in2out(OperatorNodeBase &opr); | |||||
public: | public: | ||||
/*! | /*! | ||||
@@ -178,6 +184,11 @@ MGB_DEFINE_CLS_WITH_SUPER(ForwardInputToOutput, | |||||
void init_output_static_infer_desc() override { | void init_output_static_infer_desc() override { | ||||
this->mixin_init_output_static_infer_desc(*this); | this->mixin_init_output_static_infer_desc(*this); | ||||
} | } | ||||
void init_output_comp_node() override { | |||||
Super::init_output_comp_node(); | |||||
this->register_stream_propagate_in2out(*this); | |||||
} | |||||
}; | }; | ||||
} // namespace intl | } // namespace intl | ||||