Browse Source

fix(mge/jit): fix typo InternalGraphGenerator

GitOrigin-RevId: 1ab14acbd2
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
5fe7cd1457
6 changed files with 28 additions and 28 deletions
  1. +11
    -11
      src/jit/impl/fusion_pass.cpp
  2. +8
    -8
      src/jit/impl/internal_graph.cpp
  3. +3
    -3
      src/jit/include/megbrain/jit/internal_graph.h
  4. +2
    -2
      src/jit/test/codegen.cpp
  5. +3
    -3
      src/jit/test/fusion.cpp
  6. +1
    -1
      src/jit/test/helper.cpp

+ 11
- 11
src/jit/impl/fusion_pass.cpp View File

@@ -33,16 +33,16 @@ class JITFusionPass::Impl final {
CompNode::UnorderedMap<size_t> m_cn2max_nr_input; CompNode::UnorderedMap<size_t> m_cn2max_nr_input;


SubGraph::Rewriter m_rewriter; SubGraph::Rewriter m_rewriter;
SmallVector<std::unique_ptr<InternalGraphGenrator>> m_igraph_gen_storage;
ThinHashMap<VarNode*, InternalGraphGenrator*> m_var2igraph_gen;
SmallVector<std::unique_ptr<InternalGraphGenerator>> m_igraph_gen_storage;
ThinHashMap<VarNode*, InternalGraphGenerator*> m_var2igraph_gen;


//! map from var to its reader oprs and the corresponding dependency types //! map from var to its reader oprs and the corresponding dependency types
ThinHashMap<VarNode*, SmallVector<std::pair<OperatorNodeBase*, DepType>>> ThinHashMap<VarNode*, SmallVector<std::pair<OperatorNodeBase*, DepType>>>
m_var_readers; m_var_readers;
ThinHashSet<VarNode*> m_endpoint_set; ThinHashSet<VarNode*> m_endpoint_set;


//! create a new InternalGraphGenrator rooted at given opr
InternalGraphGenrator* create_new_igraph_gen(OperatorNodeBase* opr);
//! create a new InternalGraphGenerator rooted at given opr
InternalGraphGenerator* create_new_igraph_gen(OperatorNodeBase* opr);


//! process a single operator, maintaining m_var2igraph_gen //! process a single operator, maintaining m_var2igraph_gen
void process_opr(OperatorNodeBase* opr); void process_opr(OperatorNodeBase* opr);
@@ -51,11 +51,11 @@ class JITFusionPass::Impl final {


//! check whether all oprs which depend on the var are in i_graph //! check whether all oprs which depend on the var are in i_graph
bool test_all_readers_in_the_graph(VarNode* var, bool test_all_readers_in_the_graph(VarNode* var,
InternalGraphGenrator* i_graph);
InternalGraphGenerator* i_graph);


//! check shape to determine whether the opr should be added to the internal //! check shape to determine whether the opr should be added to the internal
//! graph //! graph
bool check_shape(cg::OperatorNodeBase* opr, InternalGraphGenrator* i_graph);
bool check_shape(cg::OperatorNodeBase* opr, InternalGraphGenerator* i_graph);


//! use m_rewriter to update graph //! use m_rewriter to update graph
void update_graph(); void update_graph();
@@ -155,7 +155,7 @@ void JITFusionPass::Impl::update_graph() {
} }


bool JITFusionPass::Impl::test_all_readers_in_the_graph( bool JITFusionPass::Impl::test_all_readers_in_the_graph(
VarNode* var, InternalGraphGenrator* ig_gen) {
VarNode* var, InternalGraphGenerator* ig_gen) {
for (auto&& reader : m_var_readers.at(var)) { for (auto&& reader : m_var_readers.at(var)) {
if (reader.second & DepType::DEV_VALUE) { if (reader.second & DepType::DEV_VALUE) {
if (ig_gen->opr_set().count(reader.first) == 0) { if (ig_gen->opr_set().count(reader.first) == 0) {
@@ -167,7 +167,7 @@ bool JITFusionPass::Impl::test_all_readers_in_the_graph(
} }


bool JITFusionPass::Impl::check_shape(cg::OperatorNodeBase* opr, bool JITFusionPass::Impl::check_shape(cg::OperatorNodeBase* opr,
InternalGraphGenrator* ig_gen) {
InternalGraphGenerator* ig_gen) {
if (!cg::is_static_var_shape(opr->output(0))) { if (!cg::is_static_var_shape(opr->output(0))) {
// currently we do not handle dynamic shape in JIT // currently we do not handle dynamic shape in JIT
return false; return false;
@@ -249,9 +249,9 @@ bool JITFusionPass::Impl::check_shape(cg::OperatorNodeBase* opr,
} }
} }


InternalGraphGenrator* JITFusionPass::Impl::create_new_igraph_gen(
InternalGraphGenerator* JITFusionPass::Impl::create_new_igraph_gen(
OperatorNodeBase* opr) { OperatorNodeBase* opr) {
auto uptr = std::make_unique<InternalGraphGenrator>(opr);
auto uptr = std::make_unique<InternalGraphGenerator>(opr);
auto ptr = uptr.get(); auto ptr = uptr.get();
m_igraph_gen_storage.emplace_back(std::move(uptr)); m_igraph_gen_storage.emplace_back(std::move(uptr));
m_var2igraph_gen[opr->output(0)] = ptr; m_var2igraph_gen[opr->output(0)] = ptr;
@@ -267,7 +267,7 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
} }
// dimshuffle should not be an endpoint, because megbrain has lazy // dimshuffle should not be an endpoint, because megbrain has lazy
// dimshuffle machanism // dimshuffle machanism
InternalGraphGenrator* ig_gen = nullptr;
InternalGraphGenerator* ig_gen = nullptr;
if (m_var2igraph_gen.count(opr->output(0)) == 0) { if (m_var2igraph_gen.count(opr->output(0)) == 0) {
// because of the reverse traversal, when an operator is being // because of the reverse traversal, when an operator is being
// processed but not in m_var2igraph_gen, means it is a endpoint of a // processed but not in m_var2igraph_gen, means it is a endpoint of a


+ 8
- 8
src/jit/impl/internal_graph.cpp View File

@@ -81,12 +81,12 @@ InternalGraphPtr expand_executor_opr(const InternalGraphPtr& prev_igraph) {


} // namespace } // namespace


InternalGraphGenrator::InternalGraphGenrator(cg::OperatorNodeBase* opr)
InternalGraphGenerator::InternalGraphGenerator(cg::OperatorNodeBase* opr)
: m_output{opr->output(0)} { : m_output{opr->output(0)} {
add_opr(opr); add_opr(opr);
} }


VarNode* InternalGraphGenrator::replace_graph_by_placeholder() {
VarNode* InternalGraphGenerator::replace_graph_by_placeholder() {
ThinHashMap<VarNode*, VarNode*> old2new; ThinHashMap<VarNode*, VarNode*> old2new;
auto cpu_default = CompNode::default_cpu(); auto cpu_default = CompNode::default_cpu();
auto igraph_copy_opr_shallow = [cpu_default](OperatorNodeBase* opr, auto igraph_copy_opr_shallow = [cpu_default](OperatorNodeBase* opr,
@@ -163,7 +163,7 @@ VarNode* InternalGraphGenrator::replace_graph_by_placeholder() {
return old2new.at(m_output); return old2new.at(m_output);
} }


InternalGraphPtr InternalGraphGenrator::generate() {
InternalGraphPtr InternalGraphGenerator::generate() {
m_input_idx = 0; m_input_idx = 0;


auto new_nd = replace_graph_by_placeholder(); auto new_nd = replace_graph_by_placeholder();
@@ -172,7 +172,7 @@ InternalGraphPtr InternalGraphGenrator::generate() {
return expand_executor_opr(igraph); return expand_executor_opr(igraph);
} }


size_t InternalGraphGenrator::get_cnt_input_if_add(
size_t InternalGraphGenerator::get_cnt_input_if_add(
cg::OperatorNodeBase* opr) const { cg::OperatorNodeBase* opr) const {
// minus 1 first because this opr should be removed from subgraph's input // minus 1 first because this opr should be removed from subgraph's input
size_t new_cnt_input = m_graph_input_set.size() - 1; size_t new_cnt_input = m_graph_input_set.size() - 1;
@@ -183,7 +183,7 @@ size_t InternalGraphGenrator::get_cnt_input_if_add(
return new_cnt_input; return new_cnt_input;
} }


void InternalGraphGenrator::add_opr(cg::OperatorNodeBase* opr) {
void InternalGraphGenerator::add_opr(cg::OperatorNodeBase* opr) {
if (m_opr_set.count(opr)) { if (m_opr_set.count(opr)) {
// ignore duplicated oprs (which occur in tests) // ignore duplicated oprs (which occur in tests)
return; return;
@@ -253,7 +253,7 @@ void InternalGraphGenrator::add_opr(cg::OperatorNodeBase* opr) {
} }
} }


void InternalGraphGenrator::find_reduce_opr_deps(cg::OperatorNodeBase* opr) {
void InternalGraphGenerator::find_reduce_opr_deps(cg::OperatorNodeBase* opr) {
mgb_assert(opr->same_type<opr::Reduce>() || mgb_assert(opr->same_type<opr::Reduce>() ||
(opr->same_type<jit::JITExecutor>() && (opr->same_type<jit::JITExecutor>() &&
try_cast_as_op<jit::JITExecutor>(opr)->has_reduce())); try_cast_as_op<jit::JITExecutor>(opr)->has_reduce()));
@@ -264,7 +264,7 @@ void InternalGraphGenrator::find_reduce_opr_deps(cg::OperatorNodeBase* opr) {
cg::DepOprIter{cb}.add(opr); cg::DepOprIter{cb}.add(opr);
} }


void InternalGraphGenrator::find_oprs_depended_by_dimshuffle(
void InternalGraphGenerator::find_oprs_depended_by_dimshuffle(
cg::OperatorNodeBase* dimshuffle) { cg::OperatorNodeBase* dimshuffle) {
mgb_assert( mgb_assert(
dimshuffle->same_type<opr::Dimshuffle>() || dimshuffle->same_type<opr::Dimshuffle>() ||
@@ -287,7 +287,7 @@ void InternalGraphGenrator::find_oprs_depended_by_dimshuffle(
cg::DepOprIter{cb}.add(dimshuffle); cg::DepOprIter{cb}.add(dimshuffle);
} }


PlaceholderArray InternalGraphGenrator::to_placeholder_opr_arr(
PlaceholderArray InternalGraphGenerator::to_placeholder_opr_arr(
const VarNodeArray& vars) { const VarNodeArray& vars) {
PlaceholderArray ret(vars.size()); PlaceholderArray ret(vars.size());
for (size_t i = 0; i < vars.size(); ++i) { for (size_t i = 0; i < vars.size(); ++i) {


+ 3
- 3
src/jit/include/megbrain/jit/internal_graph.h View File

@@ -76,12 +76,12 @@ private:
* This object stores intermediate state during visiting the computing graph in * This object stores intermediate state during visiting the computing graph in
* JITFusionPass. * JITFusionPass.
* *
* The graph is iterated in reverse topological order. InternalGraphGenrator
* The graph is iterated in reverse topological order. InternalGraphGenerator
* starts with a single operator (i.e. the output node of the fused opr), and * starts with a single operator (i.e. the output node of the fused opr), and
* new oprs are gradually added into it. Thus the process is expanding a tree * new oprs are gradually added into it. Thus the process is expanding a tree
* rooted at the output node. * rooted at the output node.
*/ */
class InternalGraphGenrator {
class InternalGraphGenerator {
//! replace oprs in the graph of m_output and populate m_orig_inps, //! replace oprs in the graph of m_output and populate m_orig_inps,
//! m_placeholders //! m_placeholders
VarNode* replace_graph_by_placeholder(); VarNode* replace_graph_by_placeholder();
@@ -95,7 +95,7 @@ class InternalGraphGenrator {
void find_oprs_depended_by_dimshuffle(cg::OperatorNodeBase* opr); void find_oprs_depended_by_dimshuffle(cg::OperatorNodeBase* opr);


public: public:
explicit InternalGraphGenrator(cg::OperatorNodeBase* opr);
explicit InternalGraphGenerator(cg::OperatorNodeBase* opr);


//! generate the graph; this method can be called multiple times //! generate the graph; this method can be called multiple times
InternalGraphPtr generate(); InternalGraphPtr generate();


+ 2
- 2
src/jit/test/codegen.cpp View File

@@ -54,7 +54,7 @@ void run<simple>(Backend backend, CompNode cn) {


VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()}; VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()};
auto ig_gen = auto ig_gen =
std::make_unique<InternalGraphGenrator>(y.node()->owner_opr());
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());


for (auto i : get_rev_topo_order(y)) { for (auto i : get_rev_topo_order(y)) {
if (!i->same_type<opr::Host2DeviceCopy>()) { if (!i->same_type<opr::Host2DeviceCopy>()) {
@@ -91,7 +91,7 @@ void run<grad>(Backend backend, CompNode cn) {


VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()}; VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()};
auto ig_gen = auto ig_gen =
std::make_unique<InternalGraphGenrator>(y.node()->owner_opr());
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());


for (auto i : get_rev_topo_order(y)) { for (auto i : get_rev_topo_order(y)) {
if (!i->same_type<opr::Host2DeviceCopy>()) { if (!i->same_type<opr::Host2DeviceCopy>()) {


+ 3
- 3
src/jit/test/fusion.cpp View File

@@ -540,7 +540,7 @@ void run<expand_jit_executor>(Backend backend, CompNode cn) {


auto make_jit = [](SymbolVar target, const SymbolVarArray& inputs) { auto make_jit = [](SymbolVar target, const SymbolVarArray& inputs) {
auto y = target.node(); auto y = target.node();
auto ig_gen = std::make_unique<InternalGraphGenrator>(y->owner_opr());
auto ig_gen = std::make_unique<InternalGraphGenerator>(y->owner_opr());
auto inputs_vptr = cg::to_var_node_array(inputs); auto inputs_vptr = cg::to_var_node_array(inputs);
for (auto i : get_rev_topo_order( for (auto i : get_rev_topo_order(
target, {inputs_vptr.begin(), inputs_vptr.end()})) { target, {inputs_vptr.begin(), inputs_vptr.end()})) {
@@ -830,9 +830,9 @@ TEST(TestJITFusionHalide, JITExecutor) {
y = opr::reduce_sum(a + b, shape_of_b), y = opr::reduce_sum(a + b, shape_of_b),
z = opr::reduce_sum(a * b, shape_of_a); z = opr::reduce_sum(a * b, shape_of_a);
auto ig_gen_1 = auto ig_gen_1 =
std::make_unique<InternalGraphGenrator>(y.node()->owner_opr());
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
auto ig_gen_2 = auto ig_gen_2 =
std::make_unique<InternalGraphGenrator>(z.node()->owner_opr());
std::make_unique<InternalGraphGenerator>(z.node()->owner_opr());
{ {
ThinHashSet<VarNode*> nd_set; ThinHashSet<VarNode*> nd_set;
nd_set.insert(a.node()); nd_set.insert(a.node());


+ 1
- 1
src/jit/test/helper.cpp View File

@@ -85,7 +85,7 @@ void FusionChecker::ensure_init_graph() {


SymbolVar jit_y; SymbolVar jit_y;
if (m_direct_build) { if (m_direct_build) {
auto ig_gen = std::make_unique<InternalGraphGenrator>(
auto ig_gen = std::make_unique<InternalGraphGenerator>(
m_truth_y.node()->owner_opr()); m_truth_y.node()->owner_opr());
ThinHashSet<VarNode*> endpoints_set; ThinHashSet<VarNode*> endpoints_set;
for (size_t i = 0; i < m_nr_input; ++i) { for (size_t i = 0; i < m_nr_input; ++i) {


Loading…
Cancel
Save