Browse Source

feat(lite): add global layout transform c/c++ interface for lite

GitOrigin-RevId: 36a4b26b42
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
e70c07a223
9 changed files with 178 additions and 12 deletions
  1. +9
    -1
      lite/include/lite/network.h
  2. +16
    -0
      lite/lite-c/include/lite-c/network_c.h
  3. +17
    -0
      lite/lite-c/src/network.cpp
  4. +4
    -0
      lite/src/mge/function_dft.h
  5. +48
    -8
      lite/src/mge/network_impl.cpp
  6. +16
    -3
      lite/src/mge/network_impl.h
  7. +29
    -0
      lite/src/network.cpp
  8. +24
    -0
      lite/test/test_network.cpp
  9. +15
    -0
      lite/test/test_network_c.cpp

+ 9
- 1
lite/include/lite/network.h View File

@@ -97,7 +97,7 @@ struct LITE_API Options {
bool no_profiling_on_shape_change = false; bool no_profiling_on_shape_change = false;
uint8_t jit_level = 0; uint8_t jit_level = 0;
uint8_t comp_node_seq_record_level = 0; uint8_t comp_node_seq_record_level = 0;
uint8_t graph_opt_level = 2;
uint8_t graph_opt_level = 0;
uint16_t async_exec_level = 1; uint16_t async_exec_level = 1;


//! layout transform options //! layout transform options
@@ -366,6 +366,14 @@ public:
static void shared_weight_with_network( static void shared_weight_with_network(
std::shared_ptr<Network> dst_network, std::shared_ptr<Network> dst_network,
const std::shared_ptr<Network> src_network); const std::shared_ptr<Network> src_network);

//! set global layout transform optimization for network

static void enable_global_layout_transform(std::shared_ptr<Network> network);

//! dump network after global layout transform optimization
static void dump_layout_transform_model(
std::shared_ptr<Network> network, std::string optimized_model_path);
}; };


} // namespace lite } // namespace lite


+ 16
- 0
lite/lite-c/include/lite-c/network_c.h View File

@@ -572,6 +572,22 @@ LITE_API int LITE_enable_io_bin_dump(LiteNetwork network, const char* io_bin_out
LITE_API int LITE_get_static_memory_alloc_info( LITE_API int LITE_get_static_memory_alloc_info(
LiteNetwork network, const char* log_dir); LiteNetwork network, const char* log_dir);


/**
* \brief enable the global layout transform optimization
* \return int if the return is not zero, error happened, the error message
* can get by LITE_get_last_error
*/
LITE_API int LITE_enable_global_layout_transform(LiteNetwork network);

/**
* \brief dump the model after the global layout transform optimization
* \param[in] dump_file_path The model file path need to dump
* \return int if the return is not zero, error happened, the error message
* can get by LITE_get_last_error
*/
LITE_API int LITE_dump_layout_transform_model(
LiteNetwork network, const char* dump_file_path);

#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif


+ 17
- 0
lite/lite-c/src/network.cpp View File

@@ -648,4 +648,21 @@ int LITE_get_static_memory_alloc_info(LiteNetwork network, const char* log_dir)
LITE_CAPI_END(); LITE_CAPI_END();
} }


int LITE_enable_global_layout_transform(LiteNetwork network) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null");
std::shared_ptr<lite::Network> network_shared{
static_cast<lite::Network*>(network), [](void*) {}};
lite::Runtime::enable_global_layout_transform(network_shared);
LITE_CAPI_END();
}

int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_path) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null");
std::shared_ptr<lite::Network> network_shared{
static_cast<lite::Network*>(network), [](void*) {}};
lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path);
LITE_CAPI_END();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 4
- 0
lite/src/mge/function_dft.h View File

@@ -121,6 +121,8 @@ inline void call_func<NetworkImplDft, void>(
CALL_FUNC(use_tensorrt); CALL_FUNC(use_tensorrt);
} else if (func_name == "set_cpu_inplace_mode") { } else if (func_name == "set_cpu_inplace_mode") {
CALL_FUNC(set_cpu_inplace_mode); CALL_FUNC(set_cpu_inplace_mode);
} else if (func_name == "enable_global_layout_transform") {
CALL_FUNC(enable_global_layout_transform);
} else { } else {
THROW_FUNC_ERROR(func_name); THROW_FUNC_ERROR(func_name);
} }
@@ -186,6 +188,8 @@ inline void call_func<NetworkImplDft, void>(
return CALL_FUNC(enable_io_txt_dump, file_name); return CALL_FUNC(enable_io_txt_dump, file_name);
} else if (func_name == "enable_io_bin_dump") { } else if (func_name == "enable_io_bin_dump") {
return CALL_FUNC(enable_io_bin_dump, file_name); return CALL_FUNC(enable_io_bin_dump, file_name);
} else if (func_name == "dump_layout_transform_model") {
return CALL_FUNC(dump_layout_transform_model, file_name);
} }
THROW_FUNC_ERROR(func_name); THROW_FUNC_ERROR(func_name);
} }


+ 48
- 8
lite/src/mge/network_impl.cpp View File

@@ -22,7 +22,6 @@
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/comp_node.h" #include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/graph.h" #include "megbrain/graph.h"
#include "megbrain/graph/cg.h" #include "megbrain/graph/cg.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"
@@ -364,19 +363,26 @@ void NetworkImplDft::adapt_option_valid() {
} }
} }


void NetworkImplDft::global_layout_transform() {
if (m_set_layout_transform) {
m_load_result.output_var_list = mgb::gopt::layout_transform(
m_load_result.output_var_list, m_layout_transform_target);
}
}

void NetworkImplDft::load_model( void NetworkImplDft::load_model(
std::shared_ptr<void> model_mem, size_t size, std::shared_ptr<void> model_mem, size_t size,
std::unordered_map<std::string, LiteAny> separate_config_map) { std::unordered_map<std::string, LiteAny> separate_config_map) {
if (!m_loader) { if (!m_loader) {
m_input_file = m_input_file =
mgb::serialization::InputFile::make_mem_proxy(model_mem, size, false); mgb::serialization::InputFile::make_mem_proxy(model_mem, size, false);
auto format = mgb::serialization::GraphLoader::identify_graph_dump_format(
m_format = mgb::serialization::GraphLoader::identify_graph_dump_format(
*m_input_file); *m_input_file);
if (!format.valid()) {
if (!m_format.valid()) {
LITE_THROW("invalid model format"); LITE_THROW("invalid model format");
} }
m_loader = mgb::serialization::GraphLoader::make( m_loader = mgb::serialization::GraphLoader::make(
std::move(m_input_file), format.val());
std::move(m_input_file), m_format.val());
} }


//! applay the user configration to mge model //! applay the user configration to mge model
@@ -400,7 +406,9 @@ void NetworkImplDft::load_model(
use_tensorrt(); use_tensorrt();
} }


m_load_result = m_loader->load(m_load_config, true);
m_load_result = m_loader->load(m_load_config, false);

global_layout_transform();


adapt_option_valid(); adapt_option_valid();


@@ -847,9 +855,6 @@ const char* NetworkImplDft::get_input_name(size_t index) const {
//! Plugin part //! Plugin part
void NetworkImplDft::enable_profile_performance(std::string profile_json_file) { void NetworkImplDft::enable_profile_performance(std::string profile_json_file) {
#if MGB_ENABLE_JSON #if MGB_ENABLE_JSON
#if MGB_OPENCL
mgb::CompNode::enable_opencl_profile(true);
#endif
m_profiler = std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get()); m_profiler = std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get());
m_profiler_output_file = profile_json_file; m_profiler_output_file = profile_json_file;
#else #else
@@ -889,5 +894,40 @@ void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) co
LITE_MARK_USED_VAR(log_dir); LITE_MARK_USED_VAR(log_dir);
} }


void NetworkImplDft::enable_global_layout_transform() {
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;

switch (m_user_config->device_type) {
case LiteDeviceType::LITE_CPU:
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
break;
case LiteDeviceType::LITE_CUDA:
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
break;
default:
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
LITE_WARN(
"lite compnode type: enum value: %d. is unspecial for layout "
"transform",
(int)(m_user_config->device_type));
}
m_set_layout_transform = true;
}

void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_path) {
if (m_set_layout_transform) {
auto out_file = mgb::serialization::OutputFile::make_fs(
optimized_model_path.c_str(), 'w');
using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
DumpConfig config{1, false, false};
auto dumper = mgb::serialization::GraphDumper::make(
std::move(out_file), m_format.val());
dumper->dump(m_load_result.output_var_list, config);
} else {
LITE_THROW(
ssprintf("dump layout transform model should call "
"enable_global_layout_transform before"));
}
}
#endif #endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 16
- 3
lite/src/mge/network_impl.h View File

@@ -19,6 +19,9 @@
#include "network_impl_base.h" #include "network_impl_base.h"
#include "tensor_impl.h" #include "tensor_impl.h"


#include <memory>
#include <unordered_map>
#include "megbrain/gopt/inference.h"
#include "megbrain/graph/bases.h" #include "megbrain/graph/bases.h"
#include "megbrain/plugin/opr_io_dump.h" #include "megbrain/plugin/opr_io_dump.h"
#include "megbrain/plugin/profiler.h" #include "megbrain/plugin/profiler.h"
@@ -28,9 +31,6 @@
#include "megbrain/serialization/serializer.h" #include "megbrain/serialization/serializer.h"
#include "megbrain/utils/thin/hash_table.h" #include "megbrain/utils/thin/hash_table.h"


#include <memory>
#include <unordered_map>

namespace lite { namespace lite {


/*! /*!
@@ -170,11 +170,20 @@ public:
void get_static_memory_alloc_info( void get_static_memory_alloc_info(
const std::string& log_dir = "logs/test") const override; const std::string& log_dir = "logs/test") const override;


//! set global layout transform optimization for network
void enable_global_layout_transform();

//! dump network after global layout transform optimization
void dump_layout_transform_model(std::string optimized_model_path);

private: private:
//! construct the outputspec according to the m_network_io, and set the //! construct the outputspec according to the m_network_io, and set the
//! call_back to the outputspec //! call_back to the outputspec
void make_output_spec(); void make_output_spec();


//! do the global layout transform for the given platform target
void global_layout_transform();

//! modify the execution policy //! modify the execution policy
void modify_exection_policy(); void modify_exection_policy();


@@ -223,6 +232,7 @@ private:
int m_nr_device_type = 0; int m_nr_device_type = 0;
size_t m_nr_threads = 1; size_t m_nr_threads = 1;
bool m_compute_configured_output_only = false; bool m_compute_configured_output_only = false;
bool m_set_layout_transform = false;
mgb::CompNode::Locator m_compnode_locator; mgb::CompNode::Locator m_compnode_locator;


AsyncCallback m_async_callback = nullptr; AsyncCallback m_async_callback = nullptr;
@@ -233,6 +243,9 @@ private:
//! The model load related data //! The model load related data
S m_execution_policy = static_cast<S>(0); S m_execution_policy = static_cast<S>(0);
std::unique_ptr<mgb::serialization::InputFile> m_input_file; std::unique_ptr<mgb::serialization::InputFile> m_input_file;
mgb::Maybe<mgb::serialization::GraphDumpFormat> m_format;
mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;

mgb::serialization::GraphLoadConfig m_load_config; mgb::serialization::GraphLoadConfig m_load_config;
mgb::serialization::GraphLoader::LoadResult m_load_result; mgb::serialization::GraphLoader::LoadResult m_load_result;
mgb::ComputingGraph::OutputSpec m_output_spec; mgb::ComputingGraph::OutputSpec m_output_spec;


+ 29
- 0
lite/src/network.cpp View File

@@ -505,4 +505,33 @@ void Runtime::shared_weight_with_network(
LITE_ERROR_HANDLER_END LITE_ERROR_HANDLER_END
} }


void Runtime::enable_global_layout_transform(std::shared_ptr<Network> network) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
!NetworkHelper::loaded(network),
"enable_global_layout_transform should be used before model loaded.");
call_func<NetworkImplDft, void>("enable_global_layout_transform", network_impl);
return;
}
LITE_THROW("enable_global_layout_transform is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}

void Runtime::dump_layout_transform_model(
std::shared_ptr<Network> network, std::string optimized_model_path) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
NetworkHelper::loaded(network),
"dump_layout_transform_model should be used after model loaded.");
call_func<NetworkImplDft, void>(
"dump_layout_transform_model", network_impl, optimized_model_path);
return;
}
LITE_THROW("dump_layout_transform_model is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 24
- 0
lite/test/test_network.cpp View File

@@ -909,6 +909,30 @@ TEST(TestNetWork, LoadPackedModel) {
network->wait(); network->wait();
} }


TEST(TestNetWork, GlabalLayoutTransform) {
// set_log_level(LiteLogLevel::DEBUG);
auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge";
std::string input_name = "data";
std::string dump_model_name = "./shufflenet_after_trans.mge";

NetworkIO IO;
Config config;
std::shared_ptr<Network> network = std::make_shared<Network>(config, IO);
Runtime::enable_global_layout_transform(network);
network->load_model(model_path);

std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name);
auto src_ptr = tensor->get_memory_ptr();
auto src_layout = tensor->get_layout();
input_tensor->reset(src_ptr, src_layout);

Runtime::dump_layout_transform_model(network, dump_model_name);
network->forward();
network->wait();
ASSERT_TRUE(fopen(dump_model_name.c_str(), "r"));
}

TEST(TestNetWork, GetDeviceType) { TEST(TestNetWork, GetDeviceType) {
auto tensor = get_input_data("./input_data.npy"); auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge"; std::string model_path = "./shufflenet.mge";


+ 15
- 0
lite/test/test_network_c.cpp View File

@@ -889,6 +889,21 @@ TEST(TestCapiNetWork, ProfileIOdump) {
LITE_CAPI_CHECK(LITE_destroy_network(c_network)); LITE_CAPI_CHECK(LITE_destroy_network(c_network));
} }


TEST(TestCapiNetWork, GlabalLayoutTransform) {
ForwardMgb;
MakeNetwork;
LITE_CAPI_CHECK(LITE_enable_global_layout_transform(c_network));
LoadNetwork;
LITE_CAPI_CHECK(LITE_dump_layout_transform_model(
c_network, "./shufflenet_after_trans.mge"));
SetInput;
ForwardNetwork;
ASSERT_TRUE(fopen("./shufflenet_after_trans.mge", "r"));
GetOutput;
CompareResult;
LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}

TEST(TestCapiNetWork, GetDeviceType) { TEST(TestCapiNetWork, GetDeviceType) {
lite::Config config; lite::Config config;
auto lite_tensor = lite::get_input_data("./input_data.npy"); auto lite_tensor = lite::get_input_data("./input_data.npy");


Loading…
Cancel
Save