diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index b7b638de..7e58cc04 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -1136,11 +1136,39 @@ Status AiCpuTask::SetMemCopyTask(const domi::KernelExDef &kernel_def) { return SUCCESS; } -Status AiCpuBaseTask::LaunchKernel(const std::vector &input_desc, - const std::vector &input_buffers, - std::vector &output_desc, - std::vector &output_buffers, - rtStream_t stream) { +Status AiCpuTask::LaunchKernel(const std::vector &input_desc, + const std::vector &input_buffers, + std::vector &output_desc, + std::vector &output_buffers, + rtStream_t stream) { + GE_CHK_STATUS_RET_NOLOG(UpdateExtInfo(input_desc, output_desc, stream)); + if (unknown_type_ == DEPEND_COMPUTE) { + std::vector summary_buffers; + for (size_t i = 0; i < num_outputs_; ++i) { + summary_buffers.emplace_back(output_summary_[i], sizeof(aicpu::FWKAdapter::ResultSummary), false); + } + GE_CHK_STATUS_RET_NOLOG(UpdateIoAddr(input_buffers, summary_buffers)); + } else { + GE_CHK_STATUS_RET_NOLOG(UpdateIoAddr(input_buffers, output_buffers)); + } + + GE_CHK_STATUS_RET_NOLOG(LaunchKernel(stream)); + if (unknown_type_ == DEPEND_SHAPE_RANGE) { + GE_CHK_RT_RET(rtStreamSynchronize(stream)); + GE_CHK_STATUS_RET_NOLOG(UpdateOutputShape(output_desc)); + } else if (unknown_type_ == DEPEND_COMPUTE) { + GE_CHK_RT_RET(rtStreamSynchronize(stream)); + GE_CHK_STATUS_RET_NOLOG(UpdateShapeAndDataByResultSummary(output_desc, output_buffers, stream)); + } + + return SUCCESS; +} + +Status AiCpuCCTask::LaunchKernel(const std::vector &input_desc, + const std::vector &input_buffers, + std::vector &output_desc, + std::vector &output_buffers, + rtStream_t stream) { GE_CHK_STATUS_RET_NOLOG(UpdateExtInfo(input_desc, output_desc, stream)); if (unknown_type_ == DEPEND_COMPUTE) { std::vector summary_buffers; diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index ac7e489c..836b2046 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -167,12 +167,6 @@ class AiCpuBaseTask : public OpTask { UnknowShapeOpType GetUnknownType() const { return unknown_type_; } Status UpdateArgTable(const SingleOpModelParam ¶m) override; const std::string &GetTaskType() const override; - Status LaunchKernel(const std::vector &input_desc, - const std::vector &input_buffers, - std::vector &output_desc, - std::vector &output_buffers, - rtStream_t stream) override; - virtual Status LaunchKernel(rtStream_t stream) = 0; protected: Status UpdateIoAddr(const std::vector &inputs, const std::vector &outputs); Status SetInputConst(); @@ -226,7 +220,11 @@ class AiCpuTask : public AiCpuBaseTask { Status LaunchKernel(rtStream_t stream) override; void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; - + Status LaunchKernel(const std::vector &input_desc, + const std::vector &input_buffers, + std::vector &output_desc, + std::vector &output_buffers, + rtStream_t stream) override; Status SetMemCopyTask(const domi::KernelExDef &kernel_def); private: @@ -265,6 +263,11 @@ class AiCpuCCTask : public AiCpuBaseTask { AiCpuCCTask &operator=(const AiCpuCCTask &) = delete; Status SetMemCopyTask(const domi::KernelDef &kernel_def); Status LaunchKernel(rtStream_t stream) override; + Status LaunchKernel(const std::vector &input_desc, + const std::vector &input_buffers, + std::vector &output_desc, + std::vector &output_buffers, + rtStream_t stream) override; void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; const void *GetArgs() const; void SetKernelArgs(std::unique_ptr args, size_t arg_size); @@ -274,7 +277,6 @@ class AiCpuCCTask : public AiCpuBaseTask { size_t GetArgSize() const; private: Status InitForSummaryAndCopy(); - Status CopyDataToHbm(vector &outputs, rtStream_t stream) override; private: friend class AiCpuCCTaskBuilder;