|
|
@@ -20,6 +20,11 @@ |
|
|
|
#define protected public |
|
|
|
#include "generator/ge_generator.h" |
|
|
|
#include "graph/utils/tensor_utils.h" |
|
|
|
#include "graph/attr_value.h" |
|
|
|
#include "graph/debug/ge_attr_define.h" |
|
|
|
#include "graph/utils/graph_utils.h" |
|
|
|
#include "../graph/passes/graph_builder_utils.h" |
|
|
|
#include "../graph/manager/graph_manager.h |
|
|
|
|
|
|
|
using namespace std; |
|
|
|
|
|
|
@@ -31,6 +36,16 @@ class UtestGeGenerator : public testing::Test { |
|
|
|
void TearDown() {} |
|
|
|
}; |
|
|
|
|
|
|
|
namespace { |
|
|
|
ComputeGraphPtr MakeGraph() { |
|
|
|
ge::ut::GraphBuilder builder("graph"); |
|
|
|
auto data = builder.AddNode("data", "Data", 1, 1); |
|
|
|
auto addn1 = builder.AddNode("addn1", "AddN", 1, 1); |
|
|
|
builder.AddDataEdge(data, 0, addn1, 0); |
|
|
|
return builder.GetGraph(); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
/* |
|
|
|
TEST_F(UtestGeGenerator, test_build_single_op_offline) { |
|
|
|
GeTensorDesc tensor_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); |
|
|
@@ -71,4 +86,28 @@ TEST_F(UtestGeGenerator, test_build_single_op_online) { |
|
|
|
ModelBufferData model_buffer; |
|
|
|
EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, ENGINE_AIVECTOR, model_buffer), FAILED); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(UtestGeGenerator, test_graph_manager) { |
|
|
|
GraphManager graph_manager; |
|
|
|
GraphPartitioner graph_partitioner; |
|
|
|
|
|
|
|
auto root_graph = MakeGraph(); |
|
|
|
auto sub_graph = MakeGraph(); |
|
|
|
root_graph->AddSubGraph(sub_graph); |
|
|
|
|
|
|
|
auto sgi = MakeShared<SubGraphInfo>(); |
|
|
|
// set engine name |
|
|
|
sgi->SetEngineName("AIcoreEngine"); |
|
|
|
sgi->SetSubGraph(sub_graph); |
|
|
|
|
|
|
|
auto sgi_gelocal = MakeShared<SubGraphInfo>(); |
|
|
|
// set engine name |
|
|
|
sgi_gelocal->SetEngineName("GELOCAL"); |
|
|
|
sgi_gelocal->SetSubGraph(sub_graph); |
|
|
|
|
|
|
|
graph_partitioner.graph_2_input_subgraph_[root_graph] = sgi_gelocal; |
|
|
|
graph_partitioner.graph_2_subgraph_list_.insert({root_graph, {sgi, sgi_gelocal}}); |
|
|
|
graph_partitioner.graph_2_subgraph_list_.insert({sub_graph, {sgi, sgi_gelocal}}); |
|
|
|
EXPECT_EQ(graph_manager.ConvertGraphToFile(root_graph, graph_partitioner, "./"), GRAPH_SUCCESS); |
|
|
|
} |
|
|
|
} // namespace ge |