GitOrigin-RevId: e499f3ebf8
tags/v1.9.0
@@ -373,6 +373,14 @@ public: | |||
//! dump network after global layout transform optimization | |||
static void dump_layout_transform_model( | |||
std::shared_ptr<Network> network, std::string optimized_model_path); | |||
//! get the model io information before model loaded by model path. | |||
static NetworkIO get_model_io_info( | |||
const std::string& model_path, const Config& config = {}); | |||
//! get the model io information before model loaded by model memory. | |||
static NetworkIO get_model_io_info( | |||
const void* model_mem, size_t size, const Config& config = {}); | |||
}; | |||
} // namespace lite | |||
@@ -588,6 +588,28 @@ LITE_API int LITE_enable_global_layout_transform(LiteNetwork network); | |||
LITE_API int LITE_dump_layout_transform_model( | |||
LiteNetwork network, const char* dump_file_path); | |||
/**! get the model io information before model loaded by model path. | |||
* \param[in] model_path The model file path | |||
* \param[in] config The model config for loading | |||
* \param[out] ios The model io infermation | |||
* \return int if the return is not zero, error happened, the error message | |||
* can get by LITE_get_last_error | |||
*/ | |||
LITE_API int LITE_get_model_io_info_by_path( | |||
const char* model_path, const LiteConfig config, LiteNetworkIO* ios); | |||
/** get the model io information before model loaded by model memory. | |||
* \param[in] model_mem The model memory ptr | |||
* \param[in] size The model memory ptr length | |||
* \param[in] config The model config for loading | |||
* \param[out] ios The model io infermation | |||
* \return int if the return is not zero, error happened, the error message | |||
* can get by LITE_get_last_error | |||
*/ | |||
LITE_API int LITE_get_model_io_info_by_memory( | |||
const void* model_mem, size_t size, const LiteConfig config, | |||
LiteNetworkIO* ios); | |||
#ifdef __cplusplus | |||
} | |||
#endif | |||
@@ -167,6 +167,31 @@ lite::NetworkIO convert_to_lite_io(const LiteNetworkIO c_network_io) { | |||
return network_io; | |||
} | |||
struct InnerIO { | |||
std::vector<std::string> names; | |||
std::vector<LiteIO> inputs; | |||
std::vector<LiteIO> outputs; | |||
}; | |||
InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) { | |||
InnerIO innner_io; | |||
for (size_t i = 0; i < network_io.inputs.size(); i++) { | |||
lite::IO io = network_io.inputs[i]; | |||
innner_io.names.push_back(io.name); | |||
innner_io.inputs.push_back( | |||
{innner_io.names.back().c_str(), io.is_host, io.io_type, | |||
convert_to_clayout(io.config_layout)}); | |||
} | |||
for (size_t i = 0; i < network_io.outputs.size(); i++) { | |||
lite::IO io = network_io.outputs[i]; | |||
innner_io.names.push_back(io.name); | |||
innner_io.outputs.push_back( | |||
{innner_io.names.back().c_str(), io.is_host, io.io_type, | |||
convert_to_clayout(io.config_layout)}); | |||
} | |||
return innner_io; | |||
} | |||
int LITE_make_default_network(LiteNetwork* network) { | |||
LITE_CAPI_BEGIN(); | |||
LITE_ASSERT(network, "The network pass to LITE api is null"); | |||
@@ -665,4 +690,59 @@ int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_ | |||
lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path); | |||
LITE_CAPI_END(); | |||
} | |||
namespace { | |||
static LITE_MUTEX mtx_io; | |||
static std::unordered_map<const void*, InnerIO>& get_global_io_holder() { | |||
static std::unordered_map<const void*, InnerIO> global_holder; | |||
return global_holder; | |||
} | |||
int write_ios_from_cpp_io( | |||
const lite::NetworkIO& cpp_io, LiteNetworkIO* ios, const void* key) { | |||
LITE_CAPI_BEGIN(); | |||
LITE_LOCK_GUARD(mtx_io); | |||
get_global_io_holder()[key] = convert_to_inner_io(cpp_io); | |||
auto&& inner_io = get_global_io_holder()[key]; | |||
ios->input_size = inner_io.inputs.size(); | |||
ios->output_size = inner_io.outputs.size(); | |||
ios->inputs = inner_io.inputs.data(); | |||
ios->outputs = inner_io.outputs.data(); | |||
size_t i = 0; | |||
for (; i < ios->input_size; i++) { | |||
auto io_ptr = ios->inputs + i; | |||
io_ptr->name = inner_io.names[i].c_str(); | |||
} | |||
for (; i < ios->output_size; i++) { | |||
auto io_ptr = ios->outputs + i; | |||
io_ptr->name = inner_io.names[i].c_str(); | |||
} | |||
LITE_CAPI_END(); | |||
} | |||
} // namespace | |||
int LITE_get_model_io_info_by_path( | |||
const char* model_path, const LiteConfig config, LiteNetworkIO* ios) { | |||
LITE_CAPI_BEGIN(); | |||
LITE_ASSERT(model_path, "The model_path pass to LITE api is null"); | |||
auto&& cpp_ios = lite::Runtime::get_model_io_info( | |||
std::string{model_path}, convert_to_lite_config(config)); | |||
return write_ios_from_cpp_io( | |||
cpp_ios, ios, reinterpret_cast<const void*>(model_path)); | |||
LITE_CAPI_END(); | |||
} | |||
int LITE_get_model_io_info_by_memory( | |||
const void* model_mem, size_t size, const LiteConfig config, | |||
LiteNetworkIO* ios) { | |||
LITE_CAPI_BEGIN(); | |||
LITE_ASSERT(model_mem, "The model_mem pass to LITE api is null"); | |||
auto&& cpp_ios = lite::Runtime::get_model_io_info( | |||
model_mem, size, convert_to_lite_config(config)); | |||
return write_ios_from_cpp_io( | |||
cpp_ios, ios, reinterpret_cast<const void*>(model_mem)); | |||
LITE_CAPI_END(); | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -364,6 +364,14 @@ class _NetworkAPI(_LiteCObjBase): | |||
("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), | |||
("LITE_enable_global_layout_transform", [_Cnetwork]), | |||
("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]), | |||
( | |||
"LITE_get_model_io_info_by_path", | |||
[c_char_p, LiteConfig, POINTER(_LiteNetworkIO)], | |||
), | |||
( | |||
"LITE_get_model_io_info_by_memory", | |||
[c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)], | |||
), | |||
] | |||
@@ -619,3 +627,27 @@ class LiteNetwork(object): | |||
def dump_layout_transform_model(self, model_file): | |||
c_file = model_file.encode("utf-8") | |||
self._api.LITE_dump_layout_transform_model(self._network, c_file) | |||
def get_model_io_info(model_path, config=None): | |||
""" | |||
get the model IO information before create the NetWork, this IO | |||
information can be used to configuration the NetWork. | |||
""" | |||
api = _NetworkAPI()._lib | |||
c_path = c_char_p(model_path.encode("utf-8")) | |||
ios = _LiteNetworkIO() | |||
if config is not None: | |||
api.LITE_get_model_io_info_by_path(c_path, config, byref(ios)) | |||
else: | |||
config = LiteConfig() | |||
api.LITE_get_model_io_info_by_path(c_path, config, byref(ios)) | |||
ret_ios = LiteNetworkIO() | |||
for i in range(ios.input_size): | |||
ret_ios.add_input(ios.inputs[i]) | |||
for i in range(ios.output_size): | |||
ret_ios.add_output(ios.outputs[i]) | |||
return ret_ios |
@@ -8,6 +8,7 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import functools | |||
import os | |||
import numpy as np | |||
@@ -200,3 +201,20 @@ def test_tensor_collect_batch_device_numpy(): | |||
for i in range(4): | |||
for j in range(48): | |||
assert data[i][j // 8][j % 8] == i + 1 | |||
def test_get_model_io_ahead(): | |||
source_dir = os.getenv("LITE_TEST_RESOURCE") | |||
model_path = os.path.join(source_dir, "shufflenet.mge") | |||
ios = get_model_io_info(model_path) | |||
assert len(ios.inputs) == 1 | |||
assert ios.inputs[0].name == "data" | |||
assert ios.inputs[0].config_layout.shapes[1] == 3 | |||
assert ios.inputs[0].config_layout.shapes[2] == 224 | |||
assert ios.inputs[0].config_layout.shapes[3] == 224 | |||
assert len(ios.outputs) == 1 | |||
assert ios.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]" | |||
assert ios.outputs[0].config_layout.shapes[0] == 1 | |||
assert ios.outputs[0].config_layout.shapes[1] == 1000 |
@@ -34,7 +34,7 @@ ADD_STATEMENT(NetworkImplDft, Dft); | |||
} // namespace | |||
// if it can't find the function, ignore | |||
template <typename tensor_type, typename ret_type, typename... Args> | |||
template <typename type, typename ret_type, typename... Args> | |||
ret_type try_call_func(std::string func_name, Args... args) { | |||
mark_used_variable(func_name); | |||
mark_used_variable(args...); | |||
@@ -42,10 +42,10 @@ ret_type try_call_func(std::string func_name, Args... args) { | |||
} | |||
// if it can't find the function, throw error | |||
template <typename tensor_type, typename ret_type, typename... Args> | |||
template <typename type, typename ret_type, typename... Args> | |||
ret_type call_func(std::string func_name, Args... args) { | |||
mark_used_variable(args...); | |||
auto backend_name = class_type_name<tensor_type>()(); | |||
auto backend_name = class_type_name<type>()(); | |||
auto msg_info = func_name + " is not aviliable in " + backend_name + " backend."; | |||
LITE_THROW(msg_info.c_str()); | |||
} | |||
@@ -206,6 +206,26 @@ inline void call_func<NetworkImplDft, void>( | |||
THROW_FUNC_ERROR(func_name); | |||
} | |||
} | |||
template <> | |||
inline NetworkIO call_func<NetworkImplDft, NetworkIO>( | |||
std::string func_name, std::string model_path, Config config) { | |||
if (func_name == "get_model_io_info") { | |||
return get_model_io_info_dft(model_path, config); | |||
} else { | |||
THROW_FUNC_ERROR(func_name); | |||
} | |||
} | |||
template <> | |||
inline NetworkIO call_func<NetworkImplDft, NetworkIO>( | |||
std::string func_name, const void* model_mem, size_t size, Config config) { | |||
if (func_name == "get_model_io_info") { | |||
return get_model_io_info_dft(model_mem, size, config); | |||
} else { | |||
THROW_FUNC_ERROR(func_name); | |||
} | |||
} | |||
#undef THROW_FUNC_ERROR | |||
} // namespace lite | |||
@@ -929,5 +929,75 @@ void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_pat | |||
"enable_global_layout_transform before")); | |||
} | |||
} | |||
NetworkIO lite::get_model_io_info_dft( | |||
const std::string& model_path, const Config& config) { | |||
FILE* fin = fopen(model_path.c_str(), "rb"); | |||
LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno)); | |||
fseek(fin, 0, SEEK_END); | |||
size_t size = ftell(fin); | |||
fseek(fin, 0, SEEK_SET); | |||
void* ptr = malloc(size); | |||
std::shared_ptr<void> buf{ptr, ::free}; | |||
auto nr = fread(buf.get(), 1, size, fin); | |||
LITE_ASSERT(nr == size); | |||
fclose(fin); | |||
return get_model_io_info_dft(ptr, size, config); | |||
} | |||
NetworkIO lite::get_model_io_info_dft( | |||
const void* model_mem, size_t size, const Config& config) { | |||
std::shared_ptr<void> model{const_cast<void*>(model_mem), [](void*) {}}; | |||
auto input_file = mgb::serialization::InputFile::make_mem_proxy(model, size, false); | |||
auto format = | |||
mgb::serialization::GraphLoader::identify_graph_dump_format(*input_file); | |||
if (!format.valid()) { | |||
LITE_THROW("invalid model format"); | |||
} | |||
auto loader = | |||
mgb::serialization::GraphLoader::make(std::move(input_file), format.val()); | |||
mgb::serialization::GraphLoadConfig load_config; | |||
load_config.comp_graph = mgb::ComputingGraph::make(); | |||
if (config.has_compression) { | |||
load_config.tensor_value_loader = decompressed_tensor_value_loader; | |||
} | |||
auto compnode_locator = to_compnode_locator(config.device_type); | |||
load_config.comp_node_mapper = [=](mgb::CompNode::Locator& loc) { | |||
if (loc.type == mgb::CompNode::DeviceType::UNSPEC) { | |||
loc.type = compnode_locator.type; | |||
} | |||
loc.device = compnode_locator.device; | |||
}; | |||
auto load_result = loader->load(load_config, true); | |||
NetworkIO IOs; | |||
for (auto&& in_tensor_iter : load_result.tensor_map) { | |||
IO in_io; | |||
in_io.name = in_tensor_iter.first; | |||
in_io.config_layout = to_lite_layout(in_tensor_iter.second->layout()); | |||
IOs.inputs.push_back(in_io); | |||
} | |||
auto infer_shape = [=](mgb::cg::SymbolVar var) -> const megdnn::TensorShape* { | |||
auto&& static_infer_mgr = load_config.comp_graph->static_infer_manager(); | |||
using InferType = mgb::cg::static_infer::InferType; | |||
if (static_infer_mgr.get_infer_type(var.node()).shape & | |||
(InferType::CONST | InferType::RT_STATIC)) { | |||
return static_infer_mgr.infer_shape_fallible(var.node()); | |||
} else { | |||
return nullptr; | |||
} | |||
}; | |||
for (auto&& out : load_result.output_var_list) { | |||
IO out_io; | |||
out_io.name = out.node()->name(); | |||
if (auto shape = infer_shape(out)) { | |||
out_io.config_layout = to_lite_layout(TensorLayout{*shape, out.dtype()}); | |||
} else { | |||
out_io.config_layout = to_lite_layout(TensorLayout{{}, out.dtype()}); | |||
} | |||
IOs.outputs.push_back(out_io); | |||
} | |||
return IOs; | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -262,6 +262,13 @@ private: | |||
#endif | |||
std::unique_ptr<mgb::OprIODumpBase> m_iodump; | |||
}; | |||
//! get the model information before model loaded by Network | |||
NetworkIO get_model_io_info_dft(const std::string& model_path, const Config& config); | |||
//! get the model information before model loaded by Network by model memory and | |||
//! size | |||
NetworkIO get_model_io_info_dft( | |||
const void* model_mem, size_t size, const Config& config); | |||
} // namespace lite | |||
@@ -534,4 +534,26 @@ void Runtime::dump_layout_transform_model( | |||
LITE_THROW("dump_layout_transform_model is not aviliable in the backend."); | |||
LITE_ERROR_HANDLER_END | |||
} | |||
NetworkIO Runtime::get_model_io_info( | |||
const std::string& model_path, const Config& config) { | |||
LITE_ERROR_HANDLER_BEGIN | |||
if (config.backend == LiteBackend::LITE_DEFAULT) { | |||
return call_func<NetworkImplDft, NetworkIO>( | |||
"get_model_io_info", model_path, config); | |||
} | |||
LITE_THROW("get_model_io_info is not aviliable in the backend."); | |||
LITE_ERROR_HANDLER_END | |||
} | |||
NetworkIO Runtime::get_model_io_info( | |||
const void* model_mem, size_t size, const Config& config) { | |||
LITE_ERROR_HANDLER_BEGIN | |||
if (config.backend == LiteBackend::LITE_DEFAULT) { | |||
return call_func<NetworkImplDft, NetworkIO>( | |||
"get_model_io_info", model_mem, size, config); | |||
} | |||
LITE_THROW("get_model_io_info is not aviliable in the backend."); | |||
LITE_ERROR_HANDLER_END | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -106,6 +106,54 @@ TEST(TestNetWork, GetAllName) { | |||
ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | |||
} | |||
TEST(TestNetWork, GetAllIoInfoAhead) { | |||
Config config; | |||
std::string model_path = "./shufflenet.mge"; | |||
auto ios = Runtime::get_model_io_info(model_path); | |||
FILE* fin = fopen(model_path.c_str(), "rb"); | |||
ASSERT_TRUE(fin); | |||
fseek(fin, 0, SEEK_END); | |||
size_t size = ftell(fin); | |||
fseek(fin, 0, SEEK_SET); | |||
void* ptr = malloc(size); | |||
std::shared_ptr<void> buf{ptr, ::free}; | |||
auto nr = fread(buf.get(), 1, size, fin); | |||
LITE_ASSERT(nr == size); | |||
fclose(fin); | |||
auto ios_mem = Runtime::get_model_io_info(ptr, size); | |||
ASSERT_EQ(ios.inputs.size(), ios_mem.inputs.size()); | |||
ASSERT_EQ(ios.inputs.size(), 1); | |||
ASSERT_EQ(ios.outputs.size(), ios_mem.outputs.size()); | |||
ASSERT_EQ(ios.outputs.size(), 1); | |||
ASSERT_TRUE(ios.inputs[0].name == "data"); | |||
ASSERT_TRUE(ios.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | |||
ASSERT_TRUE(ios_mem.inputs[0].name == "data"); | |||
ASSERT_TRUE( | |||
ios_mem.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | |||
ASSERT_EQ(ios.inputs[0].config_layout.ndim, 4); | |||
ASSERT_EQ(ios.inputs[0].config_layout.shapes[1], 3); | |||
ASSERT_EQ(ios.inputs[0].config_layout.shapes[2], 224); | |||
ASSERT_EQ(ios.outputs[0].config_layout.ndim, 2); | |||
ASSERT_EQ(ios.outputs[0].config_layout.shapes[0], 1); | |||
ASSERT_EQ(ios.outputs[0].config_layout.shapes[1], 1000); | |||
ASSERT_EQ(ios_mem.inputs[0].config_layout.ndim, 4); | |||
ASSERT_EQ(ios_mem.inputs[0].config_layout.shapes[1], 3); | |||
ASSERT_EQ(ios_mem.inputs[0].config_layout.shapes[2], 224); | |||
ASSERT_EQ(ios_mem.outputs[0].config_layout.ndim, 2); | |||
ASSERT_EQ(ios_mem.outputs[0].config_layout.shapes[0], 1); | |||
ASSERT_EQ(ios_mem.outputs[0].config_layout.shapes[1], 1000); | |||
} | |||
TEST(TestNetWork, LoadFBSModel) { | |||
Config config; | |||
std::string model_path = "./ax.mge"; | |||
@@ -252,6 +252,55 @@ TEST(TestCapiNetWork, GetAllName) { | |||
LITE_destroy_network(c_network); | |||
} | |||
TEST(TestCapiNetWork, GetAllNameAhead) { | |||
std::string model_path = "./shufflenet.mge"; | |||
LiteNetworkIO ios, ios_mem; | |||
LITE_CAPI_CHECK(LITE_get_model_io_info_by_path( | |||
model_path.c_str(), *default_config(), &ios)); | |||
FILE* fin = fopen(model_path.c_str(), "rb"); | |||
ASSERT_TRUE(fin); | |||
fseek(fin, 0, SEEK_END); | |||
size_t size = ftell(fin); | |||
fseek(fin, 0, SEEK_SET); | |||
void* ptr = malloc(size); | |||
std::shared_ptr<void> buf{ptr, ::free}; | |||
auto nr = fread(buf.get(), 1, size, fin); | |||
LITE_ASSERT(nr == size); | |||
fclose(fin); | |||
LITE_CAPI_CHECK( | |||
LITE_get_model_io_info_by_memory(ptr, size, *default_config(), &ios_mem)); | |||
ASSERT_EQ(ios.input_size, 1); | |||
ASSERT_EQ(ios.output_size, 1); | |||
ASSERT_EQ(ios_mem.input_size, 1); | |||
ASSERT_EQ(ios_mem.output_size, 1); | |||
ASSERT_TRUE(std::string(ios.inputs->name) == "data"); | |||
ASSERT_TRUE(ios.inputs->config_layout.ndim == 4); | |||
ASSERT_TRUE(ios.inputs->config_layout.shapes[1] == 3); | |||
ASSERT_TRUE(ios.inputs->config_layout.shapes[2] == 224); | |||
ASSERT_TRUE(ios.inputs->config_layout.shapes[3] == 224); | |||
ASSERT_TRUE( | |||
std::string(ios.outputs->name) == | |||
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | |||
ASSERT_TRUE(ios.outputs->config_layout.ndim == 2); | |||
ASSERT_TRUE(ios.outputs->config_layout.shapes[0] == 1); | |||
ASSERT_TRUE(ios.outputs->config_layout.shapes[1] == 1000); | |||
ASSERT_TRUE(std::string(ios_mem.inputs->name) == "data"); | |||
ASSERT_TRUE(ios_mem.inputs->config_layout.ndim == 4); | |||
ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[1] == 3); | |||
ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[2] == 224); | |||
ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[3] == 224); | |||
ASSERT_TRUE( | |||
std::string(ios_mem.outputs->name) == | |||
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | |||
ASSERT_TRUE(ios_mem.outputs->config_layout.ndim == 2); | |||
ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[0] == 1); | |||
ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[1] == 1000); | |||
} | |||
#if LITE_BUILD_WITH_RKNPU | |||
static int GetTop( | |||