Browse Source

onnx if bugfix

pull/475/head
huanruizhi 3 years ago
parent
commit
d63a8dad98
1 changed files with 8 additions and 5 deletions
  1. +8
    -5
      parser/onnx/subgraph_adapter/if_subgraph_adapter.cc

+ 8
- 5
parser/onnx/subgraph_adapter/if_subgraph_adapter.cc View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <unordered_set>
#include "if_subgraph_adapter.h"
#include "subgraph_adapter_factory.h"
#include "common/util.h"
@@ -95,8 +96,8 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(

domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph,
std::set<std::string> &all_inputs) const {
std::set<std::string> graph_inputs;
std::set<std::string> graph_outputs;
std::unordered_set<std::string> graph_inputs;
std::unordered_set<std::string> graph_outputs;
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i);
for (int j = 0; j < node_proto->input_size(); j++) {
@@ -106,10 +107,12 @@ domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx
graph_outputs.emplace(node_proto->output(j));
}
}

std::unordered_set<std::string> graph_initializer_tensors;
for (int32_t i = 0; i < onnx_graph.initializer_size(); i++) {
graph_initializer_tensors.emplace(onnx_graph.initializer(i).name());
}
for (const auto &input : graph_inputs) {
std::set<std::string>::const_iterator out_iter = graph_outputs.find(input);
if (out_iter == graph_outputs.end()) {
if (graph_outputs.count(input) == 0 && graph_initializer_tensors.count(input) == 0) {
// Record input node need to be constructed
all_inputs.emplace(input);
}


Loading…
Cancel
Save