From 619d78ed86573e1f04ac64875273afd13338e8a4 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 17 Nov 2021 12:38:49 +0800
Subject: [PATCH] fix(imperative): check async error when getting value

GitOrigin-RevId: 52b8a29932d2abb33f4bb3d4acff91fe53a6a998
---
 imperative/python/megengine/functional/vision.py     |  1 +
 imperative/python/megengine/jit/tracing.py           | 20 +++++++++++++-------
 imperative/python/src/tensor.cpp                     |  4 ++++
 imperative/python/test/unit/core/test_interpreter.py |  9 +++++++++
 imperative/src/impl/interpreter/interpreter_impl.cpp | 11 ++++++++++-
 5 files changed, 37 insertions(+), 8 deletions(-)

diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py
index 7ab6bf24..0fd905ee 100644
--- a/imperative/python/megengine/functional/vision.py
+++ b/imperative/python/megengine/functional/vision.py
@@ -420,6 +420,7 @@ def warp_affine(
 
     Here all available options for params are listed,
     however it does not mean that you can use all the combinations.
     On different platforms, different combinations are supported.
+    ``warp_affine`` only supports forward inference. Please refer to ``warp_perspective`` if backward is needed.
     """
     conv_format = _config._get_actual_op_param(format, _config.__conv_format)
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index 0e136d6b..19f51544 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -104,6 +104,7 @@ class TensorInfo:
         "shape",
         "is_const",
         "bound_data",
+        "bound_data_numpy",
         # resources for execution
         "varnode",
         "data_setter",
@@ -119,12 +120,18 @@ class TensorInfo:
         self.shape_read = None
         self.value_read = None
         self.bound_data = None
+        self.bound_data_numpy = None
         self.data_setter = None
         self.shape_reader = None
         self.value_reader = None
         self.data_reader = None
 
+    def get_numpy(self):
+        if self.bound_data_numpy is None:
+            self.bound_data_numpy = self.bound_data.numpy()
+        return self.bound_data_numpy
+
 
 _io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}
@@ -292,7 +299,7 @@ class trace:
                 # Const op is represented by a str
                 assert isinstance(op_, str) and op_ == "Const"
-                expected = self._tinfo[ohandles[0]].bound_data.numpy()
+                expected = self._tinfo[ohandles[0]].get_numpy()
                 shape = value.shape
                 if shape != expected.shape or dtype != expected.dtype:
                     eq = False
@@ -369,6 +376,7 @@ class trace:
             info.dtype = x.dtype
             info.shape = x.shape
             info.bound_data = x
+            info.bound_data_numpy = None
             info.is_const = True
             x._mixin_handle = h
             x._recording = True
@@ -612,9 +620,7 @@ class trace:
                     assert info.external
                     assert info.bound_data
                     info.varnode = graph.make_const(
-                        info.bound_data.numpy(),
-                        info.bound_data.dtype,
-                        info.bound_data.device,
+                        info.get_numpy(), info.bound_data.dtype, info.bound_data.device,
                     )
                     continue
 
@@ -627,7 +633,7 @@ class trace:
             if info.bound_data:
                 if getattr(info, "is_const", False):
                     info.varnode = graph.make_const(
-                        info.bound_data.numpy(),
+                        info.get_numpy(),
                         info.bound_data.dtype,
                         info.bound_data.device,
                     )
@@ -1174,7 +1180,7 @@ class trace:
                 assert info.external
                 assert info.bound_data
                 h2v[h] = graph.make_const(
-                    info.bound_data.numpy(),
+                    info.get_numpy(),
                     dtype=info.dtype,
                     device=dumped_device(info),
                     name=info.name,
@@ -1187,7 +1193,7 @@ class trace:
                 assert info.external
                 assert info.bound_data
                 h2v[h] = graph.make_const(
-                    info.bound_data.numpy(),
+                    info.get_numpy(),
                     dtype=info.dtype,
                     device=dumped_device(info),
                     name=info.name,
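The ``get_numpy`` helper introduced in tracing.py memoizes the host copy of
``bound_data``, so the repeated ``graph.make_const`` calls for the same handle
trigger at most one device-to-host transfer. A minimal standalone sketch of the
same lazy-caching pattern; ``FakeDeviceTensor`` is hypothetical and exists only
to count transfers, it is not a MegEngine type:

    import numpy as np

    class FakeDeviceTensor:
        # Stands in for a device tensor; counts device-to-host copies.
        def __init__(self, value):
            self._value = np.asarray(value)
            self.transfers = 0

        def numpy(self):
            self.transfers += 1  # each call models one device-to-host copy
            return self._value

    class Info:
        def __init__(self, bound_data):
            self.bound_data = bound_data
            self.bound_data_numpy = None  # host cache, filled lazily

        def get_numpy(self):
            # Fetch from the device only on first use, then reuse the cache.
            if self.bound_data_numpy is None:
                self.bound_data_numpy = self.bound_data.numpy()
            return self.bound_data_numpy

    info = Info(FakeDeviceTensor([1.0, 2.0]))
    info.get_numpy()
    info.get_numpy()
    assert info.bound_data.transfers == 1  # second call served from the cache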
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 87f2459d..f415f572 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -1074,6 +1074,10 @@ void init_tensor(py::module m) {
             []() {
                 interpreter_for_py->sync();
                 CompNode::sync_all();
+                CompNode::foreach ([](CompNode cn) {
+                    auto err = cn.check_async_error();
+                    mgb_assert(!err, "%s", err->what());
+                });
                 sync_py_task_q();
             },
             py::call_guard<py::gil_scoped_release>());
diff --git a/imperative/python/test/unit/core/test_interpreter.py b/imperative/python/test/unit/core/test_interpreter.py
index 3bbb5caf..2513c36b 100644
--- a/imperative/python/test/unit/core/test_interpreter.py
+++ b/imperative/python/test/unit/core/test_interpreter.py
@@ -96,6 +96,15 @@ def test_regression_2870():
     (x + x).numpy()
 
 
+@pytest.mark.require_ngpu(1)
+def test_async_error_check():
+    src = mge.tensor([[1.0, 2.0]])
+    index = mge.tensor([3])
+    val = F.indexing_one_hot(src, index)
+    with pytest.raises(RuntimeError):
+        val.numpy()
+
+
 # NOTE: DO NOT REMOVE THIS TEST
 # This is also a compatibility test for
 # mge.core.set_option('async_level', 0).
diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp
index bf122982..ecf63e65 100644
--- a/imperative/src/impl/interpreter/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter/interpreter_impl.cpp
@@ -156,6 +156,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
     if (m_async_level == 0) {
         sync_impl();
         info->desc.comp_node.sync();
+        auto err = info->desc.comp_node.check_async_error();
+        mgb_assert(!err, "%s", err->what());
     }
     return info;
 }
@@ -336,6 +338,8 @@ void ChannelImpl::dispatch_kernel(
         for (auto&& oup : *outputs) {
             auto info = reinterpret_cast<TensorInfo*>(oup);
             info->ptr->comp_node().sync();
+            auto err = info->ptr->comp_node().check_async_error();
+            mgb_assert(!err, "%s", err->what());
         }
     }
 }
@@ -931,7 +935,8 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
     MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
     bool require_host = prop == TensorProp::HostValue;
     auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
-    if (require_host && !host_available()) {
+    bool wait_host = !host_available();
+    if (require_host && wait_host) {
         // avoid dead lock
         lock.unlock();
         m_buffer.enqueue(GetValue{info});
@@ -944,6 +949,10 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
     });
     MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
     m_waitee = nullptr;
+    if (require_host && wait_host) {
+        auto err = info->ptr->comp_node().check_async_error();
+        mgb_assert(!err, "%s", err->what());
+    }
    return info->ptr;
 }
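With this patch, an error raised by an asynchronously executed kernel is
re-checked at each synchronization point (``put_impl`` and ``dispatch_kernel``
at async level 0, ``wait_tensor``, and the Python-level ``sync``), so it
surfaces as a Python ``RuntimeError`` where the value is actually consumed
instead of being swallowed by the worker thread. A usage sketch mirroring the
new unit test; it assumes one GPU device with default async execution, where
the out-of-range index makes ``indexing_one_hot`` fail on the device only
after dispatch:

    import megengine as mge
    import megengine.functional as F

    src = mge.tensor([[1.0, 2.0]])
    index = mge.tensor([3])  # out of range: src has only 2 columns
    val = F.indexing_one_hot(src, index)  # dispatched async, no error yet
    try:
        val.numpy()  # sync point: the async error is checked and raised here
    except RuntimeError as exc:
        print("async error surfaced:", exc)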