Browse Source

fix(mge): fix cpp trace function release

GitOrigin-RevId: 73f9642821
release-1.2
Megvii Engine Team 4 years ago
parent
commit
243a05b410
2 changed files with 13 additions and 1 deletions
  1. +3
    -1
      imperative/python/megengine/__init__.py
  2. +10
    -0
      imperative/python/src/tensor.cpp

+ 3
- 1
imperative/python/megengine/__init__.py View File

@@ -71,7 +71,7 @@ if sys.platform == "win32":

kernel32.SetErrorMode(old_error_mode)

from .core._imperative_rt.core2 import sync
from .core._imperative_rt.core2 import sync, release_trace_apply_func
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
@@ -90,7 +90,9 @@ _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
_persistent_cache_impl_ins.reg()

atexit.register(sync)
atexit.register(release_trace_apply_func)

del sync
del release_trace_apply_func
del _set_fork_exec_path_for_timed_func
del _persistent_cache_impl_ins

+ 10
- 0
imperative/python/src/tensor.cpp View File

@@ -39,6 +39,14 @@ py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing,

py::object cpp_apply_backward_varnode;

void release_trace_apply_func(){
cpp_apply_with_tracing.release();
cpp_apply_const_with_tracing.release();
cpp_apply_compiled_mode.release();
cpp_apply_const_compiled_mode.release();
cpp_apply_backward_varnode.release();
}

#define REGISTE_APPLY_FUNC(mode) \
void set_##mode(py::object pyf) { \
mode = pybind11::reinterpret_steal<py::object>(pyf); \
@@ -720,6 +728,8 @@ void init_tensor(py::module m) {
py_task_q.wait_all_task_finish();
},
py::call_guard<py::gil_scoped_release>());
m.def("release_trace_apply_func", &release_trace_apply_func);

py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach")


Loading…
Cancel
Save