|
|
@@ -11,8 +11,9 @@
 #include "megbrain/graph/operator_node.h"
 #include "megbrain/imperative/op_def.h"
 #include "megbrain/imperative/physical_tensor.h"
 #include "megbrain/imperative/ops/autogen.h"
 
+#include "../blob_manager_impl.h"
 #include "./common.h"
 #include "./proxy_graph_base.h"
@@ -80,6 +81,20 @@ TensorAdaptor(T&) -> TensorAdaptor<T, void>;
 template <typename T>
 TensorAdaptor(T*) -> TensorAdaptor<T, void>;
 
+SmallVector<Tensor*> to_raw_ptr_array(
+        const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) {
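+    // grab raw pointers without taking ownership; optionally force lazy
+    // storage allocation so downstream code can rely on device memory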
|
|
|
+    SmallVector<Tensor*> ret;
+    for (auto&& i : inputs) {
+        mgb_assert(i);
+        ret.push_back(i.get());
+        if (ensure_storage) {
+            // apply lazy allocation
+            i->blob()->storage();
+        }
+    }
+    return ret;
+}
+
 // single opr graph, for static inference and execution
 // contains static inference descs
 class ProxyGraph::MiniGraph {
@@ -146,6 +161,9 @@ protected:
     virtual const DeviceTensorND* infer_value_fallible(VarNode*) { mgb_assert(0); }
 };
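+    // raw hash key of the (OpDef, inputs) signature this graph was built
+    // from; kept so cache lookups can verify digest matches exactly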
|
|
|
|
|
|
|
+    size_t buf_size;
+    SmallVector<size_t> hash_buf;
 
     OperatorNodeBase* m_opr = nullptr;
     SmallVector<std::unique_ptr<OperatorNodeBase>> opr_ref_keeper;
@@ -194,6 +212,7 @@ protected:
                     return nullptr;
                 }
             }
             return &storage.value();
         } else {
             auto& value = tensor.value();
             return value.shape_valid() ? &value : nullptr;
@@ -203,8 +222,10 @@ protected:
 
 public:
     template <typename I, typename G>
-    MiniGraph(G& graph, const OpDef& opdef, const I& inputs)
-            : input_value_storage(inputs.size()) {
+    MiniGraph(
+            G& graph, const OpDef& opdef, const I& inputs, const size_t* hash_buf_,
+            const size_t buf_size_)
+            : buf_size(buf_size_), input_value_storage(inputs.size()) {
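+        // the caller hands in the raw key buffer it hashed; it is copied
+        // into hash_buf at the end of construction for later exact matching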
|
|
|
         mgb_assert(!m_opr);
         auto _ = graph.scoped_attach(this);
         cg::VarNodeArray vinputs(inputs.size());
@@ -222,7 +243,8 @@ public:
         }
         m_opr->init_output_static_infer_desc();
 
-        // fix permuted input
+        // fix permuted input: the order of m_opr->input() and vinputs may
+        // differ; input_remap maps indices in m_opr->input() to vinputs
         input_remap.reserve(m_opr->input().size());
         for (auto* v : m_opr->input()) {
             auto [found, i] = find_index(vinputs, v);
@@ -248,6 +270,23 @@ public:
             mgb_assert(found);
             output_remap.push_back(i);
         }
+
+        hash_buf.resize(buf_size);
+        for (size_t i = 0; i < buf_size; ++i) {
+            hash_buf[i] = hash_buf_[i];
+        }
     }
 
+    bool is_same_buf(const size_t hash_buf_[], const size_t buf_size_) {
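+        // exact comparison of raw key buffers, used to reject cache entries
+        // whose XXHash digests collide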
|
|
|
+        if (buf_size != buf_size_) {
+            return false;
+        }
+        for (size_t i = 0; i < buf_size; i++) {
+            if (hash_buf[i] != hash_buf_[i]) {
+                return false;
+            }
+        }
+        return true;
+    }
 
     // methods for containing graph
@@ -264,6 +303,87 @@ public:
         return m_opr;
     }
 
+    void init_input_tensor(const SmallVector<Tensor*>& inputs) {
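+        // borrow each caller tensor's device storage into the corresponding
+        // InputPlaceholder var, translating order through input_remap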
|
|
|
+        auto&& opr_inputs = m_opr->input();
+        mgb_assert(opr_inputs.size() == inputs.size());
+        size_t idx = 0;
+        for (auto&& input : opr_inputs) {
+            mgb_assert(input->owner_opr()->same_type<InputPlaceholder>());
+            input->m_dev_tensor.storage({});
+            auto&& dev_tensor = inputs[input_remap[idx]]->dev_tensor();
+            auto&& layout = dev_tensor.layout();
+            input->shape(dev_tensor.shape());
+
+            auto&& chk = input->m_mem_plan.reset_from_owner_var().chunk();
+            input->m_dev_tensor.reset(dev_tensor.storage(), layout);
+            input->m_mem_plan.layout(layout);
+            chk.mem_alloc_status.set_from_owner_var();
+
+            mgb_assert(input->comp_node() == dev_tensor.comp_node());
+            mgb_assert(input->shape().eq_shape(layout));
+            mgb_assert(input->dtype() == layout.dtype);
+            idx++;
+        }
+    }
+
+    void init_output_tensor(const SmallVector<Tensor*>& outputs) {
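+        // VOLATILE_CONTENT vars are internal workspaces and get fresh memory
+        // from the blob manager; real outputs adopt caller-provided storage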
|
|
|
+        size_t idx = 0;
+        mgb_assert(m_opr->usable_output().size() == outputs.size());
+        for (auto&& var : m_opr->output()) {
+            auto&& chk = var->m_mem_plan.reset_from_owner_var().chunk();
+            if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
+                // alloc workspace
+                TensorLayout layout{var->shape(), var->dtype(), var->format()};
+                var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(
+                        var->comp_node(), layout);
+            } else {
+                mgb_assert(idx < outputs.size());
+                auto&& tensor = outputs[idx];
+                auto&& layout = tensor->layout();
+                mgb_assert(var->comp_node() == tensor->comp_node());
+                mgb_assert(var->shape().eq_shape(layout));
+                mgb_assert(var->dtype() == layout.dtype);
+                if (!tensor->layout().is_empty()) {
+                    var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
+                } else {
+                    var->m_dev_tensor.storage({var->comp_node()});
+                }
+                ++idx;
+            }
+            chk.mem_alloc_status.set_from_owner_var();
+        }
+        mgb_assert(idx == outputs.size());
+
+        // Memory forwarding is bypassed in megbrain when the graph option
+        // imperative_proxy_graph is on; here we call mem_plan_fwd_in2out_readonly
+        // to initialize some oprs' (e.g. Subtensor) internal state
+        // TODO: implement memory forwarding
+        m_opr->mem_plan_fwd_in2out_readonly();
+        {
+            // some oprs (e.g. Reduce) rely on on_mem_status_changed to set
+            // input/output tensors correctly; since we bypass var_node_mem_mgr,
+            // on_mem_status_changed should be called here
+            auto&& cb = m_opr->get_opr_event_callback().on_mem_status_changed;
+            if (cb.valid()) {
+                cb.val()();
+            }
+        }
+    }
+
+    void execute(
+            const SmallVector<Tensor*>& inputs, const SmallVector<Tensor*>& outputs,
+            cg::GraphExecutable::ExecEnv& env) {
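+        // bind caller tensors, run the opr, then drop the borrowed storages
+        // so this graph does not extend any tensor's lifetime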
|
|
|
+        init_input_tensor(inputs);
+        init_output_tensor(outputs);
+        m_opr->execute(env);
+        for (auto&& i : m_opr->input()) {
+            i->m_dev_tensor.storage({});
+        }
+        for (auto&& i : m_opr->output()) {
+            i->m_dev_tensor.storage({});
+        }
+    }
 
     void register_shape_infer(
             VarNode* varnode, const cg::static_infer::ShapeInferDesc& desc) {
         auto [found, i] = find_index(m_opr->output(), varnode);
@@ -278,15 +398,22 @@ public:
         output_data[i].value_infer.initialize(m_opr, desc.deps, desc.infer_func);
     }
 
-    const TensorShape& infer_shape(VarNode* var) { return m_sess->infer_shape(var); }
+    const TensorShape& infer_shape(VarNode* var) {
+        mgb_assert(m_sess);
+        return m_sess->infer_shape(var);
+    }
 
-    const DeviceTensorND& infer_value(VarNode* var) { return m_sess->infer_value(var); }
+    const DeviceTensorND& infer_value(VarNode* var) {
+        mgb_assert(m_sess);
+        return m_sess->infer_value(var);
+    }
 
     OperatorNodeBase* opr() { return m_opr; }
 
     // inference routine template for type of input
     template <typename I>
     class InferSession : protected InferSessionBase {
     public:
        MiniGraph& owner;
        SmallVector<OutputData>& output_data;
        InputAdaptor<I> inputs;
@@ -355,7 +482,7 @@ public:
            auto [found, i] = find_index(owner.m_opr->input(), var);
            mgb_assert(found);
            i = owner.input_remap[i];
-            auto* value = inputs.value(i, false);
+            auto* value = inputs.value(i, true);
            mgb_assert(value);
            return *value;
        }
@@ -379,12 +506,18 @@ public:
 
        const TensorShape* infer_shape(size_t i, bool sync) {
            i = owner.output_remap[i];
-            return infer(output_data[i].shape_infer, sync);
+            auto* p = infer(output_data[i].shape_infer, sync);
+            if (sync)
+                mgb_assert(p, "failed to infer shape");
+            return p;
        }
 
        const DeviceTensorND* infer_value(size_t i, bool sync) {
            i = owner.output_remap[i];
-            return infer(output_data[i].shape_infer, sync);
+            auto* p = infer(output_data[i].value_infer, sync);
+            if (sync)
+                mgb_assert(p, "failed to infer value");
+            return p;
        }
    };
@@ -499,10 +632,12 @@ class ProxyGraphTypeI : public ProxyGraphBase {
    public:
        void register_shape_infer(
                VarNode* var, const cg::static_infer::ShapeInferDesc& desc) override {
+            mgb_assert(target);
            target->register_shape_infer(var, desc);
        };
        void register_value_infer(
                VarNode* var, const cg::static_infer::ValueInferDesc& desc) override {
+            mgb_assert(target);
            target->register_value_infer(var, desc);
        };
        cg::static_infer::InferType get_infer_type(VarNode*) override {
@@ -511,17 +646,22 @@ class ProxyGraphTypeI : public ProxyGraphBase {
        }
        // some poorly written inference funcs would call infer_{shape,value}
        const TensorShape& infer_shape(VarNode* var) override {
+            mgb_assert(target);
            return target->infer_shape(var);
        }
        const DeviceTensorND& infer_value(VarNode* var) override {
+            mgb_assert(target);
            return target->infer_value(var);
        }
    };
 
    ProxyGraph::MiniGraph* target = nullptr;
    StaticInferManager m_static_infer_manager;
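+    // a digest may map to several MiniGraphs since XXHash keys can collide;
+    // MiniGraph::is_same_buf picks the exact entry within a bucket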
|
|
|
-    std::unordered_map<size_t, ProxyGraph::MiniGraph> m_mini_graph_cache;
+    std::unordered_multimap<size_t, ProxyGraph::MiniGraph> m_mini_graph_cache;
+    std::mutex m_mini_graph_cache_mtx;
    size_t opr_count = 0;
    ExecEnvBase m_env;
+    CompNode::UnorderedSet m_used_comp_node;
 
    static thread_local std::unique_ptr<ProxyGraphTypeI> sm_instance;
@@ -531,8 +671,12 @@ class ProxyGraphTypeI : public ProxyGraphBase {
 
    size_t next_node_id() override { return opr_count; }
 
+    void add_used_comp_node(CompNode cn) { m_used_comp_node.insert(cn); }
+
    std::shared_ptr<void> on_comp_node_finalize() override {
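+        // comp nodes are being finalized: cached minigraphs hold device
+        // resources and must be dropped before that completes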
|
|
|
-        sm_instance.reset();
+        assert(!target);
+        MGB_LOCK_GUARD(m_mini_graph_cache_mtx);
+        m_mini_graph_cache.clear();
        return {};
    }
@@ -575,38 +719,62 @@ class ProxyGraphTypeI : public ProxyGraphBase {
    }
 
public:
+    ~ProxyGraphTypeI() {
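+        // wait for pending work on every comp node we touched before the
+        // cached graphs go away; CUDA comp nodes are skipped here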
|
|
|
+        if (is_finalized()) {
+            return;
+        }
+        for (auto&& i : m_used_comp_node) {
+            if (i.device_type() == CompNode::DeviceType::CUDA)
+                continue;
+            i.sync();
+        }
+    }
 
    OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr_uniqp) override {
+        mgb_assert(target);
        return target->insert_opr(std::move(opr_uniqp));
    }
 
    static ProxyGraphTypeI& inst() {
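+        // lazily-created thread-local singleton, rebuilt if comp node
+        // finalization already tore the previous instance down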
|
|
|
-        if (!sm_instance) {
+        if (!sm_instance || sm_instance->is_finalized()) {
            sm_instance.reset(new ProxyGraphTypeI);
        }
        return *sm_instance;
    }
 
-    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
-            const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    template <typename T>
+    ProxyGraph::MiniGraph& get_cached_minigraph(const OpDef& def, const T& inputs) {
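+        // cache key: op hash followed by (dtype, comp_node) of each input;
+        // TensorAdaptor lets descs and live tensors share this code path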
|
|
|
+        mgb_assert(!is_finalized());
        size_t buf_size = 2 * inputs.size() + 1;
        size_t buf[buf_size];
        size_t pos = 0;
        buf[pos++] = def.hash();
-        for (auto&& desc : inputs) {
-            buf[pos++] = mgb::hash(desc.layout.dtype.handle());
-            buf[pos++] = mgb::hash(desc.comp_node);
+        for (auto&& inp : inputs) {
+            auto tensor = TensorAdaptor(inp);
+            buf[pos++] = mgb::hash(tensor.dtype().handle());
+            buf[pos++] = mgb::hash(tensor.comp_node());
        }
        mgb_assert(pos == buf_size);
        auto key = XXHash{}.update(buf, buf_size * sizeof(size_t)).digest();
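+        // scan the bucket and accept an entry only on an exact raw-key
+        // match, so digest collisions cannot alias two different ops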
|
|
|
-        auto it = m_mini_graph_cache.find(key);
-        if (it == m_mini_graph_cache.end()) {
-            auto&& result = m_mini_graph_cache.emplace(
-                    std::piecewise_construct, std::make_tuple(key),
-                    std::forward_as_tuple(*this, def, inputs));
-            mgb_assert(result.second);
-            it = result.first;
-        }
-        auto& minigraph = it->second;
+        auto its = m_mini_graph_cache.equal_range(key);
+        auto it = its.first;
+        for (; it != its.second; ++it) {
+            if (it->second.is_same_buf(buf, buf_size)) {
+                return it->second;
+            }
+            mgb_log_warn("hash collision occurs in minigraph cache with key: %lu", key);
+        }
+        auto&& result = m_mini_graph_cache.emplace(
+                std::piecewise_construct, std::make_tuple(key),
+                std::forward_as_tuple(
+                        *this, def, inputs, static_cast<size_t*>(buf), buf_size));
+        return result->second;
+    }
+
+    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+            const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+        auto& minigraph = get_cached_minigraph(def, inputs);
        auto _ = scoped_attach(&minigraph);
        auto sess = minigraph.infer_session(inputs);
        std::tuple<SmallVector<LogicalTensorDesc>, bool> ret;
@@ -627,6 +795,88 @@ public:
        }
        return ret;
    }
 
+    SmallVector<LogicalTensorDesc> infer_output_attrs(
+            const OpDef& def, const SmallVector<Tensor*>& inputs) {
+        SmallVector<LogicalTensorDesc> descs;
+        auto& minigraph = get_cached_minigraph(def, inputs);
+        auto _ = scoped_attach(&minigraph);
+        auto sess = minigraph.infer_session(inputs);
+        // some output vars in minigraph.opr()->output() may not appear in
+        // minigraph.opr()->usable_output(), but execution may use the attrs
+        // of those vars, so we infer attrs for all outputs while only
+        // returning LogicalTensorDesc for minigraph.opr()->usable_output()
+        for (size_t i = 0; i < minigraph.opr()->output().size(); ++i) {
+            auto* shape = sess.infer(sess.output_data[i].shape_infer, true);
+            mgb_assert(shape);
+            minigraph.opr()->output()[i]->shape(*shape);
+        }
+        descs.reserve(minigraph.output_size());
+        for (size_t i = 0; i < minigraph.output_size(); ++i) {
+            auto* ovar = minigraph.output_var(i);
+            descs.emplace_back();
+            auto& desc = descs.back();
+            desc.layout.dtype = ovar->dtype();
+            desc.comp_node = ovar->comp_node();
+            mgb_assert(ovar->dtype().valid() && ovar->comp_node().valid());
+            mgb_assert(
+                    ovar->shape().ndim ||
+                    ovar->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC));
+            desc.layout.init_contiguous_stride(ovar->shape());
+        }
+        return descs;
+    }
+
+    SmallVector<LogicalTensorDesc> infer_output_attrs(
+            const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+        return infer_output_attrs(def, to_raw_ptr_array(inputs));
+    }
+
+    void exec(
+            const OpDef& def, const SmallVector<TensorPtr>& inputs,
+            const SmallVector<TensorPtr>& outputs) {
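+        // an input on a different comp node than an output must be visible
+        // before the opr runs (device_wait_by) and kept alive until that
+        // comp node finishes with it (release callback below)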
|
|
|
+        auto raw_inputs = to_raw_ptr_array(inputs),
+             raw_outputs = to_raw_ptr_array(outputs);
+        CompNode::UnorderedSet used_cns;
+        for (auto&& out : raw_outputs) {
+            auto cn = out->comp_node();
+            add_used_comp_node(cn);
+            if (used_cns.insert(cn).second) {
+                for (auto&& in : inputs) {
+                    if (in->comp_node() != cn) {
+                        auto&& e = in->get_or_create_event();
+                        e->device_wait_by(cn);
+                    }
+                }
+            }
+        }
+        auto& minigraph = get_cached_minigraph(def, raw_inputs);
+        auto _ = scoped_attach(&minigraph);
+        // some oprs (e.g. Subtensor) may invoke infer_value during execution,
+        // so we need to create an inference session here
+        auto sess = minigraph.infer_session(raw_inputs);
+        minigraph.execute(raw_inputs, raw_outputs, m_env);
+        for (auto&& cn : used_cns) {
+            for (auto&& in : inputs) {
+                if (in->comp_node() != cn) {
+                    in->add_release_callback(cn);
+                }
+            }
+        }
+    }
 
+    SmallVector<TensorPtr> apply_on_physical_tensor(
+            const OpDef& def, SmallVector<TensorPtr> inputs) {
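+        // allocate outputs according to the inferred descs, then run the
+        // cached minigraph over them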
|
|
|
+        auto&& raw_inputs = to_raw_ptr_array(inputs);
+        auto output_descs = infer_output_attrs(def, raw_inputs);
+        SmallVector<TensorPtr> outputs(output_descs.size(), {});
+        for (size_t i = 0; i < outputs.size(); i++) {
+            outputs[i] =
+                    Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
+        }
+        exec(def, inputs, outputs);
+        return outputs;
+    }
};
 
} // namespace mgb::imperative::proxy_graph