diff --git a/inc/framework/ge_runtime/task_info.h b/inc/framework/ge_runtime/task_info.h index e36c4333..4f36eece 100644 --- a/inc/framework/ge_runtime/task_info.h +++ b/inc/framework/ge_runtime/task_info.h @@ -274,8 +274,9 @@ class HcclTaskInfo : public TaskInfo { HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, const std::vector &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, - int64_t op_type, int64_t data_type, const std::string &group, - std::function hcom_bind_model, std::function hcom_unbind_model, + int64_t op_type, int64_t data_type, int64_t src_rank, int64_t dest_rank, int64_t sr_tag, + const std::string &group, std::function hcom_bind_model, + std::function hcom_unbind_model, std::function, void *)> hcom_distribute_task, bool dump_flag) : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag), hccl_type_(hccl_type), @@ -290,6 +291,9 @@ class HcclTaskInfo : public TaskInfo { root_id_(root_id), op_type_(op_type), data_type_(data_type), + src_rank_(src_rank), + dest_rank_(dest_rank), + sr_tag_(sr_tag), group_(group), hcom_bind_model_(hcom_bind_model), hcom_unbind_model_(hcom_unbind_model), @@ -308,6 +312,9 @@ class HcclTaskInfo : public TaskInfo { int64_t root_id() const { return root_id_; } int64_t op_type() const { return op_type_; } int64_t data_type() const { return data_type_; } + int64_t src_rank() const { return src_rank_; } + int64_t dest_rank() const { return dest_rank_; } + int64_t sr_tag() const { return sr_tag_; } const std::string &group() const { return group_; } std::function hcom_bind_model() const { return hcom_bind_model_; } std::function hcom_unbind_model() const { return hcom_unbind_model_; } @@ -328,6 +335,9 @@ class HcclTaskInfo : public TaskInfo { int64_t root_id_; int64_t op_type_; int64_t data_type_; + int64_t src_rank_; + int64_t dest_rank_; + int64_t sr_tag_; std::string group_; std::function hcom_bind_model_; std::function hcom_unbind_model_;