GitOrigin-RevId: cd155a1fcf
release-1.10
@@ -118,6 +118,17 @@ struct LITE_API Config { | |||||
}; | }; | ||||
/*! | /*! | ||||
* \brief Extra Configuration for a network | |||||
* | |||||
* \param disable_configure_by_model_info disable the configuration dumped with model, | |||||
* if set true, all configuration in the model will not apply, users should configure | |||||
* the network. | |||||
*/ | |||||
struct LITE_API ExtraConfig { | |||||
bool disable_configure_by_model_info = false; | |||||
}; | |||||
/*! | |||||
* \brief config the network input and output item | * \brief config the network input and output item | ||||
* | * | ||||
*/ | */ | ||||
@@ -275,6 +286,12 @@ public: | |||||
//! get static peak memory info showed by Graph visualization | //! get static peak memory info showed by Graph visualization | ||||
void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const; | void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const; | ||||
/** @brief the extra configuration | |||||
* | |||||
* @param extra_config the extra configuration to set into the network | |||||
*/ | |||||
void extra_configure(const ExtraConfig& extra_config); | |||||
public: | public: | ||||
friend class NetworkHelper; | friend class NetworkHelper; | ||||
@@ -288,6 +305,7 @@ private: | |||||
private: | private: | ||||
bool m_loaded = false; | bool m_loaded = false; | ||||
Config m_config; | Config m_config; | ||||
ExtraConfig m_extra_config; | |||||
NetworkIO m_network_io; | NetworkIO m_network_io; | ||||
std::unique_ptr<NetworkImplBase> m_impl; | std::unique_ptr<NetworkImplBase> m_impl; | ||||
std::string m_extra_info; | std::string m_extra_info; | ||||
@@ -114,6 +114,17 @@ typedef struct LiteConfig { | |||||
LITE_API LiteConfig* default_config(); | LITE_API LiteConfig* default_config(); | ||||
/*! | /*! | ||||
 * \brief Extra Configuration for a network | |||||
* | |||||
* \param disable_configure_by_model_info disable the configuration dumped with model, | |||||
* if set true, all configuration in the model will not apply, users should configure | |||||
* the network. | |||||
*/ | |||||
typedef struct LiteExtraConfig { | |||||
int disable_configure_by_model_info; | |||||
} LiteExtraConfig; | |||||
/*! | |||||
* \brief config the network input and output item | * \brief config the network input and output item | ||||
* | * | ||||
*/ | */ | ||||
@@ -599,6 +610,12 @@ LITE_API int LITE_get_model_io_info_by_memory( | |||||
const void* model_mem, size_t size, const LiteConfig config, | const void* model_mem, size_t size, const LiteConfig config, | ||||
LiteNetworkIO* ios); | LiteNetworkIO* ios); | ||||
/** @brief the extra configuration | |||||
* | |||||
* @param extra_config the extra configuration to set into the network | |||||
*/ | |||||
LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config); | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
@@ -181,6 +181,12 @@ InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) { | |||||
return innner_io; | return innner_io; | ||||
} | } | ||||
lite::ExtraConfig convert_extra_config(const LiteExtraConfig& extra_config) { | |||||
lite::ExtraConfig ret; | |||||
ret.disable_configure_by_model_info = extra_config.disable_configure_by_model_info; | |||||
return ret; | |||||
} | |||||
int LITE_make_default_network(LiteNetwork* network) { | int LITE_make_default_network(LiteNetwork* network) { | ||||
LITE_CAPI_BEGIN(); | LITE_CAPI_BEGIN(); | ||||
LITE_ASSERT(network, "The network pass to LITE api is null"); | LITE_ASSERT(network, "The network pass to LITE api is null"); | ||||
@@ -734,4 +740,12 @@ int LITE_get_model_io_info_by_memory( | |||||
LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
} | } | ||||
LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config) { | |||||
LITE_CAPI_BEGIN(); | |||||
LITE_ASSERT(network, "The network pass to LITE api is null"); | |||||
static_cast<lite::Network*>(network)->extra_configure( | |||||
convert_extra_config(extra_config)); | |||||
LITE_CAPI_END(); | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -134,6 +134,31 @@ class LiteConfig(Structure): | |||||
return data.__repr__() | return data.__repr__() | ||||
class LiteExtraConfig(Structure): | |||||
""" | |||||
Extra configuration when load and compile the graph | |||||
disable_configure_by_model_info: disable the configuration dumped with | |||||
model, if set true, all configuration in the model will not apply, users | |||||
should configure the network. | |||||
""" | |||||
_fields_ = [ | |||||
("disable_configure_by_model_info", c_int), | |||||
] | |||||
def __init__(self, disable_model_config=False): | |||||
self.disable_configure_by_model_info = disable_model_config | |||||
def __repr__(self): | |||||
data = { | |||||
"disable_configure_by_model_info": bool( | |||||
self.disable_configure_by_model_info | |||||
), | |||||
} | |||||
return data.__repr__() | |||||
class LiteIO(Structure): | class LiteIO(Structure): | ||||
""" | """ | ||||
config the network input and output item | config the network input and output item | ||||
@@ -365,6 +390,7 @@ class _NetworkAPI(_LiteCObjBase): | |||||
"LITE_get_model_io_info_by_memory", | "LITE_get_model_io_info_by_memory", | ||||
[c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)], | [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)], | ||||
), | ), | ||||
("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]), | |||||
] | ] | ||||
@@ -541,6 +567,12 @@ class LiteNetwork(object): | |||||
ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)] | ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)] | ||||
return ret_name | return ret_name | ||||
def extra_configure(self, extra_config): | |||||
""" | |||||
Extra Configuration to the network. | |||||
""" | |||||
self._api.LITE_extra_configure(self._network, extra_config) | |||||
def share_weights_with(self, src_network): | def share_weights_with(self, src_network): | ||||
""" | """ | ||||
share weights with the loaded network | share weights with the loaded network | ||||
@@ -112,6 +112,13 @@ class TestNetwork(TestShuffleNet): | |||||
network.load(model_path) | network.load(model_path) | ||||
self.do_forward(network) | self.do_forward(network) | ||||
def test_disable_model_config(self): | |||||
model_path = os.path.join(self.source_dir, "test_packed_model_rc4.lite") | |||||
network = LiteNetwork() | |||||
network.extra_configure(LiteExtraConfig(True)) | |||||
network.load(model_path) | |||||
self.do_forward(network) | |||||
def test_pack_cache_to_model(self): | def test_pack_cache_to_model(self): | ||||
model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite") | model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite") | ||||
network = LiteNetwork() | network = LiteNetwork() | ||||
@@ -31,7 +31,6 @@ using namespace mgb; | |||||
LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft); | LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft); | ||||
void NetworkImplDft::set_config(const Config& config) { | void NetworkImplDft::set_config(const Config& config) { | ||||
m_user_config = std::make_unique<Config>(); | |||||
*m_user_config = config; | *m_user_config = config; | ||||
m_compnode_locator = to_compnode_locator(m_user_config->device_type); | m_compnode_locator = to_compnode_locator(m_user_config->device_type); | ||||
m_compnode_locator.device = config.device_id; | m_compnode_locator.device = config.device_id; | ||||
@@ -428,8 +427,11 @@ void NetworkImplDft::load_model( | |||||
global_layout_transform(); | global_layout_transform(); | ||||
//! some optimization option maybe invalid in some case, so here just | |||||
//! auto determine whether some options will apply. | |||||
adapt_option_valid(); | adapt_option_valid(); | ||||
//! find how many compnode the model has, this should call before update_io | |||||
cross_compnode_model_detect(); | cross_compnode_model_detect(); | ||||
//! update the IO of the network | //! update the IO of the network | ||||
@@ -496,7 +498,6 @@ void NetworkImplDft::finish() const { | |||||
} | } | ||||
void NetworkImplDft::set_io(const NetworkIO& network_io) { | void NetworkImplDft::set_io(const NetworkIO& network_io) { | ||||
m_network_io = std::make_unique<NetworkIOInner>(); | |||||
for (auto&& in : network_io.inputs) { | for (auto&& in : network_io.inputs) { | ||||
m_network_io->inputs.emplace_back(in); | m_network_io->inputs.emplace_back(in); | ||||
} | } | ||||
@@ -29,7 +29,11 @@ class NetworkImplDft final : public Network::NetworkImplBase { | |||||
LITE_DYN_TYPE_OBJ_FINAL_DECL; | LITE_DYN_TYPE_OBJ_FINAL_DECL; | ||||
public: | public: | ||||
NetworkImplDft() { m_load_config.comp_graph = mgb::ComputingGraph::make(); } | |||||
NetworkImplDft() { | |||||
m_load_config.comp_graph = mgb::ComputingGraph::make(); | |||||
m_user_config = std::make_unique<Config>(); | |||||
m_network_io = std::make_unique<NetworkIOInner>(); | |||||
} | |||||
using S = megdnn::param::ExecutionPolicy::Strategy; | using S = megdnn::param::ExecutionPolicy::Strategy; | ||||
using Var = mgb::cg::SymbolVar; | using Var = mgb::cg::SymbolVar; | ||||
//! set the config of the network, include: | //! set the config of the network, include: | ||||
@@ -80,14 +80,17 @@ void Network::prase_model(std::shared_ptr<void> model_data, size_t size) { | |||||
ModelParser model_parser(model_data, size); | ModelParser model_parser(model_data, size); | ||||
//! parse the model info | //! parse the model info | ||||
if (model_parser.parse_model_info( | if (model_parser.parse_model_info( | ||||
m_config, m_network_io, separate_config_map, m_extra_info)) { | |||||
m_config, m_network_io, separate_config_map, m_extra_info, | |||||
!m_extra_config.disable_configure_by_model_info)) { | |||||
if (m_config.backend == LiteBackend::LITE_DEFAULT && | if (m_config.backend == LiteBackend::LITE_DEFAULT && | ||||
m_impl->get_backend_type() != LiteBackend::LITE_DEFAULT) { | m_impl->get_backend_type() != LiteBackend::LITE_DEFAULT) { | ||||
m_impl.reset(try_call_func<NetworkImplDft, lite::Network::NetworkImplBase*>( | m_impl.reset(try_call_func<NetworkImplDft, lite::Network::NetworkImplBase*>( | ||||
"parse_model")); | "parse_model")); | ||||
} | } | ||||
m_impl->set_config(m_config); | |||||
m_impl->set_io(m_network_io); | |||||
if (!m_extra_config.disable_configure_by_model_info) { | |||||
m_impl->set_config(m_config); | |||||
m_impl->set_io(m_network_io); | |||||
} | |||||
} | } | ||||
//! decryption the model | //! decryption the model | ||||
size_t model_length; | size_t model_length; | ||||
@@ -290,6 +293,18 @@ void Network::get_static_memory_alloc_info(const std::string& log_dir) const { | |||||
LITE_ERROR_HANDLER_END | LITE_ERROR_HANDLER_END | ||||
} | } | ||||
void Network::extra_configure(const ExtraConfig& extra_config) { | |||||
LITE_ERROR_HANDLER_BEGIN | |||||
if (!extra_config.disable_configure_by_model_info) { | |||||
LITE_ASSERT( | |||||
!m_loaded, | |||||
"disable_configure_by_model_info should be configured before model " | |||||
"loaded."); | |||||
} | |||||
m_extra_config = extra_config; | |||||
LITE_ERROR_HANDLER_END | |||||
} | |||||
/*********************** MGE special network function ***************/ | /*********************** MGE special network function ***************/ | ||||
void Runtime::set_cpu_threads_number( | void Runtime::set_cpu_threads_number( | ||||
@@ -43,7 +43,7 @@ void ModelParser::parse_header() { | |||||
bool ModelParser::parse_model_info( | bool ModelParser::parse_model_info( | ||||
Config& network_config, NetworkIO& network_io, | Config& network_config, NetworkIO& network_io, | ||||
std::unordered_map<std::string, LiteAny>& isolated_config_map, | std::unordered_map<std::string, LiteAny>& isolated_config_map, | ||||
std::string& extra_info) const { | |||||
std::string& extra_info, bool configure_valid) const { | |||||
//! no model info, no parse, direct return | //! no model info, no parse, direct return | ||||
if (m_is_bare_model || !m_info) { | if (m_is_bare_model || !m_info) { | ||||
return false; | return false; | ||||
@@ -78,7 +78,7 @@ bool ModelParser::parse_model_info( | |||||
} | } | ||||
} | } | ||||
//! parse ModelInfo::algo_policy | //! parse ModelInfo::algo_policy | ||||
if (m_info->algo_policy()) { | |||||
if (m_info->algo_policy() && configure_valid) { | |||||
size_t cache_length = m_info->algo_policy()->size(); | size_t cache_length = m_info->algo_policy()->size(); | ||||
const uint8_t* cache = m_info->algo_policy()->Data(); | const uint8_t* cache = m_info->algo_policy()->Data(); | ||||
if (m_info_cache_parse_func_name == "LITE_parse_cache") { | if (m_info_cache_parse_func_name == "LITE_parse_cache") { | ||||
@@ -93,6 +93,10 @@ bool ModelParser::parse_model_info( | |||||
} else { | } else { | ||||
LITE_THROW("opencl binary cache is not given"); | LITE_THROW("opencl binary cache is not given"); | ||||
} | } | ||||
} else { | |||||
LITE_THROW(ssprintf( | |||||
"model cache parse function of %s is not defined.", | |||||
m_info_cache_parse_func_name.c_str())); | |||||
} | } | ||||
} | } | ||||
return true; | return true; | ||||
@@ -25,7 +25,7 @@ public: | |||||
bool parse_model_info( | bool parse_model_info( | ||||
Config& network_config, NetworkIO& network_io, | Config& network_config, NetworkIO& network_io, | ||||
std::unordered_map<std::string, LiteAny>& isolated_config_map, | std::unordered_map<std::string, LiteAny>& isolated_config_map, | ||||
std::string& extra_info) const; | |||||
std::string& extra_info, bool configure_valid) const; | |||||
//! parse the model and decrypt the model | //! parse the model and decrypt the model | ||||
std::shared_ptr<void> parse_model(size_t& model_length, const Config& config) const; | std::shared_ptr<void> parse_model(size_t& model_length, const Config& config) const; | ||||
@@ -7,6 +7,8 @@ | |||||
#include "lite/global.h" | #include "lite/global.h" | ||||
#include "megbrain/tensor.h" | #include "megbrain/tensor.h" | ||||
#include "megbrain/utils/infile_persistent_cache.h" | |||||
#include "megbrain/utils/persistent_cache.h" | |||||
#include "test_common.h" | #include "test_common.h" | ||||
#include <string.h> | #include <string.h> | ||||
@@ -173,6 +175,29 @@ TEST(TestNetWorkOptions, test_cache) { | |||||
compare_lite_tensor<float>(output_tensor, result_mgb); | compare_lite_tensor<float>(output_tensor, result_mgb); | ||||
} | } | ||||
TEST(TestNetWorkOptions, DisableModelInfo) { | |||||
//! clear the cache set by other test | |||||
mgb::PersistentCache::inst().set_impl( | |||||
std::make_shared<mgb::InMemoryPersistentCache>()); | |||||
Config config; | |||||
auto tensor = get_input_data("./input_data.npy"); | |||||
std::string model_path = "./test_pack_cache_to_model.lite"; | |||||
std::string model_path2 = "./test_pack_cache_to_model.lite"; | |||||
std::string input_name = "data"; | |||||
std::shared_ptr<Network> network = std::make_shared<Network>(config); | |||||
network->extra_configure({true}); | |||||
Runtime::set_cpu_inplace_mode(network); | |||||
network->load_model(model_path); | |||||
//! the fast-run cache will not configure, so it is not support dump | |||||
ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), false); | |||||
ASSERT_EQ(Runtime::is_cpu_inplace_mode(network), true); | |||||
std::shared_ptr<Network> network2 = std::make_shared<Network>(config); | |||||
network2->load_model(model_path2); | |||||
//! the fast-run cache is configured by the model information | |||||
ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), true); | |||||
} | |||||
TEST(TestNetWorkOptions, FastRunIgnorBatch) { | TEST(TestNetWorkOptions, FastRunIgnorBatch) { | ||||
Config config; | Config config; | ||||
auto tensor = get_input_data("./input_data.npy"); | auto tensor = get_input_data("./input_data.npy"); | ||||