
feat(lite): add disable configure by model info interface

GitOrigin-RevId: cd155a1fcf
release-1.10
Megvii Engine Team 3 years ago
parent commit d610c98756
11 changed files with 146 additions and 9 deletions
  1. lite/include/lite/network.h (+18 -0)
  2. lite/lite-c/include/lite-c/network_c.h (+17 -0)
  3. lite/lite-c/src/network.cpp (+14 -0)
  4. lite/pylite/megenginelite/network.py (+32 -0)
  5. lite/pylite/test/test_network.py (+7 -0)
  6. lite/src/mge/network_impl.cpp (+3 -2)
  7. lite/src/mge/network_impl.h (+5 -1)
  8. lite/src/network.cpp (+18 -3)
  9. lite/src/parse_model/model_parser.cpp (+6 -2)
  10. lite/src/parse_model/model_parser.h (+1 -1)
  11. lite/test/test_network_options.cpp (+25 -0)

lite/include/lite/network.h (+18 -0)

@@ -118,6 +118,17 @@ struct LITE_API Config {
 };
 
 /*!
+ * \brief Extra configuration for a network
+ *
+ * \param disable_configure_by_model_info disable the configuration dumped with
+ * the model; if set to true, none of the configuration in the model is applied,
+ * and users should configure the network themselves.
+ */
+struct LITE_API ExtraConfig {
+    bool disable_configure_by_model_info = false;
+};
+
+/*!
  * \brief config the network input and output item
  *
  */
@@ -275,6 +286,12 @@ public:
     //! get static peak memory info showed by Graph visualization
     void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const;
 
+    /** @brief set the extra configuration
+     *
+     * @param extra_config the extra configuration to set into the network
+     */
+    void extra_configure(const ExtraConfig& extra_config);
+
 public:
     friend class NetworkHelper;


@@ -288,6 +305,7 @@ private:
 private:
     bool m_loaded = false;
     Config m_config;
+    ExtraConfig m_extra_config;
     NetworkIO m_network_io;
     std::unique_ptr<NetworkImplBase> m_impl;
     std::string m_extra_info;
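
The pieces above combine into a simple calling pattern: construct the network, apply the extra configuration, then load the model. A minimal C++ sketch of that order, assuming a model file at ./model.lite (the path and the surrounding main() are illustrative, not part of this commit):

#include <memory>

#include "lite/network.h"

int main() {
    lite::Config config;
    auto network = std::make_shared<lite::Network>(config);

    //! ignore every option packed inside the model; this must happen before
    //! load_model(), otherwise the network asserts
    lite::ExtraConfig extra_config;
    extra_config.disable_configure_by_model_info = true;
    network->extra_configure(extra_config);

    network->load_model("./model.lite");
    //! from here on the caller owns all options the model would have set
    //! (e.g. the packed fast-run cache)
    return 0;
}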


lite/lite-c/include/lite-c/network_c.h (+17 -0)

@@ -114,6 +114,17 @@ typedef struct LiteConfig {
 LITE_API LiteConfig* default_config();
 
 /*!
+ * \brief Extra configuration for a network
+ *
+ * \param disable_configure_by_model_info disable the configuration dumped with
+ * the model; if set to true, none of the configuration in the model is applied,
+ * and users should configure the network themselves.
+ */
+typedef struct LiteExtraConfig {
+    int disable_configure_by_model_info;
+} LiteExtraConfig;
+
+/*!
  * \brief config the network input and output item
  *
  */
@@ -599,6 +610,12 @@ LITE_API int LITE_get_model_io_info_by_memory(
         const void* model_mem, size_t size, const LiteConfig config,
         LiteNetworkIO* ios);
 
+/** @brief set the extra configuration
+ *
+ * @param extra_config the extra configuration to set into the network
+ */
+LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config);
+
 #ifdef __cplusplus
 }
 #endif


lite/lite-c/src/network.cpp (+14 -0)

@@ -181,6 +181,12 @@ InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
     return innner_io;
 }
 
+lite::ExtraConfig convert_extra_config(const LiteExtraConfig& extra_config) {
+    lite::ExtraConfig ret;
+    ret.disable_configure_by_model_info = extra_config.disable_configure_by_model_info;
+    return ret;
+}
+
 int LITE_make_default_network(LiteNetwork* network) {
     LITE_CAPI_BEGIN();
     LITE_ASSERT(network, "The network pass to LITE api is null");
@@ -734,4 +740,12 @@ int LITE_get_model_io_info_by_memory(
     LITE_CAPI_END();
 }
 
+LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config) {
+    LITE_CAPI_BEGIN();
+    LITE_ASSERT(network, "The network pass to LITE api is null");
+    static_cast<lite::Network*>(network)->extra_configure(
+            convert_extra_config(extra_config));
+    LITE_CAPI_END();
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

lite/pylite/megenginelite/network.py (+32 -0)

@@ -134,6 +134,31 @@ class LiteConfig(Structure):
         return data.__repr__()
 
 
+class LiteExtraConfig(Structure):
+    """
+    Extra configuration when loading and compiling the graph.
+
+    disable_configure_by_model_info: disable the configuration dumped with
+    the model; if set to True, none of the configuration in the model is
+    applied, and users should configure the network themselves.
+    """
+
+    _fields_ = [
+        ("disable_configure_by_model_info", c_int),
+    ]
+
+    def __init__(self, disable_model_config=False):
+        self.disable_configure_by_model_info = disable_model_config
+
+    def __repr__(self):
+        data = {
+            "disable_configure_by_model_info": bool(
+                self.disable_configure_by_model_info
+            ),
+        }
+        return data.__repr__()
+
+
 class LiteIO(Structure):
     """
     config the network input and output item
@@ -365,6 +390,7 @@ class _NetworkAPI(_LiteCObjBase):
"LITE_get_model_io_info_by_memory", "LITE_get_model_io_info_by_memory",
[c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)], [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
), ),
("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
] ]




@@ -541,6 +567,12 @@ class LiteNetwork(object):
         ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)]
         return ret_name
 
+    def extra_configure(self, extra_config):
+        """
+        Apply extra configuration to the network.
+        """
+        self._api.LITE_extra_configure(self._network, extra_config)
+
     def share_weights_with(self, src_network):
         """
         share weights with the loaded network


lite/pylite/test/test_network.py (+7 -0)

@@ -112,6 +112,13 @@ class TestNetwork(TestShuffleNet):
         network.load(model_path)
         self.do_forward(network)
 
+    def test_disable_model_config(self):
+        model_path = os.path.join(self.source_dir, "test_packed_model_rc4.lite")
+        network = LiteNetwork()
+        network.extra_configure(LiteExtraConfig(True))
+        network.load(model_path)
+        self.do_forward(network)
+
     def test_pack_cache_to_model(self):
         model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite")
         network = LiteNetwork()


lite/src/mge/network_impl.cpp (+3 -2)

@@ -31,7 +31,6 @@ using namespace mgb;
 LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft);
 
 void NetworkImplDft::set_config(const Config& config) {
-    m_user_config = std::make_unique<Config>();
     *m_user_config = config;
     m_compnode_locator = to_compnode_locator(m_user_config->device_type);
     m_compnode_locator.device = config.device_id;
@@ -428,8 +427,11 @@ void NetworkImplDft::load_model(


global_layout_transform(); global_layout_transform();


//! some optimization option maybe invalid in some case, so here just
//! auto determine whether some options will apply.
adapt_option_valid(); adapt_option_valid();


//! find how many compnode the model has, this should call before update_io
cross_compnode_model_detect(); cross_compnode_model_detect();


//! update the IO of the network //! update the IO of the network
@@ -496,7 +498,6 @@ void NetworkImplDft::finish() const {
 }
 
 void NetworkImplDft::set_io(const NetworkIO& network_io) {
-    m_network_io = std::make_unique<NetworkIOInner>();
     for (auto&& in : network_io.inputs) {
         m_network_io->inputs.emplace_back(in);
     }


lite/src/mge/network_impl.h (+5 -1)

@@ -29,7 +29,11 @@ class NetworkImplDft final : public Network::NetworkImplBase {
     LITE_DYN_TYPE_OBJ_FINAL_DECL;
 
 public:
-    NetworkImplDft() { m_load_config.comp_graph = mgb::ComputingGraph::make(); }
+    NetworkImplDft() {
+        m_load_config.comp_graph = mgb::ComputingGraph::make();
+        m_user_config = std::make_unique<Config>();
+        m_network_io = std::make_unique<NetworkIOInner>();
+    }
     using S = megdnn::param::ExecutionPolicy::Strategy;
     using Var = mgb::cg::SymbolVar;
     //! set the config of the network, include:


lite/src/network.cpp (+18 -3)

@@ -80,14 +80,17 @@ void Network::prase_model(std::shared_ptr<void> model_data, size_t size) {
     ModelParser model_parser(model_data, size);
     //! parse the model info
     if (model_parser.parse_model_info(
-                m_config, m_network_io, separate_config_map, m_extra_info)) {
+                m_config, m_network_io, separate_config_map, m_extra_info,
+                !m_extra_config.disable_configure_by_model_info)) {
         if (m_config.backend == LiteBackend::LITE_DEFAULT &&
             m_impl->get_backend_type() != LiteBackend::LITE_DEFAULT) {
             m_impl.reset(try_call_func<NetworkImplDft, lite::Network::NetworkImplBase*>(
                     "parse_model"));
         }
-        m_impl->set_config(m_config);
-        m_impl->set_io(m_network_io);
+        if (!m_extra_config.disable_configure_by_model_info) {
+            m_impl->set_config(m_config);
+            m_impl->set_io(m_network_io);
+        }
     }
     //! decryption the model
     size_t model_length;
@@ -290,6 +293,18 @@ void Network::get_static_memory_alloc_info(const std::string& log_dir) const {
     LITE_ERROR_HANDLER_END
 }
 
+void Network::extra_configure(const ExtraConfig& extra_config) {
+    LITE_ERROR_HANDLER_BEGIN
+    if (extra_config.disable_configure_by_model_info) {
+        LITE_ASSERT(
+                !m_loaded,
+                "disable_configure_by_model_info should be configured before model "
+                "loaded.");
+    }
+    m_extra_config = extra_config;
+    LITE_ERROR_HANDLER_END
+}
+
 /*********************** MGE special network function ***************/
 
 void Runtime::set_cpu_threads_number(
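
Because m_extra_config is only read inside prase_model, disabling model-info configuration after the model is loaded could never take effect; the assertion above turns that misuse into a hard error. A sketch of the failing order, assuming a lite build with exceptions enabled (path illustrative):

#include <exception>
#include <memory>

#include "lite/network.h"

int main() {
    auto network = std::make_shared<lite::Network>();
    network->load_model("./model.lite");

    lite::ExtraConfig late_config;
    late_config.disable_configure_by_model_info = true;
    try {
        //! too late: the model-packed configuration is already applied
        network->extra_configure(late_config);
    } catch (const std::exception&) {
        //! message: "disable_configure_by_model_info should be configured
        //! before model loaded."
    }
    return 0;
}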


lite/src/parse_model/model_parser.cpp (+6 -2)

@@ -43,7 +43,7 @@ void ModelParser::parse_header() {
 bool ModelParser::parse_model_info(
         Config& network_config, NetworkIO& network_io,
         std::unordered_map<std::string, LiteAny>& isolated_config_map,
-        std::string& extra_info) const {
+        std::string& extra_info, bool configure_valid) const {
     //! no model info, no parse, direct return
     if (m_is_bare_model || !m_info) {
         return false;
@@ -78,7 +78,7 @@ bool ModelParser::parse_model_info(
         }
     }
     //! parse ModelInfo::algo_policy
-    if (m_info->algo_policy()) {
+    if (m_info->algo_policy() && configure_valid) {
         size_t cache_length = m_info->algo_policy()->size();
         const uint8_t* cache = m_info->algo_policy()->Data();
         if (m_info_cache_parse_func_name == "LITE_parse_cache") {
@@ -93,6 +93,10 @@ bool ModelParser::parse_model_info(
             } else {
                 LITE_THROW("opencl binary cache is not given");
             }
+        } else {
+            LITE_THROW(ssprintf(
+                    "model cache parse function of %s is not defined.",
+                    m_info_cache_parse_func_name.c_str()));
         }
     }
     return true;


lite/src/parse_model/model_parser.h (+1 -1)

@@ -25,7 +25,7 @@ public:
     bool parse_model_info(
             Config& network_config, NetworkIO& network_io,
             std::unordered_map<std::string, LiteAny>& isolated_config_map,
-            std::string& extra_info) const;
+            std::string& extra_info, bool configure_valid) const;
 
     //! parse the model and decrypt the model
     std::shared_ptr<void> parse_model(size_t& model_length, const Config& config) const;


lite/test/test_network_options.cpp (+25 -0)

@@ -7,6 +7,8 @@
 #include "lite/global.h"
 
 #include "megbrain/tensor.h"
+#include "megbrain/utils/infile_persistent_cache.h"
+#include "megbrain/utils/persistent_cache.h"
 #include "test_common.h"
 
 #include <string.h>
@@ -173,6 +175,29 @@ TEST(TestNetWorkOptions, test_cache) {
     compare_lite_tensor<float>(output_tensor, result_mgb);
 }
 
+TEST(TestNetWorkOptions, DisableModelInfo) {
+    //! clear the cache set by other tests
+    mgb::PersistentCache::inst().set_impl(
+            std::make_shared<mgb::InMemoryPersistentCache>());
+    Config config;
+    auto tensor = get_input_data("./input_data.npy");
+    std::string model_path = "./test_pack_cache_to_model.lite";
+    std::string model_path2 = "./test_pack_cache_to_model.lite";
+    std::string input_name = "data";
+    std::shared_ptr<Network> network = std::make_shared<Network>(config);
+    network->extra_configure({true});
+    Runtime::set_cpu_inplace_mode(network);
+    network->load_model(model_path);
+    //! the fast-run cache was not configured, so dumping it is not supported
+    ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), false);
+    ASSERT_EQ(Runtime::is_cpu_inplace_mode(network), true);
+
+    std::shared_ptr<Network> network2 = std::make_shared<Network>(config);
+    network2->load_model(model_path2);
+    //! the fast-run cache is configured by the model information
+    ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), true);
+}
+
 TEST(TestNetWorkOptions, FastRunIgnorBatch) {
     Config config;
     auto tensor = get_input_data("./input_data.npy");

