From 0537cb7471bedce13400e9b70e5adaaaffe77f8f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 31 Aug 2020 21:37:27 +0800 Subject: [PATCH] chore(mge/imperative): fix BackwardGraph for jit.trace GitOrigin-RevId: 16e86a21d75ab467541b625289478c88ff175e1c --- imperative/python/megengine/jit/tracing.py | 1 + imperative/python/src/graph_rt.cpp | 1 + .../src/include/megbrain/imperative/ops/backward_graph.h | 12 ++++++++++++ 3 files changed, 14 insertions(+) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 01819de1..e3c20a5d 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -254,6 +254,7 @@ class trace: def _compile(self): graph = self._graph = G.Graph() + graph.options.no_force_inplace = True # graph.options.graph_opt_level = 0 need_reset_nodes = self._need_reset_nodes = [] # links enforce ordering of I/O nodes diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index caabf439..a021c2b7 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -105,6 +105,7 @@ void init_graph_rt(py::module m) { DEF_READWRITE(enable_grad_var_static_reshape) DEF_READWRITE(enable_memory_swap) DEF_READWRITE(comp_node_seq_record_level) + DEF_READWRITE(no_force_inplace) // DEF_READWRITE(eager_evaluation) // DEF_READWRITE(imperative_proxy_graph) // DEF_READWRITE(extra_vardeps) diff --git a/imperative/src/include/megbrain/imperative/ops/backward_graph.h b/imperative/src/include/megbrain/imperative/ops/backward_graph.h index e1d88768..cec59457 100644 --- a/imperative/src/include/megbrain/imperative/ops/backward_graph.h +++ b/imperative/src/include/megbrain/imperative/ops/backward_graph.h @@ -81,6 +81,18 @@ public: return m_graph; } + bool is_same_st(const Hashable& rhs) const override { + if (!rhs.same_type()) { + return false; + } + auto& other = rhs.cast_final_safe(); + if (this == &other) { + return true; + } + // FIXME + return false; + } + private: InternalGraph m_graph; };