GitOrigin-RevId: af5426c37d
release-1.6
@@ -11,6 +11,7 @@ | |||||
#pragma once | #pragma once | ||||
#include <exception> | |||||
#include <stdexcept> | #include <stdexcept> | ||||
#include <vector> | #include <vector> | ||||
#include <utility> | #include <utility> | ||||
@@ -69,10 +70,43 @@ inline int cvt_retint(int ret) { | |||||
struct py_err_set : std::exception {}; | struct py_err_set : std::exception {}; | ||||
#define HANDLE_ALL_EXC(RET) catch(py_err_set&) {return RET;} \ | |||||
catch(pybind11::error_already_set& e) {e.restore(); return RET;} \ | |||||
catch(pybind11::builtin_exception& e) {e.set_error(); return RET;} \ | |||||
catch(std::exception& e) {PyErr_SetString(PyExc_RuntimeError, e.what()); return RET;} | |||||
// refer to pybind11 for the following exception handling helper | |||||
inline void pybind11_translate_exception(std::exception_ptr last_exception) { | |||||
auto ®istered_exception_translators = pybind11::detail::get_internals().registered_exception_translators; | |||||
for (auto& translator : registered_exception_translators) { | |||||
try { | |||||
translator(last_exception); | |||||
} catch (...) { | |||||
last_exception = std::current_exception(); | |||||
continue; | |||||
} | |||||
return; | |||||
} | |||||
PyErr_SetString(PyExc_SystemError, "Exception escaped from default exception translator!"); | |||||
} | |||||
inline void pybind11_translate_exception() { | |||||
pybind11_translate_exception(std::current_exception()); | |||||
} | |||||
#if defined(__GNUG__) && !defined(__clang__) | |||||
#define PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND catch (::abi::__forced_unwind&) {throw;} | |||||
#else | |||||
#define PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND | |||||
#endif | |||||
#define PYEXT17_TRANSLATE_EXC \ | |||||
catch(::pyext17::py_err_set&) {} \ | |||||
catch(::pybind11::error_already_set& e) {e.restore();} \ | |||||
PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND \ | |||||
catch(...) {::pyext17::pybind11_translate_exception();} | |||||
#define PYEXT17_TRANSLATE_EXC_RET(RET) \ | |||||
catch(::pyext17::py_err_set&) {return RET;} \ | |||||
catch(::pybind11::error_already_set& e) {e.restore(); return RET;} \ | |||||
PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND \ | |||||
catch(...) {::pyext17::pybind11_translate_exception(); return RET;}; | |||||
template <typename T> | template <typename T> | ||||
struct wrap { | struct wrap { | ||||
@@ -134,7 +168,7 @@ private: | |||||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | ||||
try { | try { | ||||
CVT_RET_PYOBJ((inst->*f)()); | CVT_RET_PYOBJ((inst->*f)()); | ||||
} HANDLE_ALL_EXC(nullptr) | |||||
} PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | } | ||||
}; | }; | ||||
@@ -146,7 +180,7 @@ private: | |||||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | ||||
try { | try { | ||||
CVT_RET_PYOBJ((inst->*f)(args, kwargs)); | CVT_RET_PYOBJ((inst->*f)(args, kwargs)); | ||||
} HANDLE_ALL_EXC(nullptr) | |||||
} PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | } | ||||
}; | }; | ||||
@@ -159,7 +193,7 @@ private: | |||||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | ||||
try { | try { | ||||
CVT_RET_PYOBJ((inst->*f)(args, nargs)); | CVT_RET_PYOBJ((inst->*f)(args, nargs)); | ||||
} HANDLE_ALL_EXC(nullptr) | |||||
} PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | } | ||||
#else | #else | ||||
static constexpr int flags = METH_VARARGS; | static constexpr int flags = METH_VARARGS; | ||||
@@ -170,7 +204,7 @@ private: | |||||
auto size = PyTuple_GET_SIZE(args); | auto size = PyTuple_GET_SIZE(args); | ||||
try { | try { | ||||
CVT_RET_PYOBJ((inst->*f)(arr, size)); | CVT_RET_PYOBJ((inst->*f)(arr, size)); | ||||
} HANDLE_ALL_EXC(nullptr) | |||||
} PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | } | ||||
#endif | #endif | ||||
}; | }; | ||||
@@ -183,7 +217,7 @@ private: | |||||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | ||||
try { | try { | ||||
CVT_RET_PYOBJ((inst->*f)(obj)); | CVT_RET_PYOBJ((inst->*f)(obj)); | ||||
} HANDLE_ALL_EXC(nullptr) | |||||
} PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | } | ||||
}; | }; | ||||
@@ -209,7 +243,7 @@ private: | |||||
} else { | } else { | ||||
static_assert(!std::is_same_v<F, F>); | static_assert(!std::is_same_v<F, F>); | ||||
} | } | ||||
} HANDLE_ALL_EXC(nullptr) | |||||
} PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | } | ||||
}; | }; | ||||
@@ -230,7 +264,7 @@ private: | |||||
} else { | } else { | ||||
static_assert(!std::is_same_v<F, F>); | static_assert(!std::is_same_v<F, F>); | ||||
} | } | ||||
} HANDLE_ALL_EXC(-1) | |||||
} PYEXT17_TRANSLATE_EXC_RET(-1) | |||||
} | } | ||||
static constexpr auto impl = []() {if constexpr (std::is_same_v<F, std::nullptr_t>) return nullptr; | static constexpr auto impl = []() {if constexpr (std::is_same_v<F, std::nullptr_t>) return nullptr; | ||||
@@ -314,7 +348,7 @@ private: | |||||
} else { | } else { | ||||
new(inst) T(); | new(inst) T(); | ||||
} | } | ||||
} HANDLE_ALL_EXC(nullptr) | |||||
} PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
free_guard.self = nullptr; | free_guard.self = nullptr; | ||||
return self; | return self; | ||||
} | } | ||||
@@ -464,7 +498,7 @@ public: | |||||
new(inst) T(std::forward<Args>(args)...); | new(inst) T(std::forward<Args>(args)...); | ||||
return self; | return self; | ||||
} | } | ||||
struct caster { | struct caster { | ||||
static constexpr auto name = T::tp_name; | static constexpr auto name = T::tp_name; | ||||
@@ -493,4 +527,3 @@ public: | |||||
#undef HAS_MEMBER | #undef HAS_MEMBER | ||||
#undef CVT_RET_PYOBJ | #undef CVT_RET_PYOBJ | ||||
#undef CVT_RET_INT | #undef CVT_RET_INT | ||||
#undef HANDLE_ALL_EXC |
@@ -26,8 +26,11 @@ | |||||
#include "./graph_rt.h" | #include "./graph_rt.h" | ||||
#include "./helper.h" | #include "./helper.h" | ||||
#include <object.h> | |||||
#include <pybind11/numpy.h> | #include <pybind11/numpy.h> | ||||
#include <pybind11/operators.h> | #include <pybind11/operators.h> | ||||
#include <pybind11/pytypes.h> | |||||
#include <pyerrors.h> | |||||
#include <range/v3/all.hpp> | #include <range/v3/all.hpp> | ||||
#include <string> | #include <string> | ||||
@@ -230,10 +233,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||||
ret[i] = TensorWrapper::make(pytype, std::move(outputs[i])); | ret[i] = TensorWrapper::make(pytype, std::move(outputs[i])); | ||||
} | } | ||||
return ret.release().ptr(); | return ret.release().ptr(); | ||||
} catch (std::exception& e) { | |||||
PyErr_SetString(PyExc_RuntimeError, e.what()); | |||||
return nullptr; | |||||
} | |||||
} PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | } | ||||
@@ -391,7 +391,7 @@ void TensorWrapper::set_handle(PyObject* dest) { | |||||
PyObject* TensorWrapper::shape() { | PyObject* TensorWrapper::shape() { | ||||
// if it's tracing compiled mode, get value from compiled_info | |||||
// if it's tracing compiled mode, get value from compiled_info | |||||
if (m_tensor->m_trace_info.compiled_info != nullptr) { | if (m_tensor->m_trace_info.compiled_info != nullptr) { | ||||
if (m_tensor->m_flags & Tensor::Flags::SCALAR) { | if (m_tensor->m_flags & Tensor::Flags::SCALAR) { | ||||
return PyTuple_New(0); | return PyTuple_New(0); | ||||
@@ -821,10 +821,7 @@ PyObject* dtype_promotion(PyObject* self, PyObject*const* args, size_t nargs) { | |||||
try { | try { | ||||
PyArray_Descr* res = _dtype_promotion(args, nargs); | PyArray_Descr* res = _dtype_promotion(args, nargs); | ||||
return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr(); | return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr(); | ||||
} catch (std::exception& e) { | |||||
PyErr_SetString(PyExc_RuntimeError, e.what()); | |||||
return nullptr; | |||||
} | |||||
} PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | } | ||||
PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) { | PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) { | ||||
@@ -835,10 +832,7 @@ PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) { | |||||
try { | try { | ||||
CompNode cn = _get_device(args, nargs); | CompNode cn = _get_device(args, nargs); | ||||
return py::cast(cn).release().ptr(); | return py::cast(cn).release().ptr(); | ||||
} catch (std::exception& e) { | |||||
PyErr_SetString(PyExc_RuntimeError, e.what()); | |||||
return nullptr; | |||||
} | |||||
} PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | } | ||||
#ifdef METH_FASTCALL | #ifdef METH_FASTCALL | ||||
@@ -865,6 +859,34 @@ void init_tensor(py::module m) { | |||||
static auto sl_interpreter_for_py = interpreter::Interpreter::inst().create_channel(); | static auto sl_interpreter_for_py = interpreter::Interpreter::inst().create_channel(); | ||||
interpreter_for_py = sl_interpreter_for_py.get(); | interpreter_for_py = sl_interpreter_for_py.get(); | ||||
static py::exception<interpreter::AsyncError> py_async_error(m, "AsyncError", PyExc_RuntimeError); | |||||
py::register_exception_translator([](std::exception_ptr p) { | |||||
try { | |||||
if (p) std::rethrow_exception(p); | |||||
} catch (const interpreter::AsyncError& e) { | |||||
pyext17::pybind11_translate_exception(e.nested_ptr()); | |||||
if (PyErr_Occurred()) { | |||||
PyObject *exc, *val, *tb; | |||||
PyErr_Fetch(&exc, &val, &tb); | |||||
PyErr_NormalizeException(&exc, &val, &tb); | |||||
if (tb) { | |||||
PyException_SetTraceback(val, tb); | |||||
} | |||||
auto val2 = py_async_error.py::object::operator()( | |||||
"An async error is reported. See above for the actual cause." | |||||
" Hint: This is where it is reported, not where it happened." | |||||
" You may call `megengine.core.set_option('async_level', 0)` to get better error reporting." | |||||
); | |||||
PyException_SetCause(val2.ptr(), val); // PyException_SetCause steals reference | |||||
Py_XDECREF(exc); | |||||
Py_XDECREF(tb); | |||||
PyErr_Restore(py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr); | |||||
} else { | |||||
py_async_error("Unkown async error"); | |||||
} | |||||
} | |||||
}); | |||||
auto* tensor_type = TensorWrapper::wrap_t::type() | auto* tensor_type = TensorWrapper::wrap_t::type() | ||||
.def<&TensorWrapper::numpy>("numpy") | .def<&TensorWrapper::numpy>("numpy") | ||||
.def_getset<&TensorWrapper::shape>("shape") | .def_getset<&TensorWrapper::shape>("shape") | ||||
@@ -932,7 +954,7 @@ void init_tensor(py::module m) { | |||||
if (v->is_scalar) { | if (v->is_scalar) { | ||||
return py::object(py::array(np_val).squeeze()); | return py::object(py::array(np_val).squeeze()); | ||||
} | } | ||||
return np_val; | |||||
return np_val; | |||||
}) | }) | ||||
.def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) | .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) | ||||
@@ -7,6 +7,7 @@ import pytest | |||||
import megengine as mge | import megengine as mge | ||||
import megengine.functional as F | import megengine.functional as F | ||||
from megengine.core._imperative_rt.core2 import ( | from megengine.core._imperative_rt.core2 import ( | ||||
AsyncError, | |||||
_set_drop_flag, | _set_drop_flag, | ||||
_set_swap_flag, | _set_swap_flag, | ||||
config_async_level, | config_async_level, | ||||
@@ -98,3 +99,25 @@ def test_regression_2870(): | |||||
with pytest.raises(RuntimeError): | with pytest.raises(RuntimeError): | ||||
y.numpy() | y.numpy() | ||||
(x + x).numpy() | (x + x).numpy() | ||||
# NOTE: DO NOT REMOVE THIS TEST | |||||
# This is also a compatibility test for | |||||
# mge.core.set_option('async_level', 0). | |||||
# If you change the canonical API to set async level, | |||||
# update the error message of AsyncError as well. | |||||
def test_async_error(): | |||||
orig_lvl = mge.core.get_option("async_level") | |||||
try: | |||||
mge.core.set_option("async_level", 1) | |||||
x = F.utils._simulate_error() | |||||
try: | |||||
x.numpy() | |||||
except AsyncError as e: | |||||
assert isinstance(e.__cause__, RuntimeError) | |||||
mge.core.set_option("async_level", 0) | |||||
with pytest.raises(RuntimeError): | |||||
F.utils._simulate_error() | |||||
finally: | |||||
mge.core.set_option("async_level", orig_lvl) |