@@ -329,6 +329,7 @@ set(TRAIN_SRC_LIST | |||||
"client/ge_api.cc" | "client/ge_api.cc" | ||||
"client/ge_prof.cc" | "client/ge_prof.cc" | ||||
"analyzer/analyzer.cc" | "analyzer/analyzer.cc" | ||||
"ir_build/atc_ir_common.cc" | |||||
) | ) | ||||
add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) | add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) | ||||
@@ -475,6 +475,9 @@ REGISTER_OPTYPE_DEFINE(HVDCALLBACKALLGATHER, "HorovodAllgather"); | |||||
REGISTER_OPTYPE_DEFINE(HVDCALLBACKBROADCAST, "HorovodBroadcast"); | REGISTER_OPTYPE_DEFINE(HVDCALLBACKBROADCAST, "HorovodBroadcast"); | ||||
REGISTER_OPTYPE_DEFINE(HVDWAIT, "HorovodWait"); | REGISTER_OPTYPE_DEFINE(HVDWAIT, "HorovodWait"); | ||||
// aicpu op for online_infer dynamic_dims | |||||
REGISTER_OPTYPE_DEFINE(GETDYNAMICDIMS, "GetDynamicDims"); | |||||
const std::string MODEL_ATTR_TASKS = "tasks"; | const std::string MODEL_ATTR_TASKS = "tasks"; | ||||
const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; | const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; | ||||
const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr"; | const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr"; | ||||
@@ -61,6 +61,7 @@ set(SRC_LIST | |||||
"../graph/load/new_model_manager/task_info/model_exit_task_info.cc" | "../graph/load/new_model_manager/task_info/model_exit_task_info.cc" | ||||
"../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | "../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | ||||
"../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | "../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | ||||
"../graph/common/local_context.cc" | |||||
"../opskernel_manager/ops_kernel_builder_manager.cc" | "../opskernel_manager/ops_kernel_builder_manager.cc" | ||||
"../single_op/single_op_manager.cc" | "../single_op/single_op_manager.cc" | ||||
"../single_op/single_op_model.cc" | "../single_op/single_op_model.cc" | ||||
@@ -63,6 +63,7 @@ local_ge_executor_src_files := \ | |||||
../single_op/task/aicpu_kernel_task_builder.cc \ | ../single_op/task/aicpu_kernel_task_builder.cc \ | ||||
../hybrid/hybrid_davinci_model_stub.cc\ | ../hybrid/hybrid_davinci_model_stub.cc\ | ||||
../hybrid/node_executor/aicpu/aicpu_ext_info.cc \ | ../hybrid/node_executor/aicpu/aicpu_ext_info.cc \ | ||||
../graph/common/local_context.cc \ | |||||
local_ge_executor_c_include := \ | local_ge_executor_c_include := \ | ||||
proto/insert_op.proto \ | proto/insert_op.proto \ | ||||
@@ -300,6 +300,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
hybrid/hybrid_davinci_model.cc \ | hybrid/hybrid_davinci_model.cc \ | ||||
executor/ge_executor.cc \ | executor/ge_executor.cc \ | ||||
analyzer/analyzer.cc \ | analyzer/analyzer.cc \ | ||||
ir_build/atc_ir_common.cc \ | |||||
LIBCLIENT_LOCAL_SRC_FILES := \ | LIBCLIENT_LOCAL_SRC_FILES := \ | ||||
proto/ge_api.proto \ | proto/ge_api.proto \ | ||||
@@ -62,6 +62,8 @@ | |||||
#include "runtime/rt_model.h" | #include "runtime/rt_model.h" | ||||
#include "runtime/stream.h" | #include "runtime/stream.h" | ||||
#include "securec.h" | #include "securec.h" | ||||
#include "graph/common/local_context.h" | |||||
#include "common/formats/utils/formats_trans_utils.h" | |||||
// create std::thread, catch exceptions using try/catch | // create std::thread, catch exceptions using try/catch | ||||
#define CREATE_STD_THREAD(thread_id, func, args) \ | #define CREATE_STD_THREAD(thread_id, func, args) \ | ||||
@@ -80,6 +82,7 @@ namespace { | |||||
const uint32_t kDataIndex = 0; | const uint32_t kDataIndex = 0; | ||||
const uint32_t kOutputNum = 1; | const uint32_t kOutputNum = 1; | ||||
const uint32_t kTrueBranchStreamNum = 1; | const uint32_t kTrueBranchStreamNum = 1; | ||||
const uint32_t kGetDynamicDimsCount = 1; | |||||
const uint32_t kThreadNum = 16; | const uint32_t kThreadNum = 16; | ||||
const uint32_t kAddrLen = sizeof(void *); | const uint32_t kAddrLen = sizeof(void *); | ||||
const int kDecimal = 10; | const int kDecimal = 10; | ||||
@@ -883,6 +886,7 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { | |||||
GE_TIMESTAMP_ADD(InitTbeHandle); | GE_TIMESTAMP_ADD(InitTbeHandle); | ||||
} | } | ||||
AdjustDataOpList(data_by_index); | AdjustDataOpList(data_by_index); | ||||
GE_TIMESTAMP_CALLNUM_END(LoadTBEKernelBinToOpDesc, "GraphLoader::LoadTBEKernelBinToOpDesc."); | GE_TIMESTAMP_CALLNUM_END(LoadTBEKernelBinToOpDesc, "GraphLoader::LoadTBEKernelBinToOpDesc."); | ||||
GE_TIMESTAMP_CALLNUM_END(InitTbeHandle, "GraphLoader::InitTbeHandle."); | GE_TIMESTAMP_CALLNUM_END(InitTbeHandle, "GraphLoader::InitTbeHandle."); | ||||
@@ -1038,6 +1042,15 @@ Status DavinciModel::InitInputZeroCopyNodes(const NodePtr &node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
bool DavinciModel::IsGetNextSinkDynamic(const OpDescPtr &op_desc) { | |||||
bool getnext_sink_dynamic = false; | |||||
if (ge::AttrUtils::GetBool(op_desc, ATTR_GETNEXT_SINK_DYNMAIC, getnext_sink_dynamic) && getnext_sink_dynamic) { | |||||
GELOGI("ATTR_GETNEXT_SINK_DYNMAIC has been set and is true."); | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief NetOutput Op Initialize. | /// @brief NetOutput Op Initialize. | ||||
/// @param [in] NodePtr: NetOutput Op. | /// @param [in] NodePtr: NetOutput Op. | ||||
@@ -1079,7 +1092,13 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { | |||||
size_t num = new_output_data_info_.size(); | size_t num = new_output_data_info_.size(); | ||||
bool fusion_flag = false; | bool fusion_flag = false; | ||||
for (size_t idx = 0; idx < input_size_list.size(); ++idx) { | |||||
size_t input_count = input_size_list.size(); | |||||
bool is_getnext_sink_dynamic = false; | |||||
if (IsGetNextSinkDynamic(op_desc)) { | |||||
input_count = input_size_list.size() - kGetDynamicDimsCount; | |||||
is_getnext_sink_dynamic = true; | |||||
} | |||||
for (size_t idx = 0; idx < input_count; ++idx) { | |||||
ZeroCopyOffset zero_copy_offset; | ZeroCopyOffset zero_copy_offset; | ||||
Status ret = zero_copy_offset.InitOutputDataInfo(input_size_list, virtual_addr_list, op_desc, idx, fusion_flag); | Status ret = zero_copy_offset.InitOutputDataInfo(input_size_list, virtual_addr_list, op_desc, idx, fusion_flag); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
@@ -1109,10 +1128,164 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { | |||||
GELOGE(PARAM_INVALID, "Output zero copy nodes init failed!"); | GELOGE(PARAM_INVALID, "Output zero copy nodes init failed!"); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
if (is_getnext_sink_dynamic) { | |||||
merge_nodes_gear_and_real_out_size_info_.clear(); | |||||
merge_nodes_gear_and_real_out_shape_info_.clear(); | |||||
GE_IF_BOOL_EXEC(GetGetDynamicDimsNodeInfo(node) != SUCCESS, | |||||
GELOGE(PARAM_INVALID, "Failed to get info of getdynamicdims node."); return PARAM_INVALID;); | |||||
GE_IF_BOOL_EXEC(GetGearAndRealOutSizeInfo(input_count, node) != SUCCESS, | |||||
GELOGE(PARAM_INVALID, "Failed to get gear and real out size info."); return PARAM_INVALID;); | |||||
GE_IF_BOOL_EXEC(GetGearAndRealOutShapeInfo(input_count, op_desc) != SUCCESS, | |||||
GELOGE(PARAM_INVALID, "Failed to get gear and real out shape info."); return PARAM_INVALID;); | |||||
} | |||||
GELOGI("DavinciModel::InitNetoutput success."); | GELOGI("DavinciModel::InitNetoutput success."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status DavinciModel::GetGetDynamicDimsNodeInfo(const NodePtr &node) { | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
size_t input_count = node->GetAllInDataAnchors().size(); | |||||
GELOGI("input_anchor count of %s is %zu.", node->GetName().c_str(), input_count); | |||||
size_t get_dynamic_dims_index = input_count - kGetDynamicDimsCount; | |||||
auto in_anchor = node->GetAllInDataAnchors().at(get_dynamic_dims_index); | |||||
auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
if (peer_out_anchor == nullptr) { | |||||
GELOGE(PARAM_INVALID, "Out anchor of getdynmaicdims node should not be nullptr."); | |||||
return PARAM_INVALID; | |||||
} | |||||
auto peer_node = peer_out_anchor->GetOwnerNode(); | |||||
auto op_desc = peer_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (op_desc->GetName() == "ascend_mbatch_get_dynamic_dims_node" && op_desc->GetType() == GETDYNAMICDIMS) { | |||||
GELOGI("Start get info of %s.", op_desc->GetName().c_str()); | |||||
auto input_addr = ModelUtils::GetInputDataAddrs(runtime_param_, node->GetOpDesc()); | |||||
auto input_size = ModelUtils::GetInputSize(node->GetOpDesc()); | |||||
if (input_addr.empty() || input_size.empty()) { | |||||
GELOGE(PARAM_INVALID, "Not set output of %s", op_desc->GetName().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
auto input_desc = node->GetOpDesc()->GetInputDescPtr(get_dynamic_dims_index); | |||||
GE_CHECK_NOTNULL(input_desc); | |||||
if (input_desc->GetShape().GetDims().empty()) { | |||||
GELOGE(PARAM_INVALID, "Not set output desc shape of %s.", op_desc->GetName().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
netoutput_last_input_addr_ = input_addr[get_dynamic_dims_index]; | |||||
netoutput_last_input_size_ = input_size[get_dynamic_dims_index]; | |||||
shape_of_cur_dynamic_dims_ = input_desc->GetShape().GetDims().at(0); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status DavinciModel::GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr &node) { | |||||
GELOGI("Start get gear and real output size info for each input merge node."); | |||||
for (size_t idx = 0; idx < input_count; ++idx) { | |||||
auto in_anchor = node->GetAllInDataAnchors().at(idx); | |||||
auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
if (peer_out_anchor == nullptr) { | |||||
continue; | |||||
} | |||||
auto peer_node = peer_out_anchor->GetOwnerNode(); | |||||
auto op_desc = peer_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if ((peer_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { | |||||
if (GetRealOutputSizeOfMerge(idx, peer_node) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Get real output size of %s failed.", peer_node->GetName().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status DavinciModel::GetRealOutputSizeOfMerge(size_t input_index, const NodePtr &merge_node) { | |||||
GELOGD("Start get output size of %s, which is %zu input to netoutput.", merge_node->GetName().c_str(), input_index); | |||||
std::map<vector<int64_t>, int64_t> gear_and_real_out_size_info; | |||||
for (auto &in_anchor : merge_node->GetAllInDataAnchors()) { | |||||
auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
if (peer_out_anchor == nullptr) { | |||||
continue; | |||||
} | |||||
auto in_node = peer_out_anchor->GetOwnerNode(); | |||||
GELOGD("Input node of merge is %s.", in_node->GetName().c_str()); | |||||
auto op_desc = in_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
string batch_label; | |||||
if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
size_t batch_index = static_cast<size_t>(stoi(batch_label.substr(batch_label.rfind('_') + 1))); | |||||
GELOGD("Batch index of %s is %zu.", op_desc->GetName().c_str(), batch_index); | |||||
if (batch_index > ge::GetLocalOmgContext().all_gears_info.size()) { | |||||
GELOGE(PARAM_INVALID, "The value of ATTR_NAME_BATCH_LABEL is invalid."); | |||||
return PARAM_INVALID; | |||||
} | |||||
const vector<int64_t> output_size_list = ModelUtils::GetOutputSize(op_desc); | |||||
int output_index = ge::AnchorUtils::GetIdx(peer_out_anchor); | |||||
auto tensor_desc = op_desc->GetOutputDescPtr(output_index); | |||||
GE_CHECK_NOTNULL(tensor_desc); | |||||
int64_t data_size = 0; | |||||
if (TensorUtils::GetTensorSizeInBytes(*tensor_desc, data_size) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Get tensor size in bytes failed."); | |||||
return FAILED; | |||||
} | |||||
gear_and_real_out_size_info[ge::GetLocalOmgContext().all_gears_info[batch_index]] = data_size; | |||||
GELOGI("Get real gear index is: %zu, GetSize is %ld, GetTensorSizeInBytes is %ld", | |||||
batch_index, output_size_list[output_index], data_size); | |||||
} | |||||
} | |||||
merge_nodes_gear_and_real_out_size_info_[input_index] = gear_and_real_out_size_info; | |||||
return SUCCESS; | |||||
} | |||||
Status DavinciModel::GetGearAndRealOutShapeInfo(size_t input_count, const OpDescPtr &op_desc) { | |||||
GELOGI("Start to get dynamic output dims of %s.", op_desc->GetName().c_str()); | |||||
std::vector<std::string> dynamic_output_shape_info; | |||||
if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_shape_info)) { | |||||
GELOGD("Can not get dynamic output dims attr"); | |||||
return SUCCESS; | |||||
} | |||||
GELOGI("Dynamic output shape info is %s", formats::JoinToString(dynamic_output_shape_info).c_str()); | |||||
std::vector<vector<int64_t>> dynamic_output_shape; | |||||
ParseDynamicOutShape(dynamic_output_shape_info, dynamic_output_shape); | |||||
// idx: input_index to netoutput | |||||
for (size_t idx = 0; idx < input_count; ++idx) { | |||||
std::map<vector<int64_t>, vector<int64_t>> gear_and_real_out_shape_info; | |||||
vector<int64_t> output_shape; | |||||
for (auto &it : dynamic_output_shape) { | |||||
auto gear_index = static_cast<size_t>(it[0]); | |||||
if (gear_index > GetLocalOmgContext().all_gears_info.size()) { | |||||
GELOGE(PARAM_INVALID, "The value of cur index: %zu is invalid.", static_cast<size_t>(it[0])); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (static_cast<size_t>(it[1]) == idx) { | |||||
for (size_t i = 2; i < it.size(); ++i) { | |||||
output_shape.emplace_back(it[i]); | |||||
} | |||||
gear_and_real_out_shape_info[GetLocalOmgContext().all_gears_info[gear_index]] = output_shape; | |||||
} | |||||
} | |||||
merge_nodes_gear_and_real_out_shape_info_[idx] = gear_and_real_out_shape_info; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
void DavinciModel::ParseDynamicOutShape(const std::vector<std::string> &str_info, std::vector<vector<int64_t>> &vec_info) { | |||||
for (size_t i = 0; i < str_info.size(); ++i) { | |||||
std::vector<int64_t> shape; | |||||
std::vector<std::string> dims = ge::StringUtils::Split(str_info[i], ','); | |||||
for (const auto &dim : dims) { | |||||
if (dim.empty()) { | |||||
continue; | |||||
} | |||||
shape.emplace_back(std::strtol(dim.c_str(), nullptr, kDecimal)); | |||||
} | |||||
GELOGI("Shape from attr is %s.", formats::JoinToString(shape).c_str()); | |||||
vec_info.emplace_back(shape); | |||||
} | |||||
} | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief output zero copy node Initialize. | /// @brief output zero copy node Initialize. | ||||
@@ -2311,6 +2484,7 @@ Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data, r | |||||
} | } | ||||
std::vector<DataBuffer> &blobs = output_data.blobs; | std::vector<DataBuffer> &blobs = output_data.blobs; | ||||
size_t idx = 0; | |||||
for (const auto &output : new_output_data_info_) { | for (const auto &output : new_output_data_info_) { | ||||
if (output.first >= blobs.size()) { | if (output.first >= blobs.size()) { | ||||
GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld", blobs.size(), | GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld", blobs.size(), | ||||
@@ -2336,13 +2510,19 @@ Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data, r | |||||
} else if (buffer.length > mem_size) { | } else if (buffer.length > mem_size) { | ||||
GELOGW("Tensor data size=%lu, buffer size=%u", mem_size, buffer.length); | GELOGW("Tensor data size=%lu, buffer size=%u", mem_size, buffer.length); | ||||
} | } | ||||
uint64_t data_size = output.second.GetDataSize(); | |||||
int64_t data_size = output.second.GetDataSize(); | |||||
if (is_getnext_sink_dynamic_) { | |||||
auto gear_and_real_out_size_info = merge_nodes_gear_and_real_out_size_info_[idx]; | |||||
data_size = gear_and_real_out_size_info[cur_dynamic_dims_]; | |||||
} | |||||
uint64_t buffer_length = buffer.length; | uint64_t buffer_length = buffer.length; | ||||
void *buffer_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(buffer.data)); | void *buffer_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(buffer.data)); | ||||
GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] output[%u] memaddr[%p] mem_size[%lu] datasize[%u]", | |||||
GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] output[%u] memaddr[%p] mem_size[%ld] datasize[%u]", | |||||
runtime_param_.graph_id, output.first, output.second.GetBasicAddr(), data_size, buffer_length); | runtime_param_.graph_id, output.first, output.second.GetBasicAddr(), data_size, buffer_length); | ||||
GE_CHK_RT_RET(rtMemcpy(buffer_addr, buffer_length, output.second.GetBasicAddr(), data_size, kind)); | GE_CHK_RT_RET(rtMemcpy(buffer_addr, buffer_length, output.second.GetBasicAddr(), data_size, kind)); | ||||
idx++; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -2358,19 +2538,29 @@ Status DavinciModel::GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data | |||||
std::vector<int64_t> out_buffer_size_vec; | std::vector<int64_t> out_buffer_size_vec; | ||||
std::vector<std::vector<int64_t>> shape_info_vec; | std::vector<std::vector<int64_t>> shape_info_vec; | ||||
size_t input_num = op_desc->GetInputsSize(); | size_t input_num = op_desc->GetInputsSize(); | ||||
if (is_getnext_sink_dynamic_) { | |||||
input_num = input_num - kGetDynamicDimsCount; | |||||
} | |||||
for (size_t i = 0; i < input_num; ++i) { | for (size_t i = 0; i < input_num; ++i) { | ||||
int64_t size = 0; | int64_t size = 0; | ||||
auto input_desc = op_desc->GetInputDescPtr(i); | auto input_desc = op_desc->GetInputDescPtr(i); | ||||
GE_CHECK_NOTNULL(input_desc); | GE_CHECK_NOTNULL(input_desc); | ||||
auto ret = TensorUtils::GetTensorSizeInBytes(*input_desc, size); | auto ret = TensorUtils::GetTensorSizeInBytes(*input_desc, size); | ||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(ret, "Get size from TensorDesc failed, op:%s, input index:%zu", op_desc->GetName().c_str(), i); | |||||
return ret; | |||||
GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, | |||||
GELOGE(ret, "Get size from TensorDesc failed, op:%s, input id:%zu", op_desc->GetName().c_str(), i); | |||||
return ret); | |||||
std::vector<int64_t> output_shape = input_desc->GetShape().GetDims(); | |||||
if (is_getnext_sink_dynamic_) { | |||||
auto gear_and_real_out_size_info = merge_nodes_gear_and_real_out_size_info_[i]; | |||||
size = gear_and_real_out_size_info[cur_dynamic_dims_]; | |||||
auto gear_and_real_out_shape_info = merge_nodes_gear_and_real_out_shape_info_[i]; | |||||
output_shape = gear_and_real_out_shape_info[cur_dynamic_dims_]; | |||||
is_dynamic_ = true; | |||||
} | } | ||||
GELOGI("Output size is %ld, output shape is %s.", size, formats::JoinToString(output_shape).c_str()); | |||||
out_buffer_size_vec.push_back(size); | out_buffer_size_vec.push_back(size); | ||||
shape_info_vec.push_back(input_desc->GetShape().GetDims()); | |||||
shape_info_vec.push_back(output_shape); | |||||
} | } | ||||
GELOGI("Output blobs size:%zu, data index:%u, model id:%u", out_buffer_size_vec.size(), data_index, model_id_); | GELOGI("Output blobs size:%zu, data index:%u, model id:%u", out_buffer_size_vec.size(), data_index, model_id_); | ||||
for (size_t i = 0; i < out_buffer_size_vec.size(); ++i) { | for (size_t i = 0; i < out_buffer_size_vec.size(); ++i) { | ||||
std::unique_ptr<uint8_t[]> data_buf(new (std::nothrow) uint8_t[out_buffer_size_vec[i]]); | std::unique_ptr<uint8_t[]> data_buf(new (std::nothrow) uint8_t[out_buffer_size_vec[i]]); | ||||
@@ -2384,7 +2574,8 @@ Status DavinciModel::GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data | |||||
output.data = std::move(data_buf); | output.data = std::move(data_buf); | ||||
output.length = out_buffer_size_vec[i]; | output.length = out_buffer_size_vec[i]; | ||||
outputs.emplace_back(std::move(output)); | outputs.emplace_back(std::move(output)); | ||||
GELOGI("Output index:%zu, data_length:%lu.", i, output.length); | |||||
GELOGD("Output index:%zu, output dims is %s, data length:%lu.", i, | |||||
formats::JoinToString(output.dims).c_str(), output.length); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -2433,8 +2624,18 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b | |||||
output_data->index = data_id; | output_data->index = data_id; | ||||
output_data->model_id = model_id_; | output_data->model_id = model_id_; | ||||
is_getnext_sink_dynamic_ = false; | |||||
// copy output data from op to designated position | // copy output data from op to designated position | ||||
for (auto &op_desc : output_op_list_) { | for (auto &op_desc : output_op_list_) { | ||||
if (IsGetNextSinkDynamic(op_desc)) { | |||||
is_getnext_sink_dynamic_ = true; | |||||
cur_dynamic_dims_.clear(); | |||||
cur_dynamic_dims_.resize(shape_of_cur_dynamic_dims_); | |||||
GE_CHK_RT_RET(rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int64_t), | |||||
netoutput_last_input_addr_, netoutput_last_input_size_, | |||||
RT_MEMCPY_DEVICE_TO_HOST)); | |||||
GELOGD("Cur dynamic dims is %s.", formats::JoinToString(cur_dynamic_dims_).c_str()); | |||||
} | |||||
if (GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { | if (GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
@@ -2507,11 +2708,16 @@ void *DavinciModel::Run(DavinciModel *model) { | |||||
continue; | continue; | ||||
} | } | ||||
GELOGI("Getting the input data, model_id:%u", model_id); | GELOGI("Getting the input data, model_id:%u", model_id); | ||||
GE_IF_BOOL_EXEC(!model->RunFlag(), break); | GE_IF_BOOL_EXEC(!model->RunFlag(), break); | ||||
InputData current_data = data_wrapper->GetInput(); | InputData current_data = data_wrapper->GetInput(); | ||||
GELOGI("Model thread Run begin, model id:%u, data index:%u.", model_id, current_data.index); | GELOGI("Model thread Run begin, model id:%u, data index:%u.", model_id, current_data.index); | ||||
model->cur_dynamic_dims_.clear(); | |||||
auto shape_data_buffer_data = current_data.blobs.back().data; | |||||
auto shape_data_buffer_length = current_data.blobs.back().length; | |||||
GE_CHK_RT_RET(rtMemcpy(model->cur_dynamic_dims_.data(), shape_data_buffer_length, | |||||
shape_data_buffer_data, shape_data_buffer_length, | |||||
RT_MEMCPY_DEVICE_TO_HOST)); | |||||
GE_TIMESTAMP_START(Model_SyncVarData); | GE_TIMESTAMP_START(Model_SyncVarData); | ||||
ret = model->SyncVarData(); | ret = model->SyncVarData(); | ||||
@@ -47,6 +47,7 @@ | |||||
#include "mmpa/mmpa_api.h" | #include "mmpa/mmpa_api.h" | ||||
#include "proto/task.pb.h" | #include "proto/task.pb.h" | ||||
#include "task_info/task_info.h" | #include "task_info/task_info.h" | ||||
#include "graph/common/local_context.h" | |||||
namespace ge { | namespace ge { | ||||
// op debug need 2048 bits buffer | // op debug need 2048 bits buffer | ||||
@@ -521,7 +522,6 @@ class DavinciModel { | |||||
bool is_inner_p2p_mem_base_; | bool is_inner_p2p_mem_base_; | ||||
// input data manager | // input data manager | ||||
DataInputer *data_inputer_; | DataInputer *data_inputer_; | ||||
int64_t load_begin_time_; | int64_t load_begin_time_; | ||||
int64_t load_end_time_; | int64_t load_end_time_; | ||||
struct timeInfo time_info_; | struct timeInfo time_info_; | ||||
@@ -840,6 +840,12 @@ class DavinciModel { | |||||
void ParseAIPPInfo(std::string in_out_info, InputOutputDims &dims_info); | void ParseAIPPInfo(std::string in_out_info, InputOutputDims &dims_info); | ||||
void SetLabelForDynamic(const NodePtr &node); | void SetLabelForDynamic(const NodePtr &node); | ||||
void ParseDynamicOutShape(const std::vector<std::string> &str_info, std::vector<vector<int64_t>> &vec_info); | |||||
bool IsGetNextSinkDynamic(const OpDescPtr &op_desc); | |||||
Status GetGetDynamicDimsNodeInfo(const NodePtr &node); | |||||
Status GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr &node); | |||||
Status GetRealOutputSizeOfMerge(size_t input_index, const NodePtr &merge_node); | |||||
bool is_model_has_inited_; | bool is_model_has_inited_; | ||||
uint32_t model_id_; | uint32_t model_id_; | ||||
uint32_t runtime_model_id_; | uint32_t runtime_model_id_; | ||||
@@ -993,6 +999,15 @@ class DavinciModel { | |||||
void *op_debug_addr_ = nullptr; | void *op_debug_addr_ = nullptr; | ||||
void *p2p_debug_addr_ = nullptr; | void *p2p_debug_addr_ = nullptr; | ||||
bool is_new_model_desc_{false}; | bool is_new_model_desc_{false}; | ||||
bool is_getnext_sink_dynamic_ = false; | |||||
std::vector<int64_t> cur_dynamic_dims_; | |||||
void *netoutput_last_input_addr_ = nullptr; | |||||
int64_t netoutput_last_input_size_ = 0; | |||||
size_t shape_of_cur_dynamic_dims_ = 0; | |||||
// key: input_index: input is merge node; value: each gear info and each output size | |||||
std::map<size_t, std::map<vector<int64_t>, int64_t>> merge_nodes_gear_and_real_out_size_info_; | |||||
// key: input_index: input is merge node; value: each gear info and each output shape | |||||
std::map<size_t, std::map<vector<int64_t>, vector<int64_t>>> merge_nodes_gear_and_real_out_shape_info_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_ | #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_ |
@@ -29,6 +29,8 @@ | |||||
#include "graph/load/new_model_manager/davinci_model.h" | #include "graph/load/new_model_manager/davinci_model.h" | ||||
#include "graph/load/new_model_manager/davinci_model_parser.h" | #include "graph/load/new_model_manager/davinci_model_parser.h" | ||||
#include "model/ge_root_model.h" | #include "model/ge_root_model.h" | ||||
#include "graph/common/local_context.h" | |||||
#include "common/formats/utils/formats_trans_utils.h" | |||||
namespace ge { | namespace ge { | ||||
thread_local uint32_t device_count = 0; | thread_local uint32_t device_count = 0; | ||||
@@ -36,6 +38,7 @@ namespace { | |||||
const int kCmdParSize = 2; | const int kCmdParSize = 2; | ||||
const int kDumpCmdPairSize = 2; | const int kDumpCmdPairSize = 2; | ||||
const int kProfStartCmdParaSize = 2; | const int kProfStartCmdParaSize = 2; | ||||
const int kDecimal = 10; | |||||
const std::string kCmdTypeProfile = "profile"; | const std::string kCmdTypeProfile = "profile"; | ||||
const std::string kCmdTypeDump = "dump"; | const std::string kCmdTypeDump = "dump"; | ||||
const std::string kCmdTypeProfiling = "profiling"; | const std::string kCmdTypeProfiling = "profiling"; | ||||
@@ -444,6 +447,34 @@ Status ModelManager::DataInput(const InputData &input_data, OutputData &output_d | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status ModelManager::GetCurDynamicDims(const vector<vector<int64_t>> &user_real_input_dims, | |||||
const vector<pair<string, vector<int64_t>>> &user_input_dims, | |||||
vector<int64_t> &cur_dynamic_dims) { | |||||
GELOGD(" Start get cur dynamic dims."); | |||||
if (user_real_input_dims.size() != user_input_dims.size()) { | |||||
GELOGE(INTERNAL_ERROR, | |||||
"The input count of user: %zu should be equal to the data count of graph: %zu", | |||||
user_real_input_dims.size(), user_input_dims.size()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
for (size_t i = 0; i < user_input_dims.size(); ++i) { | |||||
if (user_real_input_dims[i].size() != user_input_dims[i].second.size()) { | |||||
GELOGE(INTERNAL_ERROR, | |||||
"The shape size: %zu of dynamic input: %s should be equal to the shape size of input shape: %zu.", | |||||
user_real_input_dims[i].size(), user_input_dims[i].first.c_str(), user_input_dims[i].second.size()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
for (size_t j = 0; j < user_input_dims.at(i).second.size(); ++j) { | |||||
if (user_input_dims.at(i).second.at(j) < 0) { | |||||
cur_dynamic_dims.emplace_back(user_real_input_dims[i][j]); | |||||
} | |||||
} | |||||
} | |||||
GELOGD("Cur dynamic dims is %s.", formats::JoinToString(cur_dynamic_dims).c_str()); | |||||
return SUCCESS; | |||||
} | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief load Input and output TensorInfo for Model | /// @brief load Input and output TensorInfo for Model | ||||
@@ -461,13 +492,27 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<InputT | |||||
input_data.timeout = 0; | input_data.timeout = 0; | ||||
input_data.timestamp = 0; | input_data.timestamp = 0; | ||||
input_data.index = 0; | input_data.index = 0; | ||||
for (size_t i = 0; i < inputs.size(); ++i) { | for (size_t i = 0; i < inputs.size(); ++i) { | ||||
DataBuffer data; | DataBuffer data; | ||||
data.data = inputs[i].data; | data.data = inputs[i].data; | ||||
data.length = inputs[i].length; | data.length = inputs[i].length; | ||||
input_data.blobs.push_back(data); | input_data.blobs.push_back(data); | ||||
} | } | ||||
if (!GetLocalOmgContext().user_input_dims.empty() && GetLocalOmgContext().need_multi_batch) { | |||||
std::vector<int64_t> cur_dynamic_dims; | |||||
if (!GetLocalOmgContext().user_real_input_dims.empty()) { | |||||
if (GetCurDynamicDims(GetLocalOmgContext().user_real_input_dims, GetLocalOmgContext().user_input_dims, | |||||
cur_dynamic_dims) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "[Train_Dynamic] Failed to Parse real_dynamic_dims."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
DataBuffer data; | |||||
data.data = cur_dynamic_dims.data(); | |||||
uint64_t length = static_cast<uint64_t>(cur_dynamic_dims.size() * sizeof(uint64_t)); | |||||
data.length = length; | |||||
input_data.blobs.push_back(data); | |||||
} | |||||
} | |||||
OutputData output_data; | OutputData output_data; | ||||
output_data.model_id = model_id; | output_data.model_id = model_id; | ||||
@@ -126,6 +126,18 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief Get cur_dynamic_dims for all input. | |||||
/// @param [in] vector<vector<uint64_t>> &user_real_input_dims: dims info of all user_inputs. | |||||
/// @param [in] vector<pair<string, vector<int64_t>>> &user_input_dims: key:name. value:dynamic dims from option. | |||||
/// @param [out] vector<uint64_t> &cur_dynamic_dims: real dims gather, where the index of -1. | |||||
/// @return 0: SUCCESS / others: INTERNAL_ERROR | |||||
/// | |||||
Status GetCurDynamicDims(const vector<vector<int64_t>> &user_real_input_dims, | |||||
const vector<pair<string, vector<int64_t>>> &user_input_dims, | |||||
vector<int64_t> &cur_dynamic_dims); | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief model start to run | /// @brief model start to run | ||||
/// | /// | ||||
ge::Status Start(uint32_t model_id); | ge::Status Start(uint32_t model_id); | ||||
@@ -106,6 +106,10 @@ | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
#include "init/gelib.h" | #include "init/gelib.h" | ||||
#include "ir_build/atc_ir_common.h" | |||||
#include "graph/common/local_context.h" | |||||
#include "graph/common/omg_util.h" | |||||
#include "common/formats/utils/formats_trans_utils.h" | |||||
namespace { | namespace { | ||||
const char *const kSummary = "Summary"; | const char *const kSummary = "Summary"; | ||||
@@ -119,6 +123,11 @@ const char *const kCheckPointGraph = "checkpoint_graph"; | |||||
const char *const kVectorEngine = "VectorEngine"; | const char *const kVectorEngine = "VectorEngine"; | ||||
const char *const kAIcoreEngine = "AIcoreEngine"; | const char *const kAIcoreEngine = "AIcoreEngine"; | ||||
const char *const kOffOptimize = "off_optimize"; | const char *const kOffOptimize = "off_optimize"; | ||||
const int32_t kDynamicDimsTypeIsGetNext = 0; | |||||
const int32_t kDynamicDimsTypeIsData = 1; | |||||
const int64_t kInvalidDynaimcDimsType = -1; | |||||
const char *const kSubstrOfGetNextNosinkName = "IteratorGetNext"; | |||||
const char *const kShapeDataName = "ascend_mbatch_shape_data"; | |||||
bool IsTailingOptimization() { | bool IsTailingOptimization() { | ||||
string is_tailing_optimization_option; | string is_tailing_optimization_option; | ||||
@@ -260,6 +269,42 @@ Status GraphManager::Finalize() { | |||||
return unload_model_ret; | return unload_model_ret; | ||||
} | } | ||||
Status GraphManager::InitDynamicParams(ComputeGraphPtr &compute_graph) { | |||||
for (const auto &node : compute_graph->GetAllNodes()) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
if (op_desc == nullptr) { | |||||
continue; | |||||
} | |||||
GetLocalOmgContext().need_multi_batch = false; | |||||
std::string op_type; | |||||
auto ret = GetOriginalType(node, op_type); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(FAILED, "Failed to get node %s original type.", node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
if ((op_desc->GetType() == DATA) || (op_type == ITERATORV2)) { | |||||
GELOGI("Need to process multi batch for compute graph."); | |||||
GetLocalOmgContext().need_multi_batch = true; | |||||
break; | |||||
} | |||||
} | |||||
if (!options_.input_shape.empty() && !options_.dynamic_dims.empty()) { | |||||
if (!ge::ParseInputShape(options_.input_shape, GetLocalOmgContext().input_dims, GetLocalOmgContext().user_input_dims, | |||||
true)) { | |||||
GELOGE(GRAPH_PARAM_INVALID, "Failed to parse input shape: %s.", options_.input_shape.c_str()); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
GetLocalOmgContext().dynamic_dims = options_.dynamic_dims; | |||||
} | |||||
if (options_.dynamic_node_type == kDynamicDimsTypeIsGetNext) { | |||||
GetLocalOmgContext().dynamic_node_type = GETNEXT; | |||||
} | |||||
if (options_.dynamic_node_type == kDynamicDimsTypeIsData) { | |||||
GetLocalOmgContext().dynamic_node_type = DATA; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | ||||
const std::map<std::string, std::string> &options, | const std::map<std::string, std::string> &options, | ||||
const OmgContext &omg_context) { | const OmgContext &omg_context) { | ||||
@@ -279,6 +324,7 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||||
return GE_GRAPH_GRAPH_ALREADY_EXIST; | return GE_GRAPH_GRAPH_ALREADY_EXIST; | ||||
} | } | ||||
(void)AttrUtils::SetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, true); | (void)AttrUtils::SetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, true); | ||||
compute_graph_ = compute_graph; | |||||
} else { | } else { | ||||
GELOGE(FAILED, "compute graph is null"); | GELOGE(FAILED, "compute graph is null"); | ||||
return FAILED; | return FAILED; | ||||
@@ -296,15 +342,11 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||||
} | } | ||||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | ||||
if (graph_node == nullptr) { | |||||
GELOGE(FAILED, "GraphNode make shared failed"); | |||||
return FAILED; | |||||
} | |||||
GE_IF_BOOL_EXEC(graph_node == nullptr, GELOGE(FAILED, "GraphNode make shared failed"); | |||||
return FAILED); | |||||
std::shared_ptr<Graph> graph_ptr = MakeShared<ge::Graph>(graph); | std::shared_ptr<Graph> graph_ptr = MakeShared<ge::Graph>(graph); | ||||
if (graph_ptr == nullptr) { | |||||
GELOGE(FAILED, "GraphPtr make shared failed"); | |||||
return FAILED; | |||||
} | |||||
GE_IF_BOOL_EXEC(graph_ptr == nullptr, GELOGE(FAILED, "GraphPtr make shared failed"); | |||||
return FAILED); | |||||
graph_node->SetGraph(graph_ptr); | graph_node->SetGraph(graph_ptr); | ||||
graph_node->SetOptions(options); | graph_node->SetOptions(options); | ||||
@@ -314,6 +356,10 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||||
if (!options_.output_datatype.empty()) { | if (!options_.output_datatype.empty()) { | ||||
GetLocalOmgContext().output_type = options_.output_datatype; | GetLocalOmgContext().output_type = options_.output_datatype; | ||||
} | } | ||||
if (InitDynamicParams() != SUCCESS) { | |||||
GELOGE(GRAPH_PARAM_INVALID, "Failed to init params when online infer is dynamic."); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
CompilerStages &stages = GetCompilerStages(graph_id); | CompilerStages &stages = GetCompilerStages(graph_id); | ||||
stages.preparer.SetOptions(options_); | stages.preparer.SetOptions(options_); | ||||
@@ -1299,10 +1345,9 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
// get encrypt mode | // get encrypt mode | ||||
ret = ParseOption(options, ENCRYPT_MODE, options_.encrypt_mode); | ret = ParseOption(options, ENCRYPT_MODE, options_.encrypt_mode); | ||||
if (ret != SUCCESS) { | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.encryptMode value invalid."); | |||||
return GE_GRAPH_OPTIONS_INVALID; | |||||
} | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.encryptMode value invalid."); | |||||
return GE_GRAPH_OPTIONS_INVALID); | |||||
// get ek file | // get ek file | ||||
ParseOption(options, EK_FILE, options_.ek_file); | ParseOption(options, EK_FILE, options_.ek_file); | ||||
@@ -1340,33 +1385,29 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
// get weight compress flag | // get weight compress flag | ||||
ret = ParseOption(options, COMPRESS_FLAG, options_.compress_flag); | ret = ParseOption(options, COMPRESS_FLAG, options_.compress_flag); | ||||
if (ret != SUCCESS) { | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.compressFlag value is invalid, must be 0 or 1."); | |||||
return GE_GRAPH_OPTIONS_INVALID; | |||||
} | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.compressFlag value is invalid, must be 0 or 1."); | |||||
return GE_GRAPH_OPTIONS_INVALID); | |||||
// ge.graphType. | // ge.graphType. | ||||
options_.run_graph_flag = true; | options_.run_graph_flag = true; | ||||
ret = ParseOption(options, RUN_FLAG, options_.run_graph_flag); | ret = ParseOption(options, RUN_FLAG, options_.run_graph_flag); | ||||
if (ret != SUCCESS) { | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid, must be 0 or 1."); | |||||
return GE_GRAPH_OPTIONS_INVALID; | |||||
} | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid, must be 0 or 1."); | |||||
return GE_GRAPH_OPTIONS_INVALID); | |||||
// ge.graphType | // ge.graphType | ||||
ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); | ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); | ||||
if (ret != SUCCESS) { | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); | |||||
return GE_GRAPH_OPTIONS_INVALID; | |||||
} | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); | |||||
return GE_GRAPH_OPTIONS_INVALID); | |||||
// parse FmkOp | // parse FmkOp | ||||
options_.local_fmk_op_flag = false; | options_.local_fmk_op_flag = false; | ||||
ret = ParseOption(options, LOCAL_FMKOP_FLAG, options_.local_fmk_op_flag); | ret = ParseOption(options, LOCAL_FMKOP_FLAG, options_.local_fmk_op_flag); | ||||
if (ret != SUCCESS) { | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.localFmkopFlag value is invalid, must be 0 or 1."); | |||||
return GE_GRAPH_OPTIONS_INVALID; | |||||
} | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.localFmkopFlag value is invalid, must be 0 or 1."); | |||||
return GE_GRAPH_OPTIONS_INVALID); | |||||
options_.enable_print_op_pass = true; | options_.enable_print_op_pass = true; | ||||
ret = ParseOption(options, ENABLE_PRINT_OP_PASS, options_.enable_print_op_pass); | ret = ParseOption(options, ENABLE_PRINT_OP_PASS, options_.enable_print_op_pass); | ||||
@@ -1378,11 +1419,9 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
// parse hcom parallel | // parse hcom parallel | ||||
options_.hcom_parallel = false; | options_.hcom_parallel = false; | ||||
ret = ParseOption(options, HCOM_PARALLEL, options_.hcom_parallel); | ret = ParseOption(options, HCOM_PARALLEL, options_.hcom_parallel); | ||||
if (ret != SUCCESS) { | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.hcomParallel value is invalid, must be 0 or 1."); | |||||
return GE_GRAPH_OPTIONS_INVALID; | |||||
} | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.hcomParallel value is invalid, must be 0 or 1."); | |||||
return GE_GRAPH_OPTIONS_INVALID); | |||||
// net output node dataType | // net output node dataType | ||||
ParseOption(options, OUTPUT_DATATYPE, options_.output_datatype); | ParseOption(options, OUTPUT_DATATYPE, options_.output_datatype); | ||||
@@ -1392,6 +1431,22 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
// Original model file name | // Original model file name | ||||
ParseOption(options, ORIGINAL_MODEL_FILE, options_.original_model_file); | ParseOption(options, ORIGINAL_MODEL_FILE, options_.original_model_file); | ||||
ParseOption(options, INPUT_SHAPE, options_.input_shape); | |||||
ParseOption(options, kDynamicDims, options_.dynamic_dims); | |||||
ParseOption(options, DYNAMIC_NODE_TYPE, options_.dynamic_node_type); | |||||
GELOGD("Dynamic dims params: input shape is %s, dynamic dims is %s, dynamic node type is %d.", | |||||
options_.input_shape.c_str(), options_.dynamic_dims.c_str(), options_.dynamic_node_type); | |||||
if ((!options_.input_shape.empty() && options_.dynamic_dims.empty()) || | |||||
(options_.input_shape.empty() && !options_.dynamic_dims.empty())) { | |||||
GELOGE(GRAPH_PARAM_INVALID, "Should set input shape and dynamic dims at the same time"); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
if ((!options_.input_shape.empty() && options_.dynamic_node_type == kInvalidDynaimcDimsType) || | |||||
(!options_.dynamic_dims.empty() && options_.dynamic_node_type == kInvalidDynaimcDimsType)) { | |||||
GELOGE(GRAPH_PARAM_INVALID, "Should set valid dynamic node type"); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
// Set Build model and step | // Set Build model and step | ||||
ParseOption(options, BUILD_MODE, options_.build_mode); | ParseOption(options, BUILD_MODE, options_.build_mode); | ||||
ParseOption(options, BUILD_STEP, options_.build_step); | ParseOption(options, BUILD_STEP, options_.build_step); | ||||
@@ -2550,6 +2605,118 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
} | } | ||||
} | } | ||||
Status GraphManager::DistinguishGetNextAndData(ComputeGraphPtr &graph, vector<NodePtr> &data_nodes, | |||||
vector<NodePtr> &getnext_nosink_nodes, | |||||
vector<NodePtr> &getnext_sink_nodes) { | |||||
GELOGD("Start distinguish getnext and data node."); | |||||
for (NodePtr &input_node : graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(input_node); | |||||
OpDescPtr op_desc = input_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (op_desc->GetType() == DATA && op_desc->GetName() != kShapeDataName) { | |||||
if (op_desc->GetName().find(kSubstrOfGetNextNosinkName) == string::npos) { | |||||
data_nodes.emplace_back(input_node); | |||||
} else { | |||||
getnext_nosink_nodes.emplace_back(input_node); | |||||
} | |||||
} | |||||
std::string op_type; | |||||
auto ret = GetOriginalType(input_node, op_type); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(FAILED, "Failed to get node %s original type.", input_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
if (op_type == ITERATORV2) { | |||||
GELOGD("Name of getnext sink is %s.", op_desc->GetName().c_str()); | |||||
getnext_sink_nodes.emplace_back(input_node); | |||||
} | |||||
} | |||||
GELOGI("data count is %zu, getnext nosink count is %zu, getnext sink count is %zu.", data_nodes.size(), | |||||
getnext_nosink_nodes.size(), getnext_sink_nodes.size()); | |||||
return SUCCESS; | |||||
} | |||||
void GraphManager::ParseInputsDimsForData(const std::vector<InputTensorInfo> &input_tensor) { | |||||
GELOGD("Start parse input dims from data."); | |||||
for (size_t i = 0; i < input_tensor.size(); ++i) { | |||||
std::vector<int64_t> dynamic_dim; | |||||
for (size_t j = 0; j < input_tensor[i].dims.size(); ++j) { | |||||
dynamic_dim.emplace_back(input_tensor[i].dims[j]); | |||||
} | |||||
GELOGD("input tensor dims is %s.", formats::JoinToString(dynamic_dim).c_str()); | |||||
GetLocalOmgContext().user_real_input_dims.emplace_back(input_tensor[i].dims); | |||||
} | |||||
} | |||||
Status GraphManager::ParseInputsDimsForGetNexNosinkAndData(const vector<NodePtr> &dynamic_nodes, | |||||
const std::vector<InputTensorInfo> &input_tensor) { | |||||
GELOGD("Start parse inputs dims when coexist data and getnext sink."); | |||||
for (size_t i = 0; i < dynamic_nodes.size(); ++i) { | |||||
auto op_desc = dynamic_nodes.at(i)->GetOpDesc(); | |||||
if (op_desc == nullptr) { | |||||
continue; | |||||
} | |||||
GeAttrValue::INT index = 0; | |||||
if (!(AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, index))) { | |||||
GELOGE(PARAM_INVALID, "Get index from attr failed"); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (static_cast<size_t>(index) > input_tensor.size()) { | |||||
GELOGE(PARAM_INVALID, "The count of input tensor should be equal to the count of data."); | |||||
return PARAM_INVALID; | |||||
} | |||||
GetLocalOmgContext().user_real_input_dims.emplace_back(input_tensor.at(index).dims); | |||||
GELOGI("Shape dims of %d data is %s.", index, formats::JoinToString(input_tensor.at(index).dims).c_str()); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::ParseInputsDims(const std::vector<InputTensorInfo> &input_tensor) { | |||||
GELOGI("Start parse input dims of %zu input tensor.", input_tensor.size()); | |||||
GetLocalOmgContext().user_real_input_dims.clear(); | |||||
if (!GetLocalOmgContext().dynamic_node_type.empty()) { | |||||
vector<NodePtr> data_nodes; | |||||
vector<NodePtr> getnext_nosink_nodes; | |||||
vector<NodePtr> getnext_sink_nodes; | |||||
if (DistinguishGetNextAndData(compute_graph_, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to distinguish getnext and data node."); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (GetLocalOmgContext().dynamic_node_type == DATA) { | |||||
if (getnext_nosink_nodes.empty()) { | |||||
// just data or data+getnext_sink | |||||
ParseInputsDimsForData(input_tensor); | |||||
} else { | |||||
// data+getnext_nosink, but only need to get shape_dims of data | |||||
if (ParseInputsDimsForGetNexNosinkAndData(data_nodes, input_tensor) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to parse dims from data, when data coexist with getnext nosink."); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
} else { | |||||
if (getnext_nosink_nodes.empty()) { | |||||
// just getnext_sink or getnext_sink+data, need to get shape_dims from aicpu op | |||||
GELOGI("Need to get dims from aicpu op: GETDYNAMICDIMS."); | |||||
return SUCCESS; | |||||
} else { | |||||
if (data_nodes.empty()) { | |||||
// just getnext_nosink | |||||
ParseInputsDimsForData(input_tensor); | |||||
} else { | |||||
// getnext_nosink + data, but only need to get shape_dims of getnext_nosink | |||||
if (ParseInputsDimsForGetNexNosinkAndData(getnext_nosink_nodes, input_tensor) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to parse dims from getnext nosink, when data coexist with getnext nosink"); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
GELOGI("Parse %zu inputs dims success.", GetLocalOmgContext().user_real_input_dims.size()); | |||||
return SUCCESS; | |||||
} | |||||
void GraphManager::RunThread(GraphManager *graph_manager) { | void GraphManager::RunThread(GraphManager *graph_manager) { | ||||
if (prctl(PR_SET_NAME, ("GE_Run")) != 0) { | if (prctl(PR_SET_NAME, ("GE_Run")) != 0) { | ||||
GELOGW("Set thread name failed."); | GELOGW("Set thread name failed."); | ||||
@@ -2571,6 +2738,11 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||||
if (args.graph_node->graph_run_async_listener_ != nullptr) { | if (args.graph_node->graph_run_async_listener_ != nullptr) { | ||||
args.graph_node->graph_run_async_listener_->SetCallback(args.callback); | args.graph_node->graph_run_async_listener_->SetCallback(args.callback); | ||||
} | } | ||||
// parse inputs.dims to vector<vector<uint64_t>> dynamic_dims | |||||
if (graph_manager->ParseInputsDims(args.input_tensor) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Parse input dims failed."); | |||||
return; | |||||
} | |||||
Status ret; | Status ret; | ||||
if (!args.graph_node->GetLoadFlag()) { | if (!args.graph_node->GetLoadFlag()) { | ||||
@@ -72,6 +72,7 @@ class GraphManager { | |||||
/// | /// | ||||
Status AddGraph(const GraphId &graph_id, const Graph &graph, const std::map<std::string, std::string> &options, | Status AddGraph(const GraphId &graph_id, const Graph &graph, const std::map<std::string, std::string> &options, | ||||
const OmgContext &omg_context); | const OmgContext &omg_context); | ||||
Status InitDynamicParams(ComputeGraphPtr &compute_graph); | |||||
/// | /// | ||||
/// @ingroup ge_graph | /// @ingroup ge_graph | ||||
@@ -205,6 +206,12 @@ class GraphManager { | |||||
static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, | static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, | ||||
const SubGraphInfoPtr &sub_graph_info_ptr, uint64_t session_id, | const SubGraphInfoPtr &sub_graph_info_ptr, uint64_t session_id, | ||||
const GEThreadLocalContext &ge_context); | const GEThreadLocalContext &ge_context); | ||||
Status ParseInputsDims(const std::vector<InputTensorInfo> &input_tensor); | |||||
Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector<NodePtr> &data_nodes, | |||||
vector<NodePtr> &getnext_nosink_nodes, vector<NodePtr> &getnext_sink_nodes); | |||||
void ParseInputsDimsForData(const std::vector<InputTensorInfo> &input_tensor); | |||||
Status ParseInputsDimsForGetNexNosinkAndData(const vector<NodePtr> &dynamic_nodes, | |||||
const std::vector<InputTensorInfo> &input_tensor); | |||||
Status PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, GeRootModelPtr &ge_root_model, | Status PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, GeRootModelPtr &ge_root_model, | ||||
uint64_t session_id = INVALID_SESSION_ID); | uint64_t session_id = INVALID_SESSION_ID); | ||||
@@ -360,7 +367,7 @@ class GraphManager { | |||||
BlockingQueue<RunArgs> run_args_q_{}; | BlockingQueue<RunArgs> run_args_q_{}; | ||||
std::thread prerun_thread_; | std::thread prerun_thread_; | ||||
std::thread run_thread_; | std::thread run_thread_; | ||||
ComputeGraphPtr compute_graph_; | |||||
std::map<GraphId, GraphNodePtr> graph_map_; | std::map<GraphId, GraphNodePtr> graph_map_; | ||||
std::map<GraphId, ModelCacheHelperPtr> cache_helper_map_; | std::map<GraphId, ModelCacheHelperPtr> cache_helper_map_; | ||||
@@ -249,6 +249,9 @@ struct GraphManagerOptions { | |||||
std::string save_original_model; | std::string save_original_model; | ||||
std::string build_mode; | std::string build_mode; | ||||
std::string build_step; | std::string build_step; | ||||
std::string input_shape; | |||||
std::string dynamic_dims; | |||||
int32_t dynamic_node_type = -1; | |||||
GraphManagerOptions() | GraphManagerOptions() | ||||
: stream_num(1), | : stream_num(1), | ||||
perf_level(domi::GEN_TASK_WITHOUT_FUSION), | perf_level(domi::GEN_TASK_WITHOUT_FUSION), | ||||
@@ -130,6 +130,8 @@ const char *const kMbatchSwitchnName = "mbatch-switch-name"; | |||||
// the size of user defined output datatype or format string after split by ":". | // the size of user defined output datatype or format string after split by ":". | ||||
const size_t kUserDefinedElementCount = 2; | const size_t kUserDefinedElementCount = 2; | ||||
const int kDataOutIndex = 0; | |||||
const int64_t kInvalidDynaimcDimsType = -1; | |||||
OpDescPtr CreateTensorShape(const GeTensorDesc &data_tensor) { | OpDescPtr CreateTensorShape(const GeTensorDesc &data_tensor) { | ||||
GeTensorPtr tensor = MakeShared<GeTensor>(); | GeTensorPtr tensor = MakeShared<GeTensor>(); | ||||
@@ -1130,6 +1132,9 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input) { | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (IsDynamicDims(input_node)) { | |||||
continue; | |||||
} | |||||
GeTensorDesc desc(user_input[index].GetTensorDesc()); | GeTensorDesc desc(user_input[index].GetTensorDesc()); | ||||
auto format = desc.GetFormat(); | auto format = desc.GetFormat(); | ||||
auto origin_format = desc.GetOriginFormat(); | auto origin_format = desc.GetOriginFormat(); | ||||
@@ -1523,6 +1528,22 @@ Status GraphPrepare::VerifyConstOp(const NodePtr &node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
bool GraphPrepare::IsDynamicDims(const NodePtr &input_node) { | |||||
auto data_shape = NodeUtils::GetOutputDesc(*input_node, kDataOutIndex).GetShape(); | |||||
const auto &dims = data_shape.GetDims(); | |||||
bool all_is_positive = false; | |||||
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { | |||||
all_is_positive = true; | |||||
} | |||||
if (!all_is_positive && !options_.input_shape.empty() && !options_.dynamic_dims.empty() && | |||||
options_.dynamic_node_type != kInvalidDynaimcDimsType) { | |||||
GELOGI("No need to check and update desc info, the dims of %s is %s.", input_node->GetName().c_str(), | |||||
formats::JoinToString(dims).c_str()); | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) { | Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) { | ||||
if (GetLocalOmgContext().is_dynamic_input) { | if (GetLocalOmgContext().is_dynamic_input) { | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -1545,6 +1566,9 @@ Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) { | |||||
GELOGE(GE_GRAPH_INIT_FAILED, "user_input size:%zu, data op index:%ld.", user_input.size(), index); | GELOGE(GE_GRAPH_INIT_FAILED, "user_input size:%zu, data op index:%ld.", user_input.size(), index); | ||||
return GE_GRAPH_INIT_FAILED; | return GE_GRAPH_INIT_FAILED; | ||||
} | } | ||||
if (IsDynamicDims(input_node)) { | |||||
continue; | |||||
} | |||||
GeTensorDesc desc(user_input[index].GetTensorDesc()); | GeTensorDesc desc(user_input[index].GetTensorDesc()); | ||||
for (size_t i = 0; i < desc.GetShape().GetDimNum(); ++i) { | for (size_t i = 0; i < desc.GetShape().GetDimNum(); ++i) { | ||||
@@ -84,6 +84,7 @@ class GraphPrepare { | |||||
Status GraphEquivalentTransformation(); | Status GraphEquivalentTransformation(); | ||||
void TypeConversionOfConstant(); | void TypeConversionOfConstant(); | ||||
bool IsDynamicDims(const NodePtr &input_node); | |||||
ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
GraphManagerOptions options_; | GraphManagerOptions options_; | ||||
@@ -54,10 +54,18 @@ const int kDataOutIndex = 0; | |||||
const int kDataInIndex = 0; | const int kDataInIndex = 0; | ||||
const int kMergeDataOutIndex = 0; | const int kMergeDataOutIndex = 0; | ||||
const int kStaticOutput = -1; | const int kStaticOutput = -1; | ||||
const int kDivisionConst = 2; | |||||
inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } | inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } | ||||
inline bool IsGetNextType(const NodePtr &node) { | |||||
std::string original_type; | |||||
GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, | |||||
GELOGW("Get original type failed"); return false); | |||||
return (original_type == ITERATORV2); | |||||
} | |||||
NodePtr InsertMergeNodeToGraph(const std::string &name, size_t input_num, const ComputeGraphPtr &graph) { | NodePtr InsertMergeNodeToGraph(const std::string &name, size_t input_num, const ComputeGraphPtr &graph) { | ||||
OpDescPtr desc = MakeShared<OpDesc>(); | OpDescPtr desc = MakeShared<OpDesc>(); | ||||
if (desc == nullptr) { | if (desc == nullptr) { | ||||
@@ -180,29 +188,6 @@ bool IsOnlyOutputToAipp(const NodePtr &node) { | |||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
Status CheckDataShape(const std::vector<NodePtr> &nodes) { | |||||
size_t unknown_shape_count = 0; | |||||
for (const auto &node : nodes) { | |||||
if (node->GetType() != DATA) { | |||||
continue; | |||||
} | |||||
for (auto dim : NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims()) { | |||||
if (dim < 0) { | |||||
unknown_shape_count++; | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
if (unknown_shape_count == 0) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10040"); | |||||
GELOGE(PARAM_INVALID, | |||||
"Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims"); | |||||
return PARAM_INVALID; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
} // namespace | } // namespace | ||||
Status MultiBatchGraphCopyer::CopyGraph() { | Status MultiBatchGraphCopyer::CopyGraph() { | ||||
@@ -258,15 +243,55 @@ Status MultiBatchGraphCopyer::Init() { | |||||
if (IsDataLikeType(node->GetType())) { | if (IsDataLikeType(node->GetType())) { | ||||
origin_data_nodes_.emplace_back(node); | origin_data_nodes_.emplace_back(node); | ||||
} | } | ||||
if (!GetLocalOmgContext().dynamic_node_type.empty() && IsGetNextType(node)) { | |||||
origin_data_nodes_.emplace_back(node); | |||||
} | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) { | |||||
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | |||||
GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(), | |||||
formats::JoinToString(data_shape.GetDims()).c_str()); | |||||
if (!IsAllDimsPositive(data_shape.GetDims())) { | |||||
origin_nodes_status_[data.get()] = kNodeInBatchBranch; | |||||
} | |||||
} | |||||
void MultiBatchGraphCopyer::LabelStatusForGetNextSink(const NodePtr &data) { | |||||
auto op_desc = data->GetOpDesc(); | |||||
GELOGI("Out count of %s is %zu.", data->GetName().c_str(), op_desc->GetOutputsSize()); | |||||
size_t data_count = op_desc->GetOutputsSize() / kDivisionConst; | |||||
for (size_t i = 0; i < data_count; ++i) { | |||||
GeTensorDesc output_desc = op_desc->GetOutputDesc(i); | |||||
GELOGD("The %zu data shape from getnext sink is %s.", i, | |||||
formats::JoinToString(output_desc.GetShape().GetDims()).c_str()); | |||||
const auto &out_data_anchor = data->GetOutDataAnchor(i); | |||||
if (out_data_anchor == nullptr) { | |||||
continue; | |||||
} | |||||
size_t reference_times = out_data_anchor->GetPeerInDataAnchors().size(); | |||||
GELOGD("The %zu data has %zu referenced times.", i, reference_times); | |||||
getnext_sink_dynamic_out_mapping_.emplace_back(std::make_pair(i, reference_times)); | |||||
if (!IsAllDimsPositive(output_desc.GetShape().GetDims())) { | |||||
getnext_sink_dynamic_dims_ = true; | |||||
} | |||||
} | |||||
if (getnext_sink_dynamic_dims_) { | |||||
origin_nodes_status_[data.get()] = kNodeInBatchBranch; | |||||
} | |||||
} | |||||
Status MultiBatchGraphCopyer::LabelStatus() { | Status MultiBatchGraphCopyer::LabelStatus() { | ||||
for (const auto &data : origin_data_nodes_) { | for (const auto &data : origin_data_nodes_) { | ||||
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | |||||
if (!IsAllDimsPositive(data_shape.GetDims())) { | |||||
origin_nodes_status_[data.get()] = kNodeInBatchBranch; | |||||
auto op_desc = data->GetOpDesc(); | |||||
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(PARAM_INVALID, "Op desc is nullptr."); | |||||
return PARAM_INVALID); | |||||
LabelStatusForData(data); | |||||
if (!GetLocalOmgContext().dynamic_node_type.empty()) { | |||||
LabelStatusForGetNextSink(data); | |||||
} | } | ||||
} | } | ||||
bool changed = true; | bool changed = true; | ||||
@@ -299,13 +324,24 @@ Status MultiBatchGraphCopyer::LabelStatus() { | |||||
origin_nodes_status_[node.get()] = kNodeOutBatchBranch; | origin_nodes_status_[node.get()] = kNodeOutBatchBranch; | ||||
continue; | continue; | ||||
} | } | ||||
if (IsDataLikeType(node->GetType())) { | |||||
if (IsOnlyOutputToAipp(node)) { | |||||
origin_nodes_status_[node.get()] = kNodeOutBatchBranch; | |||||
} else { | |||||
if (GetLocalOmgContext().dynamic_node_type.empty()) { | |||||
if (IsDataLikeType(node->GetType())) { | |||||
if (IsOnlyOutputToAipp(node)) { | |||||
origin_nodes_status_[node.get()] = kNodeOutBatchBranch; | |||||
} else { | |||||
origin_nodes_status_[node.get()] = kNodeStartNode; | |||||
} | |||||
continue; | |||||
} | |||||
} else { | |||||
if (IsDataLikeType(node->GetType())) { | |||||
origin_nodes_status_[node.get()] = kNodeStartNode; | origin_nodes_status_[node.get()] = kNodeStartNode; | ||||
continue; | |||||
} | |||||
if (IsGetNextType(node)) { | |||||
origin_nodes_status_[node.get()] = kNodeStartNode; | |||||
continue; | |||||
} | } | ||||
continue; | |||||
} | } | ||||
if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) { | if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) { | ||||
origin_nodes_status_[node.get()] = kNodeOutBatchBranch; | origin_nodes_status_[node.get()] = kNodeOutBatchBranch; | ||||
@@ -318,50 +354,51 @@ Status MultiBatchGraphCopyer::CheckAndParseDynamicData(){ | |||||
size_t unknown_shape_count = 0; | size_t unknown_shape_count = 0; | ||||
auto data_name_and_shape = GetLocalOmgContext().user_input_dims; | auto data_name_and_shape = GetLocalOmgContext().user_input_dims; | ||||
GELOGD("raw data_name_and_shape size: %zu", data_name_and_shape.size()); | GELOGD("raw data_name_and_shape size: %zu", data_name_and_shape.size()); | ||||
for (const auto &node : origin_all_nodes_) { | |||||
auto data_desc = NodeUtils::GetOutputDesc(*node, kDataOutIndex); | |||||
auto data_shape = data_desc.GetShape(); | |||||
auto data_format = data_desc.GetFormat() == Format::FORMAT_NCHW ? "NCHW" : | |||||
data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others"; | |||||
if (!getnext_sink_dynamic_dims_) { | |||||
for (const auto &node : origin_all_nodes_) { | |||||
auto data_desc = NodeUtils::GetOutputDesc(*node, kDataOutIndex); | |||||
auto data_shape = data_desc.GetShape(); | |||||
auto data_format = data_desc.GetFormat() == Format::FORMAT_NCHW ? "NCHW" : | |||||
data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others"; | |||||
auto data_name = node->GetName(); | |||||
auto branch_status = GetNodeStatus(node); | |||||
if (branch_status != kNodeStartNode) { | |||||
continue; | |||||
} | |||||
GELOGI("CheckAndParseDynamicData shape_dims is %s.", formats::JoinToString(data_shape.GetDims()).c_str()); | |||||
if (IsAllDimsPositive(data_shape.GetDims())) { | |||||
continue; | |||||
} | |||||
auto data_name = node->GetName(); | |||||
auto branch_status = GetNodeStatus(node); | |||||
if (branch_status != kNodeStartNode) { | |||||
continue; | |||||
} | |||||
if (IsAllDimsPositive(data_shape.GetDims())) { | |||||
continue; | |||||
} | |||||
++unknown_shape_count; | |||||
auto iter = find(data_name_order_.begin(), data_name_order_.end(), data_name); | |||||
if (iter == data_name_order_.end()) { | |||||
if (dynamic_type_ == DynamicType::kDynamicBatch) { | |||||
auto ret = CheckDynamicBatchShape(data_shape.GetDims(), data_name); | |||||
if (!ret) { | |||||
return PARAM_INVALID; | |||||
std::vector<int64_t> data_shape_dims = data_shape.GetDims(); | |||||
++unknown_shape_count; | |||||
auto iter = find(data_name_order_.begin(), data_name_order_.end(), data_name); | |||||
if (iter == data_name_order_.end()) { | |||||
if (dynamic_type_ == DynamicType::kDynamicBatch) { | |||||
auto ret = CheckDynamicBatchShape(data_shape_dims, data_name); | |||||
GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic batch shape of %s.", | |||||
data_name.c_str()); return PARAM_INVALID); | |||||
} else if (dynamic_type_ == DynamicType::kDynamicImageSize) { | |||||
auto ret = CheckDynamicImageSizeShape(data_shape_dims, data_name, data_format); | |||||
GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic image size shape of %s.", | |||||
data_name.c_str()); return PARAM_INVALID); | |||||
} else if (dynamic_type_ == DynamicType::kDynamicDims) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10001", | |||||
{"parameter", "reason"}, | |||||
{"--input_shape", | |||||
"all dynamic data must be set in --input_shape"}); | |||||
GELOGE(INTERNAL_ERROR, "data: %s shape:%s must be set int --input_shape", | |||||
node->GetName().c_str(), data_shape.ToString().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | } | ||||
} else if (dynamic_type_ == DynamicType::kDynamicImageSize) { | |||||
auto ret = CheckDynamicImageSizeShape(data_shape.GetDims(), data_name, data_format); | |||||
if (!ret) { | |||||
return PARAM_INVALID; | |||||
} | |||||
} else if (dynamic_type_ == DynamicType::kDynamicDims) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10001", | |||||
{"parameter", "reason"}, | |||||
{"--input_shape", | |||||
"all dynamic data must be set in --input_shape"}); | |||||
GELOGE(INTERNAL_ERROR, "data: %s shape:%s must be set int --input_shape", | |||||
node->GetName().c_str(), data_shape.ToString().c_str()); | |||||
return INTERNAL_ERROR; | |||||
GELOGI("Data shape of %s is %s", data_name.c_str(), formats::JoinToString(data_shape_dims).c_str()); | |||||
data_name_and_shape.emplace_back(data_name, data_shape_dims); | |||||
} | } | ||||
data_name_and_shape.emplace_back(data_name, data_shape.GetDims()); | |||||
} | } | ||||
} | } | ||||
auto ret = ParserDataToDynmaicInfo(shapes_, data_name_and_shape, data_to_dynamic_info_); | auto ret = ParserDataToDynmaicInfo(shapes_, data_name_and_shape, data_to_dynamic_info_); | ||||
if (ret != SUCCESS){ | |||||
return ret; | |||||
} | |||||
if (unknown_shape_count == 0) { | |||||
GE_CHK_STATUS_RET(ret, "Failed to parse data to dynamic info."); | |||||
if (!getnext_sink_dynamic_dims_ && unknown_shape_count == 0) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10040"); | ErrorManager::GetInstance().ATCReportErrMessage("E10040"); | ||||
GELOGE(PARAM_INVALID, | GELOGE(PARAM_INVALID, | ||||
"Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims"); | "Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims"); | ||||
@@ -371,13 +408,17 @@ Status MultiBatchGraphCopyer::CheckAndParseDynamicData(){ | |||||
} | } | ||||
Status MultiBatchGraphCopyer::CreateNewNodes() { | Status MultiBatchGraphCopyer::CreateNewNodes() { | ||||
shape_data_ = InsertShapeDataNode(); | |||||
if (shape_data_ == nullptr) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to create the shape data node for muti-batch"); | |||||
return INTERNAL_ERROR; | |||||
if (!getnext_sink_dynamic_dims_) { | |||||
shape_data_ = InsertShapeDataNode(); | |||||
} else { | |||||
shape_data_ = InsertGetDynamicDimsNode(); | |||||
} | } | ||||
GE_IF_BOOL_EXEC(shape_data_ == nullptr, GELOGE(INTERNAL_ERROR, "Failed to create the shape node for multi batch"); | |||||
return INTERNAL_ERROR); | |||||
GE_CHECK_NOTNULL(shape_data_->GetOpDesc()); | |||||
for (const auto &node : origin_all_nodes_) { | for (const auto &node : origin_all_nodes_) { | ||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
auto node_type = node->GetType(); | auto node_type = node->GetType(); | ||||
Status ret = INTERNAL_ERROR; | Status ret = INTERNAL_ERROR; | ||||
auto branch_status = GetNodeStatus(node); | auto branch_status = GetNodeStatus(node); | ||||
@@ -385,10 +426,7 @@ Status MultiBatchGraphCopyer::CreateNewNodes() { | |||||
switch (branch_status) { | switch (branch_status) { | ||||
case kNodeStartNode: | case kNodeStartNode: | ||||
GELOGD("Name: %s, type: %s, status: kNodeStartNode.", node->GetName().c_str(), node->GetType().c_str()); | GELOGD("Name: %s, type: %s, status: kNodeStartNode.", node->GetName().c_str(), node->GetType().c_str()); | ||||
ret = InsertSwitchNForData(node); | |||||
if (ret == SUCCESS) { | |||||
ret = UpdateMaxShapeToData(node); | |||||
} | |||||
ret = InsertSwitchNAndUpdateMaxShape(node); | |||||
break; | break; | ||||
case kNodeInBatchBranch: | case kNodeInBatchBranch: | ||||
GELOGD("Name: %s, type: %s, status: kNodeInBatchBranch.", node->GetName().c_str(), node->GetType().c_str()); | GELOGD("Name: %s, type: %s, status: kNodeInBatchBranch.", node->GetName().c_str(), node->GetType().c_str()); | ||||
@@ -397,6 +435,9 @@ Status MultiBatchGraphCopyer::CreateNewNodes() { | |||||
case kNodeOutBatchBranch: | case kNodeOutBatchBranch: | ||||
GELOGD("Name: %s, type: %s, status: kNodeOutBatchBranch.", node->GetName().c_str(), node->GetType().c_str()); | GELOGD("Name: %s, type: %s, status: kNodeOutBatchBranch.", node->GetName().c_str(), node->GetType().c_str()); | ||||
ret = InsertMergeForEdgeNode(node); | ret = InsertMergeForEdgeNode(node); | ||||
if (ret == SUCCESS) { | |||||
ret = LinkGetDynamicDimsToNetOutput(node); | |||||
} | |||||
break; | break; | ||||
case kNodeNotSupportNode: | case kNodeNotSupportNode: | ||||
GELOGD("Name: %s, type: %s, status: kNodeNotSupportNode.", node->GetName().c_str(), node->GetType().c_str()); | GELOGD("Name: %s, type: %s, status: kNodeNotSupportNode.", node->GetName().c_str(), node->GetType().c_str()); | ||||
@@ -443,7 +484,59 @@ NodePtr MultiBatchGraphCopyer::InsertMergeNode(const NodePtr &node, int index) { | |||||
GELOGI("Create merge node %s for node %s index %d", merge_node_name.c_str(), node->GetName().c_str(), index); | GELOGI("Create merge node %s for node %s index %d", merge_node_name.c_str(), node->GetName().c_str(), index); | ||||
return merge_node; | return merge_node; | ||||
} | } | ||||
NodePtr MultiBatchGraphCopyer::FindSwitchnNodeForDataEdge(const OutDataAnchorPtr &data_out_anchor, | |||||
const NodePtr &origin_node) { | |||||
auto data_node = data_out_anchor->GetOwnerNode(); | |||||
GELOGD("Start find switchn node insert between %s and %s", data_node->GetName().c_str(), | |||||
origin_node->GetName().c_str()); | |||||
NodePtr switchn = nullptr; | |||||
if (!getnext_sink_dynamic_dims_ && data_nodes_to_switchn_.count(data_node.get()) > 0) { | |||||
switchn = data_nodes_to_switchn_[data_node.get()]; | |||||
return switchn; | |||||
} | |||||
bool is_getnext_sink_data = false; | |||||
for (size_t i = 0; i < getnext_nodes_to_switchn_.size(); ++i) { | |||||
for (size_t j = 0; j < getnext_nodes_to_switchn_.at(i).size(); ++j) { | |||||
if (getnext_nodes_to_switchn_.at(i).at(j).first == data_node.get()) { | |||||
is_getnext_sink_data = true; | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
// get output_idx of origin_node(getnext) | |||||
if (is_getnext_sink_data) { | |||||
auto output_idx = data_out_anchor->GetIdx(); | |||||
size_t referenced_index = 0; | |||||
GELOGI("The output idx %zu has %zu referenced nums.", output_idx, data_out_anchor->GetPeerInDataAnchors().size()); | |||||
for (const auto &peer_in_anchor : data_out_anchor->GetPeerInDataAnchors()) { | |||||
if (peer_in_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { | |||||
GELOGE(INTERNAL_ERROR, "Op desc should not be nullptr."); | |||||
return nullptr; | |||||
} | |||||
if (getnext_nodes_to_switchn_.at(output_idx).empty()) { | |||||
GELOGI("Output idx %zu of %s is static output.", output_idx, data_node->GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
if (output_idx >= getnext_nodes_to_switchn_.size() || | |||||
referenced_index >= getnext_nodes_to_switchn_.at(output_idx).size()) { | |||||
GELOGE(INTERNAL_ERROR, "Output idx is %zu, referenced index is %zu", output_idx, referenced_index); | |||||
return nullptr; | |||||
} | |||||
if (peer_in_anchor->GetOwnerNode()->GetOpDesc()->GetName() == origin_node->GetName()) { | |||||
switchn = getnext_nodes_to_switchn_.at(output_idx).at(referenced_index).second; | |||||
GELOGI("Name of switchn is %s.", switchn->GetName().c_str()); | |||||
return switchn; | |||||
} | |||||
referenced_index++; | |||||
} | |||||
} | |||||
return switchn; | |||||
} | |||||
// origin_node = Add, batch_num = 0,1,2,3; copyed_node = Add_0,Add_1,Add_2,Add_3 | |||||
Status MultiBatchGraphCopyer::CopyInDataEdges(const NodePtr &origin_node, int batch_num, const NodePtr ©ed_node) { | Status MultiBatchGraphCopyer::CopyInDataEdges(const NodePtr &origin_node, int batch_num, const NodePtr ©ed_node) { | ||||
GELOGI("Start copy data edges for %s and %s.", origin_node->GetName().c_str(), copyed_node->GetName().c_str()); | |||||
for (auto &in_anchor : origin_node->GetAllInDataAnchors()) { | for (auto &in_anchor : origin_node->GetAllInDataAnchors()) { | ||||
auto origin_src_anchor = in_anchor->GetPeerOutAnchor(); | auto origin_src_anchor = in_anchor->GetPeerOutAnchor(); | ||||
if (origin_src_anchor == nullptr) { | if (origin_src_anchor == nullptr) { | ||||
@@ -453,16 +546,16 @@ Status MultiBatchGraphCopyer::CopyInDataEdges(const NodePtr &origin_node, int ba | |||||
auto origin_src_node = origin_src_anchor->GetOwnerNode(); | auto origin_src_node = origin_src_anchor->GetOwnerNode(); | ||||
auto dst_anchor = copyed_node->GetInDataAnchor(in_anchor->GetIdx()); | auto dst_anchor = copyed_node->GetInDataAnchor(in_anchor->GetIdx()); | ||||
GE_CHECK_NOTNULL(dst_anchor); | GE_CHECK_NOTNULL(dst_anchor); | ||||
auto switchn_iter = data_nodes_to_switchn_.find(origin_src_node.get()); | |||||
if (switchn_iter != data_nodes_to_switchn_.end()) { | |||||
auto ret = GraphUtils::AddEdge(switchn_iter->second->GetOutDataAnchor(batch_num), dst_anchor); | |||||
auto switchn = FindSwitchnNodeForDataEdge(origin_src_anchor, origin_node); | |||||
if (switchn != nullptr) { | |||||
auto ret = GraphUtils::AddEdge(switchn->GetOutDataAnchor(batch_num), dst_anchor); | |||||
if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s(%d) to %s(%d), error-code %u", | GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s(%d) to %s(%d), error-code %u", | ||||
switchn_iter->second->GetName().c_str(), batch_num, copyed_node->GetName().c_str(), in_anchor->GetIdx(), | |||||
switchn->GetName().c_str(), batch_num, copyed_node->GetName().c_str(), in_anchor->GetIdx(), | |||||
ret); | ret); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
GELOGD("Add data edge from %s(%d) to %s(%d)", switchn_iter->second->GetName().c_str(), batch_num, | |||||
GELOGD("Add data edge from %s(%d) to %s(%d)", switchn->GetName().c_str(), batch_num, | |||||
copyed_node->GetName().c_str(), in_anchor->GetIdx()); | copyed_node->GetName().c_str(), in_anchor->GetIdx()); | ||||
continue; | continue; | ||||
} | } | ||||
@@ -493,7 +586,9 @@ Status MultiBatchGraphCopyer::CopyInDataEdges(const NodePtr &origin_node, int ba | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MultiBatchGraphCopyer::CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr ©ed_node) { | Status MultiBatchGraphCopyer::CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr ©ed_node) { | ||||
GELOGI("Start copy control edge for %s and %s.", node->GetName().c_str(), copyed_node->GetName().c_str()); | |||||
for (auto &origin_src_node : node->GetInControlNodes()) { | for (auto &origin_src_node : node->GetInControlNodes()) { | ||||
auto switchn_iter = data_nodes_to_switchn_.find(origin_src_node.get()); | auto switchn_iter = data_nodes_to_switchn_.find(origin_src_node.get()); | ||||
if (switchn_iter != data_nodes_to_switchn_.end()) { | if (switchn_iter != data_nodes_to_switchn_.end()) { | ||||
@@ -533,6 +628,7 @@ Status MultiBatchGraphCopyer::CopyInControlEdges(const NodePtr &node, int batch_ | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
NodePtr MultiBatchGraphCopyer::InsertShapeDataNode() { | NodePtr MultiBatchGraphCopyer::InsertShapeDataNode() { | ||||
auto desc = MakeShared<OpDesc>(); | auto desc = MakeShared<OpDesc>(); | ||||
if (desc == nullptr) { | if (desc == nullptr) { | ||||
@@ -546,11 +642,8 @@ NodePtr MultiBatchGraphCopyer::InsertShapeDataNode() { | |||||
} | } | ||||
desc->SetName(node_name); | desc->SetName(node_name); | ||||
desc->SetType(DATA); | desc->SetType(DATA); | ||||
GeTensorDesc tensor_desc; | |||||
tensor_desc.SetFormat(FORMAT_ND); | |||||
tensor_desc.SetShape(GeShape({static_cast<int64_t>(shapes_.at(0).size())})); | |||||
tensor_desc.SetDataType(DT_INT64); | |||||
// input and output of DATA is gear_info | |||||
GeTensorDesc tensor_desc(GeShape({static_cast<int64_t>(shapes_.at(0).size())}), FORMAT_ND, DT_INT64); | |||||
auto ret = desc->AddInputDesc(tensor_desc); | auto ret = desc->AddInputDesc(tensor_desc); | ||||
if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data"); | GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data"); | ||||
@@ -580,6 +673,61 @@ NodePtr MultiBatchGraphCopyer::InsertShapeDataNode() { | |||||
return data_node; | return data_node; | ||||
} | } | ||||
NodePtr MultiBatchGraphCopyer::InsertGetDynamicDimsNode() { | |||||
GELOGD("Start insert getdynamicdims node to get shape info."); | |||||
auto desc = MakeShared<OpDesc>(); | |||||
if (desc == nullptr) { | |||||
GELOGE(OUT_OF_MEMORY, "Failed to create shape data node, out of memory"); | |||||
return nullptr; | |||||
} | |||||
string node_name = "ascend_mbatch_get_dynamic_dims_node"; | |||||
// Only flush subgraph name | |||||
if (graph_->GetParentGraph() != nullptr) { | |||||
node_name = graph_->GetName() + "_" + node_name; | |||||
} | |||||
desc->SetName(node_name); | |||||
desc->SetType(GETDYNAMICDIMS); | |||||
// input of GetDynamicDims is shape_of_each_data, output is gear_info | |||||
for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) { | |||||
size_t input_shape_dims = GetLocalOmgContext().user_input_dims.at(i).second.size(); | |||||
GeTensorDesc tensor_desc(GeShape({static_cast<int64_t>(input_shape_dims)}), FORMAT_ND, DT_INT64); | |||||
auto ret = desc->AddInputDesc(tensor_desc); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data"); | |||||
return nullptr; | |||||
} | |||||
} | |||||
GeTensorDesc tensor_desc(GeShape({static_cast<int64_t>(shapes_.at(0).size())}), FORMAT_ND, DT_INT64); | |||||
auto ret = desc->AddOutputDesc(tensor_desc); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to add output desc for created data"); | |||||
return nullptr; | |||||
} | |||||
if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to add attr for created data"); | |||||
return nullptr; | |||||
} | |||||
auto data_node = graph_->AddNode(desc); | |||||
if (data_node == nullptr) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to add shape data node to graph"); | |||||
return nullptr; | |||||
} | |||||
ret = GraphUtils::AppendInputNode(graph_, data_node); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to append data node %s as input to graph", data_node->GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
return data_node; | |||||
} | |||||
Status MultiBatchGraphCopyer::CheckArguments() { | Status MultiBatchGraphCopyer::CheckArguments() { | ||||
if (graph_ == nullptr) { | if (graph_ == nullptr) { | ||||
GELOGE(PARAM_INVALID, "Failed to copy graph, the graph is null"); | GELOGE(PARAM_INVALID, "Failed to copy graph, the graph is null"); | ||||
@@ -588,6 +736,7 @@ Status MultiBatchGraphCopyer::CheckArguments() { | |||||
return CheckDynamicParams(shapes_); | return CheckDynamicParams(shapes_); | ||||
} | } | ||||
Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_nodes) { | Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_nodes) { | ||||
for (auto &node : start_nodes) { | for (auto &node : start_nodes) { | ||||
if (IsOnlyOutputToAipp(node)) { | if (IsOnlyOutputToAipp(node)) { | ||||
@@ -602,12 +751,24 @@ Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_ | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
bool MultiBatchGraphCopyer::IsInBatchBranch(const NodePtr &node) { | bool MultiBatchGraphCopyer::IsInBatchBranch(const NodePtr &node) { | ||||
return (nodes_to_batch_nodes_.count(node.get()) > 0) || (data_nodes_to_switchn_.count(node.get()) > 0); | |||||
if (!getnext_sink_dynamic_dims_) { | |||||
return (nodes_to_batch_nodes_.count(node.get()) > 0) || (data_nodes_to_switchn_.count(node.get()) > 0); | |||||
} else { | |||||
for (size_t i = 0; i < getnext_nodes_to_switchn_.size(); ++i) { | |||||
for (size_t j = 0; j < getnext_nodes_to_switchn_.at(i).size(); ++j) { | |||||
if (getnext_nodes_to_switchn_.at(i).at(j).first == node.get()) { | |||||
return true; | |||||
} | |||||
} | |||||
} | |||||
return nodes_to_batch_nodes_.count(node.get()) > 0; | |||||
} | |||||
} | } | ||||
Status MultiBatchGraphCopyer::LinkDataToMerge(const NodePtr &data, const NodePtr &merge) { | |||||
Status MultiBatchGraphCopyer::LinkDataToMerge(const NodePtr &data, const NodePtr &merge, const NodePtr &switchn) { | |||||
// The caller should make sure that the there is a SwitchN node in the map | // The caller should make sure that the there is a SwitchN node in the map | ||||
auto &switchn = data_nodes_to_switchn_[data.get()]; | |||||
GELOGI("Link edge between data %s to merge %s throw switchn %s", data->GetName().c_str(), merge->GetName().c_str(), | GELOGI("Link edge between data %s to merge %s throw switchn %s", data->GetName().c_str(), merge->GetName().c_str(), | ||||
switchn->GetName().c_str()); | switchn->GetName().c_str()); | ||||
for (size_t i = 0; i < shapes_.size(); ++i) { | for (size_t i = 0; i < shapes_.size(); ++i) { | ||||
@@ -619,6 +780,7 @@ Status MultiBatchGraphCopyer::LinkDataToMerge(const NodePtr &data, const NodePtr | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MultiBatchGraphCopyer::LinkNodeToMerge(const NodePtr &node, int out_index, const NodePtr &merge) { | Status MultiBatchGraphCopyer::LinkNodeToMerge(const NodePtr &node, int out_index, const NodePtr &merge) { | ||||
auto ©ed_nodes = nodes_to_batch_nodes_[node.get()]; | auto ©ed_nodes = nodes_to_batch_nodes_[node.get()]; | ||||
if (copyed_nodes.size() != shapes_.size()) { | if (copyed_nodes.size() != shapes_.size()) { | ||||
@@ -659,12 +821,87 @@ Status MultiBatchGraphCopyer::LinkNodeToMerge(const NodePtr &node, int out_index | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MultiBatchGraphCopyer::UpdateMaxShapeToData(const NodePtr &data) { | |||||
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | |||||
auto data_name = data->GetName(); | |||||
Status MultiBatchGraphCopyer::InsertSwitchNAndUpdateMaxShape(const NodePtr &node) { | |||||
std::vector<std::pair<Node *, NodePtr>> dynamic_out_to_switchn; | |||||
if (!getnext_sink_dynamic_dims_) { | |||||
if (InsertSwitchNForData(node, kDataOutIndex, kDataOutIndex, dynamic_out_to_switchn) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to insert switchn for %s.", node->GetName().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (UpdateMaxShapeToData(node, kDataOutIndex) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", node->GetName().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
} else { | |||||
if (!IsGetNextType(node)) { | |||||
GELOGI("No need to insert switchn and update max shape for %s when get sink dynamic.", node->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
for (size_t i = 0; i < getnext_sink_dynamic_out_mapping_.size(); ++i) { | |||||
dynamic_out_to_switchn.clear(); | |||||
for (size_t j = 0; j < getnext_sink_dynamic_out_mapping_.at(i).second; ++j) { | |||||
GELOGI("The %zu data_index has %zu referenced nums.", getnext_sink_dynamic_out_mapping_.at(i).first, | |||||
getnext_sink_dynamic_out_mapping_.at(i).second); | |||||
if (InsertSwitchNForData(node, getnext_sink_dynamic_out_mapping_.at(i).first, j, dynamic_out_to_switchn) != | |||||
SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to insert switchn for %s of %zu out anchor when referenced index is %zu", | |||||
node->GetName().c_str(), getnext_sink_dynamic_out_mapping_.at(i).first, j); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
getnext_nodes_to_switchn_.emplace_back(dynamic_out_to_switchn); | |||||
} | |||||
for (size_t i = 0; i < getnext_sink_dynamic_out_mapping_.size(); ++i) { | |||||
if(UpdateMaxShapeToData(node, i) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to update max shape of %zu out anchor", node->GetName().c_str(), i); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
} | |||||
// TODO:delete | |||||
for (size_t i = 0; i < getnext_nodes_to_switchn_.size(); ++i) { | |||||
for (size_t j = 0; j < getnext_nodes_to_switchn_.at(i).size(); ++j) { | |||||
auto data_node = getnext_nodes_to_switchn_.at(i).at(j).first; | |||||
auto switchn = getnext_nodes_to_switchn_.at(i).at(j).second; | |||||
GELOGI("the output idx is %zu, ref idx is %zu, data node is %s, switchn is %s.", i, j, | |||||
data_node->GetName().c_str(), switchn->GetName().c_str()); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status MultiBatchGraphCopyer::UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index) { | |||||
auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape(); | |||||
size_t shape_index = out_anchor_index + (node->GetAllOutDataAnchors().size() / kDivisionConst); | |||||
GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(shape_index); | |||||
std::vector<int64_t> output_dims = {data_shape.GetDims().size()}; | |||||
GeShape output_shape(output_dims); | |||||
output_desc.SetShape(output_shape); | |||||
if (node->GetOpDesc()->UpdateOutputDesc(shape_index, output_desc) != SUCCESS) { | |||||
GELOGE(FAILED, "Update output desc fail."); | |||||
return FAILED; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status MultiBatchGraphCopyer::UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index) { | |||||
auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape(); | |||||
string data_name = node->GetName(); | |||||
if (IsAllDimsPositive(data_shape.GetDims())) { | if (IsAllDimsPositive(data_shape.GetDims())) { | ||||
return SUCCESS; | |||||
if (!getnext_sink_dynamic_dims_) { | |||||
return SUCCESS; | |||||
} else { | |||||
data_name.append("_").append(std::to_string(out_anchor_index)); | |||||
GELOGD("Update max shape of %s, shape dims is %s.", data_name, | |||||
formats::JoinToString(data_shape.GetDims()).c_str()); | |||||
// need to update shape of Shape_node | |||||
GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(node, out_anchor_index), "Failed to update shape of shape node"); | |||||
return SUCCESS; | |||||
} | |||||
} | } | ||||
size_t max_shape_index = 0; | size_t max_shape_index = 0; | ||||
int64_t max_size = 0; | int64_t max_size = 0; | ||||
for (size_t i = 0; i < shapes_.size(); ++i) { | for (size_t i = 0; i < shapes_.size(); ++i) { | ||||
@@ -684,39 +921,52 @@ Status MultiBatchGraphCopyer::UpdateMaxShapeToData(const NodePtr &data) { | |||||
} | } | ||||
// must not be error, the calc result has been checked in function InsertSwitchNForData | // must not be error, the calc result has been checked in function InsertSwitchNForData | ||||
(void)CalcShape(data_to_dynamic_info_.at(data_name).at(max_shape_index), data_shape); | (void)CalcShape(data_to_dynamic_info_.at(data_name).at(max_shape_index), data_shape); | ||||
auto ret = NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
ret = NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
GELOGI("Update the data %s input/output shape to the max %s", data->GetName().c_str(), | |||||
auto ret = NodeUtils::UpdateOutputShape(*node, out_anchor_index, data_shape); | |||||
GE_CHK_STATUS_RET(ret, "Failed to update output shape for data %s", node->GetName().c_str()); | |||||
// getnext_sink not has input | |||||
if (!getnext_sink_dynamic_dims_) { | |||||
ret = NodeUtils::UpdateInputShape(*node, kDataInIndex, data_shape); | |||||
GE_CHK_STATUS_RET(ret, "Failed to update input shape for data %s", node->GetName().c_str()); | |||||
} else { | |||||
// need to update shape of Shape_node when getnext_sink_dynamic | |||||
GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(node, out_anchor_index), "Failed to update shape of shape node"); | |||||
} | |||||
GELOGI("Update the data %s input/output shape to the max %s", node->GetName().c_str(), | |||||
formats::ShapeToString(data_shape).c_str()); | formats::ShapeToString(data_shape).c_str()); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { | |||||
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | |||||
auto data_name = data->GetName(); | |||||
(void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); | |||||
Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, | |||||
const size_t &peer_in_anchor_index, | |||||
std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn) { | |||||
auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape(); | |||||
string data_name = node->GetName(); | |||||
if (getnext_sink_dynamic_dims_) { | |||||
data_name.append("_").append(std::to_string(out_anchor_index)); | |||||
} | |||||
(void)AttrUtils::SetListInt(node->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); | |||||
GELOGI("Insert switchn node of %s, shape dims is %s.", data_name.c_str(), | |||||
formats::JoinToString(data_shape.GetDims()).c_str()); | |||||
if (IsAllDimsPositive(data_shape.GetDims())) { | if (IsAllDimsPositive(data_shape.GetDims())) { | ||||
GELOGI("The shape of data %s are positive(%s), skip the multi batch process", data->GetName().c_str(), | |||||
GELOGI("The shape of data %s are positive(%s), skip the multi batch process", node->GetName().c_str(), | |||||
data_shape.ToString().c_str()); | data_shape.ToString().c_str()); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
auto switchn_desc = MakeShared<OpDesc>(); | auto switchn_desc = MakeShared<OpDesc>(); | ||||
if (switchn_desc == nullptr) { | |||||
GELOGE(OUT_OF_MEMORY, "Failed to create switchn for data %s", data->GetName().c_str()); | |||||
return OUT_OF_MEMORY; | |||||
} | |||||
switchn_desc->SetName(data->GetName() + "_ascend_mbatch_switchn"); | |||||
GE_IF_BOOL_EXEC(switchn_desc == nullptr, | |||||
GELOGE(OUT_OF_MEMORY, "Failed to create switchn for data %s", node->GetName().c_str()); | |||||
return OUT_OF_MEMORY); | |||||
string switchn_name = node->GetName() + "_ascend_mbatch_switchn"; | |||||
if (getnext_sink_dynamic_dims_) { | |||||
switchn_name.append("_").append(std::to_string(out_anchor_index)) | |||||
.append("_").append(std::to_string(peer_in_anchor_index)); | |||||
} | |||||
GELOGI("name of switchn is %s.", switchn_name.c_str()); | |||||
switchn_desc->SetName(switchn_name); | |||||
switchn_desc->SetType(SWITCHN); | switchn_desc->SetType(SWITCHN); | ||||
GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex)); | |||||
GeTensorDesc tensor(NodeUtils::GetOutputDesc(*node, out_anchor_index)); | |||||
if (switchn_desc->AddInputDesc("data", tensor) != GRAPH_SUCCESS) { // data | if (switchn_desc->AddInputDesc("data", tensor) != GRAPH_SUCCESS) { // data | ||||
return OUT_OF_MEMORY; | return OUT_OF_MEMORY; | ||||
} | } | ||||
@@ -726,11 +976,13 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { | |||||
} | } | ||||
std::vector<std::string> input_dims_str; | std::vector<std::string> input_dims_str; | ||||
for (size_t i = 0; i < shapes_.size(); ++i) { | for (size_t i = 0; i < shapes_.size(); ++i) { | ||||
GELOGI("Start clac shape for data %s, batch shape is %s.", data_name.c_str(), | |||||
formats::JoinToString(data_to_dynamic_info_.at(data_name).at(i)).c_str()); | |||||
auto shape = data_shape; | auto shape = data_shape; | ||||
auto ret = CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); | auto ret = CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Failed to calculate the batched shape for data node %s, the shapes may not match", | GELOGE(ret, "Failed to calculate the batched shape for data node %s, the shapes may not match", | ||||
data->GetName().c_str()); | |||||
node->GetName().c_str()); | |||||
return ret; | return ret; | ||||
} | } | ||||
tensor.SetShape(shape); | tensor.SetShape(shape); | ||||
@@ -738,7 +990,7 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { | |||||
int64_t tensor_size = 0; | int64_t tensor_size = 0; | ||||
(void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); | (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); | ||||
input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + | input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + | ||||
TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + | |||||
TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + node->GetName() + ":" + | |||||
std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + | std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + | ||||
formats::JoinToString(tensor.GetShape().GetDims()); | formats::JoinToString(tensor.GetShape().GetDims()); | ||||
input_dims_str.emplace_back(input_str); | input_dims_str.emplace_back(input_str); | ||||
@@ -751,9 +1003,9 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { | |||||
GELOGE(GRAPH_FAILED, "Opdesc AddOutputDesc failed"); | GELOGE(GRAPH_FAILED, "Opdesc AddOutputDesc failed"); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
GELOGD("The SwitchN %s output index %zu, shape %s", switchn_desc->GetName().c_str(), i, shape.ToString().c_str()); | |||||
GELOGD("The switchn %s output index %zu, shape %s", switchn_desc->GetName().c_str(), i, shape.ToString().c_str()); | |||||
} | } | ||||
(void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); | |||||
(void)AttrUtils::SetListStr(node->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); | |||||
if (!AttrUtils::SetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) { | if (!AttrUtils::SetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to add user designate shape order attr on switchn node %s", | GELOGE(INTERNAL_ERROR, "Failed to add user designate shape order attr on switchn node %s", | ||||
switchn_desc->GetName().c_str()); | switchn_desc->GetName().c_str()); | ||||
@@ -763,8 +1015,8 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to add insert attr on switchn node %s", switchn_desc->GetName().c_str()); | GELOGE(INTERNAL_ERROR, "Failed to add insert attr on switchn node %s", switchn_desc->GetName().c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (!AttrUtils::SetStr(data->GetOpDesc(), kMbatchSwitchnName, switchn_desc->GetName())) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to add switchn attr on data node %s", data->GetName().c_str()); | |||||
if (!AttrUtils::SetStr(node->GetOpDesc(), kMbatchSwitchnName, switchn_desc->GetName())) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to add switchn attr on data node %s", node->GetName().c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (StampDynamicType(switchn_desc) != SUCCESS) { | if (StampDynamicType(switchn_desc) != SUCCESS) { | ||||
@@ -773,13 +1025,17 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { | |||||
} | } | ||||
auto switchn = graph_->AddNode(switchn_desc); | auto switchn = graph_->AddNode(switchn_desc); | ||||
if (switchn == nullptr) { | |||||
GELOGE(OUT_OF_MEMORY, "Failed to create switchn %s from desc", switchn_desc->GetName().c_str()); | |||||
return OUT_OF_MEMORY; | |||||
GE_IF_BOOL_EXEC(switchn == nullptr, | |||||
GELOGE(OUT_OF_MEMORY, "Failed to create switchn %s from desc", switchn_desc->GetName().c_str()); | |||||
return OUT_OF_MEMORY); | |||||
if (!getnext_sink_dynamic_dims_) { | |||||
data_nodes_to_switchn_[node.get()] = switchn; | |||||
} else { | |||||
dynamic_out_to_switchn.emplace_back(std::make_pair(node.get(), switchn)); | |||||
} | } | ||||
data_nodes_to_switchn_[data.get()] = switchn; | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MultiBatchGraphCopyer::InsertMergeForEdgeNode(const NodePtr &node) { | Status MultiBatchGraphCopyer::InsertMergeForEdgeNode(const NodePtr &node) { | ||||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); | auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
@@ -809,6 +1065,27 @@ Status MultiBatchGraphCopyer::InsertMergeForEdgeNode(const NodePtr &node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MultiBatchGraphCopyer::LinkGetDynamicDimsToNetOutput(const NodePtr &node) { | |||||
if (getnext_sink_dynamic_dims_ && node->GetType() == NETOUTPUT) { | |||||
size_t input_index = node->GetAllInDataAnchors().size(); | |||||
if (NodeUtils::AppendInputAnchor(node, input_index + 1) != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Append input anchor of %s of %zu failed.", node->GetName().c_str(), input_index); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
auto ret = | |||||
ge::GraphUtils::AddEdge(shape_data_->GetOutDataAnchor(kDataOutIndex), node->GetInDataAnchor(input_index)); | |||||
GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link netoutput %s to getdynamicdims %s", | |||||
node->GetName().c_str(), shape_data_->GetName().c_str()); | |||||
return INTERNAL_ERROR); | |||||
if (!AttrUtils::SetBool(node->GetOpDesc(), ATTR_GETNEXT_SINK_DYNMAIC, true)) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to set getnext sink dynamic attr on netoutput %s.", node->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status MultiBatchGraphCopyer::CopyNodeInBatchBranch(const NodePtr &node) { | Status MultiBatchGraphCopyer::CopyNodeInBatchBranch(const NodePtr &node) { | ||||
auto ©ed_nodes = nodes_to_batch_nodes_[node.get()]; | auto ©ed_nodes = nodes_to_batch_nodes_[node.get()]; | ||||
for (size_t i = 0; i < shapes_.size(); ++i) { | for (size_t i = 0; i < shapes_.size(); ++i) { | ||||
@@ -823,20 +1100,76 @@ Status MultiBatchGraphCopyer::CopyNodeInBatchBranch(const NodePtr &node) { | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MultiBatchGraphCopyer::AddAttrForGetDynamicDims(const NodePtr &node) { | |||||
GELOGD("Add attr for :%s, type is %s:", shape_data_->GetName().c_str(), shape_data_->GetType().c_str()); | |||||
size_t data_count = node->GetAllOutDataAnchors().size() / kDivisionConst; | |||||
if (!AttrUtils::SetInt(shape_data_->GetOpDesc(), ATTR_GETNEXT_SINK_DATA_COUNT, data_count)) { | |||||
GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_DATA_COUNT failed"); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
vector<int64_t> shape_info; | |||||
for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) { | |||||
shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.size()); | |||||
for (size_t j = 0; j < GetLocalOmgContext().user_input_dims.at(i).second.size(); ++j) { | |||||
shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.at(j)); | |||||
} | |||||
} | |||||
if (!AttrUtils::SetListInt(shape_data_->GetOpDesc(), ATTR_GETNEXT_SINK_SHAPE_INFO, shape_info)) { | |||||
GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_SHAPE_INFO failed"); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status MultiBatchGraphCopyer::AddLinkForGetDynamicDims(const NodePtr &node) { | |||||
GELOGD("Start relink out anchor from shape node to getdynamicdims, and delete link between shape node and identity."); | |||||
size_t input_index = 0; | |||||
GELOGD("Out count of %s is %zu.", node->GetName().c_str(), node->GetAllOutDataAnchors().size()); | |||||
size_t data_count = node->GetAllOutDataAnchors().size() / kDivisionConst; | |||||
for (size_t out_index = data_count; out_index < node->GetAllOutDataAnchors().size(); ++out_index, ++input_index) { | |||||
GELOGI("Start add %s of %zu out_anchor to %s of %zu in_anchor.", node->GetName().c_str(), out_index, | |||||
shape_data_->GetName().c_str(), input_index); | |||||
auto out_data_anchor = node->GetOutDataAnchor(out_index); | |||||
auto ret = GraphUtils::AddEdge(out_data_anchor, shape_data_->GetInDataAnchor(input_index)); | |||||
GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link getnext %s to getdynamicdims %s", | |||||
node->GetName().c_str(), shape_data_->GetName().c_str()); | |||||
return INTERNAL_ERROR); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status MultiBatchGraphCopyer::LinkEdges() { | Status MultiBatchGraphCopyer::LinkEdges() { | ||||
Status ret; | Status ret; | ||||
for (const auto &node : origin_all_nodes_) { | for (const auto &node : origin_all_nodes_) { | ||||
if (data_nodes_to_switchn_.count(node.get()) > 0) { | |||||
ret = LinkDataToSwitchN(node); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
if (!getnext_sink_dynamic_dims_) { | |||||
if (data_nodes_to_switchn_.count(node.get()) > 0) { | |||||
auto switchn = data_nodes_to_switchn_[node.get()]; | |||||
GE_IF_BOOL_EXEC(switchn == nullptr, | |||||
GELOGE(PARAM_INVALID, "Switchn should not be nullptr for %s.", node->GetName().c_str()); | |||||
return OUT_OF_MEMORY); | |||||
ret = LinkDataToSwitchN(node, switchn, kDataOutIndex); | |||||
GE_CHK_STATUS_RET(ret, "Link data to switchn failed."); | |||||
} | |||||
} else { | |||||
if (IsGetNextType(node)) { | |||||
GELOGD("Start add attr and link edge for %s.", node->GetName().c_str()); | |||||
GE_CHK_STATUS_RET(AddAttrForGetDynamicDims(node), "Failed to add attr for %s.", node->GetName().c_str()); | |||||
GE_CHK_STATUS_RET(AddLinkForGetDynamicDims(node), "Failed to add link for %s.", node->GetName().c_str()); | |||||
} | |||||
for (size_t i = 0; i < getnext_nodes_to_switchn_.size(); ++i) { | |||||
for (size_t j = 0; j < getnext_nodes_to_switchn_.at(i).size(); ++j) { | |||||
if (getnext_nodes_to_switchn_.at(i).at(j).first == node.get()) { | |||||
auto switchn = getnext_nodes_to_switchn_.at(i).at(j).second; | |||||
GE_CHK_STATUS_RET(LinkDataToSwitchN(node, switchn, i), "Link %s to %s failed.", node->GetName().c_str(), | |||||
switchn->GetName().c_str()); | |||||
} | |||||
} | |||||
} | } | ||||
} | } | ||||
if (nodes_to_merge_nodes_.count(node.get()) > 0) { | if (nodes_to_merge_nodes_.count(node.get()) > 0) { | ||||
ret = LinkToMerge(node); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | |||||
GE_CHK_STATUS_RET(LinkToMerge(node), "Link %s to merge failed.", node->GetName().c_str()); | |||||
} | } | ||||
if (nodes_to_batch_nodes_.count(node.get()) > 0) { | if (nodes_to_batch_nodes_.count(node.get()) > 0) { | ||||
ret = LinkToNodeInBranch(node); | ret = LinkToNodeInBranch(node); | ||||
@@ -849,20 +1182,21 @@ Status MultiBatchGraphCopyer::LinkEdges() { | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MultiBatchGraphCopyer::LinkDataToSwitchN(const NodePtr &data) { | |||||
auto switchn = data_nodes_to_switchn_[data.get()]; | |||||
Status MultiBatchGraphCopyer::LinkDataToSwitchN(const NodePtr &data, const NodePtr &switchn, const int &out_index) { | |||||
auto ret = | auto ret = | ||||
GraphUtils::AddEdge(shape_data_->GetOutDataAnchor(kDataOutIndex), switchn->GetInDataAnchor(kSwitchNPredIndex)); | GraphUtils::AddEdge(shape_data_->GetOutDataAnchor(kDataOutIndex), switchn->GetInDataAnchor(kSwitchNPredIndex)); | ||||
GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link shape data %s to switchn %s", | GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link shape data %s to switchn %s", | ||||
shape_data_->GetName().c_str(), switchn->GetName().c_str()); | shape_data_->GetName().c_str(), switchn->GetName().c_str()); | ||||
return INTERNAL_ERROR); | return INTERNAL_ERROR); | ||||
ret = GraphUtils::AddEdge(data->GetOutDataAnchor(kDataOutIndex), switchn->GetInDataAnchor(kSwitchNDataIndex)); | |||||
ret = GraphUtils::AddEdge(data->GetOutDataAnchor(out_index), switchn->GetInDataAnchor(kSwitchNDataIndex)); | |||||
GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link data %s to switchn %s", | GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link data %s to switchn %s", | ||||
data->GetName().c_str(), switchn->GetName().c_str()); | data->GetName().c_str(), switchn->GetName().c_str()); | ||||
return INTERNAL_ERROR); | return INTERNAL_ERROR); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MultiBatchGraphCopyer::LinkToMerge(const NodePtr &node) { | Status MultiBatchGraphCopyer::LinkToMerge(const NodePtr &node) { | ||||
auto &merge_nodes = nodes_to_merge_nodes_[node.get()]; | auto &merge_nodes = nodes_to_merge_nodes_[node.get()]; | ||||
for (size_t i = 0; i < merge_nodes.size(); ++i) { | for (size_t i = 0; i < merge_nodes.size(); ++i) { | ||||
@@ -877,10 +1211,27 @@ Status MultiBatchGraphCopyer::LinkToMerge(const NodePtr &node) { | |||||
} | } | ||||
continue; | continue; | ||||
} | } | ||||
if (data_nodes_to_switchn_.count(node.get()) > 0) { | |||||
auto ret = LinkDataToMerge(node, merge_node); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
if (!getnext_sink_dynamic_dims_) { | |||||
if (data_nodes_to_switchn_.count(node.get()) > 0) { | |||||
auto &switchn = data_nodes_to_switchn_[node.get()]; | |||||
auto ret = LinkDataToMerge(node, merge_node, switchn); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | |||||
continue; | |||||
} | |||||
} else { | |||||
for (size_t j = 0; j < getnext_nodes_to_switchn_.size(); ++j) { | |||||
for (size_t k = 0; k < getnext_nodes_to_switchn_.at(j).size(); ++k) { | |||||
if (getnext_nodes_to_switchn_.at(j).at(k).first == node.get()) { | |||||
auto &switchn = getnext_nodes_to_switchn_.at(j).at(k).second; | |||||
auto ret = LinkDataToMerge(node, merge_node, switchn); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
continue; | continue; | ||||
} | } | ||||
@@ -890,7 +1241,9 @@ Status MultiBatchGraphCopyer::LinkToMerge(const NodePtr &node) { | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MultiBatchGraphCopyer::LinkToNodeInBranch(const NodePtr &node) { | Status MultiBatchGraphCopyer::LinkToNodeInBranch(const NodePtr &node) { | ||||
GELOGI("Start LinkToNodeInBranch for %s.", node->GetName().c_str()); | |||||
auto &branch_nodes = nodes_to_batch_nodes_[node.get()]; | auto &branch_nodes = nodes_to_batch_nodes_[node.get()]; | ||||
for (size_t i = 0; i < branch_nodes.size(); ++i) { | for (size_t i = 0; i < branch_nodes.size(); ++i) { | ||||
auto ret = CopyInDataEdges(node, i, branch_nodes[i]); | auto ret = CopyInDataEdges(node, i, branch_nodes[i]); | ||||
@@ -904,6 +1257,7 @@ Status MultiBatchGraphCopyer::LinkToNodeInBranch(const NodePtr &node) { | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { | Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { | ||||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); | auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
@@ -1025,16 +1379,42 @@ Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() { | |||||
} | } | ||||
Status ProcessMultiBatch(ComputeGraphPtr &graph) { | Status ProcessMultiBatch(ComputeGraphPtr &graph) { | ||||
if (!GetLocalOmgContext().need_multi_batch) { | |||||
GELOGI("No need to process_multi for no_train graph."); | |||||
return SUCCESS; | |||||
} | |||||
std::vector<NodePtr> data_nodes; | |||||
std::vector<NodePtr> getnext_nosink_nodes; | |||||
std::vector<NodePtr> getnext_sink_nodes; | |||||
if (CheckSequenceOfOptions(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "[Train_Dynamic] CheckSequenceOfOptions failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (UpdateNameOfInputShape(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "[Train_Dynamic] UpdateNameForInputShapeOfOption failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (DeleteIdentityInsertByAdapter(graph) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "DeleteIdentityInsertByAdapter failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
std::vector<std::vector<int64_t>> shapes; | std::vector<std::vector<int64_t>> shapes; | ||||
if (!InitDynamicParams(shapes)) { | if (!InitDynamicParams(shapes)) { | ||||
GELOGD("There is no multi-batch options, no need to process multi-batch copy"); | GELOGD("There is no multi-batch options, no need to process multi-batch copy"); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
if (CheckNegativeCountOfOptions(shapes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Input_shape and dynamic_dims should set correct params."); | |||||
return PARAM_INVALID; | |||||
} | |||||
GetLocalOmgContext().all_gears_info = shapes; | |||||
DynamicType dynamic_type = DynamicType::kDynamicUnknown; | DynamicType dynamic_type = DynamicType::kDynamicUnknown; | ||||
if (!GetLocalOmgContext().dynamic_batch_size.empty()) { | if (!GetLocalOmgContext().dynamic_batch_size.empty()) { | ||||
dynamic_type = DynamicType::kDynamicBatch; | dynamic_type = DynamicType::kDynamicBatch; | ||||
} else if (!GetLocalOmgContext().dynamic_image_size.empty()) { | } else if (!GetLocalOmgContext().dynamic_image_size.empty()) { | ||||
dynamic_type = DynamicType::kDynamicImageSize;; | |||||
dynamic_type = DynamicType::kDynamicImageSize; | |||||
} else if (!GetLocalOmgContext().dynamic_dims.empty()) { | } else if (!GetLocalOmgContext().dynamic_dims.empty()) { | ||||
dynamic_type = DynamicType::kDynamicDims; | dynamic_type = DynamicType::kDynamicDims; | ||||
} | } | ||||
@@ -1134,9 +1514,11 @@ void GetDynamicShapeByMerge(const ComputeGraphPtr &graph, const NodePtr &node, | |||||
GELOGD("Try get dynamic shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str()); | GELOGD("Try get dynamic shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str()); | ||||
const auto &netoutput_desc = node->GetOpDesc(); | const auto &netoutput_desc = node->GetOpDesc(); | ||||
const auto &inputnode_to_netoutput = node->GetInAllNodes(); | const auto &inputnode_to_netoutput = node->GetInAllNodes(); | ||||
GELOGI("Train_Dynamic Find the merge node size is %zu.", inputnode_to_netoutput.size()); | |||||
for (size_t i = 0; i < inputnode_to_netoutput.size(); ++i) { | for (size_t i = 0; i < inputnode_to_netoutput.size(); ++i) { | ||||
bool insert_by_mbatch = false; | bool insert_by_mbatch = false; | ||||
(void)AttrUtils::GetBool(inputnode_to_netoutput.at(i)->GetOpDesc(), ATTR_INSERT_BY_MBATCH, insert_by_mbatch); | (void)AttrUtils::GetBool(inputnode_to_netoutput.at(i)->GetOpDesc(), ATTR_INSERT_BY_MBATCH, insert_by_mbatch); | ||||
GELOGI("Train_Dynamic type is %s", inputnode_to_netoutput.at(i)->GetType().c_str()); | |||||
if (inputnode_to_netoutput.at(i)->GetType() == MERGE && insert_by_mbatch) { | if (inputnode_to_netoutput.at(i)->GetType() == MERGE && insert_by_mbatch) { | ||||
GELOGI("Find the merge node %s with mbatch attr and the index is %zu", | GELOGI("Find the merge node %s with mbatch attr and the index is %zu", | ||||
inputnode_to_netoutput.at(i)->GetName().c_str(), i); | inputnode_to_netoutput.at(i)->GetName().c_str(), i); | ||||
@@ -55,9 +55,7 @@ class MultiBatchGraphCopyer { | |||||
data_name_order_.push_back(item.first); | data_name_order_.push_back(item.first); | ||||
} | } | ||||
} | } | ||||
void SetDataToDynamicInfo(const map<string, vector<vector<int64_t>>> &designate_shape) { | |||||
data_to_dynamic_info_ = designate_shape; | |||||
} | |||||
void SetDynamicType(const DynamicType dynamic_type) { | void SetDynamicType(const DynamicType dynamic_type) { | ||||
dynamic_type_ = dynamic_type; | dynamic_type_ = dynamic_type; | ||||
} | } | ||||
@@ -69,15 +67,25 @@ class MultiBatchGraphCopyer { | |||||
// label status for origin_all_nodes_ | // label status for origin_all_nodes_ | ||||
Status LabelStatus(); | Status LabelStatus(); | ||||
void LabelStatusForData(const NodePtr &data); | |||||
void LabelStatusForGetNextSink(const NodePtr &data); | |||||
// add nodes functions | // add nodes functions | ||||
Status CreateNewNodes(); | Status CreateNewNodes(); | ||||
NodePtr InsertShapeDataNode(); | NodePtr InsertShapeDataNode(); | ||||
Status InsertSwitchNForData(const NodePtr &data); | |||||
NodePtr InsertGetDynamicDimsNode(); | |||||
Status InsertSwitchNAndUpdateMaxShape(const NodePtr &node); | |||||
Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index, | |||||
std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn); | |||||
Status InsertIdentityAfterSwitchN(); | Status InsertIdentityAfterSwitchN(); | ||||
Status UpdateMaxShapeToData(const NodePtr &data); | |||||
Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index); | |||||
Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index); | |||||
Status InsertMergeForEdgeNode(const NodePtr &node); | Status InsertMergeForEdgeNode(const NodePtr &node); | ||||
Status LinkGetDynamicDimsToNetOutput(const NodePtr &node); | |||||
/// Insert a merge node for src node `node` on output index `index`. The merge node will be used to merge all nodes | /// Insert a merge node for src node `node` on output index `index`. The merge node will be used to merge all nodes | ||||
/// in batch-branch to one output to the node out of the batch-branch. | /// in batch-branch to one output to the node out of the batch-branch. | ||||
@@ -95,12 +103,16 @@ class MultiBatchGraphCopyer { | |||||
// link edges functions | // link edges functions | ||||
Status LinkEdges(); | Status LinkEdges(); | ||||
Status LinkDataToSwitchN(const NodePtr &data); | |||||
Status AddAttrForGetDynamicDims(const NodePtr &node); | |||||
Status AddLinkForGetDynamicDims(const NodePtr &node); | |||||
Status LinkDataToSwitchN(const NodePtr &data, const NodePtr &switchn, const int &out_index); | |||||
Status LinkToMerge(const NodePtr &node); | Status LinkToMerge(const NodePtr &node); | ||||
Status LinkToNodeInBranch(const NodePtr &node); | Status LinkToNodeInBranch(const NodePtr &node); | ||||
Status LinkToNodeOutBranch(const NodePtr &node); | Status LinkToNodeOutBranch(const NodePtr &node); | ||||
Status LinkDataToMerge(const NodePtr &data, const NodePtr &merge); | |||||
Status LinkDataToMerge(const NodePtr &data, const NodePtr &merge, const NodePtr &switchn); | |||||
Status LinkNodeToMerge(const NodePtr &node, int out_index, const NodePtr &merge); | Status LinkNodeToMerge(const NodePtr &node, int out_index, const NodePtr &merge); | ||||
NodePtr FindSwitchnNodeForDataEdge(const OutDataAnchorPtr &data_out_anchor, const NodePtr &origin_node); | |||||
Status CopyInDataEdges(const NodePtr &origin_node, int batch_num, const NodePtr ©ed_node); | Status CopyInDataEdges(const NodePtr &origin_node, int batch_num, const NodePtr ©ed_node); | ||||
Status CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr ©ed_node); | Status CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr ©ed_node); | ||||
Status CheckAndParseDynamicData(); | Status CheckAndParseDynamicData(); | ||||
@@ -127,6 +139,11 @@ class MultiBatchGraphCopyer { | |||||
// the data nodes, and the SwitchN nodes inserted after it | // the data nodes, and the SwitchN nodes inserted after it | ||||
std::map<Node *, NodePtr> data_nodes_to_switchn_; | std::map<Node *, NodePtr> data_nodes_to_switchn_; | ||||
// the getnext_sink nodes, and the SwitchN nodes inserted after it | |||||
std::vector<std::vector<std::pair<Node *, NodePtr>>> getnext_nodes_to_switchn_; | |||||
std::vector<std::vector<std::pair<int, int>>> outidx_inidx_mappings_; | |||||
std::vector<std::pair<int, int>> outidx_inidx_mapping_; | |||||
// the nodes on the in/out-batch-branch edge, and the merge nodes inserted after it | // the nodes on the in/out-batch-branch edge, and the merge nodes inserted after it | ||||
std::map<Node *, std::vector<NodePtr>> nodes_to_merge_nodes_; | std::map<Node *, std::vector<NodePtr>> nodes_to_merge_nodes_; | ||||
@@ -142,6 +159,9 @@ class MultiBatchGraphCopyer { | |||||
// dynamic type : dynamic batch,, dynamic image size, dynamic dims. | // dynamic type : dynamic batch,, dynamic image size, dynamic dims. | ||||
DynamicType dynamic_type_ = DynamicType::kDynamicUnknown; | DynamicType dynamic_type_ = DynamicType::kDynamicUnknown; | ||||
std::vector<std::pair<size_t, size_t>> getnext_sink_dynamic_out_mapping_; | |||||
bool getnext_sink_dynamic_dims_ = false; | |||||
}; | }; | ||||
} // namespace multibatch | } // namespace multibatch | ||||
} // namespace ge | } // namespace ge | ||||
@@ -27,6 +27,8 @@ | |||||
#include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
#include "graph/common/local_context.h" | #include "graph/common/local_context.h" | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "graph/compute_graph.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
namespace ge { | namespace ge { | ||||
namespace multibatch { | namespace multibatch { | ||||
@@ -38,6 +40,17 @@ const int kDynamicBatchDynamicDimsNum = 1; | |||||
const int kDynamicImgSizeDynamciDimsNum = 2; | const int kDynamicImgSizeDynamciDimsNum = 2; | ||||
const size_t kMaxNDDimNum = 4; | const size_t kMaxNDDimNum = 4; | ||||
const size_t kMinNDDimNum = 1; | const size_t kMinNDDimNum = 1; | ||||
const size_t kNumOfGetnextNode = 1; | |||||
const int kDivisionConst = 2; | |||||
const char *const kSubstrOfGetNextNosinkName = "IteratorGetNext"; | |||||
const char *const kShapeDataName = "ascend_mbatch_shape_data"; | |||||
inline bool IsGetNextType(const NodePtr &node) { | |||||
std::string original_type; | |||||
GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, | |||||
GELOGW("Get original type failed."); return false); | |||||
return (original_type == ITERATORV2); | |||||
} | |||||
void ParseDynamicSize(string dynamic_size, vector<vector<int64_t>> &shapes) { | void ParseDynamicSize(string dynamic_size, vector<vector<int64_t>> &shapes) { | ||||
std::vector<std::string> shape_strs = ge::StringUtils::Split(dynamic_size, ';'); | std::vector<std::string> shape_strs = ge::StringUtils::Split(dynamic_size, ';'); | ||||
@@ -59,6 +72,248 @@ void ParseDynamicSize(string dynamic_size, vector<vector<int64_t>> &shapes) { | |||||
} | } | ||||
} | } | ||||
Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector<NodePtr> &data_nodes, | |||||
vector<NodePtr> &getnext_nosink_nodes, vector<NodePtr> &getnext_sink_nodes) { | |||||
GELOGD("Start distinguish getnext and data node."); | |||||
for (NodePtr &input_node : graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(input_node); | |||||
OpDescPtr op_desc = input_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (op_desc->GetType() == DATA && op_desc->GetName() != kShapeDataName) { | |||||
if (op_desc->GetName().find(kSubstrOfGetNextNosinkName) == string::npos) { | |||||
data_nodes.emplace_back(input_node); | |||||
} else { | |||||
getnext_nosink_nodes.emplace_back(input_node); | |||||
} | |||||
} | |||||
if (IsGetNextType(input_node)) { | |||||
GELOGD("Name of getnext sink is %s.", op_desc->GetName().c_str()); | |||||
getnext_sink_nodes.emplace_back(input_node); | |||||
} | |||||
} | |||||
GELOGI("Data count is %zu, getnext nosink count is %zu, getnext sink count is %zu.", data_nodes.size(), | |||||
getnext_nosink_nodes.size(), getnext_sink_nodes.size()); | |||||
return SUCCESS; | |||||
} | |||||
Status CheckSequenceOfData(ComputeGraphPtr &graph, const vector<NodePtr> &data_nodes) { | |||||
GELOGD("Start check input sequence from data nodes and input shape."); | |||||
if (data_nodes.size() != GetLocalOmgContext().user_input_dims.size()) { | |||||
GELOGE(PARAM_INVALID, "The count of input shape:%zu should be equal to the count of data num:%zu.", | |||||
GetLocalOmgContext().user_input_dims.size(), data_nodes.size()); | |||||
return PARAM_INVALID; | |||||
} | |||||
for (size_t i = 0; i < data_nodes.size(); ++i) { | |||||
auto data_node = data_nodes.at(i); | |||||
GE_CHECK_NOTNULL(data_node); | |||||
GE_CHECK_NOTNULL(data_node->GetOpDesc()); | |||||
auto output_shape = data_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims(); | |||||
auto dynamic_dims = GetLocalOmgContext().user_input_dims.at(i).second; | |||||
if (dynamic_dims.size() != output_shape.size()) { | |||||
GELOGE(PARAM_INVALID, "The output shape of %s is %s, the input shape from options of %s is %s.", | |||||
data_node->GetName().c_str(), formats::JoinToString(output_shape).c_str(), | |||||
GetLocalOmgContext().user_input_dims.at(i).first.c_str(), formats::JoinToString(dynamic_dims).c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
for (size_t j = 0; j < dynamic_dims.size(); ++j) { | |||||
if (dynamic_dims.at(j) != kDynmaicDims && dynamic_dims.at(j) != output_shape.at(j)) { | |||||
GELOGE(INTERNAL_ERROR, "Value of input shape %s should be equal to %s.", | |||||
formats::JoinToString(dynamic_dims).c_str(), formats::JoinToString(output_shape).c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status CheckSequenceOfGetnext(ComputeGraphPtr &graph, const vector<NodePtr> &getnext_sink_node) { | |||||
GELOGD("Start check input sequence from getnext sink nodes and input shape."); | |||||
if (getnext_sink_node.size() != kNumOfGetnextNode) { | |||||
GELOGE(PARAM_INVALID, "Not support dynamic dims when a graph with multi getnext nodes."); | |||||
return PARAM_INVALID; | |||||
} | |||||
auto data_node = getnext_sink_node.at(0); | |||||
GE_CHECK_NOTNULL(data_node); | |||||
auto op_desc = data_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
size_t data_count = data_node ->GetAllOutDataAnchors().size() / kDivisionConst; | |||||
if (data_count != GetLocalOmgContext().user_input_dims.size()) { | |||||
GELOGE(PARAM_INVALID, "Output count of %s is %zu, should be equal to count of input shape: %zu", | |||||
op_desc->GetName().c_str(), data_count, GetLocalOmgContext().user_input_dims.size()); | |||||
return PARAM_INVALID; | |||||
} | |||||
for (size_t i = 0; i < data_count; ++i) { | |||||
auto output_shape = data_node->GetOpDesc()->GetOutputDesc(i).GetShape().GetDims(); | |||||
auto dynamic_dims = GetLocalOmgContext().user_input_dims.at(i).second; | |||||
if (dynamic_dims.size() != output_shape.size()) { | |||||
GELOGE(PARAM_INVALID, "the output_shape of %s is %s, the input_shape from options of %s is %s.", | |||||
data_node->GetName().c_str(), formats::JoinToString(output_shape).c_str(), | |||||
GetLocalOmgContext().user_input_dims.at(i).first.c_str(), formats::JoinToString(dynamic_dims).c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
for (size_t j = 0; j < dynamic_dims.size(); ++j) { | |||||
if (dynamic_dims.at(j) != kDynmaicDims && dynamic_dims.at(j) != output_shape.at(j)) { | |||||
GELOGE(INTERNAL_ERROR, "value of input_shape %s should be equal to %s.", | |||||
formats::JoinToString(dynamic_dims).c_str(), formats::JoinToString(output_shape).c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status CheckSequenceOfOptions(ComputeGraphPtr &graph, vector<NodePtr> &data_nodes, | |||||
vector<NodePtr> &getnext_nosink_nodes, vector<NodePtr> &getnext_sink_nodes) { | |||||
if (GetLocalOmgContext().dynamic_node_type.empty()) { | |||||
GELOGI("No need to CheckSequenceOfOptions."); | |||||
return SUCCESS; | |||||
} | |||||
if (DistinguishGetNextAndData(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "DistinguishGetNextAndData failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (GetLocalOmgContext().dynamic_node_type == DATA) { | |||||
GELOGD("Users want data nodes to be dynamic."); | |||||
if(CheckSequenceOfData(graph, data_nodes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to check sequence of data nodes."); | |||||
return PARAM_INVALID; | |||||
} | |||||
} else { | |||||
GELOGD("Users want getnext nodes to be dynamic."); | |||||
if (!getnext_nosink_nodes.empty()) { | |||||
if (CheckSequenceOfData(graph, getnext_nosink_nodes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to check sequence of getnext nosink nodes."); | |||||
return PARAM_INVALID; | |||||
} | |||||
} else { | |||||
if (CheckSequenceOfGetnext(graph, getnext_sink_nodes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to check sequence of getnext sink nodes."); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status UpdateNameOfData(ComputeGraphPtr &graph, const vector<NodePtr> &data_nodes) { | |||||
GELOGD("Update first value of input shape by data nodes."); | |||||
if (data_nodes.size() != GetLocalOmgContext().user_input_dims.size()) { | |||||
GELOGE(PARAM_INVALID, "count of data_nodes: %zu should be equal to input_shape count: %zu.", | |||||
data_nodes.size(), GetLocalOmgContext().user_input_dims.size()); | |||||
return PARAM_INVALID; | |||||
} | |||||
for (size_t i = 0; i < data_nodes.size(); ++i) { | |||||
GELOGD("The %zu data name is %s.", i, data_nodes.at(i)->GetOpDesc()->GetName().c_str()); | |||||
GetLocalOmgContext().user_input_dims.at(i).first = data_nodes.at(i)->GetOpDesc()->GetName(); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status UpdateNameOfGetnext(ComputeGraphPtr &graph, const vector<NodePtr> &getnext_sink_nodes) { | |||||
GELOGD("Update first value of input shape by getnext sink nodes."); | |||||
if (getnext_sink_nodes.size() != kNumOfGetnextNode) { | |||||
GELOGE(PARAM_INVALID, "Not support dynamic dims when a graph with multi getnext nodes."); | |||||
return PARAM_INVALID; | |||||
} | |||||
auto input_node = getnext_sink_nodes.at(0); | |||||
GE_CHECK_NOTNULL(input_node); | |||||
auto op_desc = input_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
// user want getnext dynamic, just getnext or data+getnext_sink | |||||
size_t data_count = input_node->GetAllOutDataAnchors().size() / kDivisionConst; | |||||
if (data_count != GetLocalOmgContext().user_input_dims.size()) { | |||||
GELOGE(PARAM_INVALID, "Output count of %s is %zu, should be equal to count of input shape: %zu", | |||||
op_desc->GetName().c_str(), data_count, GetLocalOmgContext().user_input_dims.size()); | |||||
return PARAM_INVALID; | |||||
} | |||||
for (size_t i = 0; i < data_count; ++i) { | |||||
string data_name = op_desc->GetName() + + "_" + std::to_string(i); | |||||
GELOGD("Data just from getnext sink is %s.", data_name.c_str()); | |||||
GetLocalOmgContext().user_input_dims.at(i).first = data_name; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
// need to distinguish online and offline, offline no need to update the name of input_shape | |||||
Status UpdateNameOfInputShape(ComputeGraphPtr &graph, const vector<NodePtr> &data_nodes, | |||||
const vector<NodePtr> &getnext_nosink_nodes, const vector<NodePtr> &getnext_sink_nodes) { | |||||
if (GetLocalOmgContext().dynamic_node_type.empty()) { | |||||
GELOGI("No need to update first value of input shape when offline infer."); | |||||
return SUCCESS; | |||||
} | |||||
if (GetLocalOmgContext().dynamic_node_type == DATA) { | |||||
GELOGD("Users want data nodes to be dynamic."); | |||||
if(UpdateNameOfData(graph, data_nodes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to update first value of input shape of data nodes."); | |||||
return PARAM_INVALID; | |||||
} | |||||
} else { | |||||
GELOGD("Users want getnext nodes to be dynamic."); | |||||
if (!getnext_nosink_nodes.empty()) { | |||||
if(UpdateNameOfData(graph, getnext_nosink_nodes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to update first value of input shape of getnext nosink nodes."); | |||||
return PARAM_INVALID; | |||||
} | |||||
} else { | |||||
if (UpdateNameOfGetnext(graph, getnext_sink_nodes) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Failed to update first value of input shape of getnext sink nodes."); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status DeleteIdentityInsertByAdapter(ComputeGraphPtr &graph) { | |||||
GELOGD("Start delete identity node inserted by adapter."); | |||||
for (NodePtr &node : graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(node); | |||||
OpDescPtr op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (IsGetNextType(node)) { | |||||
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||||
GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); | |||||
for (auto &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
GE_IF_BOOL_EXEC(peer_in_anchor == nullptr, continue); | |||||
auto dst_node = peer_in_anchor->GetOwnerNode(); | |||||
GE_IF_BOOL_EXEC(dst_node == nullptr, continue); | |||||
if (dst_node->GetType() == IDENTITY) { | |||||
GELOGI("Need to remove %s.", dst_node->GetName().c_str()); | |||||
if (ge::GraphUtils::RemoveNodeWithoutRelink(graph, dst_node) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Remove Identity node %s failed.", dst_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status CheckNegativeCountOfOptions(const std::vector<std::vector<int64_t>> &shapes) { | |||||
size_t negative_count = 0; | |||||
for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) { | |||||
for (size_t j = 0; j < GetLocalOmgContext().user_input_dims.at(i).second.size(); ++j) { | |||||
if (GetLocalOmgContext().user_input_dims.at(i).second.at(j) == kDynmaicDims) { | |||||
negative_count++; | |||||
} | |||||
} | |||||
} | |||||
for (size_t i = 0; i < shapes.size(); ++i) { | |||||
if (shapes.at(i).size() != negative_count) { | |||||
GELOGE(PARAM_INVALID, "each gear num of dynamic_dims is %zu should be equal to %zu.", shapes.at(i).size(), | |||||
negative_count); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Init Dynamic Param from Options. | /// @brief Init Dynamic Param from Options. | ||||
@@ -115,8 +370,10 @@ Status ParserDataToDynmaicInfo(const vector<vector<int64_t>> &shapes, | |||||
auto &data_shape = cur_item.second; | auto &data_shape = cur_item.second; | ||||
auto dynamic_dims_num = std::count_if(data_shape.begin(), data_shape.end(), | auto dynamic_dims_num = std::count_if(data_shape.begin(), data_shape.end(), | ||||
[&data_shape](int64_t dim){ return dim < 0; }); | [&data_shape](int64_t dim){ return dim < 0; }); | ||||
GELOGI("Train_Dynamic dynamic_dims_num of %s is %zu", data_name.c_str(), dynamic_dims_num); | |||||
vector<vector<int64_t> > dynamic_info; | vector<vector<int64_t> > dynamic_info; | ||||
for (auto &dynamic_gear_info : shapes) { | for (auto &dynamic_gear_info : shapes) { | ||||
GELOGI("Train_Dynamic dynamic_gear_info is %s", formats::JoinToString(dynamic_gear_info).c_str()); | |||||
vector<int64_t> one_gear; | vector<int64_t> one_gear; | ||||
if (dynamic_gear_info.size() == static_cast<size_t>(dynamic_dims_num)) { | if (dynamic_gear_info.size() == static_cast<size_t>(dynamic_dims_num)) { | ||||
one_gear = dynamic_gear_info; | one_gear = dynamic_gear_info; | ||||
@@ -135,6 +392,7 @@ Status ParserDataToDynmaicInfo(const vector<vector<int64_t>> &shapes, | |||||
data_name.c_str(), formats::JoinToString(data_shape).c_str()); | data_name.c_str(), formats::JoinToString(data_shape).c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
GELOGI("Train_Dynamic one_gear is %s.", formats::JoinToString(one_gear).c_str()); | |||||
dynamic_info.push_back(one_gear); | dynamic_info.push_back(one_gear); | ||||
} | } | ||||
cur_data_index += dynamic_dims_num; | cur_data_index += dynamic_dims_num; | ||||
@@ -214,7 +472,7 @@ Status CalcShape(const std::vector<int64_t> &batch_shape, GeShape &data_shape) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage( | ErrorManager::GetInstance().ATCReportErrMessage( | ||||
"E19012", {"function", "reason"}, | "E19012", {"function", "reason"}, | ||||
{"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + | {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + | ||||
" does not match the data shape " + data_shape.ToString()}); | |||||
" does not match the data shape " + data_shape.ToString()}); | |||||
GELOGE(PARAM_INVALID, | GELOGE(PARAM_INVALID, | ||||
"Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", | "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", | ||||
batch_shape.size(), data_shape.ToString().c_str()); | batch_shape.size(), data_shape.ToString().c_str()); | ||||
@@ -223,6 +481,7 @@ Status CalcShape(const std::vector<int64_t> &batch_shape, GeShape &data_shape) { | |||||
data_shape.SetDim(i, batch_shape[batch_shape_index++]); | data_shape.SetDim(i, batch_shape[batch_shape_index++]); | ||||
} | } | ||||
} | } | ||||
GELOGI("CalcShape size of batch_shape is %zu, batch_shape_index is %zu.", batch_shape.size(), batch_shape_index); | |||||
if (batch_shape_index != batch_shape.size()) { | if (batch_shape_index != batch_shape.size()) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | ErrorManager::GetInstance().ATCReportErrMessage( | ||||
"E19012", {"function", "reason"}, {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + | "E19012", {"function", "reason"}, {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + | ||||
@@ -28,6 +28,21 @@ namespace ge { | |||||
namespace multibatch { | namespace multibatch { | ||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Update Dynamic Param from Options. | |||||
/// @param [in] ComputeGraphPtr &graph: the train graph | |||||
/// @return SUCCESS: valid / PARAM_INVALID: invalid. | |||||
/// | |||||
Status CheckSequenceOfOptions(ComputeGraphPtr &graph, std::vector<NodePtr> &data_nodes, | |||||
std::vector<NodePtr> &getnext_nosink_nodes, std::vector<NodePtr> &getnext_sink_nodes); | |||||
Status UpdateNameOfInputShape(ComputeGraphPtr &graph, const vector<NodePtr> &data_nodes, | |||||
const vector<NodePtr> &getnext_nosink_nodes, const vector<NodePtr> &getnext_sink_nodes); | |||||
Status DeleteIdentityInsertByAdapter(ComputeGraphPtr &graph); | |||||
Status CheckNegativeCountOfOptions(const std::vector<std::vector<int64_t>> &shapes); | |||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Init Dynamic Param from Options. | /// @brief Init Dynamic Param from Options. | ||||
/// @param [out] std::vector<std::vector<int64_t>> &shapes: Result for Params. | /// @param [out] std::vector<std::vector<int64_t>> &shapes: Result for Params. | ||||
/// @return true: Configed for Multi batch / false: Not configed for Multi batch. | /// @return true: Configed for Multi batch / false: Not configed for Multi batch. | ||||
@@ -174,6 +174,9 @@ const std::string HCOM_PARALLEL = "ge.hcomParallel"; | |||||
// configure whether to use dynamic batch size | // configure whether to use dynamic batch size | ||||
const char *const kDynamicBatchSize = "ge.dynamicBatchSize"; | const char *const kDynamicBatchSize = "ge.dynamicBatchSize"; | ||||
const std::string INPUT_SHAPE = "ge.inputShape"; | |||||
const std::string DYNAMIC_NODE_TYPE = "ge.dynamicNodeType"; | |||||
// configure whether to use dynamic image size | // configure whether to use dynamic image size | ||||
const char *const kDynamicImageSize = "ge.dynamicImageSize"; | const char *const kDynamicImageSize = "ge.dynamicImageSize"; | ||||
@@ -525,6 +525,12 @@ REGISTER_OPTYPE_DECLARE(HVDCALLBACKALLGATHER, "HorovodAllgather"); | |||||
REGISTER_OPTYPE_DECLARE(HVDCALLBACKBROADCAST, "HorovodBroadcast"); | REGISTER_OPTYPE_DECLARE(HVDCALLBACKBROADCAST, "HorovodBroadcast"); | ||||
REGISTER_OPTYPE_DECLARE(HVDWAIT, "HorovodWait"); | REGISTER_OPTYPE_DECLARE(HVDWAIT, "HorovodWait"); | ||||
// aicpu op for online_infer dynamic_dims | |||||
REGISTER_OPTYPE_DECLARE(GETDYNAMICDIMS, "GetDynamicDims"); | |||||
// getnext op for online_infer dynamic_dims | |||||
REGISTER_OPTYPE_DECLARE(ITERATORV2, "IteratorV2"); | |||||
enum InputMode { INPUT = 0, CONST }; | enum InputMode { INPUT = 0, CONST }; | ||||
// Definition of the processing status enum of the process module | // Definition of the processing status enum of the process module | ||||
@@ -115,6 +115,11 @@ struct OmgContext { | |||||
std::string dynamic_batch_size; | std::string dynamic_batch_size; | ||||
std::string dynamic_image_size; | std::string dynamic_image_size; | ||||
std::string dynamic_dims; | std::string dynamic_dims; | ||||
std::string dynamic_node_type; | |||||
std::vector<std::vector<int64_t>> user_real_input_dims; | |||||
std::vector<int64_t> cur_dynamic_dims; | |||||
std::vector<std::vector<int64_t>> all_gears_info; | |||||
bool need_multi_batch = false; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||