Browse Source

Pre Merge pull request !2041 from 梁昊/lh2

pull/2041/MERGE
梁昊 Gitee 3 years ago
parent
commit
41edf18efc
2 changed files with 36 additions and 4 deletions
  1. +11
    -3
      ge/graph/build/task_generator.cc
  2. +25
    -1
      tests/ut/ge/graph/build/task_generator_unittest.cc

+ 11
- 3
ge/graph/build/task_generator.cc View File

@@ -769,7 +769,8 @@ Status TaskGenerator::AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingP
GELOGW("not find fp_op_desc.");
return SUCCESS;
}
GELOGI("Find fp_op_desc is %s, id is %ld", fp_op_desc->GetName().c_str(), fp_op_desc->GetId());
GEEVENT("Auto find graph[%s]'s fp node[%s], type[%s], index[%u], stream id[%ld]", graph->GetName().c_str(),
fp_op_desc->GetName().c_str(), fp_op_desc->GetType().c_str(), fp_op_desc->GetId(), fp_op_desc->GetStreamId());
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
@@ -866,8 +867,8 @@ Status TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const N
break;
}
}
GELOGI("Last bp node[%s], type[%s], index[%u], stream id[%ld]", bp_op_desc->GetName().c_str(),
bp_op_desc->GetType().c_str(), bp_index, bp_op_desc->GetStreamId());
GEEVENT("Auto find graph[%s]'s bp node[%s], type[%s], index[%u], stream id[%ld]", graph->GetName().c_str(),
bp_op_desc->GetName().c_str(), bp_op_desc->GetType().c_str(), bp_op_desc->GetId(), bp_op_desc->GetStreamId());
return SUCCESS;
}

@@ -954,6 +955,8 @@ Status TaskGenerator::GetFpBpIndex(const ComputeGraphPtr &graph, ProfilingPoint
GELOGW("First forward profiling op_index not set and FindFpOpIndex failed.");
return FAILED;
}
} else {
GEEVENT("Find fp node set by user, graph[%s], node[%s].", graph->GetName().c_str(), fp_point_str.c_str());
}

if (bp_point_str.empty()) {
@@ -962,6 +965,8 @@ Status TaskGenerator::GetFpBpIndex(const ComputeGraphPtr &graph, ProfilingPoint
GELOGW("Last backward profiling op_index not set and FindBpOpIndex failed.");
return FAILED;
}
} else {
GEEVENT("Find bp node set by user, graph[%s], node[%s].", graph->GetName().c_str(), bp_point_str.c_str());
}

return SUCCESS;
@@ -1023,6 +1028,9 @@ Status TaskGenerator::FindProfilingTaskIndex(const ComputeGraphPtr &graph, Profi
if (profiling_point.bp_index == 0 && train_graph) {
GELOGW("Last backward op name can't be found in graph for training trace.");
}
for (const auto end_idx : profiling_point.end_index) {
GEEVENT("Find end index: %u, graph: %s.", end_idx, graph->GetName().c_str());
}
return SUCCESS;
}



+ 25
- 1
tests/ut/ge/graph/build/task_generator_unittest.cc View File

@@ -32,6 +32,7 @@
#include "init/gelib.h"
#include "ge/opskernel_manager/ops_kernel_builder_manager.h"
#include "graph/build/task_generator.h"
#include "graph/ge_local_context.h"
#include "graph/manager/graph_mem_manager.h"
#include "graph/manager/graph_var_manager.h"
#undef protected
@@ -202,6 +203,11 @@ TEST_F(UtestTaskGeneratorTest, AutoFindBpOpIndex) {
output_desc->SetType("HcomAllReduce");
output_desc->SetName("hcom");
EXPECT_EQ(task_generator.AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes), SUCCESS);

setenv("PROFILING_MODE", "true", true);
EXPECT_EQ(task_generator.FindProfilingTaskIndex(graph, profiling_point, all_reduce_nodes), SUCCESS);
EXPECT_EQ(profiling_point.end_index.size(), 1);
EXPECT_EQ(*profiling_point.end_index.begin(), 3);
}

TEST_F(UtestTaskGeneratorTest, GenerateTask) {
@@ -230,4 +236,22 @@ TEST_F(UtestTaskGeneratorTest, GenerateTask) {
EXPECT_EQ(task_generator.GenerateTask(run_context, graph, task_def_list, op_name_map), SUCCESS);
EXPECT_EQ(task_def_list.size(), 1);
EXPECT_EQ(task_def_list[0].ops_kernel_store_ptr(), reinterpret_cast<uintptr_t>(ops_kernel_info_store_ptr.get()));
}
}

TEST_F(UtestTaskGeneratorTest, SetFpBpByOptions) {
map<std::string, string> options_map = {
{ OPTION_EXEC_PROFILING_OPTIONS,
R"({"fp_point":"fp_node","bp_point":"bp_node"})"}};
ge::GEThreadLocalContext &context = GetThreadLocalContext();
context.SetGraphOption(options_map);

auto graph = BuildGraphBpProfiling();
TaskGenerator task_generator(nullptr, 0);
ProfilingPoint profiling_point;
vector<uint32_t> all_reduce_nodes;
std::string fp_str;
std::string bp_str;
EXPECT_EQ(task_generator.GetFpBpIndex(graph, profiling_point, all_reduce_nodes, fp_str, bp_str), SUCCESS);
EXPECT_EQ(fp_str, "fp_node");
EXPECT_EQ(bp_str, "bp_node");
}

Loading…
Cancel
Save