diff --git a/ge/offline/main.cc b/ge/offline/main.cc index c0aa583c..d72040d7 100755 --- a/ge/offline/main.cc +++ b/ge/offline/main.cc @@ -201,6 +201,10 @@ DEFINE_string(op_compiler_cache_dir, "", "Optional; the path to cache operator c DEFINE_string(op_compiler_cache_mode, "", "Optional; choose the operator compiler cache mode"); +DEFINE_string(mdl_bank_path, "", "Optional; model bank path"); + +DEFINE_string(op_bank_path, "", "Optional; op bank path"); + class GFlagUtils { public: /** @@ -1017,6 +1021,8 @@ static void SetEnvForSingleOp(std::map &options) { options.emplace(ge::DEBUG_DIR, FLAGS_debug_dir); options.emplace(ge::OP_COMPILER_CACHE_DIR, FLAGS_op_compiler_cache_dir); options.emplace(ge::OP_COMPILER_CACHE_MODE, FLAGS_op_compiler_cache_mode); + options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path); + options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path); } domi::Status GenerateSingleOp(const std::string& json_file_path) { @@ -1170,6 +1176,10 @@ domi::Status GenerateOmModel() { } options.insert(std::pair(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level))); + + options.insert(std::pair(string(ge::MDL_BANK_PATH_FLAG), FLAGS_mdl_bank_path)); + + options.insert(std::pair(string(ge::OP_BANK_PATH_FLAG), FLAGS_op_bank_path)); // set enable scope fusion passes SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes); // print atc option map diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h index 4cee1b6f..b66c5b78 100644 --- a/inc/external/ge/ge_api_types.h +++ b/inc/external/ge/ge_api_types.h @@ -245,6 +245,12 @@ const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; // 0: close debug; 1: open TBE compiler; 2: open ccec compiler const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; +// Configure model bank path +const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path"; + +// Configure op bank path +const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path"; + // Graph run mode enum GraphRunMode { PREDICTION = 0, TRAIN }; @@ -315,6 +321,9 @@ static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c static const char *const DEBUG_DIR = ge::DEBUG_DIR; static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR; static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; +static const char *const MDL_BANK_PATH_FLAG = ge::MDL_BANK_PATH_FLAG.c_str(); +static const char *const OP_BANK_PATH_FLAG = ge::OP_BANK_PATH_FLAG.c_str(); + // for interface: aclgrphBuildModel const std::set ir_builder_suppported_options = {INPUT_FORMAT, INPUT_SHAPE, @@ -347,7 +356,9 @@ const std::set ir_parser_suppported_options = {INPUT_FORMAT, OUT_NODES, COMPRESS_WEIGHT_CONF, ENABLE_SCOPE_FUSION_PASSES, - LOG_LEVEL}; + LOG_LEVEL, + MDL_BANK_PATH_FLAG, + OP_BANK_PATH_FLAG}; // for interface: aclgrphBuildInitialize const std::set global_options = {CORE_TYPE,