From d63a8dad985aa5ce12a653d51a267d2ba65c7009 Mon Sep 17 00:00:00 2001 From: huanruizhi Date: Sat, 12 Mar 2022 10:02:53 +0800 Subject: [PATCH] onnx if bugfix --- parser/onnx/subgraph_adapter/if_subgraph_adapter.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc b/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc index 37df217..6248e0a 100644 --- a/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc +++ b/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include "if_subgraph_adapter.h" #include "subgraph_adapter_factory.h" #include "common/util.h" @@ -95,8 +96,8 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set &all_inputs) const { - std::set graph_inputs; - std::set graph_outputs; + std::unordered_set graph_inputs; + std::unordered_set graph_outputs; for (int i = 0; i < onnx_graph.node_size(); i++) { ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); for (int j = 0; j < node_proto->input_size(); j++) { @@ -106,10 +107,12 @@ domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx graph_outputs.emplace(node_proto->output(j)); } } - + std::unordered_set graph_initializer_tensors; + for (int32_t i = 0; i < onnx_graph.initializer_size(); i++) { + graph_initializer_tensors.emplace(onnx_graph.initializer(i).name()); + } for (const auto &input : graph_inputs) { - std::set::const_iterator out_iter = graph_outputs.find(input); - if (out_iter == graph_outputs.end()) { + if (graph_outputs.count(input) == 0 && graph_initializer_tensors.count(input) == 0) { // Record input node need to be constructed all_inputs.emplace(input); }