diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index fd6639a5..ee752022 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -293,25 +293,26 @@ Status TbeOpTask::UpdateNodeByShape(const vector &input_desc, cons } Status TbeOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) { + node_ = node; + tiling_buffer_ = tiling_buffer; + max_tiling_size_ = max_tiling_size; if (tiling_buffer != nullptr) { - uintptr_t *arg_base = reinterpret_cast(args_.get()); - size_t arg_num = arg_size_ / sizeof(void *); + uintptr_t *arg_base = nullptr; + size_t arg_num = 0; + GetIoAddr(arg_base, arg_num); GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node->GetOpDesc()); uint32_t inputs_num = node->GetOpDesc()->GetInputsSize(); uint32_t outputs_num = node->GetOpDesc()->GetOutputsSize(); uint32_t workspace_nums = node->GetOpDesc()->GetWorkspace().size(); uint32_t tiling_index = inputs_num + outputs_num + workspace_nums; - if (arg_num == 0 || arg_num <= tiling_index) { + if (arg_num == 0 || arg_num < tiling_index) { GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Size]Tiling index %u, arg number %zu is invalid.", tiling_index, arg_num); return ACL_ERROR_GE_INTERNAL_ERROR; } arg_base[tiling_index] = reinterpret_cast(tiling_buffer); } - node_ = node; - tiling_buffer_ = tiling_buffer; - max_tiling_size_ = max_tiling_size; return SUCCESS; } @@ -481,20 +482,21 @@ void TbeOpTask::GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) { } Status AtomicAddrCleanOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) { + node_ = node; + tiling_buffer_ = tiling_buffer; + max_tiling_size_ = max_tiling_size; if (tiling_buffer != nullptr) { - uintptr_t *arg_base = reinterpret_cast(args_.get()); - size_t arg_num = arg_size_ / sizeof(void *); + uintptr_t *arg_base = nullptr; + size_t arg_num = 0; + GetIoAddr(arg_base, arg_num); uint32_t tiling_index = atomic_output_indices_.size(); - if (arg_num == 0 || arg_num <= tiling_index) { + if (arg_num == 0 || arg_num < tiling_index) { GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Size]Tiling index %u, arg number %zu is invalid.", tiling_index, arg_num); return ACL_ERROR_GE_INTERNAL_ERROR; } arg_base[tiling_index] = reinterpret_cast(tiling_buffer); } - node_ = node; - tiling_buffer_ = tiling_buffer; - max_tiling_size_ = max_tiling_size; return SUCCESS; } diff --git a/tests/ut/ge/single_op/single_op_task_unittest.cc b/tests/ut/ge/single_op/single_op_task_unittest.cc index 5960fbbc..9a0381cd 100644 --- a/tests/ut/ge/single_op/single_op_task_unittest.cc +++ b/tests/ut/ge/single_op/single_op_task_unittest.cc @@ -245,12 +245,16 @@ TEST_F(UtestSingleOpTask, test_dynamic_support) { AtomicAddrCleanOpTask atomic_task; TbeOpTask tbe_task; + tbe_task.arg_size_ = sizeof(void *) * 1; + tbe_task.args_.reset(new (std::nothrow) uint8_t[tbe_task.arg_size_]); + atomic_task.arg_size_ = sizeof(void *) * 1; + atomic_task.args_.reset(new (std::nothrow) uint8_t[atomic_task.arg_size_]); ASSERT_EQ(tbe_task.EnableDynamicSupport(node, (void *)0x0001, 1), ACL_ERROR_GE_INTERNAL_ERROR); ASSERT_EQ(atomic_task.EnableDynamicSupport(node, (void *)0x0001, 1), ACL_ERROR_GE_INTERNAL_ERROR); - tbe_task.arg_size_ = sizeof(void *); + tbe_task.arg_size_ = sizeof(void *) * 2; tbe_task.args_.reset(new (std::nothrow) uint8_t[tbe_task.arg_size_]); - atomic_task.arg_size_ = sizeof(void *); + atomic_task.arg_size_ = sizeof(void *) * 2; atomic_task.args_.reset(new (std::nothrow) uint8_t[atomic_task.arg_size_]); ASSERT_EQ(tbe_task.EnableDynamicSupport(node, (void *)0x0001, 1), SUCCESS); ASSERT_EQ(atomic_task.EnableDynamicSupport(node, (void *)0x0001, 1), SUCCESS);