diff --git a/ge/graph/build/graph_builder.cc b/ge/graph/build/graph_builder.cc index 87d2a206..ee9be124 100644 --- a/ge/graph/build/graph_builder.cc +++ b/ge/graph/build/graph_builder.cc @@ -15,6 +15,7 @@ */ #include "graph/build/graph_builder.h" +#include "graph/build/memory/graph_mem_assigner.h" #include "common/ge/ge_util.h" #include "common/helper/model_helper.h" #include "graph/build/logical_stream_allocator.h" @@ -197,10 +198,8 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vectorGetGraphUnknownFlag()) { GE_CHK_STATUS_RET( BuildForDynamicShapeGraph(comp_graph, subgraph_ptr_list, ge_root_model_ptr, ge_model_ptr, session_id), "Build for dynamic shape graph failed."); @@ -270,16 +269,78 @@ Status GraphBuilder::BuildForKnownShapeGraph(ComputeGraphPtr &comp_graph, std::v return SUCCESS; } +Status GraphBuilder::SetConstantInputOffset(ComputeGraphPtr &comp_graph) { + for (auto &node : comp_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto num_inputs = op_desc->GetInputsSize(); + std::vector input_offsets(num_inputs, 0); + int valid_input_index = -1; + for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + auto in_anchor = node->GetInDataAnchor(i); + auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + continue; + } + + ++valid_input_index; + auto peer_node = peer_out_anchor->GetOwnerNode(); + if (peer_node == nullptr) { + continue; + } + + if (peer_node->GetType() != CONSTANT) { + continue; + } + + std::vector weights = OpDescUtils::MutableWeights(peer_node); + if (weights.empty()) { + GELOGE(FAILED, "weights size of node %s is empty", node->GetName().c_str()); + return FAILED; + } + GeTensorPtr weight = weights[0]; + GE_CHECK_NOTNULL(weight); + int64_t input_offset = 0; + (void) TensorUtils::GetDataOffset(weight->MutableTensorDesc(), input_offset); + // valid_input_index must smaller than num_inputs + input_offsets[valid_input_index] = input_offset; + GELOGD("[%s] input[%u] is const, offset = %ld", node->GetName().c_str(), valid_input_index, input_offset); + } + + op_desc->SetInputOffset(input_offsets); + std::vector output_offsets(op_desc->GetOutputsSize(), 0); + op_desc->SetOutputOffset(output_offsets); + } + return SUCCESS; +} + Status GraphBuilder::BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, uint64_t session_id) { GELOGI("Begin to build unknown shape graph[%s].", comp_graph->GetName().c_str()); + Graph2SubGraphInfoList subgraph_map; + ge::ModelBuilder builder(session_id, comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); + GE_DUMP(comp_graph, "BeforePreBuildModel"); + GE_TIMESTAMP_START(PreBuildModel); + GE_CHK_STATUS_RET(builder.PreBuildModel(), "Graph[%s] builder PreBuildModel() return fail.", + comp_graph->GetName().c_str()); + GE_TIMESTAMP_END(PreBuildModel, "GraphBuilder::PreBuildModel"); + GE_DUMP(comp_graph, "AfterPreBuildModel"); + GE_TIMESTAMP_START(CalcOpParam); GE_CHK_STATUS_RET(CalcOpParam(comp_graph), "Graph[%s] builder CalcOpParam() return fail.", comp_graph->GetName().c_str()); GE_TIMESTAMP_END(CalcOpParam, "GraphBuilder::CalcOpParam"); GE_DUMP(comp_graph, "AfterCalcOpParam"); - Graph2SubGraphInfoList subgraph_map; - ge::ModelBuilder builder(session_id, comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); + + GE_TIMESTAMP_START(SetConstantInputOffset); + GE_CHK_STATUS_RET(SetConstantInputOffset(comp_graph), + "Graph[%s] failed to set constant input offset.", comp_graph->GetName().c_str()); + GE_TIMESTAMP_END(SetConstantInputOffset); + GE_TIMESTAMP_START(MergeWeights); + GE_CHK_STATUS_RET(MergeWeights(), "Graph[%s] failed to merge weights.", comp_graph->GetName().c_str()); + GE_TIMESTAMP_END(MergeWeights, "GraphBuilder::MergeWeights"); + ModelPtr model_ptr = MakeShared(); if (model_ptr == nullptr) { return MEMALLOC_FAILED; @@ -375,10 +436,15 @@ Status GraphBuilder::BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph, op_desc->GetName().c_str()); } } - // - for (auto &sub_graph : comp_graph->GetAllSubgraphs()) { + + auto all_graphs = comp_graph->GetAllSubgraphs(); + if (all_graphs.empty()) { + all_graphs.push_back(comp_graph); + } + for (auto &sub_graph : all_graphs) { // exclude functional subgraph in known subgraph - if (sub_graph->GetParentGraph() != comp_graph && !sub_graph->GetParentGraph()->GetGraphUnknownFlag()) { + if (sub_graph->GetParentGraph() != nullptr && sub_graph->GetParentGraph() != comp_graph && + !sub_graph->GetParentGraph()->GetGraphUnknownFlag()) { continue; } diff --git a/ge/graph/build/graph_builder.h b/ge/graph/build/graph_builder.h index 329f3ebc..b828a80d 100644 --- a/ge/graph/build/graph_builder.h +++ b/ge/graph/build/graph_builder.h @@ -67,6 +67,7 @@ class GraphBuilder { GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); Status BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); + Status SetConstantInputOffset(ComputeGraphPtr &comp_graph); Status AddOutputMemTypeForNode(const NodePtr &node); Status BuildForHostCpuGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); diff --git a/ge/graph/build/model_builder.h b/ge/graph/build/model_builder.h index de079768..12420614 100644 --- a/ge/graph/build/model_builder.h +++ b/ge/graph/build/model_builder.h @@ -55,13 +55,13 @@ class ModelBuilder { ge::Buffer GetWeightBuffer() const; + Status MergeWeights(); + protected: void AddNodeInputProperty(); void ClearOriginalFormat(); - Status MergeWeights(); - private: bool SetInputConst(const OpDescPtr &op_desc, const NodePtr &src_node, size_t index, vector &is_input_const); diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc index 25bf6855..2a1a14e6 100755 --- a/ge/single_op/single_op_model.cc +++ b/ge/single_op/single_op_model.cc @@ -477,6 +477,7 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp & single_op.num_inputs_ = data_ops_.size(); single_op.num_outputs_ = netoutput_op_->GetAllInputsSize(); GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource)); + model_params_.memory_size = UINT_MAX; return BuildTaskListForDynamicOp(single_op); } } // namespace ge