|
|
@@ -39,6 +39,14 @@ py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing, |
|
|
|
|
|
|
|
py::object cpp_apply_backward_varnode; |
|
|
|
|
|
|
|
void release_trace_apply_func(){ |
|
|
|
cpp_apply_with_tracing.release(); |
|
|
|
cpp_apply_const_with_tracing.release(); |
|
|
|
cpp_apply_compiled_mode.release(); |
|
|
|
cpp_apply_const_compiled_mode.release(); |
|
|
|
cpp_apply_backward_varnode.release(); |
|
|
|
} |
|
|
|
|
|
|
|
#define REGISTE_APPLY_FUNC(mode) \ |
|
|
|
void set_##mode(py::object pyf) { \ |
|
|
|
mode = pybind11::reinterpret_steal<py::object>(pyf); \ |
|
|
@@ -720,6 +728,8 @@ void init_tensor(py::module m) { |
|
|
|
py_task_q.wait_all_task_finish(); |
|
|
|
}, |
|
|
|
py::call_guard<py::gil_scoped_release>()); |
|
|
|
|
|
|
|
m.def("release_trace_apply_func", &release_trace_apply_func); |
|
|
|
|
|
|
|
py::handle grad_key_type = GradKeyWrapper::wrap_t::type() |
|
|
|
.def<&GradKeyWrapper::attach>("attach") |
|
|
|