|
|
@@ -263,7 +263,7 @@ Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &ro |
|
|
|
Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, |
|
|
|
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { |
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
|
if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST) { |
|
|
|
if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HCOMREDUCE || op_desc->GetType() == HVDCALLBACKBROADCAST) { |
|
|
|
GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
int64_t root_id = 0; |
|
|
|
Status dmrt = GetHcclRootId(op_desc, root_id); |
|
|
@@ -286,7 +286,7 @@ bool HcomOmeUtil::IsHCOMOp(const string &op_type) { |
|
|
|
|
|
|
|
bool HcomOmeUtil::IsHorovodOp(const string &op_type) { |
|
|
|
return (op_type == HVDCALLBACKALLREDUCE) || (op_type == HVDCALLBACKALLGATHER) || (op_type == HVDCALLBACKBROADCAST) || |
|
|
|
(op_type == HVDWAIT); |
|
|
|
(op_type == HVDWAIT) || (op_type == HCOMREDUCE); |
|
|
|
} |
|
|
|
|
|
|
|
Status HcomOmeUtil::CheckKernelHcclInfo(const ge::ConstOpDescPtr &op_desc, |
|
|
|