@@ -72,7 +72,6 @@ if sys.platform == "win32":
     kernel32.SetErrorMode(old_error_mode)
 
 from .core._imperative_rt.core2 import full_sync as _full_sync
-from .core._imperative_rt.core2 import release_trace_apply_func
 from .core._imperative_rt.core2 import sync as _sync
 from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
 from .device import *
@@ -92,9 +91,7 @@ _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
 _persistent_cache_impl_ins.reg()
 
 atexit.register(_full_sync)
-atexit.register(release_trace_apply_func)
-del release_trace_apply_func
 del _set_fork_exec_path_for_timed_func
 del _persistent_cache_impl_ins
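
The two Python hunks above delete the release_trace_apply_func atexit hook; the C++ hunks below make it unnecessary by caching the tracing callables as raw PyObject* instead of owning py::object globals, so no destructor touches the interpreter during teardown. A minimal sketch of the destruction-order hazard being avoided (hypothetical standalone program, not MegEngine code):

#include <pybind11/embed.h>
namespace py = pybind11;

py::object g_callback;  // C++ global: destroyed only after main() returns

int main() {
    py::scoped_interpreter guard;  // finalizes the interpreter at scope exit
    g_callback = py::eval("lambda x: x + 1");
}   // the interpreter is gone here, but ~object() still runs Py_DECREF on
    // g_callback: undefined behavior. A raw PyObject* that is simply never
    // decref'd turns this into a deliberate, harmless leak at exit.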
@@ -34,22 +34,15 @@ namespace mgb::imperative::python {
 std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
 
-py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing,
-        cpp_apply_compiled_mode, cpp_apply_const_compiled_mode;
+PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing,
+        *cpp_apply_compiled_mode, *cpp_apply_const_compiled_mode;
-py::object cpp_apply_backward_varnode;
+PyObject *cpp_apply_backward_varnode;
 
-void release_trace_apply_func(){
-    cpp_apply_with_tracing.release();
-    cpp_apply_const_with_tracing.release();
-    cpp_apply_compiled_mode.release();
-    cpp_apply_const_compiled_mode.release();
-    cpp_apply_backward_varnode.release();
-}
 
 #define REGISTE_APPLY_FUNC(mode) \
     void set_##mode(py::object pyf) { \
-        mode = pybind11::reinterpret_steal<py::object>(pyf); \
+        mode = pyf.ptr(); \
     }
 
 REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
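With the globals downgraded to plain pointers, the setter macro now stores a borrowed pointer. The removed line also carried a refcount bug: reinterpret_steal claimed ownership of a handle that the py::object argument still owned, so the same reference could be released twice. Expanded for one mode, the new setter looks as follows (a sketch of the macro expansion; the borrow is only sound assuming the Python side keeps each function object alive, e.g. as a module-level attribute):

// Expansion of REGISTE_APPLY_FUNC(cpp_apply_with_tracing) after this patch.
void set_cpp_apply_with_tracing(py::object pyf) {
    cpp_apply_with_tracing = pyf.ptr();  // borrowed: no incref, never decref'd
}
// If a temporary could be passed instead, an explicit
// Py_INCREF(cpp_apply_with_tracing) here (leaked on purpose at exit)
// would be needed to pin the callable.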
@@ -242,14 +235,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
         // const op
         if (is_const && is_tracing) {
-            py::object pyf;
+            PyObject *pyf;
             if (is_compiled) {
                 pyf = cpp_apply_const_compiled_mode;
             } else {
                 pyf = cpp_apply_const_with_tracing;
             }
-            auto ret = pyf(*tup);
+            auto ret = py::reinterpret_steal<py::object>(
+                    PyObject_Call(pyf, tup.ptr(), nullptr));
             auto py_ret = py::reinterpret_borrow<py::list>(ret);
             if (auto* t = try_cast(py_ret[0].ptr())) {
                 m_tensor = t->m_tensor;
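With pyf now a raw pointer, pybind11's pyf(*tup) call operator is no longer available; the call goes through PyObject_Call, which returns a new reference (or nullptr with a Python error set), and reinterpret_steal adopts that reference so it is decref'd exactly once. The pattern as a self-contained sketch (call_checked is a hypothetical name, not part of this patch):

#include <pybind11/pybind11.h>
namespace py = pybind11;

py::object call_checked(PyObject* func, const py::tuple& args) {
    // PyObject_Call returns a NEW reference, or nullptr on failure
    auto ret = py::reinterpret_steal<py::object>(
            PyObject_Call(func, args.ptr(), /*kwargs=*/nullptr));
    if (!ret) throw py::error_already_set();  // preserve the Python traceback
    return ret;
}

The hunk above does not check ret before borrowing it as a list; apply_trace below does, and that checked shape is the safer one.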
@@ -744,8 +738,6 @@ void init_tensor(py::module m) {
             },
             py::call_guard<py::gil_scoped_release>());
 
-    m.def("release_trace_apply_func", &release_trace_apply_func);
-
     py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
             .def<&GradKeyWrapper::attach>("attach")
             .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
@@ -253,8 +253,8 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
 void init_tensor(pybind11::module);
 
-extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode;
-extern pybind11::object cpp_apply_backward_varnode;
+extern PyObject *cpp_apply_with_tracing, *cpp_apply_compiled_mode;
+extern PyObject *cpp_apply_backward_varnode;
 
 } // namespace mgb::imperative::python
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #include "./trace.h"
@@ -23,12 +24,13 @@ apply_result_t apply_trace(ApplyContext& ctx) {
     if (ctx.backward) {
         // reach here when symbolic=True or compiled=True
         // call megbrain_graph.py apply(BackwardGraph, *args)
-        auto args = py::tuple(ctx.nargs);
+        auto args = py::tuple(ctx.nargs + 1);
         args[0] = py::cast(ctx.op);
         for (size_t i = 0; i < ctx.nargs; i++) {
-            args[i] = py::cast(ctx.args[i]->m_var);
+            args[i + 1] = py::cast(ctx.args[i]->m_var);
         }
-        py::object ret = cpp_apply_backward_varnode(py::cast(ctx.op), *args);
+        py::object ret = py::reinterpret_steal<py::object>(
+                PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr));
         if (!ret) {
             throw py::value_error("invalid py object call");
         }
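
PyObject_Call takes exactly one positional tuple, so the prepending that pybind11's operator() used to do in cpp_apply_backward_varnode(py::cast(ctx.op), *args) must now be done by hand: the tuple grows to nargs + 1, the op lands in args[0] (previously a dead store, overwritten at i == 0, with the op passed separately instead), and every VarNode shifts to i + 1. As a generic helper shape (pack_args is hypothetical, not part of the patch):

#include <pybind11/pybind11.h>
#include <vector>
namespace py = pybind11;

// Prepend `head` to `rest`, matching the single argument tuple that
// PyObject_Call expects.
py::tuple pack_args(py::object head, const std::vector<py::object>& rest) {
    py::tuple t(rest.size() + 1);
    t[0] = std::move(head);
    for (size_t i = 0; i < rest.size(); ++i)
        t[i + 1] = rest[i];  // the tuple accessor increfs each item
    return t;
}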
@@ -36,13 +38,13 @@ apply_result_t apply_trace(ApplyContext& ctx) {
         // assumption: python function always returns PyList
         auto tup = py::reinterpret_borrow<py::list>(ret);
         for (auto i = 0; i < tup.size(); i++) {
-            auto pitem = tup[i].cast<cg::VarNode *>();
+            auto pitem = tup[i].cast<cg::VarNode*>();
             outputs.emplace_back(std::make_shared<Tensor>(pitem));
         }
         return outputs;
     }
 
-    py::object pyf;
+    PyObject* pyf;
     if (is_compiled) {
         // run apply in compiled mode, step 2, 3, etc
         pyf = cpp_apply_compiled_mode;
@@ -51,11 +53,15 @@ apply_result_t apply_trace(ApplyContext& ctx) {
         pyf = cpp_apply_with_tracing;
     }
 
-    auto args = py::tuple(ctx.nargs);
+    auto args = py::tuple(ctx.nargs + 1);
+    args[0] = py::cast(ctx.op);
     for (size_t i = 0; i < ctx.nargs; i++) {
-        args[i] = TensorWrapper::make(std::move(std::shared_ptr<Tensor>(ctx.args[i]))).release();
+        args[i + 1] = TensorWrapper::make(
+                std::move(std::shared_ptr<Tensor>(ctx.args[i])))
+                .release();
     }
 
-    auto ret = pyf(py::cast(ctx.op), *args);
+    auto ret = py::reinterpret_steal<py::object>(
+            PyObject_Call(pyf, args.ptr(), nullptr));
     // assumption: python function always returns PyList
     auto tup = py::reinterpret_borrow<py::list>(ret);
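
The forward-tracing path gets the same treatment: a tuple of nargs + 1 with the op at index 0, TensorWrapper handles shifted by one, and a direct PyObject_Call in place of the pybind11 call operator. Taken together, every cached callable in this patch converges on one shape (hypothetical names; a sketch, not MegEngine's API):

#include <pybind11/pybind11.h>
namespace py = pybind11;

PyObject* g_traced_fn = nullptr;  // set once from Python, deliberately leaked

void set_traced_fn(py::object f) { g_traced_fn = f.ptr(); }

py::list call_traced_fn(py::tuple args) {
    auto ret = py::reinterpret_steal<py::object>(
            PyObject_Call(g_traced_fn, args.ptr(), nullptr));
    if (!ret) throw py::error_already_set();
    // contract from the comments above: the Python side always returns a
    // list, so this borrow performs no type check
    return py::reinterpret_borrow<py::list>(ret);
}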