|
|
@@ -49,17 +49,28 @@ class _CompGraphProfilerImpl { |
|
|
|
return json->to_string(); |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
struct WeakRendezvousArray: |
|
|
|
public std::vector<std::weak_ptr<RendezvousBase>>, |
|
|
|
public UserDataContainer::UserData { |
|
|
|
MGB_TYPEINFO_OBJ_DECL; |
|
|
|
}; |
|
|
|
MGB_TYPEINFO_OBJ_IMPL(WeakRendezvousArray); |
|
|
|
} |
|
|
|
#define DEF_READWRITE(name) .def_readwrite(#name, &CURRENT_CLASS::name) |
|
|
|
|
|
|
|
template<typename T> |
|
|
|
auto def_rendezvous(py::object m, const char* name) { |
|
|
|
return py::class_<Rendezvous<T>, std::shared_ptr<Rendezvous<T>>>(m, name) |
|
|
|
.def(py::init([](){return std::make_shared<Rendezvous<T>>();})) |
|
|
|
.def(py::init([](){return Rendezvous<T>::make();})) |
|
|
|
.def("set", [](Rendezvous<T>& r, T v) {r.set(std::move(v));}) |
|
|
|
.def("get", [](Rendezvous<T>& r) {return r.get();}, py::call_guard<py::gil_scoped_release>()) |
|
|
|
.def("drop", &Rendezvous<T>::drop) |
|
|
|
.def("reset", &Rendezvous<T>::reset); |
|
|
|
.def("reset", &Rendezvous<T>::reset) |
|
|
|
.def("set_exception", [](Rendezvous<T>& r, std::string&& message) { |
|
|
|
r.set_exception(std::make_exception_ptr( |
|
|
|
std::runtime_error(std::move(message)))); |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
using TensorAttr = LogicalTensorDesc; |
|
|
@@ -186,7 +197,21 @@ void init_graph_rt(py::module m) { |
|
|
|
|
|
|
|
py::class_<cg::AsyncExecutable>(m, "AsyncExecutable") |
|
|
|
.def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>()) |
|
|
|
.def("wait", &cg::AsyncExecutable::wait, py::call_guard<py::gil_scoped_release>()); |
|
|
|
.def("wait", &cg::AsyncExecutable::wait, py::call_guard<py::gil_scoped_release>()) |
|
|
|
// only used for exception handle |
|
|
|
.def_property_readonly("_all_rendezvous", [](cg::AsyncExecutable* exec) { |
|
|
|
auto ud = exec->owner_graph()->options().user_data |
|
|
|
.get_user_data<WeakRendezvousArray>(); |
|
|
|
std::vector<std::shared_ptr<RendezvousBase>> ret; |
|
|
|
if (ud.second) { |
|
|
|
for (auto&& r: *ud.first[0]) { |
|
|
|
if (auto p = r.lock()) { |
|
|
|
ret.emplace_back(std::move(p)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return ret; |
|
|
|
}); |
|
|
|
|
|
|
|
auto PyComputingGraph = py::class_<cg::ComputingGraph, std::shared_ptr<cg::ComputingGraph>>(m, "ComputingGraph") |
|
|
|
.def(py::init(py::overload_cast<>(&cg::ComputingGraph::make))) |
|
|
@@ -483,7 +508,14 @@ void init_graph_rt(py::module m) { |
|
|
|
}, |
|
|
|
py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none()); |
|
|
|
|
|
|
|
auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false, bool prefer_host_value = false) { |
|
|
|
auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, |
|
|
|
std::shared_ptr<RendezvousBase> r = {}, bool borrow = false, bool prefer_host_value = false) { |
|
|
|
if (r) { |
|
|
|
mgb_assert(inputs.size()); |
|
|
|
auto cg = inputs[0]->owner_graph(); |
|
|
|
cg->options().user_data.get_user_data_or_create<WeakRendezvousArray>() |
|
|
|
->emplace_back(r); |
|
|
|
} |
|
|
|
SymbolVarArray sinputs; |
|
|
|
for (auto i : inputs) { |
|
|
|
sinputs.emplace_back(i); |
|
|
@@ -508,7 +540,7 @@ void init_graph_rt(py::module m) { |
|
|
|
auto f = [p](DeviceTensorND dv) { |
|
|
|
p->set(std::move(dv)); |
|
|
|
}; |
|
|
|
return output_callback(std::move(f), std::move(inputs)); |
|
|
|
return output_callback(std::move(f), std::move(inputs), p); |
|
|
|
}); |
|
|
|
|
|
|
|
m.def("value_output_callback", [output_callback](std::shared_ptr<Rendezvous<HostNDWithEvent>> p, std::vector<cg::VarNode*> inputs) { |
|
|
@@ -519,13 +551,13 @@ void init_graph_rt(py::module m) { |
|
|
|
hv_with_event.second->record(); |
|
|
|
p->set(std::move(hv_with_event)); |
|
|
|
}; |
|
|
|
return output_callback(std::move(f), std::move(inputs), true, true); |
|
|
|
return output_callback(std::move(f), std::move(inputs), p, true, true); |
|
|
|
}); |
|
|
|
|
|
|
|
m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) { |
|
|
|
auto f = [p](DeviceTensorND dv) { |
|
|
|
p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()}); |
|
|
|
}; |
|
|
|
return output_callback(std::move(f), std::move(inputs), true); |
|
|
|
return output_callback(std::move(f), std::move(inputs), p, true); |
|
|
|
}); |
|
|
|
} |