fix(mgb/core): fix stream propagation in seq_comp_node_opt

for the ForwardInputToOutput op, the StreamPropType of output(0) should be forwarded from input(0)
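
Summary of the change: the sequential comp-node optimizer used to decide every output var's stream with a single per-operator rule aggregated over all inputs. For an identity-forwarding operator that is wrong: its output(0) must follow input(0) exactly. This commit therefore adds a per-var PropFunction hook to SeqCompNodeOptimizer and has the ForwardInputToOutput mixin register one that copies input(0)'s StreamPropType to output(0); the diffs below implement the hook, the registration, and a regression test.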

GitOrigin-RevId: 13e9587961
tags/v0.3.2
Megvii Engine Team, 5 years ago
commit 075c4562bc

6 changed files with 138 additions and 26 deletions

  1. src/core/impl/graph/seq_comp_node_opt_impl.cpp (+79, -26)
  2. src/core/impl/graph/seq_comp_node_opt_impl.h (+3, -0)
  3. src/core/include/megbrain/graph/seq_comp_node_opt.h (+7, -0)
  4. src/core/test/graph/misc.cpp (+28, -0)
  5. src/opr/impl/internal/identical_fwd.cpp (+10, -0)
  6. src/opr/include/megbrain/opr/internal/identical_fwd.h (+11, -0)

src/core/impl/graph/seq_comp_node_opt_impl.cpp (+79, -26)

@@ -63,29 +63,21 @@ void SeqCompNodeOptimizerImpl::change_to_specific_stream(
     }
 
     ThinHashMap<VarNode*, StreamPropType> changed_vars;
-    auto cb = [this, &changed_vars](OperatorNodeBase *opr) {
-        if (opr->node_prop().contain(
-                    OperatorNodeBase::NodeProp::Flag::
-                    DISALLOW_COMP_NODE_OPTIMIZE)) {
-            return;
+    std::pair<OperatorNodeBase*, StreamPropType> prop_type_storage;
+    std::pair<OperatorNodeBase*, SmallVector<StreamPropType>> input_props_storage;
+
+    // both `propagate_single_opr` and `get_input_props` might be called any number
+    // of times(>=0) with the same \p opr in a cb(opr) function call, so we cache
+    // the result of the first call.
+    auto propagate_single_opr = [&](OperatorNodeBase* opr) {
+        mgb_assert(opr);
+        if (prop_type_storage.first == opr) {
+            return prop_type_storage.second;
         }
+        prop_type_storage.first = opr;
 
-        // first check whether any output var is registered for change
-        bool output_changed = false;
-        for (auto i: opr->output()) {
-            auto iter = m_var2prop_type.find(i);
-            if (iter != m_var2prop_type.end()) {
-                output_changed = true;
-                var_to_specific_stream(i, iter->second.stream);
-                changed_vars[i] = iter->second;
-            }
-        }
-        if (output_changed)
-            return;
-
-        // check inputs
         bool any_strong_changed = false, all_weak_changed = true,
              all_weak_changed_valid = false;
         auto &&dep_map = opr->node_prop().dep_map();
 
         ThinHashSet<int> inp_streams;
@@ -100,6 +92,7 @@ void SeqCompNodeOptimizerImpl::change_to_specific_stream(
             if (iter == changed_vars.end()) {
                 all_weak_changed = false;
             } else {
+                mgb_assert(iter->second.prop_type != StreamPropType::NONE);
                 if (iter->second.prop_type == StreamPropType::STRONG) {
                     any_strong_changed = true;
                 } else {
@@ -109,21 +102,75 @@ void SeqCompNodeOptimizerImpl::change_to_specific_stream(
             }
         }
 
+        auto type = StreamPropType::NONE;
+        int stream = 0;
         if (any_strong_changed ||
                 (all_weak_changed && all_weak_changed_valid)) {
-            auto type = any_strong_changed ?
+            type = any_strong_changed ?
                 StreamPropType::STRONG : StreamPropType::WEAK;
-            int stream = 0;
             int copy_stream = CompNode::Stream::COPY;
             int nccl_stream = CompNode::Stream::NCCL;
             if (inp_streams.count(copy_stream))
                 stream = copy_stream;
            else if (inp_streams.count(nccl_stream))
                 stream = nccl_stream;
-            mgb_assert(stream != 0);
-            for (auto i: opr->output()) {
-                var_to_specific_stream(i, stream);
-                changed_vars[i] = StreamPropType{stream, type};
+            mgb_assert(type != StreamPropType::NONE && stream != 0);
+        }
+        return prop_type_storage.second = StreamPropType{stream, type};
+    };
+
+    auto get_input_props = [&](OperatorNodeBase *opr) {
+        mgb_assert(opr);
+        if (input_props_storage.first == opr) {
+            return input_props_storage.second;
+        }
+        input_props_storage.first = opr;
+
+        auto &&props = input_props_storage.second;
+        props.clear();
+        for (auto i : opr->input()) {
+            auto &&iter = changed_vars.find(i);
+            if (iter != changed_vars.end()) {
+                props.push_back(iter->second);
+            } else {
+                props.push_back(StreamPropType{0, StreamPropType::NONE});
+            }
+        }
+        return input_props_storage.second;
+    };
+
+    auto cb = [&](OperatorNodeBase *opr) {
+        if (opr->node_prop().contain(
+                    OperatorNodeBase::NodeProp::Flag::
+                    DISALLOW_COMP_NODE_OPTIMIZE)) {
+            return;
+        }
+
+        // first check whether any output var is registered for change
+        bool output_changed = false;
+        for (auto i: opr->output()) {
+            auto iter = m_var2prop_type.find(i);
+            if (iter != m_var2prop_type.end()) {
+                output_changed = true;
+                var_to_specific_stream(i, iter->second.stream);
+                changed_vars[i] = iter->second;
+            }
+        }
+        if (output_changed)
+            return;
+
+        for (auto i: opr->output()) {
+            StreamPropType prop;
+            auto &&iter = m_var2prop_func.find(i);
+            if (iter != m_var2prop_func.end()) {
+                iter->second(prop, get_input_props(opr));
+            }
+            else {
+                prop = propagate_single_opr(opr);
+            }
+            if (prop.prop_type != StreamPropType::NONE) {
+                var_to_specific_stream(i, prop.stream);
+                changed_vars[i] = prop;
             }
         }
     };
@@ -152,6 +199,12 @@ void SeqCompNodeOptimizerImpl::register_stream_var(
     }
 }
 
+void SeqCompNodeOptimizerImpl::register_propagate_function(
+        VarNode *var, PropFunction prop_func) {
+    mgb_assert(var->owner_graph() == m_owner_graph);
+    mgb_assert(m_var2prop_func.emplace(var, prop_func).second);
+}
+
 void SeqCompNodeOptimizerImpl::init_ready_event(
         const CompSeqExtraInfo &extra_info, const OprNodeArray &seq) {
     // clear existing synchronizers
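
Note: the single-slot caches above (`prop_type_storage`, `input_props_storage`) memoize only the most recent opr, which suffices because each cb(opr) invocation queries exactly one opr, possibly several times (once per output var). A self-contained sketch of the pattern; `Opr`, `SingleSlotCache` and `compute` are illustrative stand-ins, not MegEngine code:

    #include <cstdio>
    #include <utility>

    struct Opr {};  // stand-in for OperatorNodeBase

    struct SingleSlotCache {
        std::pair<const Opr*, int> storage{nullptr, 0};

        int get(const Opr* opr) {
            if (storage.first == opr)   // hit: same key as the previous call
                return storage.second;
            storage.first = opr;        // miss: remember key, recompute value
            return storage.second = compute(opr);
        }

        static int compute(const Opr*) { return 42; }  // stand-in for real work
    };

    int main() {
        Opr a, b;
        SingleSlotCache cache;
        // second get(&a) is served from the cache; get(&b) recomputes
        std::printf("%d %d %d\n", cache.get(&a), cache.get(&a), cache.get(&b));
    }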


src/core/impl/graph/seq_comp_node_opt_impl.h (+3, -0)

@@ -23,6 +23,7 @@ class SeqCompNodeOptimizerImpl final: public SeqCompNodeOptimizer {
     ComputingGraphImpl *m_owner_graph;
     std::vector<std::pair<VarNode*, CompNode>> m_comp_node_to_restore;
     ThinHashSet<OperatorNodeBase*> m_comp_node_changed_oprs;
+    ThinHashMap<VarNode*, PropFunction> m_var2prop_func;
 
     /*!
      * cn0 -> (cn1 -> [(a, b)]): an opr at step \p a on \p cn0 is known to start
@@ -56,6 +57,8 @@ class SeqCompNodeOptimizerImpl final: public SeqCompNodeOptimizer {
 
     void register_stream_var(VarNode* var, StreamPropType prop_type) override;
 
+    void register_propagate_function(VarNode* var, PropFunction prop_func) override;
+
     StreamPropType stream_prop_type(VarNode *var) override {
         auto iter = m_var2prop_type.find(var);
         return iter == m_var2prop_type.end()


src/core/include/megbrain/graph/seq_comp_node_opt.h (+7, -0)

@@ -39,11 +39,18 @@ class SeqCompNodeOptimizer {
         int stream;    //!< stream to change
         PropType prop_type;
     };
+    using PropFunction = thin_function<void(
+            StreamPropType& /* dest */,
+            const SmallVector<StreamPropType>& /* srcs */)>;
 
     //! register a var that should be placed on the stream
     virtual void register_stream_var(
             VarNode* var, StreamPropType prop_type) = 0;
 
+    //! register a propagate function on given var_node
+    virtual void register_propagate_function(
+            VarNode* var, PropFunction prop_func) = 0;
+
     //! check if a var has been registered in stream and get its
     //! propagation type
     virtual StreamPropType stream_prop_type(VarNode *var) = 0;
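
A hedged usage sketch of the new hook (`optimizer` and `var` are placeholder names; the lambda mirrors the one registered by ForwardInputToOutput below). The callback receives the already-propagated StreamPropType of every input and writes its decision for the registered var into `dest`; inputs that were not moved to a special stream carry `PropType::NONE`:

    using PropType = cg::SeqCompNodeOptimizer::StreamPropType;
    optimizer.register_propagate_function(
            var, [](PropType& dest, const SmallVector<PropType>& srcs) {
                dest = srcs[0];  // follow input(0): stream id and STRONG/WEAK type
            });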


src/core/test/graph/misc.cpp (+28, -0)

@@ -38,6 +38,11 @@ class ComputingGraphImpl : public ComputingGraph {
 public:
     GraphExecutable::ExecEnv* current_exec_env();
 };
+class SeqCompNodeOptimizerImpl : public SeqCompNodeOptimizer {
+    ~SeqCompNodeOptimizerImpl() = default;
+public:
+    void optimize_comp_nodes(const VarNodeArray &endpoints);
+};
 } // namespace cg
 } // namespace mgb


@@ -1746,4 +1751,27 @@ TEST(TestGraph, CPUGPUHybrid) {
 
 }
 
+TEST(TestGraph, In2OutOpStreamPropagate) {
+    REQUIRE_GPU(1); // seq_comp_node_opt works on comp_node with HAS_COPY_STREAM
+    HostTensorGenerator<> gen;
+    SmallVector<std::shared_ptr<HostTensorND>> host_v = {gen({233}), gen({23})};
+    using PropType = cg::SeqCompNodeOptimizer::StreamPropType;
+    for (auto type : {PropType::STRONG, PropType::WEAK})
+    for (size_t idx : {0, 1}) {
+        auto graph = ComputingGraph::make();
+        SymbolVarArray inp(2);
+        for (size_t i = 0 ; i < 2; ++ i) {
+            inp[i] = opr::Host2DeviceCopy::make(*graph, host_v[i]);
+        }
+        auto out = opr::VirtualDep::make(inp);
+        auto &&mgr = static_cast<cg::SeqCompNodeOptimizerImpl&>(graph->seq_comp_node_optimizer());
+        mgr.register_stream_var(inp[idx].node(), PropType{CompNode::Stream::COPY, type});
+        mgr.optimize_comp_nodes({out.node()});
+        ASSERT_EQ(inp[0].node()->comp_node(), out.node()->comp_node());
+        auto o_stream = out.node()->comp_node().locator().stream;
+        int expect = idx ? 0 : int(CompNode::Stream::COPY);
+        ASSERT_EQ(o_stream, expect);
+    }
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
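
The expected values follow from the forwarding rule: VirtualDep is a ForwardInputToOutput op, so the output's stream property is taken from input(0) only. Registering the COPY-stream var on input 0 (idx == 0) moves the output to CompNode::Stream::COPY; registering it on input 1 leaves the output on the default stream 0, for both STRONG and WEAK propagation types.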

src/opr/impl/internal/identical_fwd.cpp (+10, -0)

@@ -168,4 +168,14 @@ void ForwardInputToOutput::mixin_scn_do_execute(OperatorNodeBase &opr) {
 
 void ForwardInputToOutput::scn_do_execute_finish(const DeviceTensorND&) {}
 
+void ForwardInputToOutput::register_stream_propagate_in2out(OperatorNodeBase &opr) {
+    auto &&ovar = opr.output(0);
+    auto&& mgr = ovar->owner_graph()->seq_comp_node_optimizer();
+    using PropType = cg::SeqCompNodeOptimizer::StreamPropType;
+    auto func = [](PropType& dst, const SmallVector<PropType>& inp) {
+        dst = inp[0];
+    };
+    mgr.register_propagate_function(ovar, func);
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
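
Since `dst = inp[0]` copies the whole StreamPropType, both the stream id and the STRONG/WEAK propagation type of input(0) reach output(0), which is precisely the behavior described in the commit message.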

src/opr/include/megbrain/opr/internal/identical_fwd.h (+11, -0)

@@ -87,6 +87,12 @@ class ForwardInputToOutput: public cg::OperatorNodeMixinBase {
      */
     void set_ignore_side_effect();
 
+    /*!
+     * \brief register stream propagate function which forwards the
+     *      StreamPropType from \p opr input(0) to output(0).
+     */
+    void register_stream_propagate_in2out(OperatorNodeBase &opr);
+
 public:
 
     /*!
@@ -178,6 +184,11 @@ MGB_DEFINE_CLS_WITH_SUPER(ForwardInputToOutput,
     void init_output_static_infer_desc() override {
         this->mixin_init_output_static_infer_desc(*this);
     }
+
+    void init_output_comp_node() override {
+        Super::init_output_comp_node();
+        this->register_stream_propagate_in2out(*this);
+    }
 };
 
 } // namespace intl
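
Because init_output_comp_node() is part of operator initialization, every operator built on the ForwardInputToOutput mixin picks up the forwarding rule automatically; no per-operator changes are needed.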

