
fix(pylite): fix lite global layout transform and fast run conflict error

GitOrigin-RevId: 910c8da19f
tags/v1.9.0
Megvii Engine Team, 3 years ago
commit 5ebc9d50b7
5 changed files with 76 additions and 23 deletions
  1. +0  -1   lite/load_and_run/src/models/model_lite.h
  2. +26 -0   lite/pylite/test/test_network.py
  3. +28 -0   lite/pylite/test/test_network_device.py
  4. +7  -7   lite/src/mge/network_impl.cpp
  5. +15 -15  src/gopt/impl/global_layout_transform/opr_format_modifier.cpp

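For context, the conflict fixed here arises when a user enables both fast-run algo profiling and the global layout transform on the same network. A minimal sketch of that workflow, assuming the megenginelite Python package (the pylite bindings) and a hypothetical model file ./shufflenet.mge; the calls mirror the new tests below:

    from megenginelite import (
        LiteAlgoSelectStrategy,
        LiteConfig,
        LiteGlobal,
        LiteNetwork,
    )

    network = LiteNetwork(LiteConfig())
    # Fast run: profile candidate algorithms and keep the best ones.
    network.set_network_algo_policy(
        LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
        | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
    )
    # Global layout transform: rewrite the graph into faster tensor formats.
    network.enable_global_layout_transform()
    network.load("./shufflenet.mge")  # hypothetical model path
    # ... run inference, then persist both artifacts:
    network.dump_layout_transform_model("./model_after_layout_trans.mgb")
    LiteGlobal.dump_persistent_cache("./algo_cache")
    # Later runs can reuse the profiled algorithms:
    LiteGlobal.set_persistent_cache(path="./algo_cache")
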
+0 -1  lite/load_and_run/src/models/model_lite.h

@@ -40,7 +40,6 @@ public:
     void wait() override;
 
     //! enable global layout transform
-
    void set_layout_transform(bool state) { enable_layout_transform = state; }
 
    //! get the network of lite model

+26 -0  lite/pylite/test/test_network.py

@@ -468,3 +468,29 @@ class TestNetwork(TestShuffleNet):
         fi = open("./model_afer_layoutTrans.mgb", "r")
         fi.close()
         os.remove("./model_afer_layoutTrans.mgb")
+
+    def test_fast_run_and_global_layout_transform(self):
+
+        config_ = LiteConfig()
+        network = LiteNetwork(config_)
+        fast_run_cache = "./algo_cache"
+        global_layout_transform_model = "./model_afer_layoutTrans.mgb"
+        network.set_network_algo_policy(
+            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
+            | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
+        )
+        network.enable_global_layout_transform()
+        network.load(self.model_path)
+        self.do_forward(network)
+        network.dump_layout_transform_model(global_layout_transform_model)
+        LiteGlobal.dump_persistent_cache(fast_run_cache)
+        fi = open(fast_run_cache, "r")
+        fi.close()
+        fi = open(global_layout_transform_model, "r")
+        fi.close()
+
+        LiteGlobal.set_persistent_cache(path=fast_run_cache)
+        self.do_forward(network)
+
+        os.remove(fast_run_cache)
+        os.remove(global_layout_transform_model)
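
Both new tests follow the same pattern: profile with LITE_ALGO_PROFILE | LITE_ALGO_OPTIMIZED while the global layout transform is enabled, forward once, dump both the transformed model and the persistent algo cache, check that both files exist, then reload the cache and forward again to confirm the two features no longer conflict.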

+28 -0  lite/pylite/test/test_network_device.py

@@ -293,3 +293,31 @@ class TestNetwork(TestShuffleNetCuda):
         fi = open("./model_afer_layoutTrans.mgb", "r")
         fi.close()
         os.remove("./model_afer_layoutTrans.mgb")
+
+    @require_cuda()
+    def test_fast_run_and_global_layout_transform(self):
+
+        config_ = LiteConfig()
+        config_.device_type = LiteDeviceType.LITE_CUDA
+        network = LiteNetwork(config_)
+        fast_run_cache = "./algo_cache"
+        global_layout_transform_model = "./model_afer_layoutTrans.mgb"
+        network.set_network_algo_policy(
+            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
+            | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
+        )
+        network.enable_global_layout_transform()
+        network.load(self.model_path)
+        self.do_forward(network)
+        network.dump_layout_transform_model(global_layout_transform_model)
+        LiteGlobal.dump_persistent_cache(fast_run_cache)
+        fi = open(fast_run_cache, "r")
+        fi.close()
+        fi = open(global_layout_transform_model, "r")
+        fi.close()
+
+        LiteGlobal.set_persistent_cache(path=fast_run_cache)
+        self.do_forward(network)
+
+        os.remove(fast_run_cache)
+        os.remove(global_layout_transform_model)

+7 -7  lite/src/mge/network_impl.cpp

@@ -422,6 +422,8 @@ void NetworkImplDft::load_model(
 
     m_load_result = m_loader->load(m_load_config, true);
 
+    modify_exection_policy();
+
     global_layout_transform();
 
     adapt_option_valid();
@@ -436,7 +438,6 @@ void NetworkImplDft::load_model(
 }
 
 void NetworkImplDft::compile_graph() {
-    modify_exection_policy();
     replace_dev_input_pass();
     make_output_spec();
     m_execute_func = m_load_result.graph_compile(m_output_spec);
@@ -793,7 +794,8 @@ void NetworkImplDft::set_network_algo_policy(
     if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) {
         dst_strategy = dst_strategy | S::OPTIMIZED;
     }
-    m_execution_policy = dst_strategy;
+    if (static_cast<uint32_t>(dst_strategy) != 0)
+        m_execution_policy = dst_strategy;
 
     auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config;
     fast_run_config.binary_equal_between_batch = binary_equal_between_batch;
@@ -808,12 +810,10 @@ void NetworkImplDft::set_network_algo_policy(
 }
 
 void NetworkImplDft::modify_exection_policy() {
-    mgb::SymbolVarArray vars;
-    for (auto i : m_output_spec) {
-        vars.push_back(i.first);
-    }
-    if (static_cast<uint32_t>(m_execution_policy) != 0)
+    auto& vars = m_load_result.output_var_list;
+    if (static_cast<uint32_t>(m_execution_policy) != 0) {
         mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy);
+    }
 }
 
 //! set opr algorithm selection strategy in the network

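Note on the ordering change: modify_exection_policy() now runs inside load_model(), before global_layout_transform(), and operates on m_load_result.output_var_list instead of rebuilding a SymbolVarArray from m_output_spec at compile time. The new != 0 guards in both functions mean an empty strategy no longer overwrites a previously set policy, which appears to be what lets fast run and the layout transform coexist.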

+15 -15  src/gopt/impl/global_layout_transform/opr_format_modifier.cpp

@@ -289,21 +289,21 @@ namespace intl {
 template <typename Opr>
 struct OprFormatModifier;
 
-#define INST(_Opr)                                                   \
-    template <>                                                      \
-    struct OprFormatModifier<_Opr> {                                 \
-        using OprFormat = typename _Opr::Param::Format;              \
-        static VarNode* make(                                        \
-                OprFormat opr_format, const VarNodeArray& i,         \
-                const cg::OperatorNodeBase* opr_) {                  \
-            MIDOUT_B(_Opr)                                           \
-            auto&& opr = opr_->cast_final_safe<_Opr>();              \
-            auto param = opr.param();                                \
-            param.format = opr_format;                               \
-            return OprWithPolicyMaker<_Opr>::make(                   \
-                    i, param, opr.execution_policy(), opr.config()); \
-            MIDOUT_E                                                  \
-        }                                                             \
+#define INST(_Opr)                                                             \
+    template <>                                                                \
+    struct OprFormatModifier<_Opr> {                                           \
+        using OprFormat = typename _Opr::Param::Format;                        \
+        static VarNode* make(                                                  \
+                OprFormat opr_format, const VarNodeArray& i,                   \
+                const cg::OperatorNodeBase* opr_) {                            \
+            MIDOUT_B(_Opr)                                                     \
+            auto&& opr = opr_->cast_final_safe<_Opr>();                        \
+            auto param = opr.param();                                          \
+            param.format = opr_format;                                         \
+            return OprWithPolicyMaker<_Opr>::make(                             \
+                    i, param, opr.execution_policy_transient(), opr.config()); \
+            MIDOUT_E                                                           \
+        }                                                                      \
     };
 INST(Convolution);
 INST(ConvBiasForward);
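
Aside from backslash realignment, the only functional change in this macro is opr.execution_policy() becoming opr.execution_policy_transient(). The transient accessor presumably reads the stored policy without forcing on-demand algorithm selection, so re-making operators during the layout-transform pass no longer triggers fast-run profiling midway.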

