From 862bf660e085160739f33c0ecec760dd9907cfca Mon Sep 17 00:00:00 2001 From: lianghuikang <505519763@qq.com> Date: Thu, 20 May 2021 15:56:14 +0800 Subject: [PATCH] add modify_mixlist parameter. --- ge/client/ge_api.cc | 6 +++++ ge/ir_build/ge_ir_build.cc | 33 +++++++++++++++++++++++---- ge/ir_build/option_utils.cc | 25 ++++++++++++++++++++ ge/ir_build/option_utils.h | 4 ++++ ge/offline/main.cc | 27 ++++++++++++++++++++-- ge/session/inner_session.cc | 6 +++++ inc/external/ge/ge_api_types.h | 10 ++++++-- tests/ut/ge/graph_ir/ge_ir_build_unittest.cc | 22 ++++++++++++++++++ tests/ut/ge/session/ge_api_unittest.cc | 8 +++++++ tests/ut/ge/session/inner_session_unittest.cc | 9 ++++++++ 10 files changed, 142 insertions(+), 8 deletions(-) diff --git a/ge/client/ge_api.cc b/ge/client/ge_api.cc index 9cbd2d06..9c898ae3 100644 --- a/ge/client/ge_api.cc +++ b/ge/client/ge_api.cc @@ -34,6 +34,7 @@ #include "common/ge/tbe_plugin_manager.h" #include "common/util/error_manager/error_manager.h" #include "toolchain/plog.h" +#include "ir_build/option_utils.h" using domi::OpRegistry; using std::map; @@ -79,6 +80,11 @@ Status CheckOptionsValid(const std::map &options) { } } + // check modify_mixlist is valid + if (ge::CheckModifyMixlistParamValid(options) != ge::SUCCESS) { + return FAILED; + } + return SUCCESS; } diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index bd6a2d3a..c7e9522b 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -133,6 +133,15 @@ static graphStatus CheckGlobalOptions(std::map &global ? "force_fp16" : global_options[ge::ir_option::PRECISION_MODE]; global_options[ge::ir_option::PRECISION_MODE] = precision_mode; + // check modify_mixlist + std::string modify_mixlist = global_options.find(ge::ir_option::MODIFY_MIXLIST) == + global_options.end() + ? "" + : global_options[ge::ir_option::MODIFY_MIXLIST]; + if (ge::CheckModifyMixlistParamValid(precision_mode, modify_mixlist) != ge::SUCCESS) { + return ge::GRAPH_PARAM_INVALID; + } + global_options[ge::ir_option::MODIFY_MIXLIST] = modify_mixlist; return GRAPH_SUCCESS; } @@ -254,6 +263,8 @@ class Impl { omg_context_.user_attr_index_valid = false; }; ~Impl() { (void)generator_.Finalize(); }; + graphStatus GetSupportedOptions(const std::map &in, + std::map &out); graphStatus CheckOptions(const std::map &options); graphStatus CreateInputsForIRBuild(const ge::Graph &graph, vector &inputs); graphStatus UpdateDataOpAttr(const Graph &graph); @@ -440,19 +451,29 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { return GRAPH_SUCCESS; } -graphStatus Impl::CheckOptions(const std::map &options) { - for (auto &ele : options) { +graphStatus Impl::GetSupportedOptions(const std::map &in, + std::map &out) { + for (auto &ele : in) { auto it = ge::ir_option::ir_builder_suppported_options.find(ele.first); if (it == ge::ir_option::ir_builder_suppported_options.end()) { auto it_lx_fusion = ir_builder_supported_options_for_lx_fusion.find(ele.first); if (it_lx_fusion == ir_builder_supported_options_for_lx_fusion.end()) { GELOGE(GRAPH_PARAM_INVALID, "[Check][Options] unsupported option(%s), Please check!", - ele.first.c_str()); + ele.first.c_str()); return GRAPH_PARAM_INVALID; } } - options_.insert(ele); + out.insert(ele); } + return GRAPH_SUCCESS; +} + +graphStatus Impl::CheckOptions(const std::map &options) { + auto ret = GetSupportedOptions(options, options_); + if (ret != GRAPH_SUCCESS) { + return ret; + } + // Check options build_mode and build_step. std::string build_mode; auto it = options_.find(BUILD_MODE); @@ -480,6 +501,10 @@ graphStatus Impl::CheckOptions(const std::map &options if (it != options_.end() && (CheckDisableReuseMemoryParamValid(it->second) != GRAPH_SUCCESS)) { return GRAPH_PARAM_INVALID; } + // Check option modify_mixlist + if (ge::CheckModifyMixlistParamValid(options_) != GRAPH_SUCCESS) { + return GRAPH_PARAM_INVALID; + } // Check Input Format if (options_.find(kInputFormat) != options_.end()) { return CheckInputFormat(options_[kInputFormat]); diff --git a/ge/ir_build/option_utils.cc b/ge/ir_build/option_utils.cc index c23da519..ace6b420 100755 --- a/ge/ir_build/option_utils.cc +++ b/ge/ir_build/option_utils.cc @@ -785,6 +785,31 @@ Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std:: return ge::SUCCESS; } +Status CheckModifyMixlistParamValid(const std::map &options) { + std::string precision_mode; + auto it = options.find(ge::PRECISION_MODE); + if (it != options.end()) { + precision_mode = it->second; + } + it = options.find(ge::MODIFY_MIXLIST); + if (it != options.end() && CheckModifyMixlistParamValid(precision_mode, it->second) != ge::SUCCESS) { + return ge::PARAM_INVALID; + } + return ge::SUCCESS; +} + +Status CheckModifyMixlistParamValid(const std::string &precision_mode, const std::string &modify_mixlist) { + if (!modify_mixlist.empty() && precision_mode != "allow_mix_precision") { + REPORT_INPUT_ERROR("E10001", std::vector({"parameter", "value", "reason"}), + std::vector({ge::MODIFY_MIXLIST, modify_mixlist, kModifyMixlistError})); + GELOGE(ge::PARAM_INVALID, "[Check][ModifyMixlist] Failed, %s", kModifyMixlistError); + return ge::PARAM_INVALID; + } + GELOGI("Option set successfully, option_key=%s, option_value=%s", ge::MODIFY_MIXLIST.c_str(), modify_mixlist.c_str()); + + return ge::SUCCESS; +} + void PrintOptionMap(std::map &options, std::string tips) { for (auto iter = options.begin(); iter != options.end(); iter++) { std::string key = iter->first; diff --git a/ge/ir_build/option_utils.h b/ge/ir_build/option_utils.h index 44504e35..0b25bdf0 100644 --- a/ge/ir_build/option_utils.h +++ b/ge/ir_build/option_utils.h @@ -29,6 +29,8 @@ #include "graph/preprocess/multi_batch_options.h" namespace ge { +const char *const kModifyMixlistError = "modify_mixlist is assigned, please ensure that " + "precision_mode is assigned to 'allow_mix_precision'"; static std::set caffe_support_input_format = {"NCHW", "ND"}; static std::set tf_support_input_format = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"}; static std::set onnx_support_input_format = {"NCHW", "ND", "NCDHW"}; @@ -77,6 +79,8 @@ Status CheckInsertOpConfParamValid(const std::string insert_op_conf); Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory); Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream); Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode); +Status CheckModifyMixlistParamValid(const std::map &options); +Status CheckModifyMixlistParamValid(const std::string &precision_mode, const std::string &modify_mixlist); Status CheckInputFormat(const string &input_format); Status CheckKeepTypeParamValid(const std::string &keep_dtype); void PrintOptionMap(std::map &options, std::string tips); diff --git a/ge/offline/main.cc b/ge/offline/main.cc index 8eb83010..12e39680 100755 --- a/ge/offline/main.cc +++ b/ge/offline/main.cc @@ -116,7 +116,7 @@ DEFINE_string(out_nodes, "", DEFINE_string(precision_mode, "force_fp16", "Optional; precision mode." - "Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); + "Support force_fp16, force_fp32, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); DEFINE_string(keep_dtype, "", "Optional; config file to specify the precision used by the operator during compilation."); @@ -218,6 +218,8 @@ DEFINE_string(display_model_info, "0", "Optional; display model info"); DEFINE_string(device_id, "0", "Optional; user device id"); +DEFINE_string(modify_mixlist, "", "Optional; operator mixed precision configuration file path"); + class GFlagUtils { public: /** @@ -304,7 +306,7 @@ class GFlagUtils { "\"l1_optimize\", \"off_optimize\"\n" " --mdl_bank_path Set the path of the custom repository generated after model tuning.\n" "\n[Operator Tuning]\n" - " --precision_mode precision mode, support force_fp16(default), allow_mix_precision, " + " --precision_mode precision mode, support force_fp16(default), force_fp32, allow_mix_precision, " "allow_fp32_to_fp16, must_keep_origin_dtype.\n" " --keep_dtype Retains the precision of certain operators in inference " "scenarios by using a configuration file.\n" @@ -321,6 +323,7 @@ class GFlagUtils { " 2: Enable TBE pipe_all, generate the operator CCE file and Python-CCE mapping file " "(.json), and enable the CCE compiler -O0-g.\n" " 3: Disable debug, and keep generating kernel file (.o and .json)\n" + " --modify_mixlist Set the path of operator mixed precision configuration file.\n" "\n[Debug]\n" " --save_original_model Control whether to output original model. E.g.: true: output original model\n" " --log Generate log with level. Support debug, info, warning, error, null\n" @@ -369,6 +372,14 @@ class GFlagUtils { ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS, ret = ge::FAILED, "[Check][ImplMode]check optypelist_for_implmode and op_select_implmode failed!"); + + if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"modify_mixlist", FLAGS_modify_mixlist.c_str(), + ge::kModifyMixlistError}); + ret = ge::FAILED; + } + // No output file information passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( FLAGS_mode == GEN_OM_MODEL && FLAGS_output == "", @@ -1081,6 +1092,7 @@ static void SetEnvForSingleOp(std::map &options) { options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path); options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path); options.emplace(ge::TUNE_DEVICE_IDS, FLAGS_device_id); + options.emplace(ge::MODIFY_MIXLIST, FLAGS_modify_mixlist); } domi::Status GenerateSingleOp(const std::string& json_file_path) { @@ -1093,9 +1105,18 @@ domi::Status GenerateSingleOp(const std::string& json_file_path) { ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS, return ge::FAILED, "[Check][ImplmodeParam] fail for input optypelist_for_implmode and op_select_implmode."); + if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"modify_mixlist", FLAGS_modify_mixlist.c_str(), + ge::kModifyMixlistError}); + return ge::FAILED; + } + std::map options; // need to be changed when ge.ini plan is done SetEnvForSingleOp(options); + // print single op option map + ge::PrintOptionMap(options, "single op option"); auto ret = ge::GELib::Initialize(options); if (ret != ge::SUCCESS) { @@ -1234,6 +1255,8 @@ domi::Status GenerateOmModel() { options.insert(std::pair(string(ge::DISPLAY_MODEL_INFO), FLAGS_display_model_info)); + options.insert(std::pair(string(ge::MODIFY_MIXLIST), FLAGS_modify_mixlist)); + // set enable scope fusion passes SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes); // print atc option map diff --git a/ge/session/inner_session.cc b/ge/session/inner_session.cc index 39c87107..8248eecf 100755 --- a/ge/session/inner_session.cc +++ b/ge/session/inner_session.cc @@ -35,6 +35,7 @@ #include "graph/manager/graph_mem_manager.h" #include "graph/utils/tensor_adapter.h" #include "runtime/mem.h" +#include "ir_build/option_utils.h" namespace ge { namespace { @@ -81,6 +82,11 @@ Status InnerSession::Initialize() { return ret; } + // Check option modify_mixlist + if (ge::CheckModifyMixlistParamValid(all_options) != ge::SUCCESS) { + return FAILED; + } + UpdateThreadContext(std::map{}); // session device id set here diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h index 388f0fe0..fbd6c020 100644 --- a/inc/external/ge/ge_api_types.h +++ b/inc/external/ge/ge_api_types.h @@ -112,6 +112,7 @@ const char *const ORIGINAL_MODEL_FILE = "ge.originalModelFile"; const char *const INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; const char *const OP_DEBUG_LEVEL = "ge.opDebugLevel"; const char *const PERFORMANCE_MODE = "ge.performance_mode"; +const char *const MODIFY_MIXLIST = "ge.exec.modify_mixlist"; } // namespace configure_option // Configure stream num by Session constructor options param, // its value should be int32_t type, default value is "1" @@ -323,6 +324,8 @@ const char *const INPUT_SHAPE_RANGE = "input_shape_range"; // high: need to recompile, high execute performance mode const std::string PERFORMANCE_MODE = "ge.performance_mode"; +const std::string MODIFY_MIXLIST = "ge.exec.modify_mixlist"; + // Graph run mode enum GraphRunMode { PREDICTION = 0, TRAIN }; @@ -401,6 +404,7 @@ static const char *const OP_BANK_PATH = ge::OP_BANK_PATH_FLAG.c_str(); static const char *const OP_BANK_UPDATE = ge::OP_BANK_UPDATE_FLAG.c_str(); static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str(); static const char *const PERFORMANCE_MODE = ge::PERFORMANCE_MODE.c_str(); +static const char *const MODIFY_MIXLIST = ge::MODIFY_MIXLIST.c_str(); // for interface: aclgrphBuildModel #ifdef __GNUC__ @@ -427,7 +431,8 @@ const std::set ir_builder_suppported_options = {INPUT_FORMAT, MDL_BANK_PATH, OP_BANK_PATH, OP_BANK_UPDATE, - PERFORMANCE_MODE}; + PERFORMANCE_MODE, + MODIFY_MIXLIST}; // for interface: aclgrphParse const std::set ir_parser_suppported_options = { @@ -453,7 +458,8 @@ const std::set global_options = {CORE_TYPE, OP_DEBUG_LEVEL, DEBUG_DIR, OP_COMPILER_CACHE_DIR, - OP_COMPILER_CACHE_MODE}; + OP_COMPILER_CACHE_MODE, + MODIFY_MIXLIST}; #endif } // namespace ir_option } // namespace ge diff --git a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc index fb4a5a8d..e14178d8 100644 --- a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc +++ b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc @@ -250,6 +250,17 @@ TEST(UtestIrCommon, check_dynamic_input_param_failed) { EXPECT_EQ(ret, ge::PARAM_INVALID); } +TEST(UtestIrCommon, check_modify_mixlist_param) { + std::string precision_mode = "allow_mix_precision"; + std::string modify_mixlist = "/mixlist.json"; + Status ret = CheckModifyMixlistParamValid(precision_mode, modify_mixlist); + EXPECT_EQ(ret, ge::SUCCESS); + + precision_mode = ""; + ret = CheckModifyMixlistParamValid(precision_mode, modify_mixlist); + EXPECT_EQ(ret, ge::PARAM_INVALID); +} + TEST(UtestIrCommon, check_compress_weight) { std::string enable_compress_weight = "true"; std::string compress_weight_conf="./"; @@ -349,4 +360,15 @@ TEST(UtestIrBuild, check_data_attr_index_succ_no_input_range) { ModelBufferData model; graphStatus ret = aclgrphBuildModel(graph, build_options, model); EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); +} + +TEST(UtestIrBuild, check_modify_mixlist_param) { + Graph graph = BuildIrGraph1(); + const std::map build_options = { + {"ge.exec.modify_mixlist", "/modify.json"} + }; + ModelBufferData model; + + auto ret = aclgrphBuildModel(graph, build_options, model); + EXPECT_EQ(ret, GRAPH_PARAM_INVALID); } \ No newline at end of file diff --git a/tests/ut/ge/session/ge_api_unittest.cc b/tests/ut/ge/session/ge_api_unittest.cc index 371efdfa..2cabc4a3 100644 --- a/tests/ut/ge/session/ge_api_unittest.cc +++ b/tests/ut/ge/session/ge_api_unittest.cc @@ -63,4 +63,12 @@ TEST_F(UtestGeApi, build_graph_success) { auto ret = session.BuildGraph(1, inputs); ASSERT_NE(ret, SUCCESS); } + +TEST_F(UtestGeApi, ge_initialize) { + std::map options = { + {ge::MODIFY_MIXLIST, "/mixlist.json"} + }; + auto ret = GEInitialize(options); + ASSERT_NE(ret, SUCCESS); +} } // namespace ge diff --git a/tests/ut/ge/session/inner_session_unittest.cc b/tests/ut/ge/session/inner_session_unittest.cc index 19f75d9f..ecad56d6 100644 --- a/tests/ut/ge/session/inner_session_unittest.cc +++ b/tests/ut/ge/session/inner_session_unittest.cc @@ -44,4 +44,13 @@ TEST_F(Utest_Inner_session, build_graph_success) { EXPECT_NE(ret, ge::SUCCESS); } +TEST_F(Utest_Inner_session, initialize) { + std::map options = { + {ge::MODIFY_MIXLIST, "/modify.json"} + }; + uint64_t session_id = 1; + InnerSession inner_session(session_id, options); + auto ret = inner_session.Initialize(); + EXPECT_NE(ret, ge::SUCCESS); +} } // namespace ge