/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph_builder_utils.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "inc/external/graph/operator.h"
#include "inc/external/graph/operator_factory.h"
#include "graph/utils/graph_utils.h"

namespace ge {
namespace st {

/// @brief Creates an op node with identical tensor descriptors on every port
///        and adds it to the builder's graph.
///
/// A single template GeTensorDesc is built from @p shape, @p format and
/// @p data_type. Each of the @p in_cnt inputs and @p out_cnt outputs receives
/// its own Clone() of that template, so later edits to one port's descriptor
/// do not affect the others.
///
/// @param name      Name given to the new OpDesc.
/// @param type      Op type string given to the new OpDesc.
/// @param in_cnt    Number of input descriptors to add (none if <= 0).
/// @param out_cnt   Number of output descriptors to add (none if <= 0).
/// @param format    Format applied to every port's tensor descriptor.
/// @param data_type Data type applied to every port's tensor descriptor.
/// @param shape     Dimensions applied to every port's tensor descriptor;
///                  taken by value and moved into the GeShape.
/// @return Whatever graph_->AddNode(op_desc) returns for the new OpDesc.
NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
                                     Format format, DataType data_type, std::vector<int64_t> shape) {
  // Template descriptor shared (by cloning) across all ports of the node.
  auto tensor_desc = std::make_shared<GeTensorDesc>();
  tensor_desc->SetShape(GeShape(std::move(shape)));
  tensor_desc->SetFormat(format);
  tensor_desc->SetDataType(data_type);

  auto op_desc = std::make_shared<OpDesc>(name, type);
  for (int i = 0; i < in_cnt; ++i) {
    op_desc->AddInputDesc(tensor_desc->Clone());
  }
  for (int i = 0; i < out_cnt; ++i) {
    op_desc->AddOutputDesc(tensor_desc->Clone());
  }

  return graph_->AddNode(op_desc);
}

/// @brief Connects output data anchor @p src_idx of @p src_node to input data
///        anchor @p dst_idx of @p dst_node via GraphUtils::AddEdge.
///
/// NOTE(review): the status returned by GraphUtils::AddEdge is discarded, so
/// an invalid index or a null anchor is not reported to the caller. Confirm
/// that tests relying on this helper validate the resulting topology.
void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) {
  GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
}

/// @brief Adds a control edge from @p src_node's out-control anchor to
///        @p dst_node's in-control anchor via GraphUtils::AddEdge.
///
/// NOTE(review): as with AddDataEdge, the status returned by
/// GraphUtils::AddEdge is discarded.
void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) {
  GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
}

}  // namespace st
}  // namespace ge