Browse Source

fix(imperative): check async error when getting value

GitOrigin-RevId: 3945a9bfa2
tags/v1.7.0.m1
Megvii Engine Team 3 years ago
parent
commit
57e197475b
3 changed files with 11 additions and 0 deletions
  1. +1
    -0
      imperative/python/megengine/functional/vision.py
  2. +4
    -0
      imperative/python/src/tensor.cpp
  3. +6
    -0
      imperative/src/impl/interpreter/interpreter_impl.cpp

+ 1
- 0
imperative/python/megengine/functional/vision.py View File

@@ -420,6 +420,7 @@ def warp_affine(
Here all available options for params are listed, Here all available options for params are listed,
however it does not mean that you can use all the combinations. however it does not mean that you can use all the combinations.
On different platforms, different combinations are supported. On different platforms, different combinations are supported.
``warp_affine`` only support forward inference, Please refer to ``warp_perspective`` if backward is needed.
""" """
conv_format = _config._get_actual_op_param(format, _config.__conv_format) conv_format = _config._get_actual_op_param(format, _config.__conv_format)




+ 4
- 0
imperative/python/src/tensor.cpp View File

@@ -1074,6 +1074,10 @@ void init_tensor(py::module m) {
[]() { []() {
interpreter_for_py->sync(); interpreter_for_py->sync();
CompNode::sync_all(); CompNode::sync_all();
CompNode::foreach ([](CompNode cn) {
auto err = cn.check_async_error();
mgb_assert(!err, "%s", err->what());
});
sync_py_task_q(); sync_py_task_q();
}, },
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());


+ 6
- 0
imperative/src/impl/interpreter/interpreter_impl.cpp View File

@@ -156,6 +156,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
if (m_async_level == 0) { if (m_async_level == 0) {
sync_impl(); sync_impl();
info->desc.comp_node.sync(); info->desc.comp_node.sync();
auto err = info->desc.comp_node.check_async_error();
mgb_assert(!err, "%s", err->what());
} }
return info; return info;
} }
@@ -336,6 +338,8 @@ void ChannelImpl::dispatch_kernel(
for (auto&& oup : *outputs) { for (auto&& oup : *outputs) {
auto info = reinterpret_cast<TensorInfo*>(oup); auto info = reinterpret_cast<TensorInfo*>(oup);
info->ptr->comp_node().sync(); info->ptr->comp_node().sync();
auto err = info->ptr->comp_node().check_async_error();
mgb_assert(!err, "%s", err->what());
} }
} }
} }
@@ -944,6 +948,8 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
}); });
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop); MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
m_waitee = nullptr; m_waitee = nullptr;
auto err = info->ptr->comp_node().check_async_error();
mgb_assert(!err, "%s", err->what());
return info->ptr; return info->ptr;
} }




Loading…
Cancel
Save