Browse Source

fix

pull/2052/head
guopeian 3 years ago
parent
commit
f4d4445e16
2 changed files with 43 additions and 13 deletions
  1. +33
    -5
      ge/single_op/task/op_task.cc
  2. +10
    -8
      ge/single_op/task/op_task.h

+ 33
- 5
ge/single_op/task/op_task.cc View File

@@ -1136,11 +1136,39 @@ Status AiCpuTask::SetMemCopyTask(const domi::KernelExDef &kernel_def) {
return SUCCESS;
}

Status AiCpuBaseTask::LaunchKernel(const std::vector<GeTensorDesc> &input_desc,
const std::vector<DataBuffer> &input_buffers,
std::vector<GeTensorDesc> &output_desc,
std::vector<DataBuffer> &output_buffers,
rtStream_t stream) {
Status AiCpuTask::LaunchKernel(const std::vector<GeTensorDesc> &input_desc,
const std::vector<DataBuffer> &input_buffers,
std::vector<GeTensorDesc> &output_desc,
std::vector<DataBuffer> &output_buffers,
rtStream_t stream) {
GE_CHK_STATUS_RET_NOLOG(UpdateExtInfo(input_desc, output_desc, stream));
if (unknown_type_ == DEPEND_COMPUTE) {
std::vector<DataBuffer> summary_buffers;
for (size_t i = 0; i < num_outputs_; ++i) {
summary_buffers.emplace_back(output_summary_[i], sizeof(aicpu::FWKAdapter::ResultSummary), false);
}
GE_CHK_STATUS_RET_NOLOG(UpdateIoAddr(input_buffers, summary_buffers));
} else {
GE_CHK_STATUS_RET_NOLOG(UpdateIoAddr(input_buffers, output_buffers));
}

GE_CHK_STATUS_RET_NOLOG(LaunchKernel(stream));
if (unknown_type_ == DEPEND_SHAPE_RANGE) {
GE_CHK_RT_RET(rtStreamSynchronize(stream));
GE_CHK_STATUS_RET_NOLOG(UpdateOutputShape(output_desc));
} else if (unknown_type_ == DEPEND_COMPUTE) {
GE_CHK_RT_RET(rtStreamSynchronize(stream));
GE_CHK_STATUS_RET_NOLOG(UpdateShapeAndDataByResultSummary(output_desc, output_buffers, stream));
}

return SUCCESS;
}

Status AiCpuCCTask::LaunchKernel(const std::vector<GeTensorDesc> &input_desc,
const std::vector<DataBuffer> &input_buffers,
std::vector<GeTensorDesc> &output_desc,
std::vector<DataBuffer> &output_buffers,
rtStream_t stream) {
GE_CHK_STATUS_RET_NOLOG(UpdateExtInfo(input_desc, output_desc, stream));
if (unknown_type_ == DEPEND_COMPUTE) {
std::vector<DataBuffer> summary_buffers;


+ 10
- 8
ge/single_op/task/op_task.h View File

@@ -167,12 +167,6 @@ class AiCpuBaseTask : public OpTask {
UnknowShapeOpType GetUnknownType() const { return unknown_type_; }
Status UpdateArgTable(const SingleOpModelParam &param) override;
const std::string &GetTaskType() const override;
Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc,
const std::vector<DataBuffer> &input_buffers,
std::vector<GeTensorDesc> &output_desc,
std::vector<DataBuffer> &output_buffers,
rtStream_t stream) override;
virtual Status LaunchKernel(rtStream_t stream) = 0;
protected:
Status UpdateIoAddr(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs);
Status SetInputConst();
@@ -226,7 +220,11 @@ class AiCpuTask : public AiCpuBaseTask {

Status LaunchKernel(rtStream_t stream) override;
void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override;

Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc,
const std::vector<DataBuffer> &input_buffers,
std::vector<GeTensorDesc> &output_desc,
std::vector<DataBuffer> &output_buffers,
rtStream_t stream) override;
Status SetMemCopyTask(const domi::KernelExDef &kernel_def);

private:
@@ -265,6 +263,11 @@ class AiCpuCCTask : public AiCpuBaseTask {
AiCpuCCTask &operator=(const AiCpuCCTask &) = delete;
Status SetMemCopyTask(const domi::KernelDef &kernel_def);
Status LaunchKernel(rtStream_t stream) override;
Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc,
const std::vector<DataBuffer> &input_buffers,
std::vector<GeTensorDesc> &output_desc,
std::vector<DataBuffer> &output_buffers,
rtStream_t stream) override;
void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override;
const void *GetArgs() const;
void SetKernelArgs(std::unique_ptr<uint8_t[]> args, size_t arg_size);
@@ -274,7 +277,6 @@ class AiCpuCCTask : public AiCpuBaseTask {
size_t GetArgSize() const;
private:
Status InitForSummaryAndCopy();

Status CopyDataToHbm(vector<DataBuffer> &outputs, rtStream_t stream) override;
private:
friend class AiCpuCCTaskBuilder;


Loading…
Cancel
Save