
feat(lite): add disable configure by model info interface

GitOrigin-RevId: cd155a1fcf
release-1.10 · Megvii Engine Team · 3 years ago
commit d610c98756
11 changed files with 146 additions and 9 deletions
  1. lite/include/lite/network.h (+18, -0)
  2. lite/lite-c/include/lite-c/network_c.h (+17, -0)
  3. lite/lite-c/src/network.cpp (+14, -0)
  4. lite/pylite/megenginelite/network.py (+32, -0)
  5. lite/pylite/test/test_network.py (+7, -0)
  6. lite/src/mge/network_impl.cpp (+3, -2)
  7. lite/src/mge/network_impl.h (+5, -1)
  8. lite/src/network.cpp (+18, -3)
  9. lite/src/parse_model/model_parser.cpp (+6, -2)
  10. lite/src/parse_model/model_parser.h (+1, -1)
  11. lite/test/test_network_options.cpp (+25, -0)

lite/include/lite/network.h (+18, -0)

@@ -118,6 +118,17 @@ struct LITE_API Config {
};

/*!
* \brief Extra configuration for a network
*
* \param disable_configure_by_model_info disable the configuration dumped with the
* model; if set to true, none of the configuration packed in the model will be
* applied, and users should configure the network themselves.
*/
struct LITE_API ExtraConfig {
bool disable_configure_by_model_info = false;
};

/*!
* \brief config the network input and output item
*
*/
@@ -275,6 +286,12 @@ public:
//! get static peak memory info showed by Graph visualization
void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const;

/** @brief set the extra configuration of the network
*
* @param extra_config the extra configuration to set into the network
*/
void extra_configure(const ExtraConfig& extra_config);

public:
friend class NetworkHelper;

@@ -288,6 +305,7 @@ private:
private:
bool m_loaded = false;
Config m_config;
ExtraConfig m_extra_config;
NetworkIO m_network_io;
std::unique_ptr<NetworkImplBase> m_impl;
std::string m_extra_info;
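
For orientation, here is a minimal usage sketch of the C++ interface introduced above; the model path is a placeholder and error handling is omitted:

#include <memory>

#include "lite/network.h"

int main() {
    lite::Config config;
    auto network = std::make_shared<lite::Network>(config);

    // Ask the network to ignore any configuration packed into the model.
    // This must happen before load_model(); once the model is loaded, the
    // packed settings have already been applied.
    lite::ExtraConfig extra_config;
    extra_config.disable_configure_by_model_info = true;
    network->extra_configure(extra_config);

    // "./model.lite" is a placeholder path to a packed lite model.
    network->load_model("./model.lite");
    return 0;
}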


lite/lite-c/include/lite-c/network_c.h (+17, -0)

@@ -114,6 +114,17 @@ typedef struct LiteConfig {
LITE_API LiteConfig* default_config();

/*!
* \brief Extra configuration for a network
*
* \param disable_configure_by_model_info disable the configuration dumped with the
* model; if set to non-zero, none of the configuration packed in the model will be
* applied, and users should configure the network themselves.
*/
typedef struct LiteExtraConfig {
int disable_configure_by_model_info;
} LiteExtraConfig;

/*!
* \brief config the network input and output item
*
*/
@@ -599,6 +610,12 @@ LITE_API int LITE_get_model_io_info_by_memory(
const void* model_mem, size_t size, const LiteConfig config,
LiteNetworkIO* ios);

/** @brief set the extra configuration of the network
*
* @param extra_config the extra configuration to set into the network
*/
LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config);

#ifdef __cplusplus
}
#endif


lite/lite-c/src/network.cpp (+14, -0)

@@ -181,6 +181,12 @@ InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
return innner_io;
}

lite::ExtraConfig convert_extra_config(const LiteExtraConfig& extra_config) {
lite::ExtraConfig ret;
ret.disable_configure_by_model_info = extra_config.disable_configure_by_model_info;
return ret;
}

int LITE_make_default_network(LiteNetwork* network) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network passed to the LITE API is null");
@@ -734,4 +740,12 @@ int LITE_get_model_io_info_by_memory(
LITE_CAPI_END();
}

LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network passed to the LITE API is null");
static_cast<lite::Network*>(network)->extra_configure(
convert_extra_config(extra_config));
LITE_CAPI_END();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
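
The same flow through the C interface might look like the sketch below; LITE_load_model_from_path is the existing loader in this C API, the path is a placeholder, and error handling is reduced to return-code checks:

#include "lite-c/network_c.h"

int run_with_model_config_disabled(const char* model_path) {
    LiteNetwork network;
    if (LITE_make_default_network(&network) != 0)
        return -1;

    // A non-zero value disables the configuration packed into the model;
    // as in the C++ API, this must precede loading.
    LiteExtraConfig extra_config;
    extra_config.disable_configure_by_model_info = 1;
    if (LITE_extra_configure(network, extra_config) != 0)
        return -1;

    return LITE_load_model_from_path(network, model_path);
}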

lite/pylite/megenginelite/network.py (+32, -0)

@@ -134,6 +134,31 @@ class LiteConfig(Structure):
return data.__repr__()


class LiteExtraConfig(Structure):
"""
Extra configuration used when loading and compiling the graph

disable_configure_by_model_info: disable the configuration dumped with the
model; if set to True, none of the configuration packed in the model will be
applied, and users should configure the network themselves.
"""

_fields_ = [
("disable_configure_by_model_info", c_int),
]

def __init__(self, disable_model_config=False):
self.disable_configure_by_model_info = disable_model_config

def __repr__(self):
data = {
"disable_configure_by_model_info": bool(
self.disable_configure_by_model_info
),
}
return data.__repr__()


class LiteIO(Structure):
"""
config the network input and output item
@@ -365,6 +390,7 @@ class _NetworkAPI(_LiteCObjBase):
"LITE_get_model_io_info_by_memory",
[c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
),
("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
]


@@ -541,6 +567,12 @@ class LiteNetwork(object):
ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)]
return ret_name

def extra_configure(self, extra_config):
"""
Apply extra configuration to the network. When disabling the
configuration packed in the model, this should be called before load.
"""
self._api.LITE_extra_configure(self._network, extra_config)

def share_weights_with(self, src_network):
"""
share weights with the loaded network


lite/pylite/test/test_network.py (+7, -0)

@@ -112,6 +112,13 @@ class TestNetwork(TestShuffleNet):
network.load(model_path)
self.do_forward(network)

def test_disable_model_config(self):
model_path = os.path.join(self.source_dir, "test_packed_model_rc4.lite")
network = LiteNetwork()
network.extra_configure(LiteExtraConfig(True))
network.load(model_path)
self.do_forward(network)

def test_pack_cache_to_model(self):
model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite")
network = LiteNetwork()


lite/src/mge/network_impl.cpp (+3, -2)

@@ -31,7 +31,6 @@ using namespace mgb;
LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft);

void NetworkImplDft::set_config(const Config& config) {
- m_user_config = std::make_unique<Config>();
*m_user_config = config;
m_compnode_locator = to_compnode_locator(m_user_config->device_type);
m_compnode_locator.device = config.device_id;
@@ -428,8 +427,11 @@ void NetworkImplDft::load_model(

global_layout_transform();

//! some optimization options may be invalid in certain cases, so here we
//! automatically determine whether each option should be applied.
adapt_option_valid();

//! find how many compnode the model has, this should call before update_io
cross_compnode_model_detect();

//! update the IO of the network
@@ -496,7 +498,6 @@ void NetworkImplDft::finish() const {
}

void NetworkImplDft::set_io(const NetworkIO& network_io) {
- m_network_io = std::make_unique<NetworkIOInner>();
for (auto&& in : network_io.inputs) {
m_network_io->inputs.emplace_back(in);
}


lite/src/mge/network_impl.h (+5, -1)

@@ -29,7 +29,11 @@ class NetworkImplDft final : public Network::NetworkImplBase {
LITE_DYN_TYPE_OBJ_FINAL_DECL;

public:
- NetworkImplDft() { m_load_config.comp_graph = mgb::ComputingGraph::make(); }
+ NetworkImplDft() {
+ m_load_config.comp_graph = mgb::ComputingGraph::make();
+ m_user_config = std::make_unique<Config>();
+ m_network_io = std::make_unique<NetworkIOInner>();
+ }
using S = megdnn::param::ExecutionPolicy::Strategy;
using Var = mgb::cg::SymbolVar;
//! set the config of the network, include:


lite/src/network.cpp (+18, -3)

@@ -80,14 +80,17 @@ void Network::prase_model(std::shared_ptr<void> model_data, size_t size) {
ModelParser model_parser(model_data, size);
//! parse the model info
if (model_parser.parse_model_info(
- m_config, m_network_io, separate_config_map, m_extra_info)) {
+ m_config, m_network_io, separate_config_map, m_extra_info,
+ !m_extra_config.disable_configure_by_model_info)) {
if (m_config.backend == LiteBackend::LITE_DEFAULT &&
m_impl->get_backend_type() != LiteBackend::LITE_DEFAULT) {
m_impl.reset(try_call_func<NetworkImplDft, lite::Network::NetworkImplBase*>(
"parse_model"));
}
- m_impl->set_config(m_config);
- m_impl->set_io(m_network_io);
+ if (!m_extra_config.disable_configure_by_model_info) {
+ m_impl->set_config(m_config);
+ m_impl->set_io(m_network_io);
+ }
}
//! decrypt the model
size_t model_length;
@@ -290,6 +293,18 @@ void Network::get_static_memory_alloc_info(const std::string& log_dir) const {
LITE_ERROR_HANDLER_END
}

void Network::extra_configure(const ExtraConfig& extra_config) {
LITE_ERROR_HANDLER_BEGIN
if (extra_config.disable_configure_by_model_info) {
LITE_ASSERT(
!m_loaded,
"disable_configure_by_model_info should be configured before the "
"model is loaded.");
}
m_extra_config = extra_config;
LITE_ERROR_HANDLER_END
}

/*********************** MGE special network function ***************/

void Runtime::set_cpu_threads_number(
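
Conversely, a sketch of the call order that the assertion above rejects: once load_model has applied (or skipped) the packed configuration, it is too late to toggle disable_configure_by_model_info, and the LITE_ASSERT fires. The path is again a placeholder:

#include <memory>

#include "lite/network.h"

void wrong_order() {
    lite::Config config;
    auto network = std::make_shared<lite::Network>(config);
    network->load_model("./model.lite");  // packed model info applied here

    lite::ExtraConfig extra_config;
    extra_config.disable_configure_by_model_info = true;
    network->extra_configure(extra_config);  // error: network already loaded
}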


lite/src/parse_model/model_parser.cpp (+6, -2)

@@ -43,7 +43,7 @@ void ModelParser::parse_header() {
bool ModelParser::parse_model_info(
Config& network_config, NetworkIO& network_io,
std::unordered_map<std::string, LiteAny>& isolated_config_map,
- std::string& extra_info) const {
+ std::string& extra_info, bool configure_valid) const {
//! if there is no model info, skip parsing and return directly
if (m_is_bare_model || !m_info) {
return false;
@@ -78,7 +78,7 @@ bool ModelParser::parse_model_info(
}
}
//! parse ModelInfo::algo_policy
- if (m_info->algo_policy()) {
+ if (m_info->algo_policy() && configure_valid) {
size_t cache_length = m_info->algo_policy()->size();
const uint8_t* cache = m_info->algo_policy()->Data();
if (m_info_cache_parse_func_name == "LITE_parse_cache") {
@@ -93,6 +93,10 @@ bool ModelParser::parse_model_info(
} else {
LITE_THROW("opencl binary cache is not given");
}
} else {
LITE_THROW(ssprintf(
"model cache parse function of %s is not defined.",
m_info_cache_parse_func_name.c_str()));
}
}
return true;


lite/src/parse_model/model_parser.h (+1, -1)

@@ -25,7 +25,7 @@ public:
bool parse_model_info(
Config& network_config, NetworkIO& network_io,
std::unordered_map<std::string, LiteAny>& isolated_config_map,
- std::string& extra_info) const;
+ std::string& extra_info, bool configure_valid) const;

//! parse the model and decrypt the model
std::shared_ptr<void> parse_model(size_t& model_length, const Config& config) const;


lite/test/test_network_options.cpp (+25, -0)

@@ -7,6 +7,8 @@
#include "lite/global.h"

#include "megbrain/tensor.h"
#include "megbrain/utils/infile_persistent_cache.h"
#include "megbrain/utils/persistent_cache.h"
#include "test_common.h"

#include <string.h>
@@ -173,6 +175,29 @@ TEST(TestNetWorkOptions, test_cache) {
compare_lite_tensor<float>(output_tensor, result_mgb);
}

TEST(TestNetWorkOptions, DisableModelInfo) {
//! clear the cache set by other tests
mgb::PersistentCache::inst().set_impl(
std::make_shared<mgb::InMemoryPersistentCache>());
Config config;
auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./test_pack_cache_to_model.lite";
std::string model_path2 = "./test_pack_cache_to_model.lite";
std::string input_name = "data";
std::shared_ptr<Network> network = std::make_shared<Network>(config);
network->extra_configure({true});
Runtime::set_cpu_inplace_mode(network);
network->load_model(model_path);
//! the fast-run cache is not configured, so cache dumping is not supported
ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), false);
ASSERT_EQ(Runtime::is_cpu_inplace_mode(network), true);

std::shared_ptr<Network> network2 = std::make_shared<Network>(config);
network2->load_model(model_path2);
//! the fast-run cache is configured by the model information
ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), true);
}

TEST(TestNetWorkOptions, FastRunIgnorBatch) {
Config config;
auto tensor = get_input_data("./input_data.npy");

