Browse Source

fix(imperative): check async error when getting value

GitOrigin-RevId: 52b8a29932
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
619d78ed86
5 changed files with 37 additions and 8 deletions
  1. +1
    -0
      imperative/python/megengine/functional/vision.py
  2. +13
    -7
      imperative/python/megengine/jit/tracing.py
  3. +4
    -0
      imperative/python/src/tensor.cpp
  4. +9
    -0
      imperative/python/test/unit/core/test_interpreter.py
  5. +10
    -1
      imperative/src/impl/interpreter/interpreter_impl.cpp

+ 1
- 0
imperative/python/megengine/functional/vision.py View File

@@ -420,6 +420,7 @@ def warp_affine(
Here all available options for params are listed,
however it does not mean that you can use all the combinations.
On different platforms, different combinations are supported.
``warp_affine`` only support forward inference, Please refer to ``warp_perspective`` if backward is needed.
"""
conv_format = _config._get_actual_op_param(format, _config.__conv_format)



+ 13
- 7
imperative/python/megengine/jit/tracing.py View File

@@ -104,6 +104,7 @@ class TensorInfo:
"shape",
"is_const",
"bound_data",
"bound_data_numpy",
# resources for execution
"varnode",
"data_setter",
@@ -119,12 +120,18 @@ class TensorInfo:
self.shape_read = None
self.value_read = None
self.bound_data = None
self.bound_data_numpy = None

self.data_setter = None
self.shape_reader = None
self.value_reader = None
self.data_reader = None

def get_numpy(self):
if self.bound_data_numpy is None:
self.bound_data_numpy = self.bound_data.numpy()
return self.bound_data_numpy


_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}

@@ -292,7 +299,7 @@ class trace:
# Const op is represented by a str
assert isinstance(op_, str) and op_ == "Const"

expected = self._tinfo[ohandles[0]].bound_data.numpy()
expected = self._tinfo[ohandles[0]].get_numpy()
shape = value.shape
if shape != expected.shape or dtype != expected.dtype:
eq = False
@@ -369,6 +376,7 @@ class trace:
info.dtype = x.dtype
info.shape = x.shape
info.bound_data = x
info.bound_data_numpy = None
info.is_const = True
x._mixin_handle = h
x._recording = True
@@ -612,9 +620,7 @@ class trace:
assert info.external
assert info.bound_data
info.varnode = graph.make_const(
info.bound_data.numpy(),
info.bound_data.dtype,
info.bound_data.device,
info.get_numpy(), info.bound_data.dtype, info.bound_data.device,
)
continue

@@ -627,7 +633,7 @@ class trace:
if info.bound_data:
if getattr(info, "is_const", False):
info.varnode = graph.make_const(
info.bound_data.numpy(),
info.get_numpy(),
info.bound_data.dtype,
info.bound_data.device,
)
@@ -1174,7 +1180,7 @@ class trace:
assert info.external
assert info.bound_data
h2v[h] = graph.make_const(
info.bound_data.numpy(),
info.get_numpy(),
dtype=info.dtype,
device=dumped_device(info),
name=info.name,
@@ -1187,7 +1193,7 @@ class trace:
assert info.external
assert info.bound_data
h2v[h] = graph.make_const(
info.bound_data.numpy(),
info.get_numpy(),
dtype=info.dtype,
device=dumped_device(info),
name=info.name,


+ 4
- 0
imperative/python/src/tensor.cpp View File

@@ -1074,6 +1074,10 @@ void init_tensor(py::module m) {
[]() {
interpreter_for_py->sync();
CompNode::sync_all();
CompNode::foreach ([](CompNode cn) {
auto err = cn.check_async_error();
mgb_assert(!err, "%s", err->what());
});
sync_py_task_q();
},
py::call_guard<py::gil_scoped_release>());


+ 9
- 0
imperative/python/test/unit/core/test_interpreter.py View File

@@ -96,6 +96,15 @@ def test_regression_2870():
(x + x).numpy()


@pytest.mark.require_ngpu(1)
def test_async_error_check():
src = mge.tensor([[1.0, 2.0]])
index = mge.tensor([3])
val = F.indexing_one_hot(src, index)
with pytest.raises(RuntimeError):
val.numpy()


# NOTE: DO NOT REMOVE THIS TEST
# This is also a compatibility test for
# mge.core.set_option('async_level', 0).


+ 10
- 1
imperative/src/impl/interpreter/interpreter_impl.cpp View File

@@ -156,6 +156,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
if (m_async_level == 0) {
sync_impl();
info->desc.comp_node.sync();
auto err = info->desc.comp_node.check_async_error();
mgb_assert(!err, "%s", err->what());
}
return info;
}
@@ -336,6 +338,8 @@ void ChannelImpl::dispatch_kernel(
for (auto&& oup : *outputs) {
auto info = reinterpret_cast<TensorInfo*>(oup);
info->ptr->comp_node().sync();
auto err = info->ptr->comp_node().check_async_error();
mgb_assert(!err, "%s", err->what());
}
}
}
@@ -931,7 +935,8 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
bool require_host = prop == TensorProp::HostValue;
auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
if (require_host && !host_available()) {
bool wait_host = !host_available();
if (require_host && wait_host) {
// avoid dead lock
lock.unlock();
m_buffer.enqueue(GetValue{info});
@@ -944,6 +949,10 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
});
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
m_waitee = nullptr;
if (require_host && wait_host) {
auto err = info->ptr->comp_node().check_async_error();
mgb_assert(!err, "%s", err->what());
}
return info->ptr;
}



Loading…
Cancel
Save