@@ -13,15 +13,15 @@ | |||||
* See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#include "host_cpu_engine.h" | |||||
#include "graph/common/omg_util.h" | |||||
#include "ge_local_engine/engine/host_cpu_engine.h" | |||||
#include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
#include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
#include "graph/utils/node_utils.h" | |||||
#include "graph/utils/type_utils.h" | |||||
#include "register/op_kernel_registry.h" | #include "register/op_kernel_registry.h" | ||||
#include "register/host_cpu_context.h" | #include "register/host_cpu_context.h" | ||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/ge/plugin_manager.h" | #include "common/ge/plugin_manager.h" | ||||
#include "graph/utils/type_utils.h" | |||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
@@ -123,10 +123,7 @@ bool HostCpuEngine::CheckSupported(const string &op_type) { | |||||
} | } | ||||
Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) { | Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) { | ||||
std::string op_type; | |||||
auto status = GetOriginalType(node, op_type); | |||||
GE_CHK_BOOL_EXEC_NOLOG(status == SUCCESS, return status); | |||||
const std::string op_type = NodeUtils::GetNodeType(node); | |||||
auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type); | auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type); | ||||
if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str()); | GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str()); | ||||
@@ -288,10 +288,6 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
if (node->GetType() == MEMCPYASYNC) { // Convert MemcpyAsync to Identity. | |||||
node->GetOpDesc()->SetType(IDENTITY); | |||||
} | |||||
std::unique_ptr<NodeItem> new_node; | std::unique_ptr<NodeItem> new_node; | ||||
GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); | ||||
GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); | GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); | ||||
@@ -14,10 +14,8 @@ | |||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#include "node_item.h" | |||||
#include <sstream> | |||||
#include "common/debug/log.h" | |||||
#include "graph/common/omg_util.h" | |||||
#include "hybrid/model/node_item.h" | |||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "hybrid/executor/worker/shape_inference_engine.h" | #include "hybrid/executor/worker/shape_inference_engine.h" | ||||
@@ -98,8 +96,7 @@ Status ParseFusedSubgraph(NodeItem &node_item) { | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
std::string node_type; | |||||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type)); | |||||
const std::string node_type = NodeUtils::GetNodeType(node); | |||||
if (node_type == DATA) { | if (node_type == DATA) { | ||||
GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph)); | GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph)); | ||||
} else if (node_type == kNodeTypeRetVal) { | } else if (node_type == kNodeTypeRetVal) { | ||||
@@ -409,8 +406,8 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | ||||
if (switch_index < switch_groups_.size()) { | if (switch_index < switch_groups_.size()) { | ||||
std::vector<const NodeItem *> &switch_group = switch_groups_[switch_index]; | |||||
switch_group.emplace_back(node_item); | |||||
auto &switch_group = switch_groups_[switch_index]; | |||||
switch_group.emplace(node_item); | |||||
} else { | } else { | ||||
ctrl_send_.insert(node_item); | ctrl_send_.insert(node_item); | ||||
} | } | ||||
@@ -433,8 +430,8 @@ void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) { | |||||
} | } | ||||
// this is StreamMerge node, node_item is StreamActive node. | // this is StreamMerge node, node_item is StreamActive node. | ||||
std::vector<const NodeItem *> &switch_group = switch_groups_[merge_index]; | |||||
switch_group.emplace_back(node_item); | |||||
auto &switch_group = switch_groups_[merge_index]; | |||||
switch_group.emplace(node_item); | |||||
node_item->ctrl_send_.emplace(this); | node_item->ctrl_send_.emplace(this); | ||||
GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); | ||||
@@ -155,7 +155,7 @@ struct NodeItem { | |||||
std::map<const NodeItem *, int> data_recv_; // Recv data notify from | std::map<const NodeItem *, int> data_recv_; // Recv data notify from | ||||
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | ||||
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | ||||
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | |||||
std::vector<std::set<const NodeItem *>> switch_groups_; // Send ctrl notify to | |||||
std::shared_ptr<NodeTask> kernel_task; | std::shared_ptr<NodeTask> kernel_task; | ||||
std::unique_ptr<FusedSubgraph> fused_subgraph; | std::unique_ptr<FusedSubgraph> fused_subgraph; | ||||
@@ -342,6 +342,7 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
GE_CHK_RT_RET(rtEventDestroy(evt)); | GE_CHK_RT_RET(rtEventDestroy(evt)); | ||||
} | } | ||||
GELOGI("rdma callback success."); | GELOGI("rdma callback success."); | ||||
return SUCCESS; | |||||
}; | }; | ||||
HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); | HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); | ||||
@@ -17,13 +17,9 @@ | |||||
#include "hybrid/node_executor/rts/rts_node_executor.h" | #include "hybrid/node_executor/rts/rts_node_executor.h" | ||||
#include "hybrid/node_executor/rts/rts_task_factory.h" | #include "hybrid/node_executor/rts/rts_task_factory.h" | ||||
#include "common/debug/log.h" | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/types.h" | |||||
#include "graph/common/omg_util.h" | |||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "hybrid/model/hybrid_model.h" | #include "hybrid/model/hybrid_model.h" | ||||
#include "runtime/rt.h" | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
@@ -133,8 +129,7 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function< | |||||
Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GELOGD("[%s] Load for local task.", node->GetName().c_str()); | GELOGD("[%s] Load for local task.", node->GetName().c_str()); | ||||
std::string node_type; | |||||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); | |||||
const std::string node_type = NodeUtils::GetNodeType(node); | |||||
RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); | RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); | ||||
if (rts_task == nullptr) { | if (rts_task == nullptr) { | ||||
GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); | GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); | ||||