Browse Source

Fix review advice.

tags/v1.5.1
zhaozhixuan 3 years ago
parent
commit
21886e608e
2 changed files with 20 additions and 14 deletions
  1. +14
    -12
      ge/single_op/task/op_task.cc
  2. +6
    -2
      tests/ut/ge/single_op/single_op_task_unittest.cc

+ 14
- 12
ge/single_op/task/op_task.cc View File

@@ -293,25 +293,26 @@ Status TbeOpTask::UpdateNodeByShape(const vector<GeTensorDesc> &input_desc, cons
} }


Status TbeOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) { Status TbeOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) {
node_ = node;
tiling_buffer_ = tiling_buffer;
max_tiling_size_ = max_tiling_size;
if (tiling_buffer != nullptr) { if (tiling_buffer != nullptr) {
uintptr_t *arg_base = reinterpret_cast<uintptr_t *>(args_.get());
size_t arg_num = arg_size_ / sizeof(void *);
uintptr_t *arg_base = nullptr;
size_t arg_num = 0;
GetIoAddr(arg_base, arg_num);
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc()); GE_CHECK_NOTNULL(node->GetOpDesc());
uint32_t inputs_num = node->GetOpDesc()->GetInputsSize(); uint32_t inputs_num = node->GetOpDesc()->GetInputsSize();
uint32_t outputs_num = node->GetOpDesc()->GetOutputsSize(); uint32_t outputs_num = node->GetOpDesc()->GetOutputsSize();
uint32_t workspace_nums = node->GetOpDesc()->GetWorkspace().size(); uint32_t workspace_nums = node->GetOpDesc()->GetWorkspace().size();
uint32_t tiling_index = inputs_num + outputs_num + workspace_nums; uint32_t tiling_index = inputs_num + outputs_num + workspace_nums;
if (arg_num == 0 || arg_num <= tiling_index) {
if (arg_num == 0 || arg_num < tiling_index) {
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Size]Tiling index %u, arg number %zu is invalid.", GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Size]Tiling index %u, arg number %zu is invalid.",
tiling_index, arg_num); tiling_index, arg_num);
return ACL_ERROR_GE_INTERNAL_ERROR; return ACL_ERROR_GE_INTERNAL_ERROR;
} }
arg_base[tiling_index] = reinterpret_cast<uintptr_t>(tiling_buffer); arg_base[tiling_index] = reinterpret_cast<uintptr_t>(tiling_buffer);
} }
node_ = node;
tiling_buffer_ = tiling_buffer;
max_tiling_size_ = max_tiling_size;
return SUCCESS; return SUCCESS;
} }


@@ -481,20 +482,21 @@ void TbeOpTask::GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) {
} }


Status AtomicAddrCleanOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) { Status AtomicAddrCleanOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) {
node_ = node;
tiling_buffer_ = tiling_buffer;
max_tiling_size_ = max_tiling_size;
if (tiling_buffer != nullptr) { if (tiling_buffer != nullptr) {
uintptr_t *arg_base = reinterpret_cast<uintptr_t *>(args_.get());
size_t arg_num = arg_size_ / sizeof(void *);
uintptr_t *arg_base = nullptr;
size_t arg_num = 0;
GetIoAddr(arg_base, arg_num);
uint32_t tiling_index = atomic_output_indices_.size(); uint32_t tiling_index = atomic_output_indices_.size();
if (arg_num == 0 || arg_num <= tiling_index) {
if (arg_num == 0 || arg_num < tiling_index) {
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Size]Tiling index %u, arg number %zu is invalid.", GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Size]Tiling index %u, arg number %zu is invalid.",
tiling_index, arg_num); tiling_index, arg_num);
return ACL_ERROR_GE_INTERNAL_ERROR; return ACL_ERROR_GE_INTERNAL_ERROR;
} }
arg_base[tiling_index] = reinterpret_cast<uintptr_t>(tiling_buffer); arg_base[tiling_index] = reinterpret_cast<uintptr_t>(tiling_buffer);
} }
node_ = node;
tiling_buffer_ = tiling_buffer;
max_tiling_size_ = max_tiling_size;
return SUCCESS; return SUCCESS;
} }




+ 6
- 2
tests/ut/ge/single_op/single_op_task_unittest.cc View File

@@ -245,12 +245,16 @@ TEST_F(UtestSingleOpTask, test_dynamic_support) {
AtomicAddrCleanOpTask atomic_task; AtomicAddrCleanOpTask atomic_task;
TbeOpTask tbe_task; TbeOpTask tbe_task;


tbe_task.arg_size_ = sizeof(void *) * 1;
tbe_task.args_.reset(new (std::nothrow) uint8_t[tbe_task.arg_size_]);
atomic_task.arg_size_ = sizeof(void *) * 1;
atomic_task.args_.reset(new (std::nothrow) uint8_t[atomic_task.arg_size_]);
ASSERT_EQ(tbe_task.EnableDynamicSupport(node, (void *)0x0001, 1), ACL_ERROR_GE_INTERNAL_ERROR); ASSERT_EQ(tbe_task.EnableDynamicSupport(node, (void *)0x0001, 1), ACL_ERROR_GE_INTERNAL_ERROR);
ASSERT_EQ(atomic_task.EnableDynamicSupport(node, (void *)0x0001, 1), ACL_ERROR_GE_INTERNAL_ERROR); ASSERT_EQ(atomic_task.EnableDynamicSupport(node, (void *)0x0001, 1), ACL_ERROR_GE_INTERNAL_ERROR);


tbe_task.arg_size_ = sizeof(void *);
tbe_task.arg_size_ = sizeof(void *) * 2;
tbe_task.args_.reset(new (std::nothrow) uint8_t[tbe_task.arg_size_]); tbe_task.args_.reset(new (std::nothrow) uint8_t[tbe_task.arg_size_]);
atomic_task.arg_size_ = sizeof(void *);
atomic_task.arg_size_ = sizeof(void *) * 2;
atomic_task.args_.reset(new (std::nothrow) uint8_t[atomic_task.arg_size_]); atomic_task.args_.reset(new (std::nothrow) uint8_t[atomic_task.arg_size_]);
ASSERT_EQ(tbe_task.EnableDynamicSupport(node, (void *)0x0001, 1), SUCCESS); ASSERT_EQ(tbe_task.EnableDynamicSupport(node, (void *)0x0001, 1), SUCCESS);
ASSERT_EQ(atomic_task.EnableDynamicSupport(node, (void *)0x0001, 1), SUCCESS); ASSERT_EQ(atomic_task.EnableDynamicSupport(node, (void *)0x0001, 1), SUCCESS);


Loading…
Cancel
Save