
feat(lite): add get static mem info function in lite c++

GitOrigin-RevId: 8c9e42a744
Branch: release-1.7
Author: Megvii Engine Team
Commit: 8bdcf6b5a6
6 changed files with 68 additions and 0 deletions
  1. lite/include/lite/network.h     (+3, -0)
  2. lite/src/mge/network_impl.cpp   (+10, -0)
  3. lite/src/mge/network_impl.h     (+4, -0)
  4. lite/src/network.cpp            (+14, -0)
  5. lite/src/network_impl_base.h    (+8, -0)
  6. lite/test/test_network.cpp      (+29, -0)

lite/include/lite/network.h  (+3, -0)

@@ -282,6 +282,9 @@ public:
    //! get device type
    LiteDeviceType get_device_type() const;

    //! get static peak memory info, shown by graph visualization
    void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const;

public:
    friend class NetworkHelper;
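
Note: a minimal caller-side sketch of the new public API, illustrative only. The model path and log directory are placeholders, and `using namespace lite` is assumed as in the tests; the call must come after load_model(), matching the assertion added in lite/src/network.cpp below.

#include "lite/network.h"

#include <memory>
#include <string>

using namespace lite;

int main() {
    Config config;
    auto network = std::make_shared<Network>(config);
    network->load_model("./shufflenet.mge");
    // Dump the static peak-memory allocation info for graph visualization.
    network->get_static_memory_alloc_info("logs/static_mem");
    return 0;
}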



lite/src/mge/network_impl.cpp  (+10, -0)

@@ -778,6 +778,16 @@ void inline NetworkImplDft::output_plugin_result() const {
}
#endif
}

void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) const {
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
    m_execute_func->get_static_memory_alloc_info(log_dir);
    return;
#endif
#endif
    LITE_MARK_USED_VAR(log_dir);
}

#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

lite/src/mge/network_impl.h  (+4, -0)

@@ -163,6 +163,10 @@ public:
    //! directory, in binary format
    void enable_io_bin_dump(std::string io_bin_out_dir);

    //! get static peak memory info, shown by graph visualization
    void get_static_memory_alloc_info(
            const std::string& log_dir = "logs/test") const override;

private:
    //! construct the outputspec according to the m_network_io, and set the
    //! call_back to the outputspec


lite/src/network.cpp  (+14, -0)

@@ -283,6 +283,20 @@ LiteDeviceType Network::get_device_type() const {
    LITE_ERROR_HANDLER_END
}

void Network::get_static_memory_alloc_info(const std::string& log_dir) const {
    LITE_ERROR_HANDLER_BEGIN
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
    LITE_ASSERT(
            m_loaded,
            "get_static_memory_alloc_info should be used after model loaded.");
    m_impl->get_static_memory_alloc_info(log_dir);
    return;
#endif
#endif
    LITE_MARK_USED_VAR(log_dir);
    LITE_THROW(
            "Doesn't support get_static_memory_alloc_info(). Please check the "
            "MGB_ENABLE_JSON and __IN_TEE_ENV__ macros.");
    LITE_ERROR_HANDLER_END
}
}

/*********************** MGE special network function ***************/

void Runtime::set_cpu_threads_number(
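
Note: as the hunk above shows, builds without MGB_ENABLE_JSON (or TEE builds) reach the LITE_THROW path, so a defensive caller may wrap the call. A sketch under the assumption that Lite errors surface as exceptions derived from std::exception; the helper name is hypothetical.

#include <cstdio>
#include <exception>
#include <memory>
#include <string>

#include "lite/network.h"

// Best-effort dump: logs instead of propagating when the build lacks support.
void try_dump_static_mem_info(
        const std::shared_ptr<lite::Network>& network, const std::string& log_dir) {
    try {
        network->get_static_memory_alloc_info(log_dir);
    } catch (const std::exception& e) {
        std::fprintf(stderr, "static memory info unavailable: %s\n", e.what());
    }
}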


lite/src/network_impl_base.h  (+8, -0)

@@ -125,6 +125,14 @@ public:

    //! enable profile the network, a file will be generated
    virtual void enable_profile_performance(std::string profile_file_path) = 0;

    //! get static peak memory info, shown by graph visualization
    virtual void get_static_memory_alloc_info(const std::string& log_dir) const {
        LITE_MARK_USED_VAR(log_dir);
        LITE_THROW(
                "This network impl doesn't support get_static_memory_alloc_info() "
                "function.");
    }
};
};

/******************************** friend class *****************************/
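
Note: a minimal, self-contained analog of the pattern added above, not the real lite classes. The base class ships a throwing default, so only backends that can actually produce the dump (here, NetworkImplDft) need to override the hook.

#include <cstdio>
#include <stdexcept>
#include <string>

// Default hook throws, mirroring NetworkImplBase::get_static_memory_alloc_info.
struct ImplBase {
    virtual ~ImplBase() = default;
    virtual void get_static_memory_alloc_info(const std::string& log_dir) const {
        (void)log_dir;
        throw std::runtime_error("this impl does not support static memory info");
    }
};

// Analog of NetworkImplDft: overrides the hook with a working implementation.
struct DftLikeImpl final : ImplBase {
    void get_static_memory_alloc_info(const std::string& log_dir) const override {
        std::printf("dumping static memory info to %s\n", log_dir.c_str());
    }
};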


lite/test/test_network.cpp  (+29, -0)

@@ -646,6 +646,35 @@ TEST(TestNetWork, GetModelExtraInfo) {
printf("extra_info %s \n", extra_info.c_str());
}

#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
TEST(TestNetWork, GetMemoryInfo) {
Config config;
auto lite_tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge";

auto result_mgb = mgb_lar(model_path, config, "data", lite_tensor);

std::shared_ptr<Network> network = std::make_shared<Network>(config);
Runtime::set_cpu_threads_number(network, 2);

network->load_model(model_path);
network->get_static_memory_alloc_info();
std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);

auto src_ptr = lite_tensor->get_memory_ptr();
auto src_layout = lite_tensor->get_layout();
input_tensor->reset(src_ptr, src_layout);

network->forward();
network->wait();
std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);

compare_lite_tensor<float>(output_tensor, result_mgb);
}
#endif
#endif

#if LITE_WITH_CUDA

TEST(TestNetWork, BasicDevice) {

