|
|
@@ -19,7 +19,7 @@ namespace lar { |
|
|
|
template <> |
|
|
|
void FastRunOption::config_model_internel<ModelLite>( |
|
|
|
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { |
|
|
|
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { |
|
|
|
if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) { |
|
|
|
//! set the algo policy before model load |
|
|
|
using Strategy = ModelLite::Strategy; |
|
|
|
uint32_t strategy = 0; |
|
|
@@ -44,23 +44,17 @@ void FastRunOption::config_model_internel<ModelLite>( |
|
|
|
strategy; |
|
|
|
} |
|
|
|
auto lite_strategy = static_cast<Strategy>(strategy); |
|
|
|
model->set_lite_strategy(lite_strategy); |
|
|
|
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { |
|
|
|
auto&& lite_network = model->get_lite_network(); |
|
|
|
auto&& lite_strategy = model->get_lite_strategy(); |
|
|
|
//! set algo policy for model |
|
|
|
auto&& lite_network = model->get_lite_network(); |
|
|
|
lite::Runtime::set_network_algo_policy( |
|
|
|
lite_network, lite_strategy, share_batch_size, batch_binary_equal); |
|
|
|
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { |
|
|
|
if (!m_fast_run_cache.empty()) { |
|
|
|
if (!access(m_fast_run_cache.c_str(), F_OK)) { |
|
|
|
lite::set_persistent_cache(m_fast_run_cache); |
|
|
|
} else { |
|
|
|
lite::set_persistent_cache(m_fast_run_cache, true); |
|
|
|
} |
|
|
|
//! TODO:this is from mdl model settings but not matched settings in |
|
|
|
//! lite model |
|
|
|
// if (!enable_full_run && !enable_fast_run) |
|
|
|
// mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); |
|
|
|
} |
|
|
|
} else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { |
|
|
|
#if MGB_ENABLE_FASTRUN |
|
|
@@ -255,4 +249,4 @@ DEFINE_int32(fast_run_shared_batch_size, 0, "Set the batch size used during fast |
|
|
|
DEFINE_string(fast_run_algo_policy, "", "fast-run cache path."); |
|
|
|
|
|
|
|
REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option); |
|
|
|
REGIST_OPTION_VALIDATER(fastrun, lar::FastRunOption::set_valid); |
|
|
|
REGIST_OPTION_VALIDATER(fastrun, lar::FastRunOption::set_valid); |