Browse Source

fix(imperative): catch python exception in c++

GitOrigin-RevId: 16a2abfdad
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
9fb8444d24
2 changed files with 10 additions and 9 deletions
  1. +6
    -4
      imperative/python/src/tensor.cpp
  2. +4
    -5
      imperative/python/src/trace.cpp

+ 6
- 4
imperative/python/src/tensor.cpp View File

@@ -240,10 +240,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
pyf = cpp_apply_const_with_tracing;
}

auto ret = py::reinterpret_steal<py::object>(
PyObject_Call(pyf, tup.ptr(), nullptr));
auto py_ret = py::reinterpret_borrow<py::list>(ret);
if (auto* t = try_cast(py_ret[0].ptr())) {
auto py_ret = PyObject_Call(pyf, tup.ptr(), nullptr);
if (!py_ret) throw py::error_already_set();
auto py_list = py::reinterpret_steal<py::list>(py_ret);
if (auto* t = try_cast(py_list[0].ptr())) {
m_tensor = t->m_tensor;
}
return;
@@ -389,6 +389,7 @@ PyObject* TensorWrapper::device() {
PyObject* TensorWrapper::numpy() {
if (m_tensor->m_trace_info.compiled_info != nullptr) {
PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
if (!np_val) throw py::error_already_set();
if (np_val == Py_None) {
throw TraceReadError("value of this tensor is not read in trace");
}
@@ -478,6 +479,7 @@ PyObject* TensorWrapper::detach() {
PyObject* TensorWrapper::_dev_tensor(){
if (m_tensor->m_trace_info.compiled_info != nullptr) {
auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
if (!dev_tensor) throw py::error_already_set();
if (dev_tensor == Py_None) {
throw TraceReadError("raw data of this tensor is not read in trace");
}


+ 4
- 5
imperative/python/src/trace.cpp View File

@@ -31,9 +31,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
}
py::object ret = py::reinterpret_steal<py::object>(
PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr));
if (!ret) {
throw py::value_error("invalid py object call");
}
if (!ret) throw py::error_already_set();

// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
@@ -58,8 +56,9 @@ apply_result_t apply_trace(ApplyContext& ctx) {
for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
}
auto ret = py::reinterpret_steal<py::object>(
PyObject_Call(pyf, args.ptr(), nullptr));
auto pyout = PyObject_Call(pyf, args.ptr(), nullptr);
if (!pyout) throw py::error_already_set();
auto ret = py::reinterpret_steal<py::object>(pyout);

// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);


Loading…
Cancel
Save