Browse Source

fix(mge): fix sublinear cuda and mem leak

GitOrigin-RevId: 82091ec9a6
release-1.2
Megvii Engine Team 4 years ago
parent
commit
23b9a98f5e
4 changed files with 21 additions and 10 deletions
  1. +7
    -2
      imperative/python/megengine/jit/tracing.py
  2. +9
    -2
      imperative/python/src/tensor.cpp
  3. +2
    -3
      imperative/python/src/trace.cpp
  4. +3
    -3
      imperative/python/src/trace_info.h

+ 7
- 2
imperative/python/megengine/jit/tracing.py View File

@@ -350,12 +350,16 @@ class trace:
self._lazy_eval_links = () self._lazy_eval_links = ()


def _take_escaped_tensors(self): def _take_escaped_tensors(self):
escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors.values()))
escaped_tensors = tuple(
filter(lambda x: x() is not None, self._active_tensors.values())
)
self._active_tensors.clear() self._active_tensors.clear()
return escaped_tensors return escaped_tensors


def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors.values()))
lazy_eval_tensors = list(
filter(lambda x: x() is not None, lazy_eval_tensors.values())
)
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
self._apply_graph_options(lazy_eval_graph) self._apply_graph_options(lazy_eval_graph)
# FIXME # FIXME
@@ -443,6 +447,7 @@ class trace:
x()._reset_varnode() x()._reset_varnode()
x().mixin_handle = -1 x().mixin_handle = -1
x().recording = False x().recording = False
x()._trace_mixin_info = None


try: try:
do_enter() do_enter()


+ 9
- 2
imperative/python/src/tensor.cpp View File

@@ -294,8 +294,13 @@ PyObject* TensorWrapper::copied() {
return m_tensor->m_trace_info.member; \ return m_tensor->m_trace_info.member; \
} \ } \
void TensorWrapper::set_##member(PyObject* dest) { \ void TensorWrapper::set_##member(PyObject* dest) { \
Py_INCREF(dest); \
m_tensor->m_trace_info.member = dest; \
if (dest == Py_None) { \
Py_XDECREF(m_tensor->m_trace_info.member); \
m_tensor->m_trace_info.member = nullptr; \
} else { \
Py_INCREF(dest); \
m_tensor->m_trace_info.member = dest; \
} \
} }


REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info) REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info)
@@ -463,6 +468,8 @@ PyObject* TensorWrapper::_dev_tensor(){
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor); auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>()); auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
m_tensor->m_handle = std::move(SharedHandle(sh)); m_tensor->m_handle = std::move(SharedHandle(sh));
Py_DECREF(m_tensor->m_trace_info.compiled_info);
m_tensor->m_trace_info.compiled_info = nullptr;


return dev_tensor; return dev_tensor;
} }


+ 2
- 3
imperative/python/src/trace.cpp View File

@@ -55,10 +55,9 @@ apply_result_t apply_trace(ApplyContext& ctx) {


auto args = py::tuple(ctx.nargs + 1); auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op); args[0] = py::cast(ctx.op);
py::tuple args(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; i++) { for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = TensorWrapper::make(
std::move(std::shared_ptr<Tensor>(ctx.args[i])))
.release();
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
} }
auto ret = py::reinterpret_steal<py::object>( auto ret = py::reinterpret_steal<py::object>(
PyObject_Call(pyf, args.ptr(), nullptr)); PyObject_Call(pyf, args.ptr(), nullptr));


+ 3
- 3
imperative/python/src/trace_info.h View File

@@ -28,10 +28,10 @@ struct TraceInfo {
mixin_handle = that.mixin_handle; mixin_handle = that.mixin_handle;
recording = that.recording; recording = that.recording;


compiled_info = that.compiled_info;
Py_XINCREF(compiled_info);
trace_mixin_info = that.trace_mixin_info; trace_mixin_info = that.trace_mixin_info;
Py_XINCREF(trace_mixin_info); Py_XINCREF(trace_mixin_info);
compiled_info = that.compiled_info;
Py_XINCREF(compiled_info);


copied = true; copied = true;
return *this; return *this;
@@ -39,7 +39,7 @@ struct TraceInfo {


~TraceInfo() { ~TraceInfo() {
Py_XDECREF(trace_mixin_info); Py_XDECREF(trace_mixin_info);
// Py_XDECREF(compiled_info);
Py_XDECREF(compiled_info);
} }


private: private:


Loading…
Cancel
Save