Browse Source

fix sc

tags/v1.3.0
wjm 4 years ago
parent
commit
2fbf651550
16 changed files with 76 additions and 117 deletions
  1. +4
    -6
      ge/common/profiling/profiling_manager.cc
  2. +2
    -6
      ge/generator/ge_generator.cc
  3. +2
    -5
      ge/graph/build/memory/graph_mem_assigner.cc
  4. +2
    -1
      ge/graph/build/stream_allocator.cc
  5. +9
    -36
      ge/graph/load/model_manager/model_manager.cc
  6. +2
    -3
      ge/graph/manager/graph_manager.cc
  7. +5
    -5
      ge/graph/optimize/mem_rw_conflict_optimize.cc
  8. +4
    -10
      ge/graph/partition/graph_partition.cc
  9. +1
    -4
      ge/graph/passes/base_pass.cc
  10. +2
    -6
      ge/graph/passes/flow_ctrl_pass.cc
  11. +2
    -1
      ge/graph/passes/subgraph_const_migration_pass.cc
  12. +5
    -2
      ge/graph/passes/transop_breadth_fusion_pass.cc
  13. +18
    -11
      ge/host_kernels/slice_kernel.cc
  14. +1
    -0
      ge/host_kernels/slice_kernel.h
  15. +2
    -2
      ge/hybrid/node_executor/hccl/hccl_node_executor.cc
  16. +15
    -19
      ge/ir_build/ge_ir_build.cc

+ 4
- 6
ge/common/profiling/profiling_manager.cc View File

@@ -961,9 +961,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetOpInp
std::vector<DataType> input_data_type;
for (size_t i = 0; i < op->GetAllInputsSize(); ++i) {
GeTensorDescPtr input_tensor_desc = op->MutableInputDesc(i);
if (input_tensor_desc == nullptr) {
continue;
}
GE_IF_BOOL_EXEC(input_tensor_desc == nullptr, continue);

input_format.emplace_back(input_tensor_desc->GetFormat());
input_shape.emplace_back(input_tensor_desc->GetShape().GetDims());
input_data_type.emplace_back(input_tensor_desc->GetDataType());
@@ -973,9 +972,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetOpInp
std::vector<DataType> output_data_type;
for (size_t j = 0; j < op->GetOutputsSize(); ++j) {
GeTensorDescPtr output_tensor_desc = op->MutableOutputDesc(j);
if (output_tensor_desc == nullptr) {
continue;
}
GE_IF_BOOL_EXEC(output_tensor_desc == nullptr, continue);

output_format.emplace_back(output_tensor_desc->GetFormat());
output_shape.emplace_back(output_tensor_desc->GetShape().GetDims());
output_data_type.emplace_back(output_tensor_desc->GetDataType());


+ 2
- 6
ge/generator/ge_generator.cc View File

@@ -854,7 +854,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
op_desc->GetName().c_str());
return PARAM_INVALID;
}
OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_;
OmgContext &omg_context = impl_->omg_context_;
omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc);

if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
@@ -869,11 +869,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
if (!HasShapeRange(inputs) && compile_flag == kFuzzBuildPattern) {
fuzz_compile_flag = true;
}
if (!AttrUtils::SetBool(op_desc, ATTR_NAME_FUZZ_BUILD, fuzz_compile_flag)) {
REPORT_CALL_ERROR("E19999", "set ATTR_NAME_FUZZ_BUILD failed for %s.", op_desc->GetName().c_str());
GELOGE(FAILED, "[Set][ATTR_NAME_FUZZ_BUILD] Failed to set attr for %s.", op_desc->GetName().c_str());
return FAILED;
}
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_FUZZ_BUILD, fuzz_compile_flag);
impl_->omg_context_.fuzz_compile_flag = fuzz_compile_flag;

// 1. Create ComputeGraph.


+ 2
- 5
ge/graph/build/memory/graph_mem_assigner.cc View File

@@ -579,11 +579,8 @@ Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) {
if (continuous_output) {
GE_CHK_STATUS_RET(GetNodeMemoryType(node, memory_type, "output"),
"[Get][MemType]fail for node:%s", node->GetName().c_str());
ret = AssignContinuousOutputMemory(node, memory_type, continuous_type);
if (ret != ge::SUCCESS) {
GELOGE(ret, "[Assign][Memory:Continuous:Ouput]fail for node:%s", node->GetName().c_str());
return ret;
}
GE_CHK_STATUS_RET(AssignContinuousOutputMemory(node, memory_type, continuous_type),
"[Assign][Memory:Continuous:Output]fail for node:%s", node->GetName().c_str());
}
}
// Assign continuous input memory in `reverse topo order` which stored before


+ 2
- 1
ge/graph/build/stream_allocator.cc View File

@@ -1212,7 +1212,8 @@ Status StreamAllocator::SetActiveStreamsForLoop() {
for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) {
GE_CHECK_NOTNULL(node->GetOpDesc());
bool is_loop_active = false;
if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, is_loop_active) && is_loop_active) {
(void)AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, is_loop_active);
if (is_loop_active) {
vector<string> activated_label_list;

NodePtr pre_switch_node = FindSwitchNodeBeforeLoopActiveNode(node);


+ 9
- 36
ge/graph/load/model_manager/model_manager.cc View File

@@ -1668,42 +1668,23 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op
};
GE_MAKE_GUARD(release, callback);
// malloc sysOpInfoList in SysOpCheckInfo
status = rtMalloc(&d_req_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM);
if (status != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status);
GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status);
return RT_ERROR_TO_GE_STATUS(status);
}
GE_CHK_RT_RET(rtMalloc(&d_req_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM));
allocated_mem.push_back(d_req_op_list);

// malloc sysOpInfoList in SysOpCheckResp
status = rtMalloc(&d_res_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM);
if (status != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status);
GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%zu, ret = 0x%X", op_nums * sizeof(SysOpInfo), status);
return RT_ERROR_TO_GE_STATUS(status);
}
GE_CHK_RT_RET(rtMalloc(&d_res_op_list, op_nums * sizeof(SysOpInfo), RT_MEMORY_HBM));
allocated_mem.push_back(d_res_op_list);

// malloc returnCodeList in SysOpCheckResp
status = rtMalloc(&d_ret_code_list, op_nums * sizeof(ReturnCode), RT_MEMORY_HBM);
if (status != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%zu, ret = 0x%X", op_nums * sizeof(ReturnCode), status);
GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%zu, ret = 0x%X", op_nums * sizeof(ReturnCode), status);
return RT_ERROR_TO_GE_STATUS(status);
}
GE_CHK_RT_RET(rtMalloc(&d_ret_code_list, op_nums * sizeof(ReturnCode), RT_MEMORY_HBM));
allocated_mem.push_back(d_ret_code_list);

for (const auto &op_type : aicpu_optype_list) {
SysOpInfo op_info;
// malloc op_type name in SysOpInfo
void *d_op_type_name = nullptr;
status = rtMalloc(&d_op_type_name, op_type.length(), RT_MEMORY_HBM);
if (status != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%lu, ret = 0x%X", op_type.length(), status);
GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%lu, ret = 0x%X", op_type.length(), status);
return RT_ERROR_TO_GE_STATUS(status);
}
GE_CHK_RT_RET(rtMalloc(&d_op_type_name, op_type.length(), RT_MEMORY_HBM));

allocated_mem.push_back(d_op_type_name);
GE_CHK_RT(rtMemcpy(d_op_type_name, op_type.length(), op_type.c_str(), op_type.length(), RT_MEMCPY_HOST_TO_DEVICE));
op_info.opType = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_op_type_name));
@@ -1716,12 +1697,8 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op
SysOpInfo op_info;
// malloc op_type name in SysOpInfo
void *d_op_type_name = nullptr;
status = rtMalloc(&d_op_type_name, op_type.size(), RT_MEMORY_HBM);
if (status != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%lu, ret = 0x%X", op_type.length(), status);
GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%lu, ret = 0x%X", op_type.size(), status);
return RT_ERROR_TO_GE_STATUS(status);
}
GE_CHK_RT_RET(rtMalloc(&d_op_type_name, op_type.length(), RT_MEMORY_HBM));

allocated_mem.push_back(d_op_type_name);
GE_CHK_RT(rtMemcpy(d_op_type_name, op_type.size(), op_type.c_str(), op_type.size(), RT_MEMCPY_HOST_TO_DEVICE));
op_info.opType = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_op_type_name));
@@ -1745,12 +1722,8 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op
op_check_info_res.sysOpInfoList = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_res_op_list));

uint32_t args_size = sizeof(SysOpCheckInfo) + sizeof(SysOpCheckResp);
status = rtMalloc(&args, args_size, RT_MEMORY_HBM);
if (status != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, size:%u, ret = 0x%X", args_size, status);
GELOGE(RT_FAILED, "[Call][RtMalloc] fail, size:%u, ret = 0x%X", args_size, status);
return RT_ERROR_TO_GE_STATUS(status);
}
GE_CHK_RT_RET(rtMalloc(&args, args_size, RT_MEMORY_HBM));

allocated_mem.push_back(args);
GE_CHK_RT(rtMemcpy(args, sizeof(SysOpCheckInfo), reinterpret_cast<void *>(&op_check_info_req), sizeof(SysOpCheckInfo),
RT_MEMCPY_HOST_TO_DEVICE));


+ 2
- 3
ge/graph/manager/graph_manager.cc View File

@@ -3532,9 +3532,8 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra
return ret;
}
GE_TIMESTAMP_EVENT_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph");
if ((options_.build_mode == BUILD_MODE_TUNING) &&
(options_.build_step == BUILD_STEP_BEFORE_UB_MATCH || options_.build_step == BUILD_STEP_AFTER_BUILDER ||
options_.build_step == BUILD_STEP_AFTER_BUILDER_SUB)) {
std::set<string> build_steps = {BUILD_STEP_BEFORE_UB_MATCH, BUILD_STEP_AFTER_BUILDER, BUILD_STEP_AFTER_BUILDER_SUB};
if ((options_.build_mode == BUILD_MODE_TUNING) && (build_steps.count(options_.build_step) > 0)) {
GE_TIMESTAMP_START(ConvertGraphToFile);
std::string tuning_path;
(void) GetContext().GetOption(TUNING_PATH, tuning_path);


+ 5
- 5
ge/graph/optimize/mem_rw_conflict_optimize.cc View File

@@ -743,12 +743,12 @@ Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) {
continue;
}
// ignore data / netoutput of subgraph
if (node->GetType() == DATA && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
continue;
}
if (node->GetType() == NETOUTPUT && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
continue;
if (AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
if (node->GetType() == DATA || node->GetType() == NETOUTPUT) {
continue;
}
}

bool identity_reserved = false;
AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CANNOT_BE_DELETED, identity_reserved);
if (identity_reserved) {


+ 4
- 10
ge/graph/partition/graph_partition.cc View File

@@ -366,11 +366,8 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr
// link input -> end
string end_name = kEndType + std::to_string(graph_info_.num_of_pld_end_);
auto end_op_desc = MakeShared<OpDesc>(end_graph->GetName() + "_" + end_name, END);
if (end_op_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "New Memory for OpDesc failed.");
GELOGE(GRAPH_PARAM_INVALID, "[New][Memory] for OpDesc failed, pld_op_desc is nullptr.");
return FAILED;
}
GE_CHECK_NOTNULL(end_op_desc);

GE_IF_BOOL_EXEC(!AttrUtils::SetInt(end_op_desc, "peerIndex", graph_info_.num_of_pld_end_),
GELOGW("SetInt peerIndex failed");)
GE_IF_BOOL_EXEC(!AttrUtils::SetStr(end_op_desc, "parentOpType", dst_node->GetType()),
@@ -429,11 +426,8 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr
int64_t node_id = src_node_opdesc->GetId();
const string pld_name = kPlaceHolderType + std::to_string(graph_info_.num_of_pld_end_);
auto pld_op_desc = MakeShared<OpDesc>(pld_graph->GetName() + "_" + pld_name, PLACEHOLDER);
if (pld_op_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "New Memory for OpDesc failed.");
GELOGE(GRAPH_PARAM_INVALID, "[New][Memory] for OpDesc failed.");
return FAILED;
}
GE_CHECK_NOTNULL(pld_op_desc);
GE_IF_BOOL_EXEC(!AttrUtils::SetInt(pld_op_desc, "peerIndex", graph_info_.num_of_pld_end_),
GELOGW("SetInt peerIndex failed");)
GE_IF_BOOL_EXEC(!AttrUtils::SetStr(pld_op_desc, "_peerNodeName", new_end_node->GetName()),


+ 1
- 4
ge/graph/passes/base_pass.cc View File

@@ -333,11 +333,8 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
during_pass_node_set.nodes_last.clear();
} while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes);

if (re_pass_times == kMaxRePassTimes) {
GELOGW("re_pass_times should not come to %d", kMaxRePassTimes);
}
GE_IF_BOOL_EXEC(re_pass_times == kMaxRePassTimes, GELOGW("re_pass_times should not come to %d", kMaxRePassTimes));
GELOGD("All passes runs end");

return SUCCESS;
}
Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) {


+ 2
- 6
ge/graph/passes/flow_ctrl_pass.cc View File

@@ -41,9 +41,7 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) {
bool graph_change = false;
// 1. Add FP/BP flow ctrl (big cycle)
for (auto &node : compute_graph->GetDirectNode()) {
if (node == nullptr) {
continue;
}
GE_IF_BOOL_EXEC(node == nullptr, continue);
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
uint32_t true_stream_id = 0;
bool is_found = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, true_stream_id);
@@ -65,9 +63,7 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) {
// 2. Add special node flow ctrl. eg, IteratorGetNext. (small cycle)
// NOTE: Small cycle share the variables with big cycle.
for (auto &node : compute_graph->GetDirectNode()) {
if (node == nullptr) {
continue;
}
GE_IF_BOOL_EXEC(node == nullptr, continue);
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
bool need_cycle_flag = false;
bool is_found = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, need_cycle_flag);


+ 2
- 1
ge/graph/passes/subgraph_const_migration_pass.cc View File

@@ -164,9 +164,10 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra

data_nodes[parent_index] = node;
GELOGD("%s, index: %u, Data: %s", subgraph->GetName().c_str(), parent_index, node->GetName().c_str());
} else if ((node->GetType() == CONSTANT) && (node->GetOutDataAnchor(kZeroIndex) != nullptr)) {
} else if (node->GetType() == CONSTANT) {
set<string> peer_name_list;
const auto &out_anchor = node->GetOutDataAnchor(kZeroIndex);
GE_IF_BOOL_EXEC(out_anchor == nullptr, continue);
for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
const auto &peer_node = in_anchor->GetOwnerNode();
// Trim subgraph node name prefix.


+ 5
- 2
ge/graph/passes/transop_breadth_fusion_pass.cc View File

@@ -64,16 +64,19 @@ std::string TransOpBreadthFusionPass::GetNodeId(const int anchor_index, const No
GE_IF_BOOL_EXEC(node == nullptr || node->GetOpDesc() == nullptr,
REPORT_INNER_ERROR("E19999", "Param node or its op_desc is nullptr, check invalid");
GELOGE(FAILED, "[Check][Param] Param node or its op_desc is nullptr"); return "");

std::set<std::string> trans_shapes = { RESHAPE, EXPANDDIMS, SQUEEZE };
std::set<std::string> trans_shape_and_format = { TRANSPOSE, TRANSPOSED, EXPANDDIMS };
if (node->GetType() == CAST) {
trans_data_type = true;
} else if (node->GetType() == TRANSPOSE || node->GetType() == TRANSPOSED || node->GetType() == EXPANDDIMS) {
} else if (trans_shape_and_format.count(node->GetType()) > 0) {
trans_format = true;
trans_shape = true;
} else if (node->GetType() == TRANSDATA) {
trans_data_type = true;
trans_format = true;
trans_shape = true;
} else if (node->GetType() == RESHAPE || node->GetType() == EXPANDDIMS || node->GetType() == SQUEEZE) {
} else if (trans_shapes.count(node->GetType()) > 0) {
trans_shape = true;
} else if (node->GetType() == REFORMAT) {
trans_format = true;


+ 18
- 11
ge/host_kernels/slice_kernel.cc View File

@@ -71,15 +71,13 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
GELOGW("The number of input for slice must be %zu.", kSliceInputSize);
return NOT_CHANGED;
}

ConstGeTensorPtr x_ = input[kSliceInputIndexX];
ConstGeTensorPtr begin = input[kSliceInputIndexBegin];
ConstGeTensorPtr size = input[kSliceInputIndexSize];
if (x_ == nullptr || begin == nullptr || size == nullptr) {
GELOGW("input tensor is nullptr.");
return NOT_CHANGED;
Status ret = CheckInput(x_, begin, size);
if (ret != SUCCESS) {
return ret;
}

// data type in input_x
auto data_type = x_->GetTensorDesc().GetDataType();
// check supported
@@ -92,11 +90,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
if (!is_success) {
return NOT_CHANGED;
}
// check data type of begin and size
if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) {
GELOGW("Data type of begin and size for slice are not DT_INT32.");
return NOT_CHANGED;
}


void *data = reinterpret_cast<void *>(const_cast<uint8_t *>(x_->GetData().data()));
int32_t *begin_data = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(begin->GetData().GetData()));
@@ -145,7 +139,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
return NOT_CHANGED;
}

Status ret = CheckOutputDims(output_dims, attr);
ret = CheckOutputDims(output_dims, attr);
if (ret != SUCCESS) {
return ret;
}
@@ -161,6 +155,19 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
return SUCCESS;
}

// Validates the three Slice inputs before Compute dereferences them.
// Returns NOT_CHANGED (after logging a warning) when any tensor pointer is null or when
// `begin` / `size` do not carry DT_INT32 data; returns SUCCESS otherwise.
Status SliceKernel::CheckInput(const ConstGeTensorPtr x_, const ConstGeTensorPtr begin, const ConstGeTensorPtr size) {
  // Null check comes first: the data-type check below dereferences begin and size.
  const bool has_null_input = (x_ == nullptr) || (begin == nullptr) || (size == nullptr);
  if (has_null_input) {
    GELOGW("input tensor is nullptr.");
    return NOT_CHANGED;
  }

  // check data type of begin and size
  const auto is_int32 = [](const ConstGeTensorPtr &tensor) {
    return tensor->GetTensorDesc().GetDataType() == DT_INT32;
  };
  // `size` is only inspected when `begin` already passed, same as the original `||` chain.
  if (!is_int32(begin) || !is_int32(size)) {
    GELOGW("Data type of begin and size for slice are not DT_INT32.");
    return NOT_CHANGED;
  }

  return SUCCESS;
}

Status SliceKernel::CheckOutputDims(const std::vector<int64_t> &output_dims, const OpDescPtr attr) {
// check dim not all less than 0
for (auto dim : output_dims) {


+ 1
- 0
ge/host_kernels/slice_kernel.h View File

@@ -28,6 +28,7 @@ class SliceKernel : public Kernel {
vector<GeTensorPtr> &v_output) override;

Status CheckOutputDims(const std::vector<int64_t> &output_dims, const OpDescPtr attr);
Status CheckInput(const ConstGeTensorPtr x_, const ConstGeTensorPtr begin, const ConstGeTensorPtr size);
};
} // namespace ge



+ 2
- 2
ge/hybrid/node_executor/hccl/hccl_node_executor.cc View File

@@ -95,8 +95,8 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
}
op_info.dataType = iter->second;
HcclReduceOp op_type = HCCL_REDUCE_SUM;
if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMREDUCESCATTER ||
op_desc->GetType() == HVDCALLBACKALLREDUCE || op_desc->GetType() == HCOMREDUCE) {
std::set<std::string> hccl_types = { HCOMALLREDUCE, HCOMREDUCESCATTER, HVDCALLBACKALLREDUCE, HCOMREDUCE };
if (hccl_types.count(op_desc->GetType()) > 0) {
GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type),
"[Get][HcclOperationType] failed for %s type:%s", op_desc->GetName().c_str(),
op_desc->GetType().c_str());


+ 15
- 19
ge/ir_build/ge_ir_build.cc View File

@@ -283,6 +283,7 @@ class Impl {
void SetRtSocVersion();
void UpdateThreadContext();
void LoadOpsProto();
std::string GetParam(const std::string &param);
public:
ge::GeGenerator generator_;
std::map<std::string, std::string> options_;
@@ -512,6 +513,10 @@ graphStatus Impl::CheckOptions(const std::map<std::string, std::string> &options
return GRAPH_SUCCESS;
}

// Returns the value stored in options_ under `param`, or an empty string when the
// key is not present. The map is never modified by this lookup.
std::string Impl::GetParam(const std::string &param) {
  // One find() and reuse of the iterator, instead of find() followed by a second
  // lookup through operator[].
  const auto iter = options_.find(param);
  return (iter == options_.end()) ? std::string() : iter->second;
}

graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::string> &options) {
// 1. check options
graphStatus ret = CheckOptions(options);
@@ -533,20 +538,14 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri
GE_CHK_BOOL_RET_STATUS_NOLOG(ge::CheckLogParamValidAndSetLogLevel(log) == 0, GRAPH_PARAM_INVALID);
options_[ge::ir_option::LOG_LEVEL] = log;

string input_shape = options_.find("input_shape") == options_.end() ? "" : options_["input_shape"];
string input_format = options_.find("input_format") == options_.end() ? "" : options_["input_format"];
string net_format = options_.find("net_format") == options_.end() ? "" : options_["net_format"];
string dynamic_batch_size = options_.find(ge::ir_option::DYNAMIC_BATCH_SIZE) == options_.end()
? ""
: options_[ge::ir_option::DYNAMIC_BATCH_SIZE];
string dynamic_image_size = options_.find(ge::ir_option::DYNAMIC_IMAGE_SIZE) == options_.end()
? ""
: options_[ge::ir_option::DYNAMIC_IMAGE_SIZE];
string dynamic_dims =
options_.find(ge::ir_option::DYNAMIC_DIMS) == options_.end() ? "" : options_[ge::ir_option::DYNAMIC_DIMS];
string input_shape_range =
options_.find(ge::INPUT_SHAPE_RANGE) == options_.end() ? "" : options_[ge::INPUT_SHAPE_RANGE];
string input_shape = GetParam("input_shape");
string input_format = GetParam("input_format");
string net_format = GetParam("net_format");

string dynamic_batch_size = GetParam(ge::ir_option::DYNAMIC_BATCH_SIZE);
string dynamic_image_size = GetParam(ge::ir_option::DYNAMIC_IMAGE_SIZE);
string dynamic_dims = GetParam(ge::ir_option::DYNAMIC_DIMS);
string input_shape_range = GetParam(ge::INPUT_SHAPE_RANGE);
auto status = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape,
input_shape_range, input_format, is_dynamic_input_);
if (status != ge::SUCCESS) {
@@ -559,15 +558,12 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri
omg_context_.dynamic_image_size = dynamic_image_size;
omg_context_.dynamic_dims = dynamic_dims;
// check output_type
std::string output_type = options_.find(ge::ir_option::OUTPUT_TYPE) == options_.end()
? ""
: options_[ge::ir_option::OUTPUT_TYPE];
std::string output_type = GetParam(ge::ir_option::OUTPUT_TYPE);
GE_CHK_BOOL_EXEC(ge::CheckOutputTypeParamValid(output_type) == ge::SUCCESS,
return ge::GRAPH_PARAM_INVALID, "[Check][OutputType] failed!");
// check insert_op_conf
std::string insert_op_conf = options_.find(ge::ir_option::INSERT_OP_FILE) == options_.end()
? ""
: options_[ge::ir_option::INSERT_OP_FILE];

std::string insert_op_conf = GetParam(ge::ir_option::INSERT_OP_FILE);
GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(insert_op_conf)) == ge::SUCCESS,
return ge::GRAPH_PARAM_INVALID, "[Check][InsertOpConf] failed!");



Loading…
Cancel
Save