Browse Source

bank path

tags/v1.1.0
baker 4 years ago
parent
commit
981be3bef0
2 changed files with 22 additions and 1 deletions
  1. +10
    -0
      ge/offline/main.cc
  2. +12
    -1
      inc/external/ge/ge_api_types.h

+ 10
- 0
ge/offline/main.cc View File

@@ -201,6 +201,10 @@ DEFINE_string(op_compiler_cache_dir, "", "Optional; the path to cache operator c

DEFINE_string(op_compiler_cache_mode, "", "Optional; choose the operator compiler cache mode");

DEFINE_string(mdl_bank_path, "", "Optional; model bank path");

DEFINE_string(op_bank_path, "", "Optional; op bank path");

class GFlagUtils {
public:
/**
@@ -1013,6 +1017,8 @@ static void SetEnvForSingleOp(std::map<string, string> &options) {
options.emplace(ge::DEBUG_DIR, FLAGS_debug_dir);
options.emplace(ge::OP_COMPILER_CACHE_DIR, FLAGS_op_compiler_cache_dir);
options.emplace(ge::OP_COMPILER_CACHE_MODE, FLAGS_op_compiler_cache_mode);
options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path);
options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path);
}

domi::Status GenerateSingleOp(const std::string& json_file_path) {
@@ -1166,6 +1172,10 @@ domi::Status GenerateOmModel() {
}

options.insert(std::pair<string, string>(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level)));

options.insert(std::pair<string, string>(string(ge::MDL_BANK_PATH_FLAG), FLAGS_mdl_bank_path));

options.insert(std::pair<string, string>(string(ge::OP_BANK_PATH_FLAG), FLAGS_op_bank_path));
// set enable scope fusion passes
SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes);
// print atc option map


+ 12
- 1
inc/external/ge/ge_api_types.h View File

@@ -245,6 +245,12 @@ const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16";
// 0: close debug; 1: open TBE compiler; 2: open ccec compiler
const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel";

// Configure model bank path
const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path";

// Configure op bank path
const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path";

// Graph run mode
enum GraphRunMode { PREDICTION = 0, TRAIN };

@@ -315,6 +321,9 @@ static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c
static const char *const DEBUG_DIR = ge::DEBUG_DIR;
static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR;
static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE;
static const char *const MDL_BANK_PATH_FLAG = ge::MDL_BANK_PATH_FLAG.c_str();
static const char *const OP_BANK_PATH_FLAG = ge::OP_BANK_PATH_FLAG.c_str();

// for interface: aclgrphBuildModel
const std::set<std::string> ir_builder_suppported_options = {
INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP,
@@ -336,7 +345,9 @@ const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT,
OUT_NODES,
COMPRESS_WEIGHT_CONF,
ENABLE_SCOPE_FUSION_PASSES,
LOG_LEVEL};
LOG_LEVEL,
MDL_BANK_PATH_FLAG,
OP_BANK_PATH_FLAG};

// for interface: aclgrphBuildInitialize
const std::set<std::string> global_options = {CORE_TYPE,


Loading…
Cancel
Save