GitOrigin-RevId: 910c8da19f
tags/v1.9.0
@@ -40,7 +40,6 @@ public:
     void wait() override;
     //! enable global layout transform
     void set_layout_transform(bool state) { enable_layout_transform = state; }
     //! get the network of lite model
@@ -468,3 +468,29 @@ class TestNetwork(TestShuffleNet):
         fi = open("./model_afer_layoutTrans.mgb", "r")
         fi.close()
         os.remove("./model_afer_layoutTrans.mgb")
+    def test_fast_run_and_global_layout_transform(self):
+        config_ = LiteConfig()
+        network = LiteNetwork(config_)
+        fast_run_cache = "./algo_cache"
+        global_layout_transform_model = "./model_afer_layoutTrans.mgb"
+        network.set_network_algo_policy(
+            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
+            | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
+        )
+        network.enable_global_layout_transform()
+        network.load(self.model_path)
+        self.do_forward(network)
+        network.dump_layout_transform_model(global_layout_transform_model)
+        LiteGlobal.dump_persistent_cache(fast_run_cache)
+        fi = open(fast_run_cache, "r")
+        fi.close()
+        fi = open(global_layout_transform_model, "r")
+        fi.close()
+        LiteGlobal.set_persistent_cache(path=fast_run_cache)
+        self.do_forward(network)
+        os.remove(fast_run_cache)
+        os.remove(global_layout_transform_model)
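
For reference, the workflow exercised by the test above, distilled into a minimal application-side sketch. It assumes the Lite Python bindings are importable as megenginelite, uses hypothetical paths ("shufflenet.mge", "algo_cache", "model_layout_trans.mgb"), and calls forward()/wait() directly instead of the test helper do_forward; only the identifiers that appear in the test itself (LiteConfig, LiteNetwork, LiteAlgoSelectStrategy, LiteGlobal and the methods called on them) are taken from the source.

# Profiling-side sketch: fast-run (PROFILE | OPTIMIZED) plus global layout
# transform, then dump the transformed model and the algo cache for reuse.
# Assumptions: package name megenginelite, the file paths, and the bare
# forward()/wait() calls (input tensor setup omitted for brevity).
from megenginelite import (
    LiteAlgoSelectStrategy,
    LiteConfig,
    LiteGlobal,
    LiteNetwork,
)

network = LiteNetwork(LiteConfig())
network.set_network_algo_policy(
    LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
    | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
)
network.enable_global_layout_transform()
network.load("shufflenet.mge")  # hypothetical model path

# fill the input tensors here (omitted), then run once so that the fast-run
# profiling results exist before the persistent cache is dumped
network.forward()
network.wait()

network.dump_layout_transform_model("model_layout_trans.mgb")
LiteGlobal.dump_persistent_cache("algo_cache")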
@@ -293,3 +293,31 @@ class TestNetwork(TestShuffleNetCuda):
         fi = open("./model_afer_layoutTrans.mgb", "r")
         fi.close()
         os.remove("./model_afer_layoutTrans.mgb")
+    @require_cuda()
+    def test_fast_run_and_global_layout_transform(self):
+        config_ = LiteConfig()
+        config_.device_type = LiteDeviceType.LITE_CUDA
+        network = LiteNetwork(config_)
+        fast_run_cache = "./algo_cache"
+        global_layout_transform_model = "./model_afer_layoutTrans.mgb"
+        network.set_network_algo_policy(
+            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
+            | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
+        )
+        network.enable_global_layout_transform()
+        network.load(self.model_path)
+        self.do_forward(network)
+        network.dump_layout_transform_model(global_layout_transform_model)
+        LiteGlobal.dump_persistent_cache(fast_run_cache)
+        fi = open(fast_run_cache, "r")
+        fi.close()
+        fi = open(global_layout_transform_model, "r")
+        fi.close()
+        LiteGlobal.set_persistent_cache(path=fast_run_cache)
+        self.do_forward(network)
+        os.remove(fast_run_cache)
+        os.remove(global_layout_transform_model)
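
The CUDA variant follows the same flow, only selecting LITE_CUDA in the config. On the consuming side, the dumped artifacts would presumably be reused by pointing the persistent cache at the saved fast-run results before running again; note that the tests above only re-run the same network after set_persistent_cache, so loading the dumped .mgb as a standalone model in a fresh process is an assumption about the intended deployment flow, not something the tests demonstrate.

# Deployment-side sketch (assumptions: megenginelite package name, the paths,
# and that the dumped layout-transformed .mgb is loadable as a normal model).
from megenginelite import LiteConfig, LiteGlobal, LiteNetwork

LiteGlobal.set_persistent_cache(path="algo_cache")  # reuse profiled algorithms

network = LiteNetwork(LiteConfig())
network.load("model_layout_trans.mgb")  # model dumped after layout transform

# fill the input tensors here (omitted), then:
network.forward()
network.wait()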
@@ -422,6 +422,8 @@ void NetworkImplDft::load_model(
     m_load_result = m_loader->load(m_load_config, true);
+    modify_exection_policy();
     global_layout_transform();
     adapt_option_valid();
@@ -436,7 +438,6 @@ void NetworkImplDft::load_model(
 }
 void NetworkImplDft::compile_graph() {
-    modify_exection_policy();
     replace_dev_input_pass();
     make_output_spec();
     m_execute_func = m_load_result.graph_compile(m_output_spec);
@@ -793,7 +794,8 @@ void NetworkImplDft::set_network_algo_policy(
     if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) {
         dst_strategy = dst_strategy | S::OPTIMIZED;
     }
-    m_execution_policy = dst_strategy;
+    if (static_cast<uint32_t>(dst_strategy) != 0)
+        m_execution_policy = dst_strategy;
     auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config;
     fast_run_config.binary_equal_between_batch = binary_equal_between_batch;
@@ -808,12 +810,10 @@ void NetworkImplDft::set_network_algo_policy(
 }
 void NetworkImplDft::modify_exection_policy() {
-    mgb::SymbolVarArray vars;
-    for (auto i : m_output_spec) {
-        vars.push_back(i.first);
-    }
-    if (static_cast<uint32_t>(m_execution_policy) != 0)
+    auto& vars = m_load_result.output_var_list;
+    if (static_cast<uint32_t>(m_execution_policy) != 0) {
         mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy);
+    }
 }
 //! set opr algorithm selection strategy in the network
@@ -289,21 +289,21 @@ namespace intl {
 template <typename Opr>
 struct OprFormatModifier;
-#define INST(_Opr) \
-    template <> \
-    struct OprFormatModifier<_Opr> { \
-        using OprFormat = typename _Opr::Param::Format; \
-        static VarNode* make( \
-                OprFormat opr_format, const VarNodeArray& i, \
-                const cg::OperatorNodeBase* opr_) { \
-            MIDOUT_B(_Opr) \
-            auto&& opr = opr_->cast_final_safe<_Opr>(); \
-            auto param = opr.param(); \
-            param.format = opr_format; \
-            return OprWithPolicyMaker<_Opr>::make( \
-                    i, param, opr.execution_policy(), opr.config()); \
-            MIDOUT_E \
-        } \
+#define INST(_Opr) \
+    template <> \
+    struct OprFormatModifier<_Opr> { \
+        using OprFormat = typename _Opr::Param::Format; \
+        static VarNode* make( \
+                OprFormat opr_format, const VarNodeArray& i, \
+                const cg::OperatorNodeBase* opr_) { \
+            MIDOUT_B(_Opr) \
+            auto&& opr = opr_->cast_final_safe<_Opr>(); \
+            auto param = opr.param(); \
+            param.format = opr_format; \
+            return OprWithPolicyMaker<_Opr>::make( \
+                    i, param, opr.execution_policy_transient(), opr.config()); \
+            MIDOUT_E \
+        } \
 };
 INST(Convolution);
 INST(ConvBiasForward);