fix(lite/load_and_run): fix some bugs of load and run

GitOrigin-RevId: ffd578b97b
tags/v1.7.0.m1
Megvii Engine Team, 3 years ago
commit e65e3f0579
8 changed files with 23 additions and 20 deletions:

  1. lite/load_and_run/src/helpers/outdumper.cpp        (+1, -1)
  2. lite/load_and_run/src/models/model_lite.h          (+1, -1)
  3. lite/load_and_run/src/models/model_mdl.h           (+5, -0)
  4. lite/load_and_run/src/options/device_options.cpp   (+1, -1)
  5. lite/load_and_run/src/options/fastrun_options.cpp  (+4, -4)
  6. lite/load_and_run/src/options/io_options.cpp       (+8, -9)
  7. lite/load_and_run/src/options/optimize_options.cpp (+1, -1)
  8. lite/load_and_run/src/options/plugin_options.cpp   (+2, -3)

lite/load_and_run/src/helpers/outdumper.cpp (+1, -1)

@@ -39,7 +39,7 @@ void OutputDumper::write_to_file() {
                 info.owner_inputs_info.c_str()));
         mgb::debug::write_to_file(
                 mgb::ssprintf(
-                        "%s/run%zu-var %zd", dump_file.c_str(), m_run_id, info.id)
+                        "%s/run%zu-var%zd", dump_file.c_str(), m_run_id, info.id)
                         .c_str(),
                 value);
     }
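The fix here is purely in the dump-file name: the stray space in "run%zu-var %zd" put a space inside the generated file name. A minimal sketch of the effect, with hypothetical run/variable ids (the %zd conversion mirrors the format string in the diff):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    size_t run_id = 0;      // hypothetical run counter
    ptrdiff_t var_id = 21;  // hypothetical variable id; matches %zd on common platforms
    // before: "output/run0-var 21" -- the embedded space splits the file name
    std::printf("%s/run%zu-var %zd\n", "output", run_id, var_id);
    // after:  "output/run0-var21" -- a single token, easier to script against
    std::printf("%s/run%zu-var%zd\n", "output", run_id, var_id);
    return 0;
}
```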


lite/load_and_run/src/models/model_lite.h (+1, -1)

@@ -40,7 +40,7 @@ public:
     void wait() override;

     //! get the network of lite model
-    std::shared_ptr<lite::Network> get_lite_network() { return m_network; }
+    std::shared_ptr<lite::Network>& get_lite_network() { return m_network; }

     //! get the config of lite model
     lite::Config& get_config() { return config; }
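get_lite_network() now returns a reference to the stored shared_ptr rather than a copy. The call sites in the option files below bind the result with auto&&; against the old by-value getter, that binding would lifetime-extend a temporary copy, so reseating or configuring the pointer could miss the model's own network. A minimal sketch of the difference, using hypothetical Model/Network stand-ins rather than the real lite classes:

```cpp
#include <memory>

struct Network { int device = 0; };

struct Model {
    // old signature: hands back a temporary copy of the shared_ptr
    std::shared_ptr<Network> get_by_value() { return m_network; }
    // new signature: hands back the member itself
    std::shared_ptr<Network>& get_by_ref() { return m_network; }
    std::shared_ptr<Network> m_network = std::make_shared<Network>();
};

int main() {
    Model model;

    auto&& copy = model.get_by_value(); // lifetime-extended *copy*
    copy = std::make_shared<Network>(); // reseats the copy; model unchanged

    auto&& ref = model.get_by_ref();    // alias of model.m_network
    ref = std::make_shared<Network>();  // reseats the model's own pointer

    return (model.m_network == ref) ? 0 : 1; // 0: the alias took effect
}
```

auto&& works at the call sites because it binds to either form without copying; with the reference-returning getter it collapses to a plain lvalue reference.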


lite/load_and_run/src/models/model_mdl.h (+5, -0)

@@ -67,13 +67,16 @@ public:

     //! get data parser
     DataParser& get_input_parser() { return parser; }
+
     uint32_t get_testcase_num() { return testcase_num; }
+
     std::vector<std::pair<std::string, mgb::HostTensorND*>>& get_test_input() {
         return test_input_tensors;
     }

     //! get output specified configuration
     mgb::ComputingGraph::OutputSpec& get_output_spec() { return m_output_spec; }
+
     std::unique_ptr<mgb::cg::AsyncExecutable>& get_async_func() { return m_asyc_exec; }

     void set_output_callback(std::vector<mgb::ComputingGraph::Callback>& cb) {
@@ -84,6 +87,7 @@ public:
             m_callbacks[i] = cb[i];
         }
     }
+
 #if MGB_ENABLE_JSON
     std::unique_ptr<mgb::GraphProfiler>& get_profiler() { return m_profiler; }
     void set_profiler() {
@@ -91,6 +95,7 @@ public:
                 std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get());
     }
 #endif
+
     void set_num_range_checker(float range) {
         m_num_range_checker = std::make_unique<mgb::NumRangeChecker>(
                 m_load_config.comp_graph.get(), range);


lite/load_and_run/src/options/device_options.cpp (+1, -1)

@@ -37,7 +37,7 @@ void XPUDeviceOption::config_model_internel<ModelLite>(
         }
 #endif
     } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
-        auto network = model->get_lite_network();
+        auto&& network = model->get_lite_network();
         if (enable_cpu_default) {
             LITE_WARN("using cpu default device\n");
             lite::Runtime::set_cpu_inplace_mode(network);


lite/load_and_run/src/options/fastrun_options.cpp (+4, -4)

@@ -55,8 +55,8 @@ void FastRunOption::config_model_internel<ModelLite>(
         auto lite_strategy = static_cast<Strategy>(strategy);
         model->set_lite_strategy(lite_strategy);
     } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
-        auto lite_network = model->get_lite_network();
-        auto lite_strategy = model->get_lite_strategy();
+        auto&& lite_network = model->get_lite_network();
+        auto&& lite_strategy = model->get_lite_strategy();
         //! set algo policy for model
         lite::Runtime::set_network_algo_policy(
                 lite_network, lite_strategy, share_batch_size, batch_binary_equal);
@@ -121,8 +121,8 @@ void FastRunOption::config_model_internel<ModelMdl>(
                 .fast_run_config.shared_batch_size = share_batch_size;
         }
     } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
-        auto vars = model->get_mdl_load_result().output_var_list;
-        auto strategy = model->get_mdl_strategy();
+        auto& vars = model->get_mdl_load_result().output_var_list;
+        auto&& strategy = model->get_mdl_strategy();
         mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy);
         // set algo cache path
         if (!m_fast_run_cache.empty()) {
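In the ModelMdl hunk, vars is now bound with auto& so the in-place strategy pass sees the output list stored in the load result instead of a fresh copy of the vector. A minimal sketch of the binding semantics, with hypothetical stand-ins (the real pass is mgb::gopt::modify_opr_algo_strategy_inplace over a SymbolVarArray):

```cpp
#include <vector>

struct LoadResult { std::vector<int> output_var_list; };

struct Model {
    // assumed to return a reference, as the real get_mdl_load_result()
    // must for `auto&` at the call site to compile
    LoadResult& get_mdl_load_result() { return m_result; }
    LoadResult m_result{{1, 2, 3}};
};

// stand-in for an in-place pass over the output vars
void modify_inplace(std::vector<int>& vars) {
    for (auto& v : vars) v *= 10;
}

int main() {
    Model model;
    // `auto vars = ...` would copy the vector and the pass would mutate
    // the copy; `auto& vars = ...` aliases the list the model keeps
    auto& vars = model.get_mdl_load_result().output_var_list;
    modify_inplace(vars);
    return model.m_result.output_var_list[0] == 10 ? 0 : 1;
}
```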


lite/load_and_run/src/options/io_options.cpp (+8, -9)

@@ -20,8 +20,8 @@ template <>
 void InputOption::config_model_internel<ModelLite>(
         RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
     if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
-        auto parser = model->get_input_parser();
-        auto io = model->get_networkIO();
+        auto&& parser = model->get_input_parser();
+        auto&& io = model->get_networkIO();
         for (size_t idx = 0; idx < data_path.size(); ++idx) {
             parser.feed(data_path[idx].c_str());
         }
@@ -32,9 +32,8 @@ void InputOption::config_model_internel<ModelLite>(
             io.inputs.push_back({i.first, is_host});
         }
     } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
-        auto config = model->get_config();
-        auto parser = model->get_input_parser();
-        auto network = model->get_lite_network();
+        auto&& parser = model->get_input_parser();
+        auto&& network = model->get_lite_network();

         //! datd type map from mgb data type to lite data type
         std::map<megdnn::DTypeEnum, LiteDataType> type_map = {
@@ -75,8 +74,8 @@ void InputOption::config_model_internel<ModelMdl>(
             parser.feed(data_path[idx].c_str());
         }
     } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
-        auto parser = model->get_input_parser();
-        auto network = model->get_mdl_load_result();
+        auto&& parser = model->get_input_parser();
+        auto&& network = model->get_mdl_load_result();
         auto tensormap = network.tensor_map;
         for (auto& i : parser.inputs) {
             mgb_assert(
@@ -156,7 +155,7 @@ void IOdumpOption::config_model_internel<ModelMdl>(
         }
     } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
         if (enable_bin_out_dump) {
-            auto load_result = model->get_mdl_load_result();
+            auto&& load_result = model->get_mdl_load_result();
             out_dumper->set(load_result.output_var_list);

             std::vector<mgb::ComputingGraph::Callback> cb;
@@ -166,7 +165,7 @@ void IOdumpOption::config_model_internel<ModelMdl>(
             model->set_output_callback(cb);
         }
         if (enable_copy_to_host) {
-            auto load_result = model->get_mdl_load_result();
+            auto&& load_result = model->get_mdl_load_result();

             std::vector<mgb::ComputingGraph::Callback> cb;
             for (size_t i = 0; i < load_result.output_var_list.size(); i++) {
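Several hunks in this file fix the same latent bug: get_input_parser() returns a DataParser&, but plain auto strips the reference and deduces a copy, so anything fed into that copy is silently dropped. auto&& preserves the reference. A minimal sketch with a hypothetical Parser/Model pair:

```cpp
#include <string>
#include <vector>

struct Parser { std::vector<std::string> inputs; };

struct Model {
    Parser& get_input_parser() { return m_parser; }
    Parser m_parser;
};

int main() {
    Model model;

    // bug pattern: `auto` deduces Parser (a copy), so the fed input
    // lands in a temporary object and is lost
    auto copy = model.get_input_parser();
    copy.inputs.push_back("data0.npy");

    // fixed pattern: `auto&&` collapses to Parser& and feeds the
    // parser the model actually owns
    auto&& parser = model.get_input_parser();
    parser.inputs.push_back("data0.npy");

    return model.m_parser.inputs.size() == 1 ? 0 : 1; // 0: only the fix landed
}
```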


lite/load_and_run/src/options/optimize_options.cpp (+1, -1)

@@ -365,7 +365,7 @@ void MemoryOptimizeOption::config_model_internel<ModelMdl>(
         }
         if (workspace_limit < SIZE_MAX) {
             mgb_log_warn("set workspace limit to %ld", workspace_limit);
-            auto output_spec = model->get_output_spec();
+            auto&& output_spec = model->get_output_spec();
             mgb::SymbolVarArray vars;
             for (auto i : output_spec) {
                 vars.push_back(i.first);


lite/load_and_run/src/options/plugin_options.cpp (+2, -3)

@@ -46,7 +46,7 @@ template <>
 void PluginOption::config_model_internel<ModelMdl>(
         RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
     if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
-        auto config = model->get_mdl_config();
+        auto&& config = model->get_mdl_config();
         if (range > 0) {
             mgb_log_warn("enable number range check");
             model->set_num_range_checker(float(range));
@@ -151,7 +151,7 @@ template <>
 void DebugOption::format_and_print(
         const std::string& tablename, std::shared_ptr<ModelLite> model) {
     auto table = mgb::TextTable(tablename);
-    auto network = model->get_lite_network();
+    auto&& network = model->get_lite_network();
     table.padding(1);
     table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor();

@@ -259,7 +259,6 @@ template <>
 void DebugOption::config_model_internel<ModelMdl>(
         RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
     if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
-        auto config = model->get_mdl_config();
         if (enable_verbose) {
             mgb_log_warn("enable verbose");
             mgb::set_log_level(mgb::LogLevel::DEBUG);

