From a9b4cf400ab86baf3b6f6e2059831398018752b8 Mon Sep 17 00:00:00 2001 From: wjm Date: Tue, 29 Dec 2020 15:03:03 +0800 Subject: [PATCH 1/2] fix dynamic aipp error --- ge/graph/passes/multi_batch_clone_pass.cc | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index 872f94fb..516e7eb7 100644 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -22,6 +22,8 @@ #include "graph/preprocess/multi_batch_options.h" #include "graph/utils/node_utils.h" #include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/type_utils.h" #include "register/op_registry.h" namespace ge { @@ -478,8 +480,28 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { return SUCCESS; } - (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); + + GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex)); + std::vector input_dims_str; + for (size_t i = 0; i < batch_shapes_.size(); ++i) { + auto shape = data_shape; + auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", data->GetName().c_str()); + return ret; + } + tensor.SetShape(shape); + int64_t tensor_size = 0; + (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); + string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + + TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + + std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + + formats::JoinToString(tensor.GetShape().GetDims()); + input_dims_str.emplace_back(input_str); + } + (void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); + size_t max_shape_index = 0; int64_t max_size = 0; for (size_t i = 0; i < batch_shapes_.size(); ++i) { @@ -593,7 +615,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const graph->AddSubgraph(subgraph->GetName(), subgraph); all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]), - "Update %s failed", all_branch_output_[subgraph]->GetName().c_str()); + "Update %s failed", all_branch_output_[subgraph]->GetName().c_str()); const string key_name = "branches" + std::to_string(i); op_desc->AddSubgraphName(key_name); From ba745a12d34de7fed9d9e849c8e725c6fcdd3250 Mon Sep 17 00:00:00 2001 From: wjm Date: Wed, 30 Dec 2020 11:32:11 +0800 Subject: [PATCH 2/2] fix --- ge/graph/passes/multi_batch_clone_pass.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index 516e7eb7..f8451ace 100644 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -495,8 +495,8 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { int64_t tensor_size = 0; (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + - TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + - std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + + TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + + std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + formats::JoinToString(tensor.GetShape().GetDims()); input_dims_str.emplace_back(input_str); }