GitOrigin-RevId: 52b8a29932
tags/v1.8.0
@@ -420,6 +420,7 @@ def warp_affine( | |||
Here all available options for params are listed, | |||
however it does not mean that you can use all the combinations. | |||
On different platforms, different combinations are supported. | |||
``warp_affine`` only support forward inference, Please refer to ``warp_perspective`` if backward is needed. | |||
""" | |||
conv_format = _config._get_actual_op_param(format, _config.__conv_format) | |||
@@ -104,6 +104,7 @@ class TensorInfo: | |||
"shape", | |||
"is_const", | |||
"bound_data", | |||
"bound_data_numpy", | |||
# resources for execution | |||
"varnode", | |||
"data_setter", | |||
@@ -119,12 +120,18 @@ class TensorInfo: | |||
self.shape_read = None | |||
self.value_read = None | |||
self.bound_data = None | |||
self.bound_data_numpy = None | |||
self.data_setter = None | |||
self.shape_reader = None | |||
self.value_reader = None | |||
self.data_reader = None | |||
def get_numpy(self): | |||
if self.bound_data_numpy is None: | |||
self.bound_data_numpy = self.bound_data.numpy() | |||
return self.bound_data_numpy | |||
_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv} | |||
@@ -292,7 +299,7 @@ class trace: | |||
# Const op is represented by a str | |||
assert isinstance(op_, str) and op_ == "Const" | |||
expected = self._tinfo[ohandles[0]].bound_data.numpy() | |||
expected = self._tinfo[ohandles[0]].get_numpy() | |||
shape = value.shape | |||
if shape != expected.shape or dtype != expected.dtype: | |||
eq = False | |||
@@ -369,6 +376,7 @@ class trace: | |||
info.dtype = x.dtype | |||
info.shape = x.shape | |||
info.bound_data = x | |||
info.bound_data_numpy = None | |||
info.is_const = True | |||
x._mixin_handle = h | |||
x._recording = True | |||
@@ -612,9 +620,7 @@ class trace: | |||
assert info.external | |||
assert info.bound_data | |||
info.varnode = graph.make_const( | |||
info.bound_data.numpy(), | |||
info.bound_data.dtype, | |||
info.bound_data.device, | |||
info.get_numpy(), info.bound_data.dtype, info.bound_data.device, | |||
) | |||
continue | |||
@@ -627,7 +633,7 @@ class trace: | |||
if info.bound_data: | |||
if getattr(info, "is_const", False): | |||
info.varnode = graph.make_const( | |||
info.bound_data.numpy(), | |||
info.get_numpy(), | |||
info.bound_data.dtype, | |||
info.bound_data.device, | |||
) | |||
@@ -1174,7 +1180,7 @@ class trace: | |||
assert info.external | |||
assert info.bound_data | |||
h2v[h] = graph.make_const( | |||
info.bound_data.numpy(), | |||
info.get_numpy(), | |||
dtype=info.dtype, | |||
device=dumped_device(info), | |||
name=info.name, | |||
@@ -1187,7 +1193,7 @@ class trace: | |||
assert info.external | |||
assert info.bound_data | |||
h2v[h] = graph.make_const( | |||
info.bound_data.numpy(), | |||
info.get_numpy(), | |||
dtype=info.dtype, | |||
device=dumped_device(info), | |||
name=info.name, | |||
@@ -1074,6 +1074,10 @@ void init_tensor(py::module m) { | |||
[]() { | |||
interpreter_for_py->sync(); | |||
CompNode::sync_all(); | |||
CompNode::foreach ([](CompNode cn) { | |||
auto err = cn.check_async_error(); | |||
mgb_assert(!err, "%s", err->what()); | |||
}); | |||
sync_py_task_q(); | |||
}, | |||
py::call_guard<py::gil_scoped_release>()); | |||
@@ -96,6 +96,15 @@ def test_regression_2870(): | |||
(x + x).numpy() | |||
@pytest.mark.require_ngpu(1) | |||
def test_async_error_check(): | |||
src = mge.tensor([[1.0, 2.0]]) | |||
index = mge.tensor([3]) | |||
val = F.indexing_one_hot(src, index) | |||
with pytest.raises(RuntimeError): | |||
val.numpy() | |||
# NOTE: DO NOT REMOVE THIS TEST | |||
# This is also a compatibility test for | |||
# mge.core.set_option('async_level', 0). | |||
@@ -156,6 +156,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) { | |||
if (m_async_level == 0) { | |||
sync_impl(); | |||
info->desc.comp_node.sync(); | |||
auto err = info->desc.comp_node.check_async_error(); | |||
mgb_assert(!err, "%s", err->what()); | |||
} | |||
return info; | |||
} | |||
@@ -336,6 +338,8 @@ void ChannelImpl::dispatch_kernel( | |||
for (auto&& oup : *outputs) { | |||
auto info = reinterpret_cast<TensorInfo*>(oup); | |||
info->ptr->comp_node().sync(); | |||
auto err = info->ptr->comp_node().check_async_error(); | |||
mgb_assert(!err, "%s", err->what()); | |||
} | |||
} | |||
} | |||
@@ -931,7 +935,8 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { | |||
MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop); | |||
bool require_host = prop == TensorProp::HostValue; | |||
auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); }; | |||
if (require_host && !host_available()) { | |||
bool wait_host = !host_available(); | |||
if (require_host && wait_host) { | |||
// avoid dead lock | |||
lock.unlock(); | |||
m_buffer.enqueue(GetValue{info}); | |||
@@ -944,6 +949,10 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { | |||
}); | |||
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop); | |||
m_waitee = nullptr; | |||
if (require_host && wait_host) { | |||
auto err = info->ptr->comp_node().check_async_error(); | |||
mgb_assert(!err, "%s", err->what()); | |||
} | |||
return info->ptr; | |||
} | |||