GitOrigin-RevId: 8615a23b75
release-1.1
@@ -77,12 +77,14 @@ void init_imperative_rt(py::module m) {
         .def("get_shape", &Interpreter::Channel::get_shape)
         .def("_get_dev_tensor", &Interpreter::Channel::get_dev_tensor)
         .def("apply_op", &Interpreter::Channel::apply_op)
+        .def("config_async_level", &Interpreter::Channel::config_async_level)
+        .def("get_async_level", &Interpreter::Channel::get_async_level)
         .def("sync", &Interpreter::Channel::sync, py::call_guard<py::gil_scoped_release>());
     std::unique_ptr<Interpreter::Channel> ch = Interpreter::inst().create_channel();
     m.attr("interpreter") = py::detail::make_caster<decltype(ch)>::cast(
             std::move(ch), py::return_value_policy::move, {});
-    for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op"}) {
+    for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op", "config_async_level", "get_async_level"}) {
         m.attr(name) = m.attr("interpreter").attr(name);
     }
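Aside (not part of the patch): the two new `.def(...)` bindings plus the extended name list make `config_async_level` / `get_async_level` reachable both on the default channel object and as module-level aliases. A minimal usage sketch, assuming the module path used by the test file below:

```python
# Sketch only: module path taken from the new test; the alias behaviour follows
# from the binding loop above, which re-exports channel methods on the module.
from megengine.core._imperative_rt import imperative

imperative.config_async_level(2)              # via the module-level alias
assert imperative.get_async_level() == 2
imperative.interpreter.config_async_level(1)  # or via the default channel object
assert imperative.get_async_level() == 1      # both paths hit the same channel
```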
@@ -0,0 +1,35 @@
+import pytest
+
+import megengine as mge
+import megengine.functional as F
+from megengine.core._imperative_rt.imperative import config_async_level, get_async_level
+
+
+def test_basic():
+    config_async_level(2)
+    assert get_async_level() == 2
+    with pytest.raises(RuntimeError):
+        config_async_level(3)
+
+
+def test_level1_infer_value():
+    config_async_level(1)
+    a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32")
+    b = mge.tensor([1, 1], dtype="float32")
+    # make DepType::VALUE unknown
+    c = b * 2
+    with pytest.raises(RuntimeError):
+        d = F.reshape(a, c)
+
+
+def test_level1_infer_shape_with_unknown():
+    config_async_level(2)
+    a = mge.tensor([[1, 2, 2, 3]], dtype="float32")
+    b = mge.tensor([1, 1])
+    c = b * 2
+    # make DepType::SHAPE unknown
+    d = F.reshape(a, c)
+    config_async_level(1)
+    e = mge.tensor([[1, 2]], dtype="float32")
+    with pytest.raises(RuntimeError):
+        f = F.matmul(d, e)
@@ -54,21 +54,25 @@ void ChannelImpl::del(void* handle) {
 SmallVector<void*> ChannelImpl::apply_op(
         std::shared_ptr<OpDef> op,
         const SmallVector<void*>& inputs) {
+    SmallVector<TensorInfo*> input_infos;
+    input_infos.reserve(inputs.size());
     SmallVector<LogicalTensorDesc> input_descs;
     input_descs.reserve(inputs.size());
-    for (auto h : inputs) {
-        auto info = reinterpret_cast<TensorInfo*>(h);
+    for (auto i : inputs) {
+        auto info = reinterpret_cast<TensorInfo*>(i);
+        input_infos.push_back(info);
         input_descs.push_back(info->desc);
     }
     auto output_descs = OpDef::infer_output_attrs_fallible(*op, input_descs);
     ApplyOp cmd{std::move(op)};
-    cmd.inputs.reserve(inputs.size());
-    for (auto i : inputs) {
-        cmd.inputs.push_back(reinterpret_cast<TensorInfo*>(i));
-    }
+    cmd.inputs = std::move(input_infos);
     cmd.outputs.reserve(output_descs.size());
     SmallVector<void*> outputs;
+    bool is_fallible = false;
     for (auto&& desc : output_descs) {
+        if (desc.layout.ndim == 0) {
+            is_fallible = true;
+        }
         auto info = alloc();
         info->desc = desc;
         m_valid_handle.insert(info);
@@ -76,6 +80,9 @@ SmallVector<void*> ChannelImpl::apply_op(
         outputs.push_back(info);
     }
     m_worker.add_task(std::move(cmd));
+    if (is_fallible && m_async_level <= 1) {
+        sync();
+    }
     return outputs;
 }
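Aside (not part of the patch): `is_fallible` is set whenever an output shape could not be inferred up front (`desc.layout.ndim == 0`); with `m_async_level <= 1` the channel syncs right after enqueueing the command, so user-side failures surface at the Python call site rather than at some later sync. A minimal sketch of the observable effect, mirroring `test_level1_infer_value` above:

```python
import pytest
import megengine as mge
import megengine.functional as F
from megengine.core._imperative_rt.imperative import config_async_level

config_async_level(1)
a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32")
target = mge.tensor([1, 1], dtype="float32") * 2  # target shape only known at runtime
with pytest.raises(RuntimeError):
    F.reshape(a, target)  # raises here: level <= 1 forces a sync after apply_op
```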
@@ -162,7 +169,12 @@ void ChannelImpl::close() {
 }
 
 void ChannelImpl::config_async_level(int level) {
-    mgb_assert(0);
+    mgb_assert(level <= 2 and level >= 0, "async_level should be 0, 1 or 2");
+    m_async_level = level;
+}
+
+int ChannelImpl::get_async_level() {
+    return m_async_level;
 }
 
 TensorInfo* ChannelImpl::alloc() {
@@ -74,6 +74,7 @@ struct ChannelImpl : Interpreter::Channel {
     void close() override;
     void config_async_level(int level) override;
+    int get_async_level() override;
 
 private:
     TensorInfo* alloc();
@@ -101,7 +102,11 @@ private:
         ChannelImpl* m_owner;
     } m_worker;
 
-    int m_async_level = 2;
+    //! configure whether to raise errors exactly when invoking an op.
+    //! level 2: both device and user side errors are async;
+    //! level 1: user side errors are sync;
+    //! level 0: both sync.
+    int m_async_level = 1;
 };
 
 } // namespace mgb::imperative::interpreter::intl
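Aside (not part of the patch): the comment above defines the three levels, and `config_async_level` rejects anything outside [0, 2] via the `mgb_assert`, which reaches Python as a `RuntimeError` (see `test_basic`). A small sketch:

```python
import pytest
from megengine.core._imperative_rt.imperative import config_async_level, get_async_level

config_async_level(0)              # level 0: device and user side errors are both sync
assert get_async_level() == 0
with pytest.raises(RuntimeError):  # anything outside [0, 2] trips the mgb_assert
    config_async_level(3)
```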
@@ -41,6 +41,7 @@ struct Interpreter {
         virtual void close() = 0;
         virtual void config_async_level(int level) = 0;
+        virtual int get_async_level() = 0;
     };
 
     virtual std::unique_ptr<Channel> create_channel() = 0;