/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "hybrid/node_executor/rts/rts_node_task.h" #include "hybrid/node_executor/rts/rts_task_factory.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" #include "graph/utils/node_utils.h" #include "common/ge/ge_util.h" #include "common/op/ge_op_utils.h" namespace { constexpr uint8_t kSwitchPredIndex = 0; constexpr uint8_t kSwitchCompIndex = 1; const static std::map> kCompHandle = { {RT_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value == comp_value; }}, {RT_NOT_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value != comp_value; }}, {RT_GREATER, [](int64_t pred_value, int64_t comp_value) { return pred_value > comp_value; }}, {RT_GREATER_OR_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value >= comp_value; }}, {RT_LESS, [](int64_t pred_value, int64_t comp_value) { return pred_value < comp_value; }}, {RT_LESS_OR_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value <= comp_value; }}, }; } namespace ge { namespace hybrid { REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask); REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask); REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask); REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask); REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask); REGISTER_RTS_TASK_CREATOR(LOOPCOND, PassThroughNodeTask); REGISTER_RTS_TASK_CREATOR(NEXTITERATION, PassThroughNodeTask); REGISTER_RTS_TASK_CREATOR(REFNEXTITERATION, PassThroughNodeTask); REGISTER_RTS_TASK_CREATOR(EXIT, PassThroughNodeTask); REGISTER_RTS_TASK_CREATOR(REFEXIT, PassThroughNodeTask); REGISTER_RTS_TASK_CREATOR(LABELSET, LabelSetNodeTask); REGISTER_RTS_TASK_CREATOR(LABELGOTO, LabelGotoNodeTask); REGISTER_RTS_TASK_CREATOR(LABELGOTOEX, LabelGotoNodeTask); REGISTER_RTS_TASK_CREATOR(LABELSWITCH, LabelSwitchNodeTask); REGISTER_RTS_TASK_CREATOR(LABELSWITCHBYINDEX, LabelSwitchNodeTask); Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t index, int64_t &value) { auto tensor_value = task_context.GetInput(index); GE_CHECK_NOTNULL(tensor_value); auto tensor_desc = task_context.MutableInputDesc(index); GE_CHECK_NOTNULL(tensor_desc); auto data_type = tensor_desc->GetDataType(); switch (data_type) { #define CASE_TYPE(DT, VT) \ case (DT): { \ VT data_val{}; \ GE_CHK_STATUS_RET(tensor_value->CopyScalarValueToHost(data_val)); \ value = static_cast(data_val); \ break; \ } // Just accept index data type. CASE_TYPE(DT_INT32, int32_t) CASE_TYPE(DT_INT64, int64_t) #undef CASE_TYPE default: { GELOGE(UNSUPPORTED, "Data type %s not index type.", TypeUtils::DataTypeToSerialString(data_type).c_str()); return UNSUPPORTED; } } return SUCCESS; } Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { GELOGD("[%s] Start to execute.", task_context.GetNodeName()); const auto &node_state = task_context.GetNodeState(); node_state->RunStreamActive(); if (done_callback) { GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); } GELOGI("[%s] Done executing successfully.", task_context.GetNodeName()); return SUCCESS; } Status StreamSwitchNodeTask::Init(const HybridModel &model, const NodePtr &node) { uint32_t value = 0; if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, value)) { GELOGE(INTERNAL_ERROR, "[%s] Get %s failed.", node->GetName().c_str(), ATTR_NAME_STREAM_SWITCH_COND.c_str()); return INTERNAL_ERROR; } rtCondition_t cond = static_cast(value); const auto it = kCompHandle.find(cond); if (it == kCompHandle.end()) { GELOGE(INTERNAL_ERROR, "[%s] Get Condition: %u handle failed.", node->GetName().c_str(), value); return INTERNAL_ERROR; } comp_func_ = it->second; GELOGD("[%s] Done initialization successfully, condition is %u.", node->GetName().c_str(), value); return SUCCESS; } Status StreamSwitchNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { GELOGD("[%s] Start to execute.", task_context.GetNodeName()); GE_CHECK_NOTNULL(comp_func_); int64_t pred_value = 0; GE_CHK_STATUS_RET(GetScalarIndexValue(task_context, kSwitchPredIndex, pred_value)); int64_t comp_value = 0; GE_CHK_STATUS_RET(GetScalarIndexValue(task_context, kSwitchCompIndex, comp_value)); bool switch_idx = comp_func_(pred_value, comp_value); auto node_state = task_context.GetNodeState(); node_state->SetSwitchIndex(static_cast(switch_idx)); if (done_callback) { GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); } GELOGI("[%s] Done executing successfully, pred value: %ld, comp value: %ld, switch index: %d.", task_context.GetNodeName(), pred_value, comp_value, static_cast(switch_idx)); return SUCCESS; } Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { int index = task_context.GetNodeState()->GetMergeIndex(); GELOGD("[%s] Start to execute, merge index: %d.", task_context.GetNodeName(), index); if (index < 0 || index >= task_context.NumInputs()) { GELOGE(INTERNAL_ERROR, "[%s] Invalid merge param, inputs num: %d, merge index: %d.", task_context.GetNodeName(), task_context.NumInputs(), index); return INTERNAL_ERROR; } const auto in_x = task_context.MutableInput(index); // x GE_CHECK_NOTNULL(in_x); GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(MERGE_DATA_OUTPUT, *in_x)); // y const auto out_y = task_context.MutableOutput(MERGE_INDEX_OUTPUT); // value_index GE_CHECK_NOTNULL(out_y); if (out_y->GetSize() > 0) { GE_CHK_RT_RET(rtMemcpyAsync(out_y->MutableData(), out_y->GetSize(), &index, sizeof(index), RT_MEMCPY_HOST_TO_DEVICE_EX, task_context.GetStream())); } if (done_callback) { GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); } task_context.GetNodeState()->SetMergeIndex(-1); // Invalidate for loop. GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); return SUCCESS; } Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { GELOGD("[%s] Start to execute.", task_context.GetNodeName()); const auto in_x = task_context.GetInput(0); // x GE_CHECK_NOTNULL(in_x); GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(0, *in_x)); // y const auto &node_state = task_context.GetNodeState(); if (kNextIterationOpTypes.count(node_state->GetType()) > 0) { node_state->RunNextIteration(); } if (done_callback) { GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); } GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); return SUCCESS; } Status LabelSetNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { GELOGD("[%s] Start to execute.", task_context.GetNodeName()); if (done_callback) { GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); } GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); return UNSUPPORTED; } Status LabelGotoNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { GELOGD("[%s] Start to execute.", task_context.GetNodeName()); if (done_callback) { GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); } GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); return UNSUPPORTED; } Status LabelSwitchNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { GELOGD("[%s] Start to execute.", task_context.GetNodeName()); if (done_callback) { GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); } GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); return UNSUPPORTED; } } // namespace hybrid } // namespace ge