diff --git a/ge/graph/manager/util/hcom_util.cc b/ge/graph/manager/util/hcom_util.cc index d865b40e..ae4c5838 100644 --- a/ge/graph/manager/util/hcom_util.cc +++ b/ge/graph/manager/util/hcom_util.cc @@ -263,7 +263,7 @@ Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &ro Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); - if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST) { + if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) { GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); int64_t root_id = 0; Status dmrt = GetHcclRootId(op_desc, root_id); @@ -281,7 +281,7 @@ Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, bool HcomOmeUtil::IsHCOMOp(const string &op_type) { return (op_type == HCOMALLREDUCE) || (op_type == HCOMALLGATHER) || (op_type == HCOMBROADCAST) || - (op_type == HCOMSEND) || (op_type == HCOMRECEIVE) || (op_type == HCOMREDUCESCATTER); + (op_type == HCOMSEND) || (op_type == HCOMRECEIVE) || (op_type == HCOMREDUCESCATTER || op_desc->GetType() == HCOMREDUCE); } bool HcomOmeUtil::IsHorovodOp(const string &op_type) {