
feat(lite): add interface to get model information before creating the network

GitOrigin-RevId: e499f3ebf8
tags/v1.9.0
Megvii Engine Team 3 years ago
parent commit 26b52a61de
12 changed files with 379 additions and 3 deletions
  1. lite/include/lite/network.h (+8, -0)
  2. lite/lite-c/include/lite-c/network_c.h (+22, -0)
  3. lite/lite-c/src/network.cpp (+80, -0)
  4. lite/pylite/megenginelite/network.py (+32, -0)
  5. lite/pylite/test/test_utils.py (+18, -0)
  6. lite/src/function_base.h (+3, -3)
  7. lite/src/mge/function_dft.h (+20, -0)
  8. lite/src/mge/network_impl.cpp (+70, -0)
  9. lite/src/mge/network_impl.h (+7, -0)
  10. lite/src/network.cpp (+22, -0)
  11. lite/test/test_network.cpp (+48, -0)
  12. lite/test/test_network_c.cpp (+49, -0)

lite/include/lite/network.h (+8, -0)

@@ -373,6 +373,14 @@ public:
//! dump network after global layout transform optimization
static void dump_layout_transform_model(
std::shared_ptr<Network> network, std::string optimized_model_path);

//! get the model IO information by model path, before the model is loaded.
static NetworkIO get_model_io_info(
const std::string& model_path, const Config& config = {});

//! get the model IO information by model memory, before the model is loaded.
static NetworkIO get_model_io_info(
const void* model_mem, size_t size, const Config& config = {});
};

} // namespace lite
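
As a usage sketch (not part of this commit): the two overloads above let a caller inspect a model's IO before any Network is built. The field accesses below follow the layouts exercised in lite/test/test_network.cpp further down; constructing the Network from the queried NetworkIO assumes the existing Network(config, network_io) constructor and load_model() method.

#include <cstdio>
#include <memory>
#include "lite/network.h"

int main() {
    lite::Config config;
    // query the IO metadata without building or loading a Network
    lite::NetworkIO ios =
            lite::Runtime::get_model_io_info("./shufflenet.mge", config);
    for (auto&& in : ios.inputs)
        printf("input  %s ndim=%zu\n", in.name.c_str(),
               static_cast<size_t>(in.config_layout.ndim));
    for (auto&& out : ios.outputs)
        printf("output %s ndim=%zu\n", out.name.c_str(),
               static_cast<size_t>(out.config_layout.ndim));

    // the queried NetworkIO can be adjusted and used to configure the Network
    // before the model is actually loaded (assumed constructor signature)
    auto network = std::make_shared<lite::Network>(config, ios);
    network->load_model("./shufflenet.mge");
    return 0;
}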


lite/lite-c/include/lite-c/network_c.h (+22, -0)

@@ -588,6 +588,28 @@ LITE_API int LITE_enable_global_layout_transform(LiteNetwork network);
LITE_API int LITE_dump_layout_transform_model(
LiteNetwork network, const char* dump_file_path);

/** get the model IO information by model path, before the model is loaded.
* \param[in] model_path The model file path
* \param[in] config The model config for loading
* \param[out] ios The model IO information
* \return int, if the return value is not zero an error happened; the error
* message can be retrieved by LITE_get_last_error
*/
LITE_API int LITE_get_model_io_info_by_path(
const char* model_path, const LiteConfig config, LiteNetworkIO* ios);

/** get the model IO information by model memory, before the model is loaded.
* \param[in] model_mem The model memory pointer
* \param[in] size The length of the model memory in bytes
* \param[in] config The model config for loading
* \param[out] ios The model IO information
* \return int, if the return value is not zero an error happened; the error
* message can be retrieved by LITE_get_last_error
*/
LITE_API int LITE_get_model_io_info_by_memory(
const void* model_mem, size_t size, const LiteConfig config,
LiteNetworkIO* ios);

#ifdef __cplusplus
}
#endif
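
A corresponding sketch for the C API (written as C++ so it mirrors the tests): it reuses the shufflenet.mge model and the default_config() helper that lite/test/test_network_c.cpp uses below, and reports failures through LITE_get_last_error as the doc comments above describe.

#include <cstdio>
#include "lite-c/network_c.h"

int print_model_io(const char* model_path) {
    LiteNetworkIO ios;
    // query the IO information before any LiteNetwork is created
    if (LITE_get_model_io_info_by_path(model_path, *default_config(), &ios) != 0) {
        printf("failed: %s\n", LITE_get_last_error());
        return -1;
    }
    for (size_t i = 0; i < ios.input_size; ++i)
        printf("input  %s ndim=%zu\n", ios.inputs[i].name,
               (size_t)ios.inputs[i].config_layout.ndim);
    for (size_t i = 0; i < ios.output_size; ++i)
        printf("output %s ndim=%zu\n", ios.outputs[i].name,
               (size_t)ios.outputs[i].config_layout.ndim);
    return 0;
}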


lite/lite-c/src/network.cpp (+80, -0)

@@ -167,6 +167,31 @@ lite::NetworkIO convert_to_lite_io(const LiteNetworkIO c_network_io) {
return network_io;
}

struct InnerIO {
std::vector<std::string> names;
std::vector<LiteIO> inputs;
std::vector<LiteIO> outputs;
};

InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
InnerIO inner_io;
for (size_t i = 0; i < network_io.inputs.size(); i++) {
lite::IO io = network_io.inputs[i];
inner_io.names.push_back(io.name);
inner_io.inputs.push_back(
{inner_io.names.back().c_str(), io.is_host, io.io_type,
convert_to_clayout(io.config_layout)});
}
for (size_t i = 0; i < network_io.outputs.size(); i++) {
lite::IO io = network_io.outputs[i];
inner_io.names.push_back(io.name);
inner_io.outputs.push_back(
{inner_io.names.back().c_str(), io.is_host, io.io_type,
convert_to_clayout(io.config_layout)});
}
return inner_io;
}

int LITE_make_default_network(LiteNetwork* network) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null");
@@ -665,4 +690,59 @@ int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_
lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path);
LITE_CAPI_END();
}

namespace {
static LITE_MUTEX mtx_io;
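//! the LiteIO arrays and name strings handed back through LiteNetworkIO must
//! outlive this call, so keep them in a process-global holder keyed by the
//! queried model path pointer / model memory pointer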
static std::unordered_map<const void*, InnerIO>& get_global_io_holder() {
static std::unordered_map<const void*, InnerIO> global_holder;
return global_holder;
}

int write_ios_from_cpp_io(
const lite::NetworkIO& cpp_io, LiteNetworkIO* ios, const void* key) {
LITE_CAPI_BEGIN();
LITE_LOCK_GUARD(mtx_io);
get_global_io_holder()[key] = convert_to_inner_io(cpp_io);
auto&& inner_io = get_global_io_holder()[key];
ios->input_size = inner_io.inputs.size();
ios->output_size = inner_io.outputs.size();
ios->inputs = inner_io.inputs.data();
ios->outputs = inner_io.outputs.data();
//! re-point the name pointers at the strings owned by the copy stored in the
//! global holder, so they remain valid after this function returns
size_t i = 0;
for (; i < ios->input_size; i++) {
auto io_ptr = ios->inputs + i;
io_ptr->name = inner_io.names[i].c_str();
}
for (size_t j = 0; j < ios->output_size; j++, i++) {
auto io_ptr = ios->outputs + j;
io_ptr->name = inner_io.names[i].c_str();
}
LITE_CAPI_END();
}

} // namespace

int LITE_get_model_io_info_by_path(
const char* model_path, const LiteConfig config, LiteNetworkIO* ios) {
LITE_CAPI_BEGIN();
LITE_ASSERT(model_path, "The model_path pass to LITE api is null");
auto&& cpp_ios = lite::Runtime::get_model_io_info(
std::string{model_path}, convert_to_lite_config(config));
return write_ios_from_cpp_io(
cpp_ios, ios, reinterpret_cast<const void*>(model_path));
LITE_CAPI_END();
}

int LITE_get_model_io_info_by_memory(
const void* model_mem, size_t size, const LiteConfig config,
LiteNetworkIO* ios) {
LITE_CAPI_BEGIN();
LITE_ASSERT(model_mem, "The model_mem pass to LITE api is null");
auto&& cpp_ios = lite::Runtime::get_model_io_info(
model_mem, size, convert_to_lite_config(config));
return write_ios_from_cpp_io(
cpp_ios, ios, reinterpret_cast<const void*>(model_mem));
LITE_CAPI_END();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

lite/pylite/megenginelite/network.py (+32, -0)

@@ -364,6 +364,14 @@ class _NetworkAPI(_LiteCObjBase):
("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
("LITE_enable_global_layout_transform", [_Cnetwork]),
("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
(
"LITE_get_model_io_info_by_path",
[c_char_p, LiteConfig, POINTER(_LiteNetworkIO)],
),
(
"LITE_get_model_io_info_by_memory",
[c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
),
]


@@ -619,3 +627,27 @@ class LiteNetwork(object):
def dump_layout_transform_model(self, model_file):
c_file = model_file.encode("utf-8")
self._api.LITE_dump_layout_transform_model(self._network, c_file)


def get_model_io_info(model_path, config=None):
"""
get the model IO information before creating the Network; this IO
information can be used to configure the Network.
"""
api = _NetworkAPI()._lib
c_path = c_char_p(model_path.encode("utf-8"))

ios = _LiteNetworkIO()

if config is None:
config = LiteConfig()
api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))

ret_ios = LiteNetworkIO()
for i in range(ios.input_size):
ret_ios.add_input(ios.inputs[i])
for i in range(ios.output_size):
ret_ios.add_output(ios.outputs[i])
return ret_ios

lite/pylite/test/test_utils.py (+18, -0)

@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import functools
import os

import numpy as np

@@ -200,3 +201,20 @@ def test_tensor_collect_batch_device_numpy():
for i in range(4):
for j in range(48):
assert data[i][j // 8][j % 8] == i + 1


def test_get_model_io_ahead():
source_dir = os.getenv("LITE_TEST_RESOURCE")
model_path = os.path.join(source_dir, "shufflenet.mge")
ios = get_model_io_info(model_path)

assert len(ios.inputs) == 1
assert ios.inputs[0].name == "data"
assert ios.inputs[0].config_layout.shapes[1] == 3
assert ios.inputs[0].config_layout.shapes[2] == 224
assert ios.inputs[0].config_layout.shapes[3] == 224

assert len(ios.outputs) == 1
assert ios.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"
assert ios.outputs[0].config_layout.shapes[0] == 1
assert ios.outputs[0].config_layout.shapes[1] == 1000

lite/src/function_base.h (+3, -3)

@@ -34,7 +34,7 @@ ADD_STATEMENT(NetworkImplDft, Dft);
} // namespace

// if it can't find the function, ignore
template <typename tensor_type, typename ret_type, typename... Args>
template <typename type, typename ret_type, typename... Args>
ret_type try_call_func(std::string func_name, Args... args) {
mark_used_variable(func_name);
mark_used_variable(args...);
@@ -42,10 +42,10 @@ ret_type try_call_func(std::string func_name, Args... args) {
}

// if it can't find the function, throw error
template <typename tensor_type, typename ret_type, typename... Args>
template <typename type, typename ret_type, typename... Args>
ret_type call_func(std::string func_name, Args... args) {
mark_used_variable(args...);
auto backend_name = class_type_name<tensor_type>()();
auto backend_name = class_type_name<type>()();
auto msg_info = func_name + " is not available in " + backend_name + " backend.";
LITE_THROW(msg_info.c_str());
}


lite/src/mge/function_dft.h (+20, -0)

@@ -206,6 +206,26 @@ inline void call_func<NetworkImplDft, void>(
THROW_FUNC_ERROR(func_name);
}
}

template <>
inline NetworkIO call_func<NetworkImplDft, NetworkIO>(
std::string func_name, std::string model_path, Config config) {
if (func_name == "get_model_io_info") {
return get_model_io_info_dft(model_path, config);
} else {
THROW_FUNC_ERROR(func_name);
}
}

template <>
inline NetworkIO call_func<NetworkImplDft, NetworkIO>(
std::string func_name, const void* model_mem, size_t size, Config config) {
if (func_name == "get_model_io_info") {
return get_model_io_info_dft(model_mem, size, config);
} else {
THROW_FUNC_ERROR(func_name);
}
}
#undef THROW_FUNC_ERROR

} // namespace lite


lite/src/mge/network_impl.cpp (+70, -0)

@@ -929,5 +929,75 @@ void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_pat
"enable_global_layout_transform before"));
}
}

NetworkIO lite::get_model_io_info_dft(
const std::string& model_path, const Config& config) {
FILE* fin = fopen(model_path.c_str(), "rb");
LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
fseek(fin, 0, SEEK_END);
size_t size = ftell(fin);
fseek(fin, 0, SEEK_SET);
void* ptr = malloc(size);
std::shared_ptr<void> buf{ptr, ::free};
auto nr = fread(buf.get(), 1, size, fin);
LITE_ASSERT(nr == size);
fclose(fin);
return get_model_io_info_dft(ptr, size, config);
}

NetworkIO lite::get_model_io_info_dft(
const void* model_mem, size_t size, const Config& config) {
std::shared_ptr<void> model{const_cast<void*>(model_mem), [](void*) {}};
auto input_file = mgb::serialization::InputFile::make_mem_proxy(model, size, false);
auto format =
mgb::serialization::GraphLoader::identify_graph_dump_format(*input_file);
if (!format.valid()) {
LITE_THROW("invalid model format");
}
auto loader =
mgb::serialization::GraphLoader::make(std::move(input_file), format.val());

mgb::serialization::GraphLoadConfig load_config;
load_config.comp_graph = mgb::ComputingGraph::make();
if (config.has_compression) {
load_config.tensor_value_loader = decompressed_tensor_value_loader;
}
auto compnode_locator = to_compnode_locator(config.device_type);
load_config.comp_node_mapper = [=](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
loc.type = compnode_locator.type;
}
loc.device = compnode_locator.device;
};
auto load_result = loader->load(load_config, true);
NetworkIO IOs;
for (auto&& in_tensor_iter : load_result.tensor_map) {
IO in_io;
in_io.name = in_tensor_iter.first;
in_io.config_layout = to_lite_layout(in_tensor_iter.second->layout());
IOs.inputs.push_back(in_io);
}
auto infer_shape = [=](mgb::cg::SymbolVar var) -> const megdnn::TensorShape* {
auto&& static_infer_mgr = load_config.comp_graph->static_infer_manager();
using InferType = mgb::cg::static_infer::InferType;
if (static_infer_mgr.get_infer_type(var.node()).shape &
(InferType::CONST | InferType::RT_STATIC)) {
return static_infer_mgr.infer_shape_fallible(var.node());
} else {
return nullptr;
}
};
for (auto&& out : load_result.output_var_list) {
IO out_io;
out_io.name = out.node()->name();
if (auto shape = infer_shape(out)) {
out_io.config_layout = to_lite_layout(TensorLayout{*shape, out.dtype()});
} else {
out_io.config_layout = to_lite_layout(TensorLayout{{}, out.dtype()});
}
IOs.outputs.push_back(out_io);
}
return IOs;
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

lite/src/mge/network_impl.h (+7, -0)

@@ -262,6 +262,13 @@ private:
#endif
std::unique_ptr<mgb::OprIODumpBase> m_iodump;
};
//! get the model IO information by model path, before the model is loaded by a Network
NetworkIO get_model_io_info_dft(const std::string& model_path, const Config& config);

//! get the model IO information by model memory and size, before the model is
//! loaded by a Network
NetworkIO get_model_io_info_dft(
const void* model_mem, size_t size, const Config& config);

} // namespace lite



lite/src/network.cpp (+22, -0)

@@ -534,4 +534,26 @@ void Runtime::dump_layout_transform_model(
LITE_THROW("dump_layout_transform_model is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}

NetworkIO Runtime::get_model_io_info(
const std::string& model_path, const Config& config) {
LITE_ERROR_HANDLER_BEGIN
if (config.backend == LiteBackend::LITE_DEFAULT) {
return call_func<NetworkImplDft, NetworkIO>(
"get_model_io_info", model_path, config);
}
LITE_THROW("get_model_io_info is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}

NetworkIO Runtime::get_model_io_info(
const void* model_mem, size_t size, const Config& config) {
LITE_ERROR_HANDLER_BEGIN
if (config.backend == LiteBackend::LITE_DEFAULT) {
return call_func<NetworkImplDft, NetworkIO>(
"get_model_io_info", model_mem, size, config);
}
LITE_THROW("get_model_io_info is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

lite/test/test_network.cpp (+48, -0)

@@ -106,6 +106,54 @@ TEST(TestNetWork, GetAllName) {
ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
}

TEST(TestNetWork, GetAllIoInfoAhead) {
Config config;
std::string model_path = "./shufflenet.mge";

auto ios = Runtime::get_model_io_info(model_path);

FILE* fin = fopen(model_path.c_str(), "rb");
ASSERT_TRUE(fin);
fseek(fin, 0, SEEK_END);
size_t size = ftell(fin);
fseek(fin, 0, SEEK_SET);
void* ptr = malloc(size);
std::shared_ptr<void> buf{ptr, ::free};
auto nr = fread(buf.get(), 1, size, fin);
LITE_ASSERT(nr == size);
fclose(fin);

auto ios_mem = Runtime::get_model_io_info(ptr, size);

ASSERT_EQ(ios.inputs.size(), ios_mem.inputs.size());
ASSERT_EQ(ios.inputs.size(), 1);

ASSERT_EQ(ios.outputs.size(), ios_mem.outputs.size());
ASSERT_EQ(ios.outputs.size(), 1);

ASSERT_TRUE(ios.inputs[0].name == "data");
ASSERT_TRUE(ios.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");

ASSERT_TRUE(ios_mem.inputs[0].name == "data");
ASSERT_TRUE(
ios_mem.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
ASSERT_EQ(ios.inputs[0].config_layout.ndim, 4);
ASSERT_EQ(ios.inputs[0].config_layout.shapes[1], 3);
ASSERT_EQ(ios.inputs[0].config_layout.shapes[2], 224);

ASSERT_EQ(ios.outputs[0].config_layout.ndim, 2);
ASSERT_EQ(ios.outputs[0].config_layout.shapes[0], 1);
ASSERT_EQ(ios.outputs[0].config_layout.shapes[1], 1000);

ASSERT_EQ(ios_mem.inputs[0].config_layout.ndim, 4);
ASSERT_EQ(ios_mem.inputs[0].config_layout.shapes[1], 3);
ASSERT_EQ(ios_mem.inputs[0].config_layout.shapes[2], 224);

ASSERT_EQ(ios_mem.outputs[0].config_layout.ndim, 2);
ASSERT_EQ(ios_mem.outputs[0].config_layout.shapes[0], 1);
ASSERT_EQ(ios_mem.outputs[0].config_layout.shapes[1], 1000);
}

TEST(TestNetWork, LoadFBSModel) {
Config config;
std::string model_path = "./ax.mge";


lite/test/test_network_c.cpp (+49, -0)

@@ -252,6 +252,55 @@ TEST(TestCapiNetWork, GetAllName) {
LITE_destroy_network(c_network);
}

TEST(TestCapiNetWork, GetAllNameAhead) {
std::string model_path = "./shufflenet.mge";
LiteNetworkIO ios, ios_mem;
LITE_CAPI_CHECK(LITE_get_model_io_info_by_path(
model_path.c_str(), *default_config(), &ios));
FILE* fin = fopen(model_path.c_str(), "rb");
ASSERT_TRUE(fin);
fseek(fin, 0, SEEK_END);
size_t size = ftell(fin);
fseek(fin, 0, SEEK_SET);
void* ptr = malloc(size);
std::shared_ptr<void> buf{ptr, ::free};
auto nr = fread(buf.get(), 1, size, fin);
LITE_ASSERT(nr == size);
fclose(fin);

LITE_CAPI_CHECK(
LITE_get_model_io_info_by_memory(ptr, size, *default_config(), &ios_mem));

ASSERT_EQ(ios.input_size, 1);
ASSERT_EQ(ios.output_size, 1);
ASSERT_EQ(ios_mem.input_size, 1);
ASSERT_EQ(ios_mem.output_size, 1);

ASSERT_TRUE(std::string(ios.inputs->name) == "data");
ASSERT_TRUE(ios.inputs->config_layout.ndim == 4);
ASSERT_TRUE(ios.inputs->config_layout.shapes[1] == 3);
ASSERT_TRUE(ios.inputs->config_layout.shapes[2] == 224);
ASSERT_TRUE(ios.inputs->config_layout.shapes[3] == 224);
ASSERT_TRUE(
std::string(ios.outputs->name) ==
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
ASSERT_TRUE(ios.outputs->config_layout.ndim == 2);
ASSERT_TRUE(ios.outputs->config_layout.shapes[0] == 1);
ASSERT_TRUE(ios.outputs->config_layout.shapes[1] == 1000);

ASSERT_TRUE(std::string(ios_mem.inputs->name) == "data");
ASSERT_TRUE(ios_mem.inputs->config_layout.ndim == 4);
ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[1] == 3);
ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[2] == 224);
ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[3] == 224);
ASSERT_TRUE(
std::string(ios_mem.outputs->name) ==
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
ASSERT_TRUE(ios_mem.outputs->config_layout.ndim == 2);
ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[0] == 1);
ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[1] == 1000);
}

#if LITE_BUILD_WITH_RKNPU

static int GetTop(

