
fix(imperative): check async error when getting value

GitOrigin-RevId: 52b8a29932
tags/v1.8.0
Megvii Engine Team, 3 years ago
Parent commit: 619d78ed86
5 changed files with 37 additions and 8 deletions
  1. imperative/python/megengine/functional/vision.py (+1 -0)
  2. imperative/python/megengine/jit/tracing.py (+13 -7)
  3. imperative/python/src/tensor.cpp (+4 -0)
  4. imperative/python/test/unit/core/test_interpreter.py (+9 -0)
  5. imperative/src/impl/interpreter/interpreter_impl.cpp (+10 -1)

imperative/python/megengine/functional/vision.py (+1 -0)

@@ -420,6 +420,7 @@ def warp_affine(
Here all available options for params are listed,
however it does not mean that you can use all the combinations.
On different platforms, different combinations are supported.
+ ``warp_affine`` only supports forward inference; please refer to ``warp_perspective`` if backward is needed.
"""
conv_format = _config._get_actual_op_param(format, _config.__conv_format)
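The note added above marks ``warp_affine`` as forward-only and points to ``warp_perspective`` when a backward pass is needed. A minimal sketch of that substitution, embedding the 2x3 affine matrix into a 3x3 perspective matrix (argument order and the ``out_shape`` parameter are assumed from MegEngine's public API, not from this diff):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.randn(1, 3, 32, 32).astype("float32"))   # NCHW input
    affine = np.array([[[1.0, 0.0, 2.0],
                        [0.0, 1.0, -2.0]]], dtype="float32")            # (1, 2, 3) affine matrix
    bottom = np.array([[[0.0, 0.0, 1.0]]], dtype="float32")             # pad to (1, 3, 3)
    mat = mge.tensor(np.concatenate([affine, bottom], axis=1))
    out = F.vision.warp_perspective(inp, mat, (32, 32))                 # differentiable, unlike warp_affine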




imperative/python/megengine/jit/tracing.py (+13 -7)

@@ -104,6 +104,7 @@ class TensorInfo:
"shape",
"is_const",
"bound_data",
+ "bound_data_numpy",
# resources for execution
"varnode",
"data_setter",
@@ -119,12 +120,18 @@ class TensorInfo:
self.shape_read = None
self.value_read = None
self.bound_data = None
+ self.bound_data_numpy = None

self.data_setter = None
self.shape_reader = None
self.value_reader = None
self.data_reader = None


+ def get_numpy(self):
+     if self.bound_data_numpy is None:
+         self.bound_data_numpy = self.bound_data.numpy()
+     return self.bound_data_numpy



_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}


@@ -292,7 +299,7 @@ class trace:
# Const op is represented by a str
assert isinstance(op_, str) and op_ == "Const"

- expected = self._tinfo[ohandles[0]].bound_data.numpy()
+ expected = self._tinfo[ohandles[0]].get_numpy()
shape = value.shape
if shape != expected.shape or dtype != expected.dtype:
    eq = False
@@ -369,6 +376,7 @@ class trace:
info.dtype = x.dtype
info.shape = x.shape
info.bound_data = x
+ info.bound_data_numpy = None
info.is_const = True
x._mixin_handle = h
x._recording = True
@@ -612,9 +620,7 @@ class trace:
assert info.external
assert info.bound_data
info.varnode = graph.make_const(
-     info.bound_data.numpy(),
-     info.bound_data.dtype,
-     info.bound_data.device,
+     info.get_numpy(), info.bound_data.dtype, info.bound_data.device,
)
continue


@@ -627,7 +633,7 @@ class trace:
if info.bound_data:
    if getattr(info, "is_const", False):
        info.varnode = graph.make_const(
-             info.bound_data.numpy(),
+             info.get_numpy(),
            info.bound_data.dtype,
            info.bound_data.device,
        )
@@ -1174,7 +1180,7 @@ class trace:
assert info.external
assert info.bound_data
h2v[h] = graph.make_const(
-     info.bound_data.numpy(),
+     info.get_numpy(),
    dtype=info.dtype,
    device=dumped_device(info),
    name=info.name,
@@ -1187,7 +1193,7 @@ class trace:
assert info.external
assert info.bound_data
h2v[h] = graph.make_const(
-     info.bound_data.numpy(),
+     info.get_numpy(),
    dtype=info.dtype,
    device=dumped_device(info),
    name=info.name,
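The new ``get_numpy`` helper is a one-slot cache: ``bound_data.numpy()`` forces a host-value fetch through the interpreter (the same path that now checks for async errors), so tracing pays that fetch once per constant and every later ``graph.make_const`` call reuses the cached array. The same pattern in isolation (hypothetical ``ConstInfo`` class, not the real ``TensorInfo``):

    class ConstInfo:
        def __init__(self, bound_data):
            self.bound_data = bound_data          # device tensor captured while tracing
            self.bound_data_numpy = None          # host copy, filled lazily

        def get_numpy(self):
            # first call blocks on the device and may surface an async error;
            # subsequent calls reuse the cached host array
            if self.bound_data_numpy is None:
                self.bound_data_numpy = self.bound_data.numpy()
            return self.bound_data_numpy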


imperative/python/src/tensor.cpp (+4 -0)

@@ -1074,6 +1074,10 @@ void init_tensor(py::module m) {
[]() {
    interpreter_for_py->sync();
    CompNode::sync_all();
+     CompNode::foreach ([](CompNode cn) {
+         auto err = cn.check_async_error();
+         mgb_assert(!err, "%s", err->what());
+     });
    sync_py_task_q();
},
py::call_guard<py::gil_scoped_release>());
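After this hunk, the full-sync binding walks every comp node and raises on the first asynchronous error it finds, so a plain synchronization is enough to surface failures from earlier kernels. A hedged usage sketch, assuming the binding is exposed as ``megengine._full_sync`` and that a CUDA device defers the invalid-index error (as in the test below); neither the name nor the behavior is guaranteed by this diff alone:

    import megengine as mge
    import megengine.functional as F

    x = mge.tensor([[1.0, 2.0]])
    y = F.indexing_one_hot(x, mge.tensor([3]))    # out-of-range index; on GPU the error is recorded asynchronously
    try:
        mge._full_sync()                          # sync_all + check_async_error on every CompNode
    except RuntimeError as err:
        print("async error surfaced at sync:", err)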


imperative/python/test/unit/core/test_interpreter.py (+9 -0)

@@ -96,6 +96,15 @@ def test_regression_2870():
(x + x).numpy()




+ @pytest.mark.require_ngpu(1)
+ def test_async_error_check():
+     src = mge.tensor([[1.0, 2.0]])
+     index = mge.tensor([3])
+     val = F.indexing_one_hot(src, index)
+     with pytest.raises(RuntimeError):
+         val.numpy()


# NOTE: DO NOT REMOVE THIS TEST
# This is also a compatibility test for
# mge.core.set_option('async_level', 0).
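The new test needs ``val.numpy()`` to trip the error because, at the default async level, the invalid index is only detected on the worker thread after dispatch has returned. With ``async_level`` set to 0 the other hunks in this commit check the error right after each kernel, so the failure moves to the op call itself. A sketch of that mode (assumes a CUDA device, like the test above):

    import pytest
    import megengine as mge
    import megengine.functional as F

    mge.core.set_option("async_level", 0)         # execute and error-check each op eagerly
    src = mge.tensor([[1.0, 2.0]])
    index = mge.tensor([3])
    with pytest.raises(RuntimeError):
        F.indexing_one_hot(src, index)            # raises at dispatch; no .numpy() needed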


imperative/src/impl/interpreter/interpreter_impl.cpp (+10 -1)

@@ -156,6 +156,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
if (m_async_level == 0) {
    sync_impl();
    info->desc.comp_node.sync();
+     auto err = info->desc.comp_node.check_async_error();
+     mgb_assert(!err, "%s", err->what());
}
return info;
}
@@ -336,6 +338,8 @@ void ChannelImpl::dispatch_kernel(
for (auto&& oup : *outputs) {
    auto info = reinterpret_cast<TensorInfo*>(oup);
    info->ptr->comp_node().sync();
+     auto err = info->ptr->comp_node().check_async_error();
+     mgb_assert(!err, "%s", err->what());
}
}
}
@@ -931,7 +935,8 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
bool require_host = prop == TensorProp::HostValue;
auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
- if (require_host && !host_available()) {
+ bool wait_host = !host_available();
+ if (require_host && wait_host) {
    // avoid dead lock
    lock.unlock();
    m_buffer.enqueue(GetValue{info});
@@ -944,6 +949,10 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
});
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
m_waitee = nullptr;
+ if (require_host && wait_host) {
+     auto err = info->ptr->comp_node().check_async_error();
+     mgb_assert(!err, "%s", err->what());
+ }
return info->ptr;
}
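All three interpreter hunks follow the same discipline: synchronize the comp node, then immediately ask it for any error its asynchronous work recorded and turn it into an assertion. A toy Python illustration of that record-then-check-at-wait pattern (not MegEngine internals):

    import queue
    import threading

    class ToyChannel:
        """Errors on the worker thread are recorded, not raised, until someone waits."""

        def __init__(self):
            self.tasks = queue.Queue()
            self.async_error = None                        # plays the role of check_async_error()
            threading.Thread(target=self._worker, daemon=True).start()

        def _worker(self):
            while True:
                fn, done = self.tasks.get()
                try:
                    fn()
                except Exception as exc:                   # record, never raise on the worker
                    self.async_error = exc
                finally:
                    done.set()

        def put(self, fn):
            done = threading.Event()
            self.tasks.put((fn, done))
            return done

        def wait(self, done):
            done.wait()                                    # sync(), as in put_impl / wait_tensor
            if self.async_error is not None:               # then check the recorded async error
                raise RuntimeError(self.async_error)

    chan = ToyChannel()
    handle = chan.put(lambda: 1 / 0)                       # fails asynchronously on the worker
    # chan.wait(handle) would now raise RuntimeError wrapping the ZeroDivisionError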



