|
@@ -14,6 +14,7 @@ |
|
|
* limitations under the License. |
|
|
* limitations under the License. |
|
|
*/ |
|
|
*/ |
|
|
|
|
|
|
|
|
|
|
|
#include <unordered_set> |
|
|
#include "if_subgraph_adapter.h" |
|
|
#include "if_subgraph_adapter.h" |
|
|
#include "subgraph_adapter_factory.h" |
|
|
#include "subgraph_adapter_factory.h" |
|
|
#include "common/util.h" |
|
|
#include "common/util.h" |
|
@@ -95,8 +96,8 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( |
|
|
|
|
|
|
|
|
domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, |
|
|
domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, |
|
|
std::set<std::string> &all_inputs) const { |
|
|
std::set<std::string> &all_inputs) const { |
|
|
std::set<std::string> graph_inputs; |
|
|
|
|
|
std::set<std::string> graph_outputs; |
|
|
|
|
|
|
|
|
std::unordered_set<std::string> graph_inputs; |
|
|
|
|
|
std::unordered_set<std::string> graph_outputs; |
|
|
for (int i = 0; i < onnx_graph.node_size(); i++) { |
|
|
for (int i = 0; i < onnx_graph.node_size(); i++) { |
|
|
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); |
|
|
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); |
|
|
for (int j = 0; j < node_proto->input_size(); j++) { |
|
|
for (int j = 0; j < node_proto->input_size(); j++) { |
|
@@ -106,10 +107,12 @@ domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx |
|
|
graph_outputs.emplace(node_proto->output(j)); |
|
|
graph_outputs.emplace(node_proto->output(j)); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::unordered_set<std::string> graph_initializer_tensors; |
|
|
|
|
|
for (int32_t i = 0; i < onnx_graph.initializer_size(); i++) { |
|
|
|
|
|
graph_initializer_tensors.emplace(onnx_graph.initializer(i).name()); |
|
|
|
|
|
} |
|
|
for (const auto &input : graph_inputs) { |
|
|
for (const auto &input : graph_inputs) { |
|
|
std::set<std::string>::const_iterator out_iter = graph_outputs.find(input); |
|
|
|
|
|
if (out_iter == graph_outputs.end()) { |
|
|
|
|
|
|
|
|
if (graph_outputs.count(input) == 0 && graph_initializer_tensors.count(input) == 0) { |
|
|
// Record input node need to be constructed |
|
|
// Record input node need to be constructed |
|
|
all_inputs.emplace(input); |
|
|
all_inputs.emplace(input); |
|
|
} |
|
|
} |
|
|