|
|
@@ -58,6 +58,41 @@ int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) { |
|
|
|
} |
|
|
|
return var_size; |
|
|
|
} |
|
|
|
|
|
|
|
Status CollectDependenciesForFusedGraph(NodeItem &node_item, std::set<OpDesc *> &data_ops) { |
|
|
|
for (const auto &node : node_item.fused_subgraph->nodes) { |
|
|
|
auto op_desc = node->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
|
const auto &depends = op_desc->GetOpInferDepends(); |
|
|
|
if (depends.empty()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &input_name : depends) { |
|
|
|
auto input_index = op_desc->GetInputIndexByName(input_name); |
|
|
|
const auto &in_anchor = ge_node->GetInDataAnchor(input_index); |
|
|
|
GE_CHECK_NOTNULL(in_anchor); |
|
|
|
const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); |
|
|
|
GE_CHECK_NOTNULL(peer_out_anchor); |
|
|
|
const auto &src_node = peer_out_anchor->GetOwnerNode(); |
|
|
|
GE_CHECK_NOTNULL(src_node); |
|
|
|
auto src_op_desc = src_node->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(src_op_desc); |
|
|
|
if (src_node->GetType() != DATA_TYPE) { |
|
|
|
GELOGE(UNSUPPORTED, |
|
|
|
"[%s::%s] Node in fused subgraph can only depend on Data nodes, but depend on %s", |
|
|
|
node_item.NodeName().c_str(), |
|
|
|
node->GetName().c_str(), |
|
|
|
src_node->GetType().c_str()); |
|
|
|
return UNSUPPORTED; |
|
|
|
} |
|
|
|
|
|
|
|
data_ops.emplace(src_op_desc.get()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) |
|
|
|
: hybrid_model_(hybrid_model), runtime_param_(hybrid_model.root_runtime_param_) { |
|
|
@@ -272,6 +307,53 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s |
|
|
|
node_item.dependents_for_shape_inference.emplace_back(dep_node); |
|
|
|
} |
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(noe_item)); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) { |
|
|
|
if (node_item.fused_subgraph == nullptr) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
std::set<OpDesc *> data_ops; |
|
|
|
GE_CHK_STATUS_RET_NOLOG(CollectDependenciesForFusedGraph(node_item, data_ops)); |
|
|
|
for (auto &op_desc : data_ops) { |
|
|
|
uint32_t parent_index = 0; |
|
|
|
if (!AttrUtils::GetInt(data_op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { |
|
|
|
GELOGE(INTERNAL_ERROR, |
|
|
|
"[%s] Failed to get attr [%s]", |
|
|
|
data_op_desc->GetName().c_str(), |
|
|
|
ATTR_NAME_PARENT_NODE_INDEX.c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
const auto &in_anchor = node_item.node->GetInDataAnchor(parent_index); |
|
|
|
GE_CHECK_NOTNULL(in_anchor); |
|
|
|
const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); |
|
|
|
GE_CHECK_NOTNULL(peer_out_anchor); |
|
|
|
const auto &src_node = peer_out_anchor->GetOwnerNode(); |
|
|
|
GE_CHECK_NOTNULL(src_node); |
|
|
|
NodeItem *src_node_item = nullptr; |
|
|
|
GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(src_node, &src_node_item)); |
|
|
|
op_desc->SetId(src_node_item->op_desc->GetId()); |
|
|
|
GELOGD("[%S::%S] Node id was set to that of outer src node's, src_node = %s", |
|
|
|
node_item.NodeName().c_str(), |
|
|
|
op_desc->GetName().c_str(), |
|
|
|
src_node_item->NodeName().c_str()); |
|
|
|
src_node_item->has_observer = true; |
|
|
|
src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); |
|
|
|
|
|
|
|
auto &depends = node_item.dependents_for_shape_inference; |
|
|
|
if (std::find(depends.begin(), depends.end(), src_node) == depends.end()) { |
|
|
|
depends.emplace_back(src_node); |
|
|
|
GELOGD("[%s] Dependent added from output of [%s:%d]", |
|
|
|
node_item.NodeName().c_str(), |
|
|
|
src_node_item->NodeName().c_str(), |
|
|
|
peer_out_anchor->GetIdx()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|