@@ -350,12 +350,16 @@ class trace: | |||||
self._lazy_eval_links = () | self._lazy_eval_links = () | ||||
def _take_escaped_tensors(self): | def _take_escaped_tensors(self): | ||||
escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors.values())) | |||||
escaped_tensors = tuple( | |||||
filter(lambda x: x() is not None, self._active_tensors.values()) | |||||
) | |||||
self._active_tensors.clear() | self._active_tensors.clear() | ||||
return escaped_tensors | return escaped_tensors | ||||
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): | def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): | ||||
lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors.values())) | |||||
lazy_eval_tensors = list( | |||||
filter(lambda x: x() is not None, lazy_eval_tensors.values()) | |||||
) | |||||
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] | readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] | ||||
self._apply_graph_options(lazy_eval_graph) | self._apply_graph_options(lazy_eval_graph) | ||||
# FIXME | # FIXME | ||||
@@ -443,6 +447,7 @@ class trace: | |||||
x()._reset_varnode() | x()._reset_varnode() | ||||
x().mixin_handle = -1 | x().mixin_handle = -1 | ||||
x().recording = False | x().recording = False | ||||
x()._trace_mixin_info = None | |||||
try: | try: | ||||
do_enter() | do_enter() | ||||
@@ -294,8 +294,13 @@ PyObject* TensorWrapper::copied() { | |||||
return m_tensor->m_trace_info.member; \ | return m_tensor->m_trace_info.member; \ | ||||
} \ | } \ | ||||
void TensorWrapper::set_##member(PyObject* dest) { \ | void TensorWrapper::set_##member(PyObject* dest) { \ | ||||
Py_INCREF(dest); \ | |||||
m_tensor->m_trace_info.member = dest; \ | |||||
if (dest == Py_None) { \ | |||||
Py_XDECREF(m_tensor->m_trace_info.member); \ | |||||
m_tensor->m_trace_info.member = nullptr; \ | |||||
} else { \ | |||||
Py_INCREF(dest); \ | |||||
m_tensor->m_trace_info.member = dest; \ | |||||
} \ | |||||
} | } | ||||
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info) | REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info) | ||||
@@ -463,6 +468,8 @@ PyObject* TensorWrapper::_dev_tensor(){ | |||||
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor); | auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor); | ||||
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>()); | auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>()); | ||||
m_tensor->m_handle = std::move(SharedHandle(sh)); | m_tensor->m_handle = std::move(SharedHandle(sh)); | ||||
Py_DECREF(m_tensor->m_trace_info.compiled_info); | |||||
m_tensor->m_trace_info.compiled_info = nullptr; | |||||
return dev_tensor; | return dev_tensor; | ||||
} | } | ||||
@@ -55,10 +55,9 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||||
auto args = py::tuple(ctx.nargs + 1); | auto args = py::tuple(ctx.nargs + 1); | ||||
args[0] = py::cast(ctx.op); | args[0] = py::cast(ctx.op); | ||||
py::tuple args(ctx.nargs); | |||||
for (size_t i = 0; i < ctx.nargs; i++) { | for (size_t i = 0; i < ctx.nargs; i++) { | ||||
args[i + 1] = TensorWrapper::make( | |||||
std::move(std::shared_ptr<Tensor>(ctx.args[i]))) | |||||
.release(); | |||||
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this()); | |||||
} | } | ||||
auto ret = py::reinterpret_steal<py::object>( | auto ret = py::reinterpret_steal<py::object>( | ||||
PyObject_Call(pyf, args.ptr(), nullptr)); | PyObject_Call(pyf, args.ptr(), nullptr)); | ||||
@@ -28,10 +28,10 @@ struct TraceInfo { | |||||
mixin_handle = that.mixin_handle; | mixin_handle = that.mixin_handle; | ||||
recording = that.recording; | recording = that.recording; | ||||
compiled_info = that.compiled_info; | |||||
Py_XINCREF(compiled_info); | |||||
trace_mixin_info = that.trace_mixin_info; | trace_mixin_info = that.trace_mixin_info; | ||||
Py_XINCREF(trace_mixin_info); | Py_XINCREF(trace_mixin_info); | ||||
compiled_info = that.compiled_info; | |||||
Py_XINCREF(compiled_info); | |||||
copied = true; | copied = true; | ||||
return *this; | return *this; | ||||
@@ -39,7 +39,7 @@ struct TraceInfo { | |||||
~TraceInfo() { | ~TraceInfo() { | ||||
Py_XDECREF(trace_mixin_info); | Py_XDECREF(trace_mixin_info); | ||||
// Py_XDECREF(compiled_info); | |||||
Py_XDECREF(compiled_info); | |||||
} | } | ||||
private: | private: | ||||