GitOrigin-RevId: e499f3ebf8
tags/v1.9.0
@@ -373,6 +373,14 @@ public: | |||||
//! dump network after global layout transform optimization | //! dump network after global layout transform optimization | ||||
static void dump_layout_transform_model( | static void dump_layout_transform_model( | ||||
std::shared_ptr<Network> network, std::string optimized_model_path); | std::shared_ptr<Network> network, std::string optimized_model_path); | ||||
//! get the model io information before model loaded by model path. | |||||
static NetworkIO get_model_io_info( | |||||
const std::string& model_path, const Config& config = {}); | |||||
//! get the model io information before model loaded by model memory. | |||||
static NetworkIO get_model_io_info( | |||||
const void* model_mem, size_t size, const Config& config = {}); | |||||
}; | }; | ||||
} // namespace lite | } // namespace lite | ||||
@@ -588,6 +588,28 @@ LITE_API int LITE_enable_global_layout_transform(LiteNetwork network); | |||||
LITE_API int LITE_dump_layout_transform_model( | LITE_API int LITE_dump_layout_transform_model( | ||||
LiteNetwork network, const char* dump_file_path); | LiteNetwork network, const char* dump_file_path); | ||||
/**! get the model io information before model loaded by model path. | |||||
* \param[in] model_path The model file path | |||||
* \param[in] config The model config for loading | |||||
* \param[out] ios The model io infermation | |||||
* \return int if the return is not zero, error happened, the error message | |||||
* can get by LITE_get_last_error | |||||
*/ | |||||
LITE_API int LITE_get_model_io_info_by_path( | |||||
const char* model_path, const LiteConfig config, LiteNetworkIO* ios); | |||||
/** get the model io information before model loaded by model memory. | |||||
* \param[in] model_mem The model memory ptr | |||||
* \param[in] size The model memory ptr length | |||||
* \param[in] config The model config for loading | |||||
* \param[out] ios The model io infermation | |||||
* \return int if the return is not zero, error happened, the error message | |||||
* can get by LITE_get_last_error | |||||
*/ | |||||
LITE_API int LITE_get_model_io_info_by_memory( | |||||
const void* model_mem, size_t size, const LiteConfig config, | |||||
LiteNetworkIO* ios); | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
@@ -167,6 +167,31 @@ lite::NetworkIO convert_to_lite_io(const LiteNetworkIO c_network_io) { | |||||
return network_io; | return network_io; | ||||
} | } | ||||
struct InnerIO { | |||||
std::vector<std::string> names; | |||||
std::vector<LiteIO> inputs; | |||||
std::vector<LiteIO> outputs; | |||||
}; | |||||
InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) { | |||||
InnerIO innner_io; | |||||
for (size_t i = 0; i < network_io.inputs.size(); i++) { | |||||
lite::IO io = network_io.inputs[i]; | |||||
innner_io.names.push_back(io.name); | |||||
innner_io.inputs.push_back( | |||||
{innner_io.names.back().c_str(), io.is_host, io.io_type, | |||||
convert_to_clayout(io.config_layout)}); | |||||
} | |||||
for (size_t i = 0; i < network_io.outputs.size(); i++) { | |||||
lite::IO io = network_io.outputs[i]; | |||||
innner_io.names.push_back(io.name); | |||||
innner_io.outputs.push_back( | |||||
{innner_io.names.back().c_str(), io.is_host, io.io_type, | |||||
convert_to_clayout(io.config_layout)}); | |||||
} | |||||
return innner_io; | |||||
} | |||||
int LITE_make_default_network(LiteNetwork* network) { | int LITE_make_default_network(LiteNetwork* network) { | ||||
LITE_CAPI_BEGIN(); | LITE_CAPI_BEGIN(); | ||||
LITE_ASSERT(network, "The network pass to LITE api is null"); | LITE_ASSERT(network, "The network pass to LITE api is null"); | ||||
@@ -665,4 +690,59 @@ int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_ | |||||
lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path); | lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path); | ||||
LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
} | } | ||||
namespace { | |||||
static LITE_MUTEX mtx_io; | |||||
static std::unordered_map<const void*, InnerIO>& get_global_io_holder() { | |||||
static std::unordered_map<const void*, InnerIO> global_holder; | |||||
return global_holder; | |||||
} | |||||
int write_ios_from_cpp_io( | |||||
const lite::NetworkIO& cpp_io, LiteNetworkIO* ios, const void* key) { | |||||
LITE_CAPI_BEGIN(); | |||||
LITE_LOCK_GUARD(mtx_io); | |||||
get_global_io_holder()[key] = convert_to_inner_io(cpp_io); | |||||
auto&& inner_io = get_global_io_holder()[key]; | |||||
ios->input_size = inner_io.inputs.size(); | |||||
ios->output_size = inner_io.outputs.size(); | |||||
ios->inputs = inner_io.inputs.data(); | |||||
ios->outputs = inner_io.outputs.data(); | |||||
size_t i = 0; | |||||
for (; i < ios->input_size; i++) { | |||||
auto io_ptr = ios->inputs + i; | |||||
io_ptr->name = inner_io.names[i].c_str(); | |||||
} | |||||
for (; i < ios->output_size; i++) { | |||||
auto io_ptr = ios->outputs + i; | |||||
io_ptr->name = inner_io.names[i].c_str(); | |||||
} | |||||
LITE_CAPI_END(); | |||||
} | |||||
} // namespace | |||||
int LITE_get_model_io_info_by_path( | |||||
const char* model_path, const LiteConfig config, LiteNetworkIO* ios) { | |||||
LITE_CAPI_BEGIN(); | |||||
LITE_ASSERT(model_path, "The model_path pass to LITE api is null"); | |||||
auto&& cpp_ios = lite::Runtime::get_model_io_info( | |||||
std::string{model_path}, convert_to_lite_config(config)); | |||||
return write_ios_from_cpp_io( | |||||
cpp_ios, ios, reinterpret_cast<const void*>(model_path)); | |||||
LITE_CAPI_END(); | |||||
} | |||||
int LITE_get_model_io_info_by_memory( | |||||
const void* model_mem, size_t size, const LiteConfig config, | |||||
LiteNetworkIO* ios) { | |||||
LITE_CAPI_BEGIN(); | |||||
LITE_ASSERT(model_mem, "The model_mem pass to LITE api is null"); | |||||
auto&& cpp_ios = lite::Runtime::get_model_io_info( | |||||
model_mem, size, convert_to_lite_config(config)); | |||||
return write_ios_from_cpp_io( | |||||
cpp_ios, ios, reinterpret_cast<const void*>(model_mem)); | |||||
LITE_CAPI_END(); | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -364,6 +364,14 @@ class _NetworkAPI(_LiteCObjBase): | |||||
("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), | ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), | ||||
("LITE_enable_global_layout_transform", [_Cnetwork]), | ("LITE_enable_global_layout_transform", [_Cnetwork]), | ||||
("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]), | ("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]), | ||||
( | |||||
"LITE_get_model_io_info_by_path", | |||||
[c_char_p, LiteConfig, POINTER(_LiteNetworkIO)], | |||||
), | |||||
( | |||||
"LITE_get_model_io_info_by_memory", | |||||
[c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)], | |||||
), | |||||
] | ] | ||||
@@ -619,3 +627,27 @@ class LiteNetwork(object): | |||||
def dump_layout_transform_model(self, model_file): | def dump_layout_transform_model(self, model_file): | ||||
c_file = model_file.encode("utf-8") | c_file = model_file.encode("utf-8") | ||||
self._api.LITE_dump_layout_transform_model(self._network, c_file) | self._api.LITE_dump_layout_transform_model(self._network, c_file) | ||||
def get_model_io_info(model_path, config=None): | |||||
""" | |||||
get the model IO information before create the NetWork, this IO | |||||
information can be used to configuration the NetWork. | |||||
""" | |||||
api = _NetworkAPI()._lib | |||||
c_path = c_char_p(model_path.encode("utf-8")) | |||||
ios = _LiteNetworkIO() | |||||
if config is not None: | |||||
api.LITE_get_model_io_info_by_path(c_path, config, byref(ios)) | |||||
else: | |||||
config = LiteConfig() | |||||
api.LITE_get_model_io_info_by_path(c_path, config, byref(ios)) | |||||
ret_ios = LiteNetworkIO() | |||||
for i in range(ios.input_size): | |||||
ret_ios.add_input(ios.inputs[i]) | |||||
for i in range(ios.output_size): | |||||
ret_ios.add_output(ios.outputs[i]) | |||||
return ret_ios |
@@ -8,6 +8,7 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import functools | import functools | ||||
import os | |||||
import numpy as np | import numpy as np | ||||
@@ -200,3 +201,20 @@ def test_tensor_collect_batch_device_numpy(): | |||||
for i in range(4): | for i in range(4): | ||||
for j in range(48): | for j in range(48): | ||||
assert data[i][j // 8][j % 8] == i + 1 | assert data[i][j // 8][j % 8] == i + 1 | ||||
def test_get_model_io_ahead(): | |||||
source_dir = os.getenv("LITE_TEST_RESOURCE") | |||||
model_path = os.path.join(source_dir, "shufflenet.mge") | |||||
ios = get_model_io_info(model_path) | |||||
assert len(ios.inputs) == 1 | |||||
assert ios.inputs[0].name == "data" | |||||
assert ios.inputs[0].config_layout.shapes[1] == 3 | |||||
assert ios.inputs[0].config_layout.shapes[2] == 224 | |||||
assert ios.inputs[0].config_layout.shapes[3] == 224 | |||||
assert len(ios.outputs) == 1 | |||||
assert ios.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]" | |||||
assert ios.outputs[0].config_layout.shapes[0] == 1 | |||||
assert ios.outputs[0].config_layout.shapes[1] == 1000 |
@@ -34,7 +34,7 @@ ADD_STATEMENT(NetworkImplDft, Dft); | |||||
} // namespace | } // namespace | ||||
// if it can't find the function, ignore | // if it can't find the function, ignore | ||||
template <typename tensor_type, typename ret_type, typename... Args> | |||||
template <typename type, typename ret_type, typename... Args> | |||||
ret_type try_call_func(std::string func_name, Args... args) { | ret_type try_call_func(std::string func_name, Args... args) { | ||||
mark_used_variable(func_name); | mark_used_variable(func_name); | ||||
mark_used_variable(args...); | mark_used_variable(args...); | ||||
@@ -42,10 +42,10 @@ ret_type try_call_func(std::string func_name, Args... args) { | |||||
} | } | ||||
// if it can't find the function, throw error | // if it can't find the function, throw error | ||||
template <typename tensor_type, typename ret_type, typename... Args> | |||||
template <typename type, typename ret_type, typename... Args> | |||||
ret_type call_func(std::string func_name, Args... args) { | ret_type call_func(std::string func_name, Args... args) { | ||||
mark_used_variable(args...); | mark_used_variable(args...); | ||||
auto backend_name = class_type_name<tensor_type>()(); | |||||
auto backend_name = class_type_name<type>()(); | |||||
auto msg_info = func_name + " is not aviliable in " + backend_name + " backend."; | auto msg_info = func_name + " is not aviliable in " + backend_name + " backend."; | ||||
LITE_THROW(msg_info.c_str()); | LITE_THROW(msg_info.c_str()); | ||||
} | } | ||||
@@ -206,6 +206,26 @@ inline void call_func<NetworkImplDft, void>( | |||||
THROW_FUNC_ERROR(func_name); | THROW_FUNC_ERROR(func_name); | ||||
} | } | ||||
} | } | ||||
template <> | |||||
inline NetworkIO call_func<NetworkImplDft, NetworkIO>( | |||||
std::string func_name, std::string model_path, Config config) { | |||||
if (func_name == "get_model_io_info") { | |||||
return get_model_io_info_dft(model_path, config); | |||||
} else { | |||||
THROW_FUNC_ERROR(func_name); | |||||
} | |||||
} | |||||
template <> | |||||
inline NetworkIO call_func<NetworkImplDft, NetworkIO>( | |||||
std::string func_name, const void* model_mem, size_t size, Config config) { | |||||
if (func_name == "get_model_io_info") { | |||||
return get_model_io_info_dft(model_mem, size, config); | |||||
} else { | |||||
THROW_FUNC_ERROR(func_name); | |||||
} | |||||
} | |||||
#undef THROW_FUNC_ERROR | #undef THROW_FUNC_ERROR | ||||
} // namespace lite | } // namespace lite | ||||
@@ -929,5 +929,75 @@ void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_pat | |||||
"enable_global_layout_transform before")); | "enable_global_layout_transform before")); | ||||
} | } | ||||
} | } | ||||
NetworkIO lite::get_model_io_info_dft( | |||||
const std::string& model_path, const Config& config) { | |||||
FILE* fin = fopen(model_path.c_str(), "rb"); | |||||
LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno)); | |||||
fseek(fin, 0, SEEK_END); | |||||
size_t size = ftell(fin); | |||||
fseek(fin, 0, SEEK_SET); | |||||
void* ptr = malloc(size); | |||||
std::shared_ptr<void> buf{ptr, ::free}; | |||||
auto nr = fread(buf.get(), 1, size, fin); | |||||
LITE_ASSERT(nr == size); | |||||
fclose(fin); | |||||
return get_model_io_info_dft(ptr, size, config); | |||||
} | |||||
NetworkIO lite::get_model_io_info_dft( | |||||
const void* model_mem, size_t size, const Config& config) { | |||||
std::shared_ptr<void> model{const_cast<void*>(model_mem), [](void*) {}}; | |||||
auto input_file = mgb::serialization::InputFile::make_mem_proxy(model, size, false); | |||||
auto format = | |||||
mgb::serialization::GraphLoader::identify_graph_dump_format(*input_file); | |||||
if (!format.valid()) { | |||||
LITE_THROW("invalid model format"); | |||||
} | |||||
auto loader = | |||||
mgb::serialization::GraphLoader::make(std::move(input_file), format.val()); | |||||
mgb::serialization::GraphLoadConfig load_config; | |||||
load_config.comp_graph = mgb::ComputingGraph::make(); | |||||
if (config.has_compression) { | |||||
load_config.tensor_value_loader = decompressed_tensor_value_loader; | |||||
} | |||||
auto compnode_locator = to_compnode_locator(config.device_type); | |||||
load_config.comp_node_mapper = [=](mgb::CompNode::Locator& loc) { | |||||
if (loc.type == mgb::CompNode::DeviceType::UNSPEC) { | |||||
loc.type = compnode_locator.type; | |||||
} | |||||
loc.device = compnode_locator.device; | |||||
}; | |||||
auto load_result = loader->load(load_config, true); | |||||
NetworkIO IOs; | |||||
for (auto&& in_tensor_iter : load_result.tensor_map) { | |||||
IO in_io; | |||||
in_io.name = in_tensor_iter.first; | |||||
in_io.config_layout = to_lite_layout(in_tensor_iter.second->layout()); | |||||
IOs.inputs.push_back(in_io); | |||||
} | |||||
auto infer_shape = [=](mgb::cg::SymbolVar var) -> const megdnn::TensorShape* { | |||||
auto&& static_infer_mgr = load_config.comp_graph->static_infer_manager(); | |||||
using InferType = mgb::cg::static_infer::InferType; | |||||
if (static_infer_mgr.get_infer_type(var.node()).shape & | |||||
(InferType::CONST | InferType::RT_STATIC)) { | |||||
return static_infer_mgr.infer_shape_fallible(var.node()); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
}; | |||||
for (auto&& out : load_result.output_var_list) { | |||||
IO out_io; | |||||
out_io.name = out.node()->name(); | |||||
if (auto shape = infer_shape(out)) { | |||||
out_io.config_layout = to_lite_layout(TensorLayout{*shape, out.dtype()}); | |||||
} else { | |||||
out_io.config_layout = to_lite_layout(TensorLayout{{}, out.dtype()}); | |||||
} | |||||
IOs.outputs.push_back(out_io); | |||||
} | |||||
return IOs; | |||||
} | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -262,6 +262,13 @@ private: | |||||
#endif | #endif | ||||
std::unique_ptr<mgb::OprIODumpBase> m_iodump; | std::unique_ptr<mgb::OprIODumpBase> m_iodump; | ||||
}; | }; | ||||
//! get the model information before model loaded by Network | |||||
NetworkIO get_model_io_info_dft(const std::string& model_path, const Config& config); | |||||
//! get the model information before model loaded by Network by model memory and | |||||
//! size | |||||
NetworkIO get_model_io_info_dft( | |||||
const void* model_mem, size_t size, const Config& config); | |||||
} // namespace lite | } // namespace lite | ||||
@@ -534,4 +534,26 @@ void Runtime::dump_layout_transform_model( | |||||
LITE_THROW("dump_layout_transform_model is not aviliable in the backend."); | LITE_THROW("dump_layout_transform_model is not aviliable in the backend."); | ||||
LITE_ERROR_HANDLER_END | LITE_ERROR_HANDLER_END | ||||
} | } | ||||
NetworkIO Runtime::get_model_io_info( | |||||
const std::string& model_path, const Config& config) { | |||||
LITE_ERROR_HANDLER_BEGIN | |||||
if (config.backend == LiteBackend::LITE_DEFAULT) { | |||||
return call_func<NetworkImplDft, NetworkIO>( | |||||
"get_model_io_info", model_path, config); | |||||
} | |||||
LITE_THROW("get_model_io_info is not aviliable in the backend."); | |||||
LITE_ERROR_HANDLER_END | |||||
} | |||||
NetworkIO Runtime::get_model_io_info( | |||||
const void* model_mem, size_t size, const Config& config) { | |||||
LITE_ERROR_HANDLER_BEGIN | |||||
if (config.backend == LiteBackend::LITE_DEFAULT) { | |||||
return call_func<NetworkImplDft, NetworkIO>( | |||||
"get_model_io_info", model_mem, size, config); | |||||
} | |||||
LITE_THROW("get_model_io_info is not aviliable in the backend."); | |||||
LITE_ERROR_HANDLER_END | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -106,6 +106,54 @@ TEST(TestNetWork, GetAllName) { | |||||
ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | ||||
} | } | ||||
TEST(TestNetWork, GetAllIoInfoAhead) { | |||||
Config config; | |||||
std::string model_path = "./shufflenet.mge"; | |||||
auto ios = Runtime::get_model_io_info(model_path); | |||||
FILE* fin = fopen(model_path.c_str(), "rb"); | |||||
ASSERT_TRUE(fin); | |||||
fseek(fin, 0, SEEK_END); | |||||
size_t size = ftell(fin); | |||||
fseek(fin, 0, SEEK_SET); | |||||
void* ptr = malloc(size); | |||||
std::shared_ptr<void> buf{ptr, ::free}; | |||||
auto nr = fread(buf.get(), 1, size, fin); | |||||
LITE_ASSERT(nr == size); | |||||
fclose(fin); | |||||
auto ios_mem = Runtime::get_model_io_info(ptr, size); | |||||
ASSERT_EQ(ios.inputs.size(), ios_mem.inputs.size()); | |||||
ASSERT_EQ(ios.inputs.size(), 1); | |||||
ASSERT_EQ(ios.outputs.size(), ios_mem.outputs.size()); | |||||
ASSERT_EQ(ios.outputs.size(), 1); | |||||
ASSERT_TRUE(ios.inputs[0].name == "data"); | |||||
ASSERT_TRUE(ios.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | |||||
ASSERT_TRUE(ios_mem.inputs[0].name == "data"); | |||||
ASSERT_TRUE( | |||||
ios_mem.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | |||||
ASSERT_EQ(ios.inputs[0].config_layout.ndim, 4); | |||||
ASSERT_EQ(ios.inputs[0].config_layout.shapes[1], 3); | |||||
ASSERT_EQ(ios.inputs[0].config_layout.shapes[2], 224); | |||||
ASSERT_EQ(ios.outputs[0].config_layout.ndim, 2); | |||||
ASSERT_EQ(ios.outputs[0].config_layout.shapes[0], 1); | |||||
ASSERT_EQ(ios.outputs[0].config_layout.shapes[1], 1000); | |||||
ASSERT_EQ(ios_mem.inputs[0].config_layout.ndim, 4); | |||||
ASSERT_EQ(ios_mem.inputs[0].config_layout.shapes[1], 3); | |||||
ASSERT_EQ(ios_mem.inputs[0].config_layout.shapes[2], 224); | |||||
ASSERT_EQ(ios_mem.outputs[0].config_layout.ndim, 2); | |||||
ASSERT_EQ(ios_mem.outputs[0].config_layout.shapes[0], 1); | |||||
ASSERT_EQ(ios_mem.outputs[0].config_layout.shapes[1], 1000); | |||||
} | |||||
TEST(TestNetWork, LoadFBSModel) { | TEST(TestNetWork, LoadFBSModel) { | ||||
Config config; | Config config; | ||||
std::string model_path = "./ax.mge"; | std::string model_path = "./ax.mge"; | ||||
@@ -252,6 +252,55 @@ TEST(TestCapiNetWork, GetAllName) { | |||||
LITE_destroy_network(c_network); | LITE_destroy_network(c_network); | ||||
} | } | ||||
TEST(TestCapiNetWork, GetAllNameAhead) { | |||||
std::string model_path = "./shufflenet.mge"; | |||||
LiteNetworkIO ios, ios_mem; | |||||
LITE_CAPI_CHECK(LITE_get_model_io_info_by_path( | |||||
model_path.c_str(), *default_config(), &ios)); | |||||
FILE* fin = fopen(model_path.c_str(), "rb"); | |||||
ASSERT_TRUE(fin); | |||||
fseek(fin, 0, SEEK_END); | |||||
size_t size = ftell(fin); | |||||
fseek(fin, 0, SEEK_SET); | |||||
void* ptr = malloc(size); | |||||
std::shared_ptr<void> buf{ptr, ::free}; | |||||
auto nr = fread(buf.get(), 1, size, fin); | |||||
LITE_ASSERT(nr == size); | |||||
fclose(fin); | |||||
LITE_CAPI_CHECK( | |||||
LITE_get_model_io_info_by_memory(ptr, size, *default_config(), &ios_mem)); | |||||
ASSERT_EQ(ios.input_size, 1); | |||||
ASSERT_EQ(ios.output_size, 1); | |||||
ASSERT_EQ(ios_mem.input_size, 1); | |||||
ASSERT_EQ(ios_mem.output_size, 1); | |||||
ASSERT_TRUE(std::string(ios.inputs->name) == "data"); | |||||
ASSERT_TRUE(ios.inputs->config_layout.ndim == 4); | |||||
ASSERT_TRUE(ios.inputs->config_layout.shapes[1] == 3); | |||||
ASSERT_TRUE(ios.inputs->config_layout.shapes[2] == 224); | |||||
ASSERT_TRUE(ios.inputs->config_layout.shapes[3] == 224); | |||||
ASSERT_TRUE( | |||||
std::string(ios.outputs->name) == | |||||
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | |||||
ASSERT_TRUE(ios.outputs->config_layout.ndim == 2); | |||||
ASSERT_TRUE(ios.outputs->config_layout.shapes[0] == 1); | |||||
ASSERT_TRUE(ios.outputs->config_layout.shapes[1] == 1000); | |||||
ASSERT_TRUE(std::string(ios_mem.inputs->name) == "data"); | |||||
ASSERT_TRUE(ios_mem.inputs->config_layout.ndim == 4); | |||||
ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[1] == 3); | |||||
ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[2] == 224); | |||||
ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[3] == 224); | |||||
ASSERT_TRUE( | |||||
std::string(ios_mem.outputs->name) == | |||||
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | |||||
ASSERT_TRUE(ios_mem.outputs->config_layout.ndim == 2); | |||||
ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[0] == 1); | |||||
ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[1] == 1000); | |||||
} | |||||
#if LITE_BUILD_WITH_RKNPU | #if LITE_BUILD_WITH_RKNPU | ||||
static int GetTop( | static int GetTop( | ||||