@@ -961,9 +961,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetOpInp
std::vector<DataType> input_data_type;
for (size_t i = 0; i < op->GetAllInputsSize(); ++i) {
GeTensorDescPtr input_tensor_desc = op->MutableInputDesc(i);
if (input_tensor_desc == nullptr) {
continue;
}
GE_IF_BOOL_EXEC(input_tensor_desc == nullptr, continue);
input_format.emplace_back(input_tensor_desc->GetFormat());
input_shape.emplace_back(input_tensor_desc->GetShape().GetDims());
input_data_type.emplace_back(input_tensor_desc->GetDataType());
@@ -973,9 +972,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetOpInp
std::vector<DataType> output_data_type;
for (size_t j = 0; j < op->GetOutputsSize(); ++j) {
GeTensorDescPtr output_tensor_desc = op->MutableOutputDesc(j);
if (output_tensor_desc == nullptr) {
continue;
}
GE_IF_BOOL_EXEC(output_tensor_desc == nullptr, continue);
output_format.emplace_back(output_tensor_desc->GetFormat());
output_shape.emplace_back(output_tensor_desc->GetShape().GetDims());
output_data_type.emplace_back(output_tensor_desc->GetDataType());
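GE_IF_BOOL_EXEC is GraphEngine's condition/execute helper; its definition is not part of this diff. A minimal sketch of what such a macro plausibly looks like (an assumption, not the exact definition from the GE debug headers):

    // Hypothetical sketch: run exec_expr when expr holds.
    // A plain brace block (not do/while) is assumed, so that passing
    // `continue` as exec_expr still applies to the enclosing for-loop.
    #define GE_IF_BOOL_EXEC(expr, exec_expr) \
      {                                      \
        if (expr) {                          \
          exec_expr;                         \
        }                                    \
      }

Under that assumption the one-line replacements above keep the original skip-on-null behaviour.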
@@ -854,7 +854,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
op_desc->GetName().c_str());
return PARAM_INVALID;
}
OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_;
OmgContext &omg_context = impl_->omg_context_;
omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc);
if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
@@ -869,11 +869,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
if (!HasShapeRange(inputs) && compile_flag == kFuzzBuildPattern) {
fuzz_compile_flag = true;
}
if (!AttrUtils::SetBool(op_desc, ATTR_NAME_FUZZ_BUILD, fuzz_compile_flag)) {
REPORT_CALL_ERROR("E19999", "set ATTR_NAME_FUZZ_BUILD failed for %s.", op_desc->GetName().c_str());
GELOGE(FAILED, "[Set][ATTR_NAME_FUZZ_BUILD] Failed to set attr for %s.", op_desc->GetName().c_str());
return FAILED;
}
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_FUZZ_BUILD, fuzz_compile_flag);
impl_->omg_context_.fuzz_compile_flag = fuzz_compile_flag;
// 1. Create ComputeGraph.
@@ -579,11 +579,8 @@ Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) {
if (continuous_output) {
GE_CHK_STATUS_RET(GetNodeMemoryType(node, memory_type, "output"),
"[Get][MemType]fail for node:%s", node->GetName().c_str());
ret = AssignContinuousOutputMemory(node, memory_type, continuous_type);
if (ret != ge::SUCCESS) {
GELOGE(ret, "[Assign][Memory:Continuous:Ouput]fail for node:%s", node->GetName().c_str());
return ret;
}
GE_CHK_STATUS_RET(AssignContinuousOutputMemory(node, memory_type, continuous_type),
"[Assign][Memory:Continuous:Output]fail for node:%s", node->GetName().c_str());
}
}
// Assign continuous input memory in `reverse topo order` which stored before
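GE_CHK_STATUS_RET folds the removed check/log/return sequence into one statement. A minimal sketch of the pattern, assuming a variadic log message (the macro GraphEngine actually uses may differ in detail):

    // Hypothetical sketch: evaluate expr, log and propagate any non-SUCCESS status.
    #define GE_CHK_STATUS_RET(expr, ...)        \
      do {                                      \
        const ge::Status _chk_status = (expr);  \
        if (_chk_status != ge::SUCCESS) {       \
          GELOGE(_chk_status, __VA_ARGS__);     \
          return _chk_status;                   \
        }                                       \
      } while (false)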
@@ -1212,7 +1212,8 @@ Status StreamAllocator::SetActiveStreamsForLoop() {
for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) {
GE_CHECK_NOTNULL(node->GetOpDesc());
bool is_loop_active = false;
if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, is_loop_active) && is_loop_active) {
(void)AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, is_loop_active);
if (is_loop_active) {
vector<string> activated_label_list;
NodePtr pre_switch_node = FindSwitchNodeBeforeLoopActiveNode(node);
@@ -1668,42 +1668,23 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op
};
GE_MAKE_GUARD(release, callback);
// malloc sysOpInfoList in SysOpCheckInfo
status = rtMalloc(&d_req_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM);
if (status != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status);
GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status);
return RT_ERROR_TO_GE_STATUS(status);
}
GE_CHK_RT_RET(rtMalloc(&d_req_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM));
allocated_mem.push_back(d_req_op_list);
// malloc sysOpInfoList in SysOpCheckResp
status = rtMalloc(&d_res_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM);
if (status != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status);
GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status);
return RT_ERROR_TO_GE_STATUS(status);
}
GE_CHK_RT_RET(rtMalloc(&d_res_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM));
allocated_mem.push_back(d_res_op_list);
// malloc returnCodeList in SysOpCheckResp
status = rtMalloc(&d_ret_code_list, op_nums * sizeof(ReturnCode), RT_MEMORY_HBM);
if (status != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%zu, ret = 0x%X", op_nums * sizeof(ReturnCode), status);
GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%zu, ret = 0x%X", op_nums * sizeof(ReturnCode), status);
return RT_ERROR_TO_GE_STATUS(status);
}
GE_CHK_RT_RET(rtMalloc(&d_ret_code_list, op_nums * sizeof(ReturnCode), RT_MEMORY_HBM));
allocated_mem.push_back(d_ret_code_list);
for (const auto &op_type : aicpu_optype_list) {
SysOpInfo op_info;
// malloc op_type name in SysOpInfo
void *d_op_type_name = nullptr;
status = rtMalloc(&d_op_type_name, op_type.length(), RT_MEMORY_HBM);
if (status != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%lu, ret = 0x%X", op_type.length(), status);
GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%lu, ret = 0x%X", op_type.length(), status);
return RT_ERROR_TO_GE_STATUS(status);
}
GE_CHK_RT_RET(rtMalloc(&d_op_type_name, op_type.length(), RT_MEMORY_HBM));
allocated_mem.push_back(d_op_type_name);
GE_CHK_RT(rtMemcpy(d_op_type_name, op_type.length(), op_type.c_str(), op_type.length(), RT_MEMCPY_HOST_TO_DEVICE));
op_info.opType = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_op_type_name));
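GE_CHK_RT_RET centralizes the rtError_t handling that each removed block performed by hand: report the failed runtime call, log it, and return the converted GE status. A minimal sketch under that assumption (the stringized-expression log text is illustrative, not the real message format):

    // Hypothetical sketch: evaluate a runtime call, convert and return on failure.
    #define GE_CHK_RT_RET(expr)                                                       \
      do {                                                                            \
        const rtError_t _rt_ret = (expr);                                             \
        if (_rt_ret != RT_ERROR_NONE) {                                               \
          REPORT_CALL_ERROR("E19999", "Call %s fail, ret = 0x%X", #expr, _rt_ret);    \
          GELOGE(RT_FAILED, "[Call][RtApi] %s fail, ret = 0x%X", #expr, _rt_ret);     \
          return RT_ERROR_TO_GE_STATUS(_rt_ret);                                      \
        }                                                                             \
      } while (false)

Under this assumed form, the log would carry the failing expression rather than the explicit allocation size that the removed blocks printed.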
@@ -1716,12 +1697,8 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op
SysOpInfo op_info;
// malloc op_type name in SysOpInfo
void *d_op_type_name = nullptr;
status = rtMalloc(&d_op_type_name, op_type.size(), RT_MEMORY_HBM);
if (status != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%lu, ret = 0x%X", op_type.length(), status);
GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%lu, ret = 0x%X", op_type.size(), status);
return RT_ERROR_TO_GE_STATUS(status);
}
GE_CHK_RT_RET(rtMalloc(&d_op_type_name, op_type.length(), RT_MEMORY_HBM));
allocated_mem.push_back(d_op_type_name);
GE_CHK_RT(rtMemcpy(d_op_type_name, op_type.size(), op_type.c_str(), op_type.size(), RT_MEMCPY_HOST_TO_DEVICE));
op_info.opType = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_op_type_name));
@@ -1745,12 +1722,8 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op
op_check_info_res.sysOpInfoList = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_res_op_list));
uint32_t args_size = sizeof(SysOpCheckInfo) + sizeof(SysOpCheckResp);
status = rtMalloc(&args, args_size, RT_MEMORY_HBM);
if (status != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%u, ret = 0x%X", args_size, status);
GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%u, ret = 0x%X", args_size, status);
return RT_ERROR_TO_GE_STATUS(status);
}
GE_CHK_RT_RET(rtMalloc(&args, args_size, RT_MEMORY_HBM));
allocated_mem.push_back(args);
GE_CHK_RT(rtMemcpy(args, sizeof(SysOpCheckInfo), reinterpret_cast<void *>(&op_check_info_req), sizeof(SysOpCheckInfo),
RT_MEMCPY_HOST_TO_DEVICE));
@@ -3532,9 +3532,8 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra
return ret;
}
GE_TIMESTAMP_EVENT_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph");
if ((options_.build_mode == BUILD_MODE_TUNING) &&
(options_.build_step == BUILD_STEP_BEFORE_UB_MATCH || options_.build_step == BUILD_STEP_AFTER_BUILDER ||
options_.build_step == BUILD_STEP_AFTER_BUILDER_SUB)) {
std::set<string> build_steps = {BUILD_STEP_BEFORE_UB_MATCH, BUILD_STEP_AFTER_BUILDER, BUILD_STEP_AFTER_BUILDER_SUB};
if ((options_.build_mode == BUILD_MODE_TUNING) && (build_steps.count(options_.build_step) > 0)) {
GE_TIMESTAMP_START(ConvertGraphToFile);
std::string tuning_path;
(void) GetContext().GetOption(TUNING_PATH, tuning_path);
@@ -743,12 +743,12 @@ Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) {
continue;
}
// ignore data / netoutput of subgraph
if (node->GetType() == DATA && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
continue;
}
if (node->GetType() == NETOUTPUT && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
continue;
if (AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
if (node->GetType() == DATA || node->GetType() == NETOUTPUT) {
continue;
}
}
bool identity_reserved = false;
AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CANNOT_BE_DELETED, identity_reserved);
if (identity_reserved) {
@@ -366,11 +366,8 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr
// link input -> end
string end_name = kEndType + std::to_string(graph_info_.num_of_pld_end_);
auto end_op_desc = MakeShared<OpDesc>(end_graph->GetName() + "_" + end_name, END);
if (end_op_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "New Memory for OpDesc failed.");
GELOGE(GRAPH_PARAM_INVALID, "[New][Memory] for OpDesc failed, pld_op_desc is nullptr.");
return FAILED;
}
GE_CHECK_NOTNULL(end_op_desc);
GE_IF_BOOL_EXEC(!AttrUtils::SetInt(end_op_desc, "peerIndex", graph_info_.num_of_pld_end_),
GELOGW("SetInt peerIndex failed");)
GE_IF_BOOL_EXEC(!AttrUtils::SetStr(end_op_desc, "parentOpType", dst_node->GetType()),
@@ -429,11 +426,8 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr
int64_t node_id = src_node_opdesc->GetId();
const string pld_name = kPlaceHolderType + std::to_string(graph_info_.num_of_pld_end_);
auto pld_op_desc = MakeShared<OpDesc>(pld_graph->GetName() + "_" + pld_name, PLACEHOLDER);
if (pld_op_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "New Memory for OpDesc failed.");
GELOGE(GRAPH_PARAM_INVALID, "[New][Memory] for OpDesc failed.");
return FAILED;
}
GE_CHECK_NOTNULL(pld_op_desc);
GE_IF_BOOL_EXEC(!AttrUtils::SetInt(pld_op_desc, "peerIndex", graph_info_.num_of_pld_end_),
GELOGW("SetInt peerIndex failed");)
GE_IF_BOOL_EXEC(!AttrUtils::SetStr(pld_op_desc, "_peerNodeName", new_end_node->GetName()),
@@ -333,11 +333,8 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
during_pass_node_set.nodes_last.clear();
} while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes);
if (re_pass_times == kMaxRePassTimes) {
GELOGW("re_pass_times should not come to %d", kMaxRePassTimes);
}
GE_IF_BOOL_EXEC(re_pass_times == kMaxRePassTimes, GELOGW("re_pass_times should not come to %d", kMaxRePassTimes));
GELOGD("All passes runs end");
return SUCCESS;
}
Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) {
@@ -41,9 +41,7 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) {
bool graph_change = false;
// 1. Add FP/BP flow ctrl (big cycle)
for (auto &node : compute_graph->GetDirectNode()) {
if (node == nullptr) {
continue;
}
GE_IF_BOOL_EXEC(node == nullptr, continue);
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
uint32_t true_stream_id = 0;
bool is_found = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, true_stream_id);
@@ -65,9 +63,7 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) {
// 2. Add special node flow ctrl. eg, IteratorGetNext. (small cycle)
// NOTE: Small cycle share the variables with big cycle.
for (auto &node : compute_graph->GetDirectNode()) {
if (node == nullptr) {
continue;
}
GE_IF_BOOL_EXEC(node == nullptr, continue);
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
bool need_cycle_flag = false;
bool is_found = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, need_cycle_flag);
@@ -164,9 +164,10 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra
data_nodes[parent_index] = node;
GELOGD("%s, index: %u, Data: %s", subgraph->GetName().c_str(), parent_index, node->GetName().c_str());
} else if ((node->GetType() == CONSTANT) && (node->GetOutDataAnchor(kZeroIndex) != nullptr)) {
} else if (node->GetType() == CONSTANT) {
set<string> peer_name_list;
const auto &out_anchor = node->GetOutDataAnchor(kZeroIndex);
GE_IF_BOOL_EXEC(out_anchor == nullptr, continue);
for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
const auto &peer_node = in_anchor->GetOwnerNode();
// Trim subgraph node name prefix.
@@ -64,16 +64,19 @@ std::string TransOpBreadthFusionPass::GetNodeId(const int anchor_index, const No
GE_IF_BOOL_EXEC(node == nullptr || node->GetOpDesc() == nullptr,
REPORT_INNER_ERROR("E19999", "Param node or its op_desc is nullptr, check invalid");
GELOGE(FAILED, "[Check][Param] Param node or its op_desc is nullptr"); return "");
std::set<std::string> trans_shapes = { RESHAPE, EXPANDDIMS, SQUEEZE };
std::set<std::string> trans_shape_and_format = { TRANSPOSE, TRANSPOSED, EXPANDDIMS };
if (node->GetType() == CAST) {
trans_data_type = true;
} else if (node->GetType() == TRANSPOSE || node->GetType() == TRANSPOSED || node->GetType() == EXPANDDIMS) {
} else if (trans_shape_and_format.count(node->GetType()) > 0) {
trans_format = true;
trans_shape = true;
} else if (node->GetType() == TRANSDATA) {
trans_data_type = true;
trans_format = true;
trans_shape = true;
} else if (node->GetType() == RESHAPE || node->GetType() == EXPANDDIMS || node->GetType() == SQUEEZE) {
} else if (trans_shapes.count(node->GetType()) > 0) {
trans_shape = true;
} else if (node->GetType() == REFORMAT) {
trans_format = true;
@@ -71,15 +71,13 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
GELOGW("The number of input for slice must be %zu.", kSliceInputSize);
return NOT_CHANGED;
}
ConstGeTensorPtr x_ = input[kSliceInputIndexX];
ConstGeTensorPtr begin = input[kSliceInputIndexBegin];
ConstGeTensorPtr size = input[kSliceInputIndexSize];
if (x_ == nullptr || begin == nullptr || size == nullptr) {
GELOGW("input tensor is nullptr.");
return NOT_CHANGED;
Status ret = CheckInput(x_, begin, size);
if (ret != SUCCESS) {
return ret;
}
// data type in input_x
auto data_type = x_->GetTensorDesc().GetDataType();
// check supported
@@ -92,11 +90,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
if (!is_success) {
return NOT_CHANGED;
}
// check data type of begin and size
if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) {
GELOGW("Data type of begin and size for slice are not DT_INT32.");
return NOT_CHANGED;
}
void *data = reinterpret_cast<void *>(const_cast<uint8_t *>(x_->GetData().data()));
int32_t *begin_data = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(begin->GetData().GetData()));
@@ -145,7 +139,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
return NOT_CHANGED;
}
Status ret = CheckOutputDims(output_dims, attr);
ret = CheckOutputDims(output_dims, attr);
if (ret != SUCCESS) {
return ret;
}
@@ -161,6 +155,19 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
return SUCCESS;
}
Status SliceKernel::CheckInput(const ConstGeTensorPtr x_, const ConstGeTensorPtr begin, const ConstGeTensorPtr size) {
if (x_ == nullptr || begin == nullptr || size == nullptr) {
GELOGW("input tensor is nullptr.");
return NOT_CHANGED;
}
// check data type of begin and size
if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) {
GELOGW("Data type of begin and size for slice are not DT_INT32.");
return NOT_CHANGED;
}
return SUCCESS;
}
Status SliceKernel::CheckOutputDims(const std::vector<int64_t> &output_dims, const OpDescPtr attr) {
// check dim not all less than 0
for (auto dim : output_dims) {
@@ -28,6 +28,7 @@ class SliceKernel : public Kernel {
vector<GeTensorPtr> &v_output) override;
Status CheckOutputDims(const std::vector<int64_t> &output_dims, const OpDescPtr attr);
Status CheckInput(const ConstGeTensorPtr x_, const ConstGeTensorPtr begin, const ConstGeTensorPtr size);
};
} // namespace ge
@@ -95,8 +95,8 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
}
op_info.dataType = iter->second;
HcclReduceOp op_type = HCCL_REDUCE_SUM;
if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMREDUCESCATTER ||
op_desc->GetType() == HVDCALLBACKALLREDUCE || op_desc->GetType() == HCOMREDUCE) {
std::set<std::string> hccl_types = { HCOMALLREDUCE, HCOMREDUCESCATTER, HVDCALLBACKALLREDUCE, HCOMREDUCE };
if (hccl_types.count(op_desc->GetType()) > 0) {
GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type),
"[Get][HcclOperationType] failed for %s type:%s", op_desc->GetName().c_str(),
op_desc->GetType().c_str());
@@ -283,6 +283,7 @@ class Impl {
void SetRtSocVersion();
void UpdateThreadContext();
void LoadOpsProto();
std::string GetParam(const std::string &param);
public:
ge::GeGenerator generator_;
std::map<std::string, std::string> options_;
@@ -512,6 +513,10 @@ graphStatus Impl::CheckOptions(const std::map<std::string, std::string> &options
return GRAPH_SUCCESS;
}
std::string Impl::GetParam(const std::string &param) {
return options_.find(param) == options_.end() ? "" : options_[param];
}
graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::string> &options) {
// 1. check options
graphStatus ret = CheckOptions(options);
@@ -533,20 +538,14 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri
GE_CHK_BOOL_RET_STATUS_NOLOG(ge::CheckLogParamValidAndSetLogLevel(log) == 0, GRAPH_PARAM_INVALID);
options_[ge::ir_option::LOG_LEVEL] = log;
string input_shape = options_.find("input_shape") == options_.end() ? "" : options_["input_shape"];
string input_format = options_.find("input_format") == options_.end() ? "" : options_["input_format"];
string net_format = options_.find("net_format") == options_.end() ? "" : options_["net_format"];
string dynamic_batch_size = options_.find(ge::ir_option::DYNAMIC_BATCH_SIZE) == options_.end()
? ""
: options_[ge::ir_option::DYNAMIC_BATCH_SIZE];
string dynamic_image_size = options_.find(ge::ir_option::DYNAMIC_IMAGE_SIZE) == options_.end()
? ""
: options_[ge::ir_option::DYNAMIC_IMAGE_SIZE];
string dynamic_dims =
options_.find(ge::ir_option::DYNAMIC_DIMS) == options_.end() ? "" : options_[ge::ir_option::DYNAMIC_DIMS];
string input_shape_range =
options_.find(ge::INPUT_SHAPE_RANGE) == options_.end() ? "" : options_[ge::INPUT_SHAPE_RANGE];
string input_shape = GetParam("input_shape");
string input_format = GetParam("input_format");
string net_format = GetParam("net_format");
string dynamic_batch_size = GetParam(ge::ir_option::DYNAMIC_BATCH_SIZE);
string dynamic_image_size = GetParam(ge::ir_option::DYNAMIC_IMAGE_SIZE);
string dynamic_dims = GetParam(ge::ir_option::DYNAMIC_DIMS);
string input_shape_range = GetParam(ge::INPUT_SHAPE_RANGE);
auto status = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape,
input_shape_range, input_format, is_dynamic_input_);
if (status != ge::SUCCESS) {
@@ -559,15 +558,12 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri
omg_context_.dynamic_image_size = dynamic_image_size;
omg_context_.dynamic_dims = dynamic_dims;
// check output_type
std::string output_type = options_.find(ge::ir_option::OUTPUT_TYPE) == options_.end()
? ""
: options_[ge::ir_option::OUTPUT_TYPE];
std::string output_type = GetParam(ge::ir_option::OUTPUT_TYPE);
GE_CHK_BOOL_EXEC(ge::CheckOutputTypeParamValid(output_type) == ge::SUCCESS,
return ge::GRAPH_PARAM_INVALID, "[Check][OutputType] failed!");
// check insert_op_conf
std::string insert_op_conf = options_.find(ge::ir_option::INSERT_OP_FILE) == options_.end()
? ""
: options_[ge::ir_option::INSERT_OP_FILE];
std::string insert_op_conf = GetParam(ge::ir_option::INSERT_OP_FILE);
GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(insert_op_conf)) == ge::SUCCESS,
return ge::GRAPH_PARAM_INVALID, "[Check][InsertOpConf] failed!");