GitOrigin-RevId: 16e86a21d7
tags/v1.0.0-rc1
@@ -254,6 +254,7 @@ class trace: | |||||
def _compile(self): | def _compile(self): | ||||
graph = self._graph = G.Graph() | graph = self._graph = G.Graph() | ||||
graph.options.no_force_inplace = True | |||||
# graph.options.graph_opt_level = 0 | # graph.options.graph_opt_level = 0 | ||||
need_reset_nodes = self._need_reset_nodes = [] | need_reset_nodes = self._need_reset_nodes = [] | ||||
# links enforce ordering of I/O nodes | # links enforce ordering of I/O nodes | ||||
@@ -105,6 +105,7 @@ void init_graph_rt(py::module m) { | |||||
DEF_READWRITE(enable_grad_var_static_reshape) | DEF_READWRITE(enable_grad_var_static_reshape) | ||||
DEF_READWRITE(enable_memory_swap) | DEF_READWRITE(enable_memory_swap) | ||||
DEF_READWRITE(comp_node_seq_record_level) | DEF_READWRITE(comp_node_seq_record_level) | ||||
DEF_READWRITE(no_force_inplace) | |||||
// DEF_READWRITE(eager_evaluation) | // DEF_READWRITE(eager_evaluation) | ||||
// DEF_READWRITE(imperative_proxy_graph) | // DEF_READWRITE(imperative_proxy_graph) | ||||
// DEF_READWRITE(extra_vardeps) | // DEF_READWRITE(extra_vardeps) | ||||
@@ -81,6 +81,18 @@ public: | |||||
return m_graph; | return m_graph; | ||||
} | } | ||||
bool is_same_st(const Hashable& rhs) const override { | |||||
if (!rhs.same_type<BackwardGraph>()) { | |||||
return false; | |||||
} | |||||
auto& other = rhs.cast_final_safe<BackwardGraph>(); | |||||
if (this == &other) { | |||||
return true; | |||||
} | |||||
// FIXME | |||||
return false; | |||||
} | |||||
private: | private: | ||||
InternalGraph m_graph; | InternalGraph m_graph; | ||||
}; | }; | ||||