Browse Source

reuse workspace memory of hccl op

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
tags/v1.2.0
zhoufeng 4 years ago
parent
commit
d6308151e0
2 changed files with 6 additions and 15 deletions
  1. +2
    -14
      ge/ge_runtime/task/hccl_task.cc
  2. +4
    -1
      inc/framework/ge_runtime/task_info.h

+ 2
- 14
ge/ge_runtime/task/hccl_task.cc View File

@@ -52,15 +52,7 @@ HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptr<Hccl
}
}

HcclTask::~HcclTask() {
if (workspace_mem_ != nullptr) {
rtError_t rt_ret = rtFree(workspace_mem_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtFree workspace_mem_ failed! ret: 0x%X.", rt_ret);
}
workspace_mem_ = nullptr;
}
}
HcclTask::~HcclTask() {}

bool HcclTask::Distribute() {
// Ops kernel info store
@@ -79,11 +71,7 @@ bool HcclTask::Distribute() {
SetSecondaryStream();

if (task_info_->workspace_size() > 0) {
rtError_t rt_ret = rtMalloc(&workspace_mem_, task_info_->workspace_size(), RT_MEMORYINFO_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
workspace_mem_ = task_info_->workspace_addr();
}

GELOGI("HcclTaskInfo Distribute Start. begin to call function LoadTask in hccl.");


+ 4
- 1
inc/framework/ge_runtime/task_info.h View File

@@ -271,13 +271,14 @@ class FusionEndTaskInfo : public TaskInfo {
class HcclTaskInfo : public TaskInfo {
public:
HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
void *output_data_addr, int64_t workspace_size, int64_t hccl_stream_num,
void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
hccl_type_(hccl_type),
input_data_addr_(input_data_addr),
output_data_addr_(output_data_addr),
workspace_addr_(workspace_addr),
workspace_size_(workspace_size),
hccl_stream_num_(hccl_stream_num),
private_def_(private_def),
@@ -292,6 +293,7 @@ class HcclTaskInfo : public TaskInfo {
const std::string &hccl_type() const { return hccl_type_; }
void *input_data_addr() const { return input_data_addr_; }
void *output_data_addr() const { return output_data_addr_; }
void *workspace_addr() const { return workspace_addr_; }
int64_t workspace_size() const { return workspace_size_; }
int64_t hccl_stream_num() const { return hccl_stream_num_; }
const std::vector<uint8_t> &private_def() const { return private_def_; }
@@ -306,6 +308,7 @@ class HcclTaskInfo : public TaskInfo {
std::string hccl_type_;
void *input_data_addr_;
void *output_data_addr_;
void *workspace_addr_;
int64_t workspace_size_;
int64_t hccl_stream_num_;
std::vector<uint8_t> private_def_;


Loading…
Cancel
Save