Browse Source

chore(mge/imperative): fix BackwardGraph for jit.trace

GitOrigin-RevId: 16e86a21d7
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
0537cb7471
3 changed files with 14 additions and 0 deletions
  1. +1
    -0
      imperative/python/megengine/jit/tracing.py
  2. +1
    -0
      imperative/python/src/graph_rt.cpp
  3. +12
    -0
      imperative/src/include/megbrain/imperative/ops/backward_graph.h

+ 1
- 0
imperative/python/megengine/jit/tracing.py View File

@@ -254,6 +254,7 @@ class trace:

def _compile(self):
graph = self._graph = G.Graph()
graph.options.no_force_inplace = True
# graph.options.graph_opt_level = 0
need_reset_nodes = self._need_reset_nodes = []
# links enforce ordering of I/O nodes


+ 1
- 0
imperative/python/src/graph_rt.cpp View File

@@ -105,6 +105,7 @@ void init_graph_rt(py::module m) {
DEF_READWRITE(enable_grad_var_static_reshape)
DEF_READWRITE(enable_memory_swap)
DEF_READWRITE(comp_node_seq_record_level)
DEF_READWRITE(no_force_inplace)
// DEF_READWRITE(eager_evaluation)
// DEF_READWRITE(imperative_proxy_graph)
// DEF_READWRITE(extra_vardeps)


+ 12
- 0
imperative/src/include/megbrain/imperative/ops/backward_graph.h View File

@@ -81,6 +81,18 @@ public:
return m_graph;
}

bool is_same_st(const Hashable& rhs) const override {
if (!rhs.same_type<BackwardGraph>()) {
return false;
}
auto& other = rhs.cast_final_safe<BackwardGraph>();
if (this == &other) {
return true;
}
// FIXME
return false;
}

private:
InternalGraph m_graph;
};


Loading…
Cancel
Save