Browse Source

add support for train_mode tune

tags/v1.3.0
gengchao4@huawei.com 4 years ago
parent
commit
d694f1dc21
4 changed files with 35 additions and 8 deletions
  1. +17
    -0
      ge/graph/build/model_builder.cc
  2. +9
    -7
      ge/graph/manager/graph_manager.cc
  3. +1
    -1
      ge/graph/manager/graph_manager.h
  4. +8
    -0
      ge/graph/passes/global_step_insert_pass.cc

+ 17
- 0
ge/graph/build/model_builder.cc View File

@@ -647,6 +647,14 @@ Status ModelBuilder::SaveAtomicTBEKernel(const OpDescPtr &op_desc) {
std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize());
tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data));
GE_CHECK_NOTNULL(tbe_kernel);
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", atomic_op_desc->GetName().c_str(),
atomic_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str());
if (!(atomic_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) {
std::string error = "Node" + FmtToStr(atomic_op_desc->GetName()) + "set extra attr" +
FmtToStr(ge::OP_EXTATTR_NAME_TBE_KERNEL) + "failed";
GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str());
return ge::FAILED;
}
}
}
if (tbe_kernel == nullptr) {
@@ -695,6 +703,15 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) {
GE_CHECK_NOTNULL(kernel_buffer.GetData());
std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize());
tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data));
GE_CHECK_NOTNULL(tbe_kernel);
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(),
node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str());
if (!(node_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) {
std::string error = "Node" + FmtToStr(node_op_desc->GetName()) + "set extra attr" +
FmtToStr(ge::OP_EXTATTR_NAME_TBE_KERNEL) + "failed";
GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str());
return ge::FAILED;
}
}
}
GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue);


+ 9
- 7
ge/graph/manager/graph_manager.cc View File

@@ -1686,7 +1686,8 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti
return GE_GRAPH_OPTIONS_INVALID);

// ge.graphType
- ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag);
+ ret =
+ ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag, options_.build_mode == BUILD_MODE_TUNING);
GE_IF_BOOL_EXEC(ret != SUCCESS,
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid");
return GE_GRAPH_OPTIONS_INVALID);
@@ -1728,20 +1729,21 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti
return SUCCESS;
}

- Status GraphManager::ParseTrainGraphFlag(bool &options, bool &option) {
+ Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag) {
std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance();
if (ge_instance_ptr == nullptr) {
GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized");
- option = false;
+ train_flag = false;
} else if (!ge_instance_ptr->isTrainMode()) {
- option = false;
+ train_flag = false;
} else { // ge_instance_ptr->isTrainMode() is true
- if (!options) {
+ // tune mode no need check
+ if (!run_flag && !tune_flag) {
GELOGE(GE_GRAPH_OPTIONS_INVALID,
- "Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", options);
+ "Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag);
return GE_GRAPH_OPTIONS_INVALID;
}
- option = true;
+ train_flag = true;
}
return SUCCESS;
}


+ 1
- 1
ge/graph/manager/graph_manager.h View File

@@ -277,7 +277,7 @@ class GraphManager {

static Status ParseParallelNum(const std::string &parallel_num, const std::string &key, int &num);

- static Status ParseTrainGraphFlag(bool &options, bool &option);
+ static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag);

static bool IsPerfLevelInvalid(int32_t perf_level);



+ 8
- 0
ge/graph/passes/global_step_insert_pass.cc View File

@@ -26,6 +26,8 @@
#include "common/ge/ge_util.h"
#include "graph/manager/graph_var_manager.h"
#include "graph/passes/pass_utils.h"
#include "graph/ge_context.h"
#include "graph/tuning_utils.h"

namespace ge {
NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph,
@@ -72,6 +74,12 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph,
}

Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) {
std::string build_mode;
if (ge::GetContext().GetOption(ge::BUILD_MODE, build_mode) == GRAPH_SUCCESS && build_mode == BUILD_MODE_TUNING) {
GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(),
compute_graph->GetName().c_str());
return SUCCESS;
}
NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT);
if (output_node == nullptr) {
GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID());


Loading…
Cancel
Save