@@ -33,16 +33,16 @@ class JITFusionPass::Impl final { | |||
CompNode::UnorderedMap<size_t> m_cn2max_nr_input; | |||
SubGraph::Rewriter m_rewriter; | |||
SmallVector<std::unique_ptr<InternalGraphGenrator>> m_igraph_gen_storage; | |||
ThinHashMap<VarNode*, InternalGraphGenrator*> m_var2igraph_gen; | |||
SmallVector<std::unique_ptr<InternalGraphGenerator>> m_igraph_gen_storage; | |||
ThinHashMap<VarNode*, InternalGraphGenerator*> m_var2igraph_gen; | |||
//! map from var to its reader oprs and the corresponding dependency types | |||
ThinHashMap<VarNode*, SmallVector<std::pair<OperatorNodeBase*, DepType>>> | |||
m_var_readers; | |||
ThinHashSet<VarNode*> m_endpoint_set; | |||
//! create a new InternalGraphGenrator rooted at given opr | |||
InternalGraphGenrator* create_new_igraph_gen(OperatorNodeBase* opr); | |||
//! create a new InternalGraphGenerator rooted at given opr | |||
InternalGraphGenerator* create_new_igraph_gen(OperatorNodeBase* opr); | |||
//! process a single operator, maintaining m_var2igraph_gen | |||
void process_opr(OperatorNodeBase* opr); | |||
@@ -51,11 +51,11 @@ class JITFusionPass::Impl final { | |||
//! check whether all oprs which depend on the var are in i_graph | |||
bool test_all_readers_in_the_graph(VarNode* var, | |||
InternalGraphGenrator* i_graph); | |||
InternalGraphGenerator* i_graph); | |||
//! check shape to determine whether the opr should be added to the internal | |||
//! graph | |||
bool check_shape(cg::OperatorNodeBase* opr, InternalGraphGenrator* i_graph); | |||
bool check_shape(cg::OperatorNodeBase* opr, InternalGraphGenerator* i_graph); | |||
//! use m_rewriter to update graph | |||
void update_graph(); | |||
@@ -155,7 +155,7 @@ void JITFusionPass::Impl::update_graph() { | |||
} | |||
bool JITFusionPass::Impl::test_all_readers_in_the_graph( | |||
VarNode* var, InternalGraphGenrator* ig_gen) { | |||
VarNode* var, InternalGraphGenerator* ig_gen) { | |||
for (auto&& reader : m_var_readers.at(var)) { | |||
if (reader.second & DepType::DEV_VALUE) { | |||
if (ig_gen->opr_set().count(reader.first) == 0) { | |||
@@ -167,7 +167,7 @@ bool JITFusionPass::Impl::test_all_readers_in_the_graph( | |||
} | |||
bool JITFusionPass::Impl::check_shape(cg::OperatorNodeBase* opr, | |||
InternalGraphGenrator* ig_gen) { | |||
InternalGraphGenerator* ig_gen) { | |||
if (!cg::is_static_var_shape(opr->output(0))) { | |||
// currently we do not handle dynamic shape in JIT | |||
return false; | |||
@@ -249,9 +249,9 @@ bool JITFusionPass::Impl::check_shape(cg::OperatorNodeBase* opr, | |||
} | |||
} | |||
InternalGraphGenrator* JITFusionPass::Impl::create_new_igraph_gen( | |||
InternalGraphGenerator* JITFusionPass::Impl::create_new_igraph_gen( | |||
OperatorNodeBase* opr) { | |||
auto uptr = std::make_unique<InternalGraphGenrator>(opr); | |||
auto uptr = std::make_unique<InternalGraphGenerator>(opr); | |||
auto ptr = uptr.get(); | |||
m_igraph_gen_storage.emplace_back(std::move(uptr)); | |||
m_var2igraph_gen[opr->output(0)] = ptr; | |||
@@ -267,7 +267,7 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) { | |||
} | |||
// dimshuffle should not be an endpoint, because megbrain has lazy | |||
// dimshuffle machanism | |||
InternalGraphGenrator* ig_gen = nullptr; | |||
InternalGraphGenerator* ig_gen = nullptr; | |||
if (m_var2igraph_gen.count(opr->output(0)) == 0) { | |||
// because of the reverse traversal, when an operator is being | |||
// processed but not in m_var2igraph_gen, means it is a endpoint of a | |||
@@ -81,12 +81,12 @@ InternalGraphPtr expand_executor_opr(const InternalGraphPtr& prev_igraph) { | |||
} // namespace | |||
InternalGraphGenrator::InternalGraphGenrator(cg::OperatorNodeBase* opr) | |||
InternalGraphGenerator::InternalGraphGenerator(cg::OperatorNodeBase* opr) | |||
: m_output{opr->output(0)} { | |||
add_opr(opr); | |||
} | |||
VarNode* InternalGraphGenrator::replace_graph_by_placeholder() { | |||
VarNode* InternalGraphGenerator::replace_graph_by_placeholder() { | |||
ThinHashMap<VarNode*, VarNode*> old2new; | |||
auto cpu_default = CompNode::default_cpu(); | |||
auto igraph_copy_opr_shallow = [cpu_default](OperatorNodeBase* opr, | |||
@@ -163,7 +163,7 @@ VarNode* InternalGraphGenrator::replace_graph_by_placeholder() { | |||
return old2new.at(m_output); | |||
} | |||
InternalGraphPtr InternalGraphGenrator::generate() { | |||
InternalGraphPtr InternalGraphGenerator::generate() { | |||
m_input_idx = 0; | |||
auto new_nd = replace_graph_by_placeholder(); | |||
@@ -172,7 +172,7 @@ InternalGraphPtr InternalGraphGenrator::generate() { | |||
return expand_executor_opr(igraph); | |||
} | |||
size_t InternalGraphGenrator::get_cnt_input_if_add( | |||
size_t InternalGraphGenerator::get_cnt_input_if_add( | |||
cg::OperatorNodeBase* opr) const { | |||
// minus 1 first because this opr should be removed from subgraph's input | |||
size_t new_cnt_input = m_graph_input_set.size() - 1; | |||
@@ -183,7 +183,7 @@ size_t InternalGraphGenrator::get_cnt_input_if_add( | |||
return new_cnt_input; | |||
} | |||
void InternalGraphGenrator::add_opr(cg::OperatorNodeBase* opr) { | |||
void InternalGraphGenerator::add_opr(cg::OperatorNodeBase* opr) { | |||
if (m_opr_set.count(opr)) { | |||
// ignore duplicated oprs (which occur in tests) | |||
return; | |||
@@ -253,7 +253,7 @@ void InternalGraphGenrator::add_opr(cg::OperatorNodeBase* opr) { | |||
} | |||
} | |||
void InternalGraphGenrator::find_reduce_opr_deps(cg::OperatorNodeBase* opr) { | |||
void InternalGraphGenerator::find_reduce_opr_deps(cg::OperatorNodeBase* opr) { | |||
mgb_assert(opr->same_type<opr::Reduce>() || | |||
(opr->same_type<jit::JITExecutor>() && | |||
try_cast_as_op<jit::JITExecutor>(opr)->has_reduce())); | |||
@@ -264,7 +264,7 @@ void InternalGraphGenrator::find_reduce_opr_deps(cg::OperatorNodeBase* opr) { | |||
cg::DepOprIter{cb}.add(opr); | |||
} | |||
void InternalGraphGenrator::find_oprs_depended_by_dimshuffle( | |||
void InternalGraphGenerator::find_oprs_depended_by_dimshuffle( | |||
cg::OperatorNodeBase* dimshuffle) { | |||
mgb_assert( | |||
dimshuffle->same_type<opr::Dimshuffle>() || | |||
@@ -287,7 +287,7 @@ void InternalGraphGenrator::find_oprs_depended_by_dimshuffle( | |||
cg::DepOprIter{cb}.add(dimshuffle); | |||
} | |||
PlaceholderArray InternalGraphGenrator::to_placeholder_opr_arr( | |||
PlaceholderArray InternalGraphGenerator::to_placeholder_opr_arr( | |||
const VarNodeArray& vars) { | |||
PlaceholderArray ret(vars.size()); | |||
for (size_t i = 0; i < vars.size(); ++i) { | |||
@@ -76,12 +76,12 @@ private: | |||
* This object stores intermediate state during visiting the computing graph in | |||
* JITFusionPass. | |||
* | |||
* The graph is iterated in reverse topological order. InternalGraphGenrator | |||
* The graph is iterated in reverse topological order. InternalGraphGenerator | |||
* starts with a single operator (i.e. the output node of the fused opr), and | |||
* new oprs are gradually added into it. Thus the process is expanding a tree | |||
* rooted at the output node. | |||
*/ | |||
class InternalGraphGenrator { | |||
class InternalGraphGenerator { | |||
//! replace oprs in the graph of m_output and populate m_orig_inps, | |||
//! m_placeholders | |||
VarNode* replace_graph_by_placeholder(); | |||
@@ -95,7 +95,7 @@ class InternalGraphGenrator { | |||
void find_oprs_depended_by_dimshuffle(cg::OperatorNodeBase* opr); | |||
public: | |||
explicit InternalGraphGenrator(cg::OperatorNodeBase* opr); | |||
explicit InternalGraphGenerator(cg::OperatorNodeBase* opr); | |||
//! generate the graph; this method can be called multiple times | |||
InternalGraphPtr generate(); | |||
@@ -54,7 +54,7 @@ void run<simple>(Backend backend, CompNode cn) { | |||
VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()}; | |||
auto ig_gen = | |||
std::make_unique<InternalGraphGenrator>(y.node()->owner_opr()); | |||
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr()); | |||
for (auto i : get_rev_topo_order(y)) { | |||
if (!i->same_type<opr::Host2DeviceCopy>()) { | |||
@@ -91,7 +91,7 @@ void run<grad>(Backend backend, CompNode cn) { | |||
VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()}; | |||
auto ig_gen = | |||
std::make_unique<InternalGraphGenrator>(y.node()->owner_opr()); | |||
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr()); | |||
for (auto i : get_rev_topo_order(y)) { | |||
if (!i->same_type<opr::Host2DeviceCopy>()) { | |||
@@ -540,7 +540,7 @@ void run<expand_jit_executor>(Backend backend, CompNode cn) { | |||
auto make_jit = [](SymbolVar target, const SymbolVarArray& inputs) { | |||
auto y = target.node(); | |||
auto ig_gen = std::make_unique<InternalGraphGenrator>(y->owner_opr()); | |||
auto ig_gen = std::make_unique<InternalGraphGenerator>(y->owner_opr()); | |||
auto inputs_vptr = cg::to_var_node_array(inputs); | |||
for (auto i : get_rev_topo_order( | |||
target, {inputs_vptr.begin(), inputs_vptr.end()})) { | |||
@@ -830,9 +830,9 @@ TEST(TestJITFusionHalide, JITExecutor) { | |||
y = opr::reduce_sum(a + b, shape_of_b), | |||
z = opr::reduce_sum(a * b, shape_of_a); | |||
auto ig_gen_1 = | |||
std::make_unique<InternalGraphGenrator>(y.node()->owner_opr()); | |||
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr()); | |||
auto ig_gen_2 = | |||
std::make_unique<InternalGraphGenrator>(z.node()->owner_opr()); | |||
std::make_unique<InternalGraphGenerator>(z.node()->owner_opr()); | |||
{ | |||
ThinHashSet<VarNode*> nd_set; | |||
nd_set.insert(a.node()); | |||
@@ -85,7 +85,7 @@ void FusionChecker::ensure_init_graph() { | |||
SymbolVar jit_y; | |||
if (m_direct_build) { | |||
auto ig_gen = std::make_unique<InternalGraphGenrator>( | |||
auto ig_gen = std::make_unique<InternalGraphGenerator>( | |||
m_truth_y.node()->owner_opr()); | |||
ThinHashSet<VarNode*> endpoints_set; | |||
for (size_t i = 0; i < m_nr_input; ++i) { | |||