|
@@ -148,13 +148,15 @@ size_t ComputingGraph::prealloc_static_storage(size_t size) { |
|
|
/* ========================== CallbackCaller ========================== */ |
|
|
/* ========================== CallbackCaller ========================== */ |
|
|
MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller, |
|
|
MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller, |
|
|
SingleCNOperatorNodeBase) // { |
|
|
SingleCNOperatorNodeBase) // { |
|
|
std::vector<ComputingGraph::Callback> m_cb; |
|
|
|
|
|
|
|
|
std::vector<std::vector<ComputingGraph::Callback>> m_cb; |
|
|
|
|
|
|
|
|
void scn_do_execute() override { |
|
|
void scn_do_execute() override { |
|
|
auto&& dv = input(0)->dev_tensor(); |
|
|
|
|
|
for (auto&& i : m_cb) { |
|
|
|
|
|
// const cast for backward API compatibility |
|
|
|
|
|
i(const_cast<DeviceTensorND&>(dv)); |
|
|
|
|
|
|
|
|
for (size_t i = 0; i < input().size(); ++i) { |
|
|
|
|
|
auto&& in = input(i)->dev_tensor(); |
|
|
|
|
|
for (auto&& callback : m_cb[i]) { |
|
|
|
|
|
// const cast for backward API compatibility |
|
|
|
|
|
callback(const_cast<DeviceTensorND&>(in)); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@@ -168,14 +170,29 @@ MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller, |
|
|
if (owner_graph()->options().comp_node_seq_record_level) { |
|
|
if (owner_graph()->options().comp_node_seq_record_level) { |
|
|
// the user callback usually copies from device to host, which |
|
|
// the user callback usually copies from device to host, which |
|
|
// involves tmp alloc if input is not contiguous |
|
|
// involves tmp alloc if input is not contiguous |
|
|
input(0)->add_layout_constraint_contiguous(); |
|
|
|
|
|
|
|
|
for (auto&& inp : input()) { |
|
|
|
|
|
inp->add_layout_constraint_contiguous(); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Give output 0 a dtype if it does not have a valid one yet: the dtype of
// input 0 is used, and it must be valid and must not be dtype::Byte().
// An already-valid output dtype is left untouched.
void init_output_dtype() override {
    auto out_var = output(0);
    if (!out_var->dtype().valid()) {
        mgb_assert(!input().empty());
        DType dtype = input(0)->dtype();
        mgb_assert(dtype.valid() && dtype != dtype::Byte());
        out_var->dtype(dtype);
    }
}
|
|
|
|
|
|
|
|
// Build the node property from Super's and additionally register every
// input var with dependency type VALUE_ALLOW_EMPTY.
NodeProp* do_make_node_prop() const override {
    using DepType = NodeProp::DepType;
    auto prop = Super::do_make_node_prop();
    for (size_t idx = 0; idx < input().size(); ++idx) {
        prop->add_dep_type_existing_var(input(idx),
                                        DepType::VALUE_ALLOW_EMPTY);
    }
    return prop;
}
|
|
|
|
|
|
|
@@ -185,25 +202,38 @@ MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
public: |
|
|
public: |
|
|
CallbackCaller(VarNode* inp) |
|
|
|
|
|
: Super{inp->owner_graph(), {}, "callback", {inp}} { |
|
|
|
|
|
add_input({inp}); |
|
|
|
|
|
|
|
|
CallbackCaller(const VarNodeArrayView& inp) |
|
|
|
|
|
: Super{inp[0]->owner_graph(), {}, "callback", inp} { |
|
|
|
|
|
mgb_assert(!inp.empty()); |
|
|
|
|
|
m_cb.resize(inp.size()); |
|
|
|
|
|
for (auto&& i : inp) { |
|
|
|
|
|
add_input({i}); |
|
|
|
|
|
} |
|
|
using F = VarNode::Flag; |
|
|
using F = VarNode::Flag; |
|
|
add_output(None) |
|
|
add_output(None) |
|
|
->add_flag(F::ALLOW_EMPTY_SHAPE) |
|
|
->add_flag(F::ALLOW_EMPTY_SHAPE) |
|
|
.add_flag(F::VOLATILE_CONTENT); |
|
|
.add_flag(F::VOLATILE_CONTENT); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// Insert a CallbackCaller over the vars in inp into the owner graph of
// inp[0] and return output 0 of the inserted operator.
// inp must be non-empty.
static SymbolVar make(const VarNodeArrayView& inp) {
    mgb_assert(!inp.empty());
    auto graph = SymbolVar{inp[0]}.node()->owner_graph();
    auto opr = graph->insert_opr(std::make_unique<CallbackCaller>(inp));
    return opr->output(0);
}
|
|
|
|
|
|
|
|
// Append cb to the callback list of input i (input 0 by default).
// cb must be callable (non-empty) and i must be a valid index into the
// per-input callback lists.
void add_callback(const ComputingGraph::Callback& cb, size_t i = 0) {
    mgb_assert(cb && i < m_cb.size());
    auto&& per_input = m_cb[i];
    per_input.emplace_back(cb);
}
|
|
|
|
|
|
|
|
// Remove all registered callbacks. Only the per-input lists are emptied;
// the outer vector keeps its size, so there is still one list per input.
void clear_callback() {
    for (auto&& per_input : m_cb) {
        per_input.clear();
    }
}
|
|
}; |
|
|
}; |
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ComputingGraphImpl::CallbackCaller); |
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ComputingGraphImpl::CallbackCaller); |
|
|
|
|
|
|
|
@@ -529,22 +559,39 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( |
|
|
cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars); |
|
|
cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars); |
|
|
|
|
|
|
|
|
auto init_opr_seq = [&]() { |
|
|
auto init_opr_seq = [&]() { |
|
|
ThinHashMap<VarNode*, CallbackCaller*> var2cb_caller; |
|
|
|
|
|
|
|
|
ThinHashMap<VarNode*, size_t> var2idx; |
|
|
|
|
|
std::unordered_map<CallbackCallerKey, CallbackCallerVal, |
|
|
|
|
|
CallbackCallerKey::Hash> |
|
|
|
|
|
opr2vars; |
|
|
for (size_t i = 0; i < out_spec.size(); ++i) { |
|
|
for (size_t i = 0; i < out_spec.size(); ++i) { |
|
|
auto&& cb = out_spec[i].second; |
|
|
auto&& cb = out_spec[i].second; |
|
|
if (cb) { |
|
|
if (cb) { |
|
|
auto var = dest_vars[i]; |
|
|
auto var = dest_vars[i]; |
|
|
auto&& cb_caller = var2cb_caller[var]; |
|
|
|
|
|
if (!cb_caller) { |
|
|
|
|
|
auto dvar = CallbackCaller::make(var); |
|
|
|
|
|
cb_caller = &dvar.node() |
|
|
|
|
|
->owner_opr() |
|
|
|
|
|
->cast_final_safe<CallbackCaller>(); |
|
|
|
|
|
++extra_info.var2recvinfo[dvar.node()].nr_direct_comp_req; |
|
|
|
|
|
cb_caller->clear_callback(); |
|
|
|
|
|
|
|
|
CallbackCallerKey key{var->owner_opr(), var->comp_node()}; |
|
|
|
|
|
auto&& vals = opr2vars[key]; |
|
|
|
|
|
auto&& var2idx_iter = var2idx.find(var); |
|
|
|
|
|
if ( var2idx_iter == var2idx.end()) { |
|
|
|
|
|
vals.vars.push_back(var); |
|
|
|
|
|
vals.indexs.push_back({i}); |
|
|
|
|
|
var2idx[var] = vals.vars.size() - 1; |
|
|
|
|
|
} else { |
|
|
|
|
|
vals.indexs[var2idx_iter->second].push_back(i); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
for (auto& item : opr2vars) { |
|
|
|
|
|
auto&& val = item.second; |
|
|
|
|
|
auto dvar = CallbackCaller::make(val.vars); |
|
|
|
|
|
CallbackCaller* cb_caller = &dvar.node() |
|
|
|
|
|
->owner_opr() |
|
|
|
|
|
->cast_final_safe<CallbackCaller>(); |
|
|
|
|
|
++extra_info.var2recvinfo[dvar.node()].nr_direct_comp_req; |
|
|
|
|
|
cb_caller->clear_callback(); |
|
|
|
|
|
for (size_t i=0;i<val.vars.size(); ++i) { |
|
|
|
|
|
for (auto&& idx : val.indexs[i]) { |
|
|
|
|
|
cb_caller->add_callback(out_spec[idx].second, i); |
|
|
|
|
|
dest_vars[idx] = cb_caller->output(0); |
|
|
} |
|
|
} |
|
|
cb_caller->add_callback(cb); |
|
|
|
|
|
dest_vars[i] = cb_caller->output(0); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
opr_seq = topo_sorter().get_comp_seq(extra_info, dest_vars); |
|
|
opr_seq = topo_sorter().get_comp_seq(extra_info, dest_vars); |
|
|