GitOrigin-RevId: fcbf945de5
HuaHua404-patch-4
@@ -114,6 +114,9 @@ struct LITE_API Options {
 * model does not pack json information data inside
 *
 * @param options configuration of Options
 *
 * @param auto_optimize_inference lite will detect the device information and
 * set the options heuristically
 */
struct LITE_API Config {
    bool has_compression = false;
@@ -122,6 +125,7 @@ struct LITE_API Config {
    LiteBackend backend = LiteBackend::LITE_DEFAULT;
    std::string bare_model_cryption_name = {};
    Options options = {};
    bool auto_optimize_inference = false;
};
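For orientation, a minimal sketch of how the new flag is meant to be used from the C++ API, mirroring the `Network` workflow exercised by the test added below (the model path is a placeholder):

    lite::Config config;
    config.auto_optimize_inference = true;  // let lite pick layouts heuristically

    auto network = std::make_shared<lite::Network>(config);
    network->load_model("./model.mge");  // placeholder path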
/*!
@@ -100,6 +100,9 @@ extern LITE_API const LiteOptions default_option;
 *
 *\param has_compression flag whether the model is compressed, the compress
 *method will be read from the model
 *\param auto_optimize_inference lite will detect the device information and
 * set the options heuristically
 */
typedef struct LiteConfig {
    int has_compression;
@@ -108,6 +111,7 @@ typedef struct LiteConfig {
    LiteBackend backend;
    const char* bare_model_cryption_name;
    LiteOptions options;
    int auto_optimize_inference;
} LiteConfig;
//! get default config
@@ -42,7 +42,8 @@ LiteConfig default_config_t = {
        .device_type = LiteDeviceType::LITE_CPU,
        .backend = LiteBackend::LITE_DEFAULT,
        .bare_model_cryption_name = nullptr,
        .options = default_option};
        .options = default_option,
        .auto_optimize_inference = false};
LiteConfig* default_config() {
    return &default_config_t;
}
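A corresponding sketch for the C API, assuming the caller copies the default config rather than mutating the shared instance and then hands it to the usual network-creation entry point (which is outside this diff):

    LiteConfig config = *default_config();
    config.auto_optimize_inference = 1;
    // pass `config` to the network-creation call (e.g. LITE_make_network) as before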
@@ -133,6 +134,8 @@ lite::Config convert_to_lite_config(const LiteConfig c_config) {
    lite_config.options.enable_nchw32 = c_config.options.enable_nchw32;
    lite_config.options.enable_nchw64 = c_config.options.enable_nchw64;
    lite_config.auto_optimize_inference = c_config.auto_optimize_inference;
    return lite_config;
}
@@ -171,15 +171,18 @@ class LiteConfig(Structure):
        options: configuration of Options
        auto_optimize_inference: lite will detect the device information and set the options heuristically

    Examples:

        .. code-block::

            from megenginelite import *
            config = LiteConfig()
            config.has_compression = false
            config.has_compression = False
            config.device_type = LiteDeviceType.LITE_CPU
            config.backend = LiteBackend.LITE_DEFAULT
            config.bare_model_cryption_name = "AES_default".encode("utf-8")
            config.auto_optimize_inference = False
    """
    _fields_ = [
@@ -189,6 +192,7 @@ class LiteConfig(Structure):
        ("backend", c_int),
        ("_bare_model_cryption_name", c_char_p),
        ("options", LiteOptions),
        ("auto_optimize_inference", c_int),
    ]

    def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
@@ -202,6 +206,7 @@ class LiteConfig(Structure):
        self.use_loader_dynamic_param = 0
        self.has_compression = 0
        self.backend = LiteBackend.LITE_DEFAULT
        self.auto_optimize_inference = 0

    @property
    def bare_model_cryption_name(self):
@@ -223,6 +228,7 @@ class LiteConfig(Structure):
            "backend": LiteBackend(self.backend),
            "bare_model_cryption_name": self.bare_model_cryption_name,
            "options": self.options,
            "auto_optimize_inference": self.auto_optimize_inference,
        }
        return data.__repr__()
@@ -21,6 +21,10 @@
#include "megcore_opencl.h"
#endif

#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif

#include <fstream>
#include <memory>
#include <set>
@@ -42,14 +46,7 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
    LITE_ASSERT(src_impl.m_loader, "Clone network must after the network is loaded.");
    m_load_result = src_impl.m_loader->load(m_load_config, true);
    //! flag weather the mode is cross compnode model
    cross_compnode_model_detect();
    //! update the IO of the network
    update_io();
    //! replace the IO when there is device input or output
    compile_graph();
    configure_after_loaded();
}
void NetworkImplDft::application_config() {
@@ -364,7 +361,7 @@ void NetworkImplDft::adapt_option_valid() {
    }
}
void NetworkImplDft::global_layout_transform() {
void NetworkImplDft::layout_transform_optimization() {
    if (m_set_layout_transform) {
        mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
        auto output_var_array = mgb::gopt::layout_transform(
@@ -382,6 +379,103 @@ void NetworkImplDft::global_layout_transform() {
        for (auto&& item : m_load_result.output_var_map) {
            item.second = out_var_map[item.second];
        }
    } else if (m_user_config->auto_optimize_inference) {
        //! set model weight preprocess
        m_load_config.comp_graph->options().graph_opt.weight_preprocess = true;
        LITE_LOG(
                "weight_preprocess is enabled, this may use more memory during "
                "inference.");
        //! get the current format and data type of the model
        bool is_model_nchw = true;
        //! whether any convolution is int8
        bool is_model_int8 = false;
        //! whether all convolutions are float32
        bool is_model_float32 = true;
        float conv_cnt = 0;
        float dimshuffle_cnt = 0;
        auto detect_int8_model = [&](const VarNode* input) {
            if (input->dtype().enumv() == megdnn::DTypeEnum::QuantizedS8 ||
                input->dtype().enumv() == megdnn::DTypeEnum::Quantized8Asymm) {
                is_model_int8 = true;
                is_model_float32 = false;
            } else if (input->dtype().enumv() == megdnn::DTypeEnum::Float32) {
                is_model_float32 = (is_model_float32 && true);
            } else {
                is_model_float32 = false;
            }
        };
        cg::DepOprIter dep([&](cg::OperatorNodeBase* opr) {
            if (auto conv = opr->try_cast_final<opr::ConvolutionForward>()) {
                if (conv->param().format != megdnn::param::ConvBias::Format::NCHW) {
                    is_model_nchw = false;
                }
                conv_cnt++;
                detect_int8_model(conv->input(0));
            } else if (auto conv_bias = opr->try_cast_final<opr::ConvBias>()) {
                if (conv_bias->param().format !=
                    megdnn::param::ConvBias::Format::NCHW) {
                    is_model_nchw = false;
                }
                conv_cnt++;
                detect_int8_model(conv_bias->input(0));
            } else if (auto dimshuffle = opr->try_cast_final<opr::Dimshuffle>()) {
                LITE_MARK_USED_VAR(dimshuffle);
                dimshuffle_cnt++;
            }
        });
        for (auto&& i : m_load_result.output_var_list)
            dep.add(i);
        float ratio_dimshuffle_conv = 0;
        if (conv_cnt > 0) {
            ratio_dimshuffle_conv = dimshuffle_cnt / conv_cnt;
        }
        //! format optimization can only be applied to an NCHW model;
        //! shufflenet-like models lose performance with the nchw88 or nchw44
        //! formats, so the ratio of dimshuffle to convolution operators is
        //! used here as a heuristic gate
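        //! e.g. a model with 20 convolutions and 4 dimshuffles has a ratio of
        //! 0.2, which exceeds the 0.15 gate, so the transform below is skipped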
        if (!is_model_nchw || ratio_dimshuffle_conv > 0.15f) {
            return;
        }
        //! determine the layout by the device information
        //! TODO: shufflenet-like models using nchw88 or nchw44 will hurt the
        //! performance
        if (m_user_config->device_type == LITE_CPU) {
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
            cpuinfo_initialize();
            //! if all convolution and matmul data types are float32
            if (is_model_float32) {
                //! if the device is x86 and supports avx, use the nchw88 format
                if (cpuinfo_has_x86_avx()) {
                    m_load_config.comp_graph->options().graph_opt.enable_nchw88();
                    LITE_LOG("Configure model inference with nchw88 format.");
                } else if (cpuinfo_has_x86_sse2() && !cpuinfo_has_x86_sse3()) {
                    //! if x86 only supports sse2, use the nchw44 format
                    m_load_config.comp_graph->options().graph_opt.enable_nchw44();
                    LITE_LOG("Configure model inference with nchw44 format.");
                } else if (cpuinfo_has_arm_neon()) {
                    //! if the device is arm, use the nchw44 format
                    m_load_config.comp_graph->options().graph_opt.enable_nchw44();
                    LITE_LOG("Configure model inference with nchw44 format.");
                }
            } else if (is_model_int8) {
                //! if the data type of convolution is int8
                //! if the device is arm and supports dot, use the nchw44-dot format
                if (cpuinfo_has_arm_neon() && cpuinfo_has_arm_neon_dot()) {
                    m_load_config.comp_graph->options().graph_opt.enable_nchw44_dot();
                    LITE_LOG("Configure model inference with nchw44-dot format.");
                } else if (cpuinfo_has_arm_neon()) {
                    //! if the device is arm and does not support dot, use the
                    //! nchw44 format
                    m_load_config.comp_graph->options().graph_opt.enable_nchw44();
                    LITE_LOG("Configure model inference with nchw44 format.");
                }
            }
#endif
        }
    }
}
@@ -422,10 +516,13 @@ void NetworkImplDft::load_model(
    }
    m_load_result = m_loader->load(m_load_config, true);
    configure_after_loaded();
}
void NetworkImplDft::configure_after_loaded() {
    modify_exection_policy();
    global_layout_transform();
    layout_transform_optimization();
    //! some optimization options may be invalid in some cases, so here we just
    //! automatically determine whether those options will apply.
@@ -178,8 +178,10 @@ private:
    //! call_back to the outputspec
    void make_output_spec();
    //! do the global layout transform for the given platform target
    void global_layout_transform();
    //! do the layout transform for the given platform target: either the global
    //! layout optimization or a heuristic choice of the best layout according
    //! to the device information
    void layout_transform_optimization();
    //! modify the execution policy
    void modify_exection_policy();
@@ -223,6 +225,9 @@ private:
    //! adapt option valid, it should call after update_io
    void adapt_option_valid();
    //! configure and optimize network after loaded
    void configure_after_loaded();

private:
    bool m_async = false;
    bool m_is_cpu_inplace_mode = false;
@@ -48,6 +48,35 @@ TEST(TestNetWorkOptions, no_var_sanity_check_and_record) {
    compare_lite_tensor<float>(output_tensor, result_mgb);
}

TEST(TestNetWorkOptions, auto_optimize_inference_layout) {
    Config config;
    auto tensor = get_input_data("./input_data.npy");
    std::string model_path = "./shufflenet.mge";
    std::string input_name = "data";
    auto result_mgb = mgb_lar(model_path, config, input_name, tensor);

    config.auto_optimize_inference = true;
    std::shared_ptr<Network> network = std::make_shared<Network>(config);
    network->load_model(model_path);
    std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name);

    auto src_ptr = tensor->get_memory_ptr();
    auto src_layout = tensor->get_layout();
    input_tensor->reset(src_ptr, src_layout);

    std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
    auto result_tensor = std::make_shared<Tensor>(
            LiteDeviceType::LITE_CPU, Layout{{1, 1000}, 2, LiteDataType::LITE_FLOAT});
    void* out_data = result_tensor->get_memory_ptr();
    output_tensor->reset(out_data, result_tensor->get_layout());

    network->forward();
    network->wait();

    compare_lite_tensor<float>(output_tensor, result_mgb);
}

TEST(TestNetWorkOptions, const_shape) {
    Config config;
    auto tensor = get_input_data("./input_data.npy");