|
|
@@ -126,20 +126,6 @@ Status CollectDependenciesForFusedGraph(NodeItem &node_item, std::set<OpDesc *> |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
bool CheckHasHostMem(NodeItem &node_item) { |
|
|
|
if (node_item.NodeType() == DATA) { |
|
|
|
auto op_desc = node_item.GetOpDesc(); |
|
|
|
if (op_desc == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto tensor = op_desc->MutableInputDesc(0); |
|
|
|
if (AttrUtils::HasAttr(tensor, ATTR_NAME_VALUE)) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) |
|
|
|
: hybrid_model_(hybrid_model), runtime_param_(hybrid_model.root_runtime_param_) { |
|
|
@@ -298,6 +284,47 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status HybridModelBuilder::ParseDependencies(NodeItem &node_item, const std::vector<string> &dependencies, |
|
|
|
std::set<NodePtr> &dependent_for_shape_inference) { |
|
|
|
auto &ge_node = node_item.node; |
|
|
|
for (const auto &input_name : dependencies) { |
|
|
|
int input_index = node_item.op_desc->GetInputIndexByName(input_name); |
|
|
|
if (input_index < 0) { |
|
|
|
GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.", |
|
|
|
node_item.NodeName().c_str(), input_name.c_str()); |
|
|
|
REPORT_CALL_ERROR("E19999", "GetInputIndexByName failed, node:[%s] inputname: %s.", |
|
|
|
node_item.NodeName().c_str(), input_name.c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
const auto &in_anchor = ge_node->GetInDataAnchor(input_index); |
|
|
|
GE_CHECK_NOTNULL(in_anchor); |
|
|
|
const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); |
|
|
|
GE_CHECK_NOTNULL(peer_out_anchor); |
|
|
|
const auto &src_node = peer_out_anchor->GetOwnerNode(); |
|
|
|
GE_CHECK_NOTNULL(src_node); |
|
|
|
auto src_node_item = MutableNodeItem(src_node); |
|
|
|
GE_CHECK_NOTNULL(src_node_item); |
|
|
|
if (src_node_item->NodeType() == DATA) { |
|
|
|
auto op_desc = src_node_item->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
|
auto tensor = op_desc->MutableInputDesc(0); |
|
|
|
if (AttrUtils::HasAttr(tensor, ATTR_NAME_VALUE)) { |
|
|
|
GELOGD("Skip d2h memcpy, get hostmem from node %s.", src_node_item->NodeName().c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); |
|
|
|
dependent_for_shape_inference.emplace(src_node); |
|
|
|
host_input_value_dependencies_[&node_item].emplace_back(peer_out_anchor->GetIdx(), src_node_item); |
|
|
|
GELOGD("[%s] Dependent added from output of [%s:%d]", |
|
|
|
node_item.NodeName().c_str(), |
|
|
|
src_node_item->NodeName().c_str(), |
|
|
|
peer_out_anchor->GetIdx()); |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) { |
|
|
|
std::set<NodePtr> dependent_for_shape_inference; |
|
|
|
std::set<NodePtr> dependent_for_execution; |
|
|
@@ -361,35 +388,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s |
|
|
|
src_node_item->NodeName().c_str()); |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &input_name : dependencies) { |
|
|
|
int input_index = node_item.op_desc->GetInputIndexByName(input_name); |
|
|
|
if (input_index < 0) { |
|
|
|
GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.", |
|
|
|
node_item.NodeName().c_str(), input_name.c_str()); |
|
|
|
REPORT_CALL_ERROR("E19999", "GetInputIndexByName failed, node:[%s] inputname: %s.", |
|
|
|
node_item.NodeName().c_str(), input_name.c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
const auto &in_anchor = ge_node->GetInDataAnchor(input_index); |
|
|
|
GE_CHECK_NOTNULL(in_anchor); |
|
|
|
const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); |
|
|
|
GE_CHECK_NOTNULL(peer_out_anchor); |
|
|
|
const auto &src_node = peer_out_anchor->GetOwnerNode(); |
|
|
|
GE_CHECK_NOTNULL(src_node); |
|
|
|
auto src_node_item = MutableNodeItem(src_node); |
|
|
|
GE_CHECK_NOTNULL(src_node_item); |
|
|
|
GE_IF_BOOL_EXEC(CheckHasHostMem(*src_node_item), |
|
|
|
GELOGD("Skip d2h memcpy, get hostmem from node %s.", src_node_item->NodeName().c_str()); |
|
|
|
continue;) |
|
|
|
src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); |
|
|
|
dependent_for_shape_inference.emplace(src_node); |
|
|
|
host_input_value_dependencies_[&node_item].emplace_back(peer_out_anchor->GetIdx(), src_node_item); |
|
|
|
GELOGD("[%s] Dependent added from output of [%s:%d]", |
|
|
|
node_item.NodeName().c_str(), |
|
|
|
src_node_item->NodeName().c_str(), |
|
|
|
peer_out_anchor->GetIdx()); |
|
|
|
} |
|
|
|
GE_CHK_STATUS_RET(ParseDependencies(node_item, dependencies, dependent_for_shape_inference)); |
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item, dependent_for_shape_inference)); |
|
|
|
for (const auto &dep_node : dependent_for_shape_inference) { |
|
|
|