Browse Source

fix(imperative): set exception which worker thread throws for rendezvous

GitOrigin-RevId: f583888fdf
release-1.1
Megvii Engine Team 4 years ago
parent
commit
8109cc4bd1
4 changed files with 109 additions and 10 deletions
  1. +10
    -1
      imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +39
    -7
      imperative/python/src/graph_rt.cpp
  3. +43
    -2
      imperative/python/src/graph_rt.h
  4. +17
    -0
      imperative/python/test/unit/core/test_megbrain_graph.py

+ 10
- 1
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -50,7 +50,16 @@ class Graph(_imperative_rt.ComputingGraph):

def execute(self, *args):
    """Start one asynchronous run of the compiled function.

    The run is submitted to ``self._executor`` and the resulting future is
    stored in ``self._future``; no other run may be pending (asserted).

    If the run raises, the error message is handed to every object in
    ``self._function._all_rendezvous`` via ``set_exception`` and the
    original exception is then re-raised inside the worker, so it is also
    recorded on the future.
    """
    assert self._future is None

    def wrapped(*args):
        try:
            self._function.execute(*args)
        except Exception as exc:
            # Only the message text is forwarded to each rendezvous.
            for rendezvous in self._function._all_rendezvous:
                rendezvous.set_exception(str(exc))
            # Bare raise keeps the original traceback intact.
            raise

    # Submit exactly once, through the wrapper. Submitting the raw
    # ``self._function.execute`` as well would run the graph twice and the
    # first future would be overwritten without ever being observed.
    self._future = self._executor.submit(wrapped, *args)

def wait(self):
assert self._future is not None


+ 39
- 7
imperative/python/src/graph_rt.cpp View File

@@ -49,17 +49,28 @@ class _CompGraphProfilerImpl {
return json->to_string();
}
};

// List of weak references to rendezvous objects that is attached to a
// computing graph through the graph options' user_data container (hence the
// UserData base class). Entries are weak_ptr, so readers must lock() each one
// and skip those whose rendezvous has already been destroyed.
struct WeakRendezvousArray:
public std::vector<std::weak_ptr<RendezvousBase>>,
public UserDataContainer::UserData {
// NOTE(review): presumably declares the type-info hooks user_data needs to
// look this type up; MGB_TYPEINFO_OBJ_IMPL below provides the definition.
MGB_TYPEINFO_OBJ_DECL;
};
MGB_TYPEINFO_OBJ_IMPL(WeakRendezvousArray);
}
// Shorthand for chaining a read/write attribute binding: DEF_READWRITE(x)
// expands to .def_readwrite("x", &CURRENT_CLASS::x). CURRENT_CLASS has to be
// defined by the code that uses this macro.
#define DEF_READWRITE(name) .def_readwrite(#name, &CURRENT_CLASS::name)

/*!
 * \brief Register the Python binding of Rendezvous<T> in \p m under \p name.
 *
 * Instances are held by std::shared_ptr and created through
 * Rendezvous<T>::make(). The bound methods are:
 *   - set(v):            forwards the value to Rendezvous<T>::set
 *   - get():             returns Rendezvous<T>::get(), with the GIL released
 *                        for the duration of the call
 *   - drop(), reset():   bound directly to the member functions
 *   - set_exception(s):  wraps the message in std::runtime_error and passes
 *                        it on as a std::exception_ptr
 *
 * \return the py::class_ object, so that callers can chain further
 *      definitions onto it
 */
template<typename T>
auto def_rendezvous(py::object m, const char* name) {
    return py::class_<Rendezvous<T>, std::shared_ptr<Rendezvous<T>>>(m, name)
        // Single constructor binding; objects come from the factory function.
        .def(py::init([](){return Rendezvous<T>::make();}))
        .def("set", [](Rendezvous<T>& r, T v) {r.set(std::move(v));})
        .def("get", [](Rendezvous<T>& r) {return r.get();}, py::call_guard<py::gil_scoped_release>())
        .def("drop", &Rendezvous<T>::drop)
        // The chain must not be terminated here: set_exception still follows.
        .def("reset", &Rendezvous<T>::reset)
        .def("set_exception", [](Rendezvous<T>& r, std::string&& message) {
            r.set_exception(std::make_exception_ptr(
                    std::runtime_error(std::move(message))));
        });
}

using TensorAttr = LogicalTensorDesc;
@@ -186,7 +197,21 @@ void init_graph_rt(py::module m) {

// Python binding of cg::AsyncExecutable. execute() and wait() run with the
// GIL released.
py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
    .def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>())
    .def("wait", &cg::AsyncExecutable::wait, py::call_guard<py::gil_scoped_release>())
    // Only used for exception handling: returns every rendezvous recorded in
    // the owner graph's WeakRendezvousArray that is still alive.
    .def_property_readonly("_all_rendezvous", [](cg::AsyncExecutable* exec) {
        auto ud = exec->owner_graph()->options().user_data
                .get_user_data<WeakRendezvousArray>();
        std::vector<std::shared_ptr<RendezvousBase>> ret;
        // NOTE(review): ud.second is presumably the number of entries found;
        // ud.first[0] is only dereferenced when it is non-zero.
        if (ud.second) {
            for (auto&& r: *ud.first[0]) {
                // Skip rendezvous objects that have already been destroyed.
                if (auto p = r.lock()) {
                    ret.emplace_back(std::move(p));
                }
            }
        }
        return ret;
    });

auto PyComputingGraph = py::class_<cg::ComputingGraph, std::shared_ptr<cg::ComputingGraph>>(m, "ComputingGraph")
.def(py::init(py::overload_cast<>(&cg::ComputingGraph::make)))
@@ -483,7 +508,14 @@ void init_graph_rt(py::module m) {
},
py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());

auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false, bool prefer_host_value = false) {
auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs,
std::shared_ptr<RendezvousBase> r = {}, bool borrow = false, bool prefer_host_value = false) {
if (r) {
mgb_assert(inputs.size());
auto cg = inputs[0]->owner_graph();
cg->options().user_data.get_user_data_or_create<WeakRendezvousArray>()
->emplace_back(r);
}
SymbolVarArray sinputs;
for (auto i : inputs) {
sinputs.emplace_back(i);
@@ -508,7 +540,7 @@ void init_graph_rt(py::module m) {
auto f = [p](DeviceTensorND dv) {
p->set(std::move(dv));
};
return output_callback(std::move(f), std::move(inputs));
return output_callback(std::move(f), std::move(inputs), p);
});

m.def("value_output_callback", [output_callback](std::shared_ptr<Rendezvous<HostNDWithEvent>> p, std::vector<cg::VarNode*> inputs) {
@@ -519,13 +551,13 @@ void init_graph_rt(py::module m) {
hv_with_event.second->record();
p->set(std::move(hv_with_event));
};
return output_callback(std::move(f), std::move(inputs), true, true);
return output_callback(std::move(f), std::move(inputs), p, true, true);
});

// Registers "attr_output_callback": builds an output callback that publishes
// only the attributes (shape + dtype as a TensorLayout, and the comp node) of
// the produced tensor to the rendezvous `p`.
m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) {
    auto f = [p](DeviceTensorND dv) {
        p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()});
    };
    // The third argument is the rendezvous itself (so that it is recorded on
    // the owner graph); the fourth is `borrow`. Passing `true` in the third
    // position would not match the rendezvous parameter.
    return output_callback(std::move(f), std::move(inputs), p, true);
});
}

+ 43
- 2
imperative/python/src/graph_rt.h View File

@@ -35,18 +35,36 @@ public:

PYBIND11_DECLARE_HOLDER_TYPE(T, GraphNodePtr<T>, true);

/*!
 * \brief Value-type independent interface of a rendezvous.
 *
 * Allows code that only holds a RendezvousBase pointer to deliver an error
 * without knowing which value type the concrete rendezvous carries.
 */
class RendezvousBase {
public:
    //! Hand the exception \p p over to this rendezvous.
    virtual void set_exception(std::exception_ptr p) = 0;

    //! Virtual, so that implementations can be destroyed via a base pointer.
    virtual ~RendezvousBase() = default;
};

template<typename R>
class Rendezvous {
class Rendezvous: public RendezvousBase {
std::mutex m_lock;
int m_read_ahead = 0;
bool m_drop_next = false;
std::promise<R> m_promise;
public:
Rendezvous() = default;
// Helper that creates Rendezvous<R> objects on the heap and wraps them in a
// std::shared_ptr.
struct Factory {
    //! Construct a Rendezvous<R> from \p args and return the owning
    //! std::shared_ptr<Rendezvous<R>>.
    template<typename ...Args>
    static auto make_rendezvous(Args&& ...args) {
        // std::forward cannot deduce its template argument; it has to be
        // spelled out, otherwise any call with a non-empty argument pack
        // fails to compile.
        auto ptr = new Rendezvous<R>{std::forward<Args>(args)...};
        return std::shared_ptr<Rendezvous<R>>(ptr);
    }
};
public:
Rendezvous(const Rendezvous& rhs) = delete;
Rendezvous(Rendezvous&& rhs) = delete;
Rendezvous& operator=(const Rendezvous& rhs) = delete;

//! Create a new Rendezvous<R>; every argument is perfectly forwarded to
//! Factory::make_rendezvous, whose result is returned unchanged.
template<typename ...Args>
static auto make(Args&& ...args) {
return Factory::make_rendezvous(std::forward<Args>(args)...);
}

R get() {
std::future<R> f;
{
@@ -96,6 +114,29 @@ public:
m_read_ahead = 0;
m_drop_next = false;
}

void set_exception(std::exception_ptr e) {
if (e) {
MGB_LOCK_GUARD(m_lock);
if (m_read_ahead >= 0) {
mgb_assert(m_read_ahead <= 1);
if (m_drop_next) {
m_drop_next = false;
} else {
m_promise.set_exception(e);
}
if (m_read_ahead == 1) {
m_promise = {};
}
--m_read_ahead;
} else {
mgb_assert(m_read_ahead == -1);
// TODO: maybe exception should be ignored
// if value was already set ?
m_promise.set_exception(e);
}
}
}
};

void init_graph_rt(pybind11::module m);

+ 17
- 0
imperative/python/test/unit/core/test_megbrain_graph.py View File

@@ -82,3 +82,20 @@ def test_op():
f()

np.testing.assert_equal(x.numpy(), -y.result().numpy())


def test_exception():
    """An error raised by an input callback must surface from get_value().

    The callback feeding the graph raises a RuntimeError; after executing the
    compiled function, fetching the output value has to raise an exception
    that carries the original message.
    """
    err_msg = "QwQ"

    def throw_exc():
        raise RuntimeError(err_msg)

    g = mgb_graph.Graph()
    x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g)
    y = mgb_graph.OutputNode(F.neg(x))
    f = g.compile(y.outputs[0])
    try:
        f.execute()
        y.get_value()
    except Exception as exc:
        assert err_msg in str(exc)
    else:
        # Without this branch the test would pass even if the error were
        # silently swallowed and no exception reached the caller.
        raise AssertionError("expected the callback error to propagate")

Loading…
Cancel
Save