GitOrigin-RevId: 910c8da19f
tags/v1.9.0
@@ -40,7 +40,6 @@ public: | |||
void wait() override; | |||
//! enable global layout transform | |||
void set_layout_transform(bool state) { enable_layout_transform = state; } | |||
//! get the network of lite model | |||
@@ -468,3 +468,29 @@ class TestNetwork(TestShuffleNet): | |||
fi = open("./model_afer_layoutTrans.mgb", "r") | |||
fi.close() | |||
os.remove("./model_afer_layoutTrans.mgb") | |||
def test_fast_run_and_global_layout_transform(self): | |||
config_ = LiteConfig() | |||
network = LiteNetwork(config_) | |||
fast_run_cache = "./algo_cache" | |||
global_layout_transform_model = "./model_afer_layoutTrans.mgb" | |||
network.set_network_algo_policy( | |||
LiteAlgoSelectStrategy.LITE_ALGO_PROFILE | |||
| LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED | |||
) | |||
network.enable_global_layout_transform() | |||
network.load(self.model_path) | |||
self.do_forward(network) | |||
network.dump_layout_transform_model(global_layout_transform_model) | |||
LiteGlobal.dump_persistent_cache(fast_run_cache) | |||
fi = open(fast_run_cache, "r") | |||
fi.close() | |||
fi = open(global_layout_transform_model, "r") | |||
fi.close() | |||
LiteGlobal.set_persistent_cache(path=fast_run_cache) | |||
self.do_forward(network) | |||
os.remove(fast_run_cache) | |||
os.remove(global_layout_transform_model) |
@@ -293,3 +293,31 @@ class TestNetwork(TestShuffleNetCuda): | |||
fi = open("./model_afer_layoutTrans.mgb", "r") | |||
fi.close() | |||
os.remove("./model_afer_layoutTrans.mgb") | |||
@require_cuda() | |||
def test_fast_run_and_global_layout_transform(self): | |||
config_ = LiteConfig() | |||
config_.device_type = LiteDeviceType.LITE_CUDA | |||
network = LiteNetwork(config_) | |||
fast_run_cache = "./algo_cache" | |||
global_layout_transform_model = "./model_afer_layoutTrans.mgb" | |||
network.set_network_algo_policy( | |||
LiteAlgoSelectStrategy.LITE_ALGO_PROFILE | |||
| LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED | |||
) | |||
network.enable_global_layout_transform() | |||
network.load(self.model_path) | |||
self.do_forward(network) | |||
network.dump_layout_transform_model(global_layout_transform_model) | |||
LiteGlobal.dump_persistent_cache(fast_run_cache) | |||
fi = open(fast_run_cache, "r") | |||
fi.close() | |||
fi = open(global_layout_transform_model, "r") | |||
fi.close() | |||
LiteGlobal.set_persistent_cache(path=fast_run_cache) | |||
self.do_forward(network) | |||
os.remove(fast_run_cache) | |||
os.remove(global_layout_transform_model) |
@@ -422,6 +422,8 @@ void NetworkImplDft::load_model( | |||
m_load_result = m_loader->load(m_load_config, true); | |||
modify_exection_policy(); | |||
global_layout_transform(); | |||
adapt_option_valid(); | |||
@@ -436,7 +438,6 @@ void NetworkImplDft::load_model( | |||
} | |||
void NetworkImplDft::compile_graph() { | |||
modify_exection_policy(); | |||
replace_dev_input_pass(); | |||
make_output_spec(); | |||
m_execute_func = m_load_result.graph_compile(m_output_spec); | |||
@@ -793,7 +794,8 @@ void NetworkImplDft::set_network_algo_policy( | |||
if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) { | |||
dst_strategy = dst_strategy | S::OPTIMIZED; | |||
} | |||
m_execution_policy = dst_strategy; | |||
if (static_cast<uint32_t>(dst_strategy) != 0) | |||
m_execution_policy = dst_strategy; | |||
auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config; | |||
fast_run_config.binary_equal_between_batch = binary_equal_between_batch; | |||
@@ -808,12 +810,10 @@ void NetworkImplDft::set_network_algo_policy( | |||
} | |||
void NetworkImplDft::modify_exection_policy() { | |||
mgb::SymbolVarArray vars; | |||
for (auto i : m_output_spec) { | |||
vars.push_back(i.first); | |||
} | |||
if (static_cast<uint32_t>(m_execution_policy) != 0) | |||
auto& vars = m_load_result.output_var_list; | |||
if (static_cast<uint32_t>(m_execution_policy) != 0) { | |||
mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy); | |||
} | |||
} | |||
//! set opr algorithm selection strategy in the network | |||
@@ -289,21 +289,21 @@ namespace intl { | |||
template <typename Opr> | |||
struct OprFormatModifier; | |||
#define INST(_Opr) \ | |||
template <> \ | |||
struct OprFormatModifier<_Opr> { \ | |||
using OprFormat = typename _Opr::Param::Format; \ | |||
static VarNode* make( \ | |||
OprFormat opr_format, const VarNodeArray& i, \ | |||
const cg::OperatorNodeBase* opr_) { \ | |||
MIDOUT_B(_Opr) \ | |||
auto&& opr = opr_->cast_final_safe<_Opr>(); \ | |||
auto param = opr.param(); \ | |||
param.format = opr_format; \ | |||
return OprWithPolicyMaker<_Opr>::make( \ | |||
i, param, opr.execution_policy(), opr.config()); \ | |||
MIDOUT_E \ | |||
} \ | |||
#define INST(_Opr) \ | |||
template <> \ | |||
struct OprFormatModifier<_Opr> { \ | |||
using OprFormat = typename _Opr::Param::Format; \ | |||
static VarNode* make( \ | |||
OprFormat opr_format, const VarNodeArray& i, \ | |||
const cg::OperatorNodeBase* opr_) { \ | |||
MIDOUT_B(_Opr) \ | |||
auto&& opr = opr_->cast_final_safe<_Opr>(); \ | |||
auto param = opr.param(); \ | |||
param.format = opr_format; \ | |||
return OprWithPolicyMaker<_Opr>::make( \ | |||
i, param, opr.execution_policy_transient(), opr.config()); \ | |||
MIDOUT_E \ | |||
} \ | |||
}; | |||
INST(Convolution); | |||
INST(ConvBiasForward); | |||