GitOrigin-RevId: 36a4b26b42
tags/v1.8.0
@@ -97,7 +97,7 @@ struct LITE_API Options { | |||||
bool no_profiling_on_shape_change = false; | bool no_profiling_on_shape_change = false; | ||||
uint8_t jit_level = 0; | uint8_t jit_level = 0; | ||||
uint8_t comp_node_seq_record_level = 0; | uint8_t comp_node_seq_record_level = 0; | ||||
uint8_t graph_opt_level = 2; | |||||
uint8_t graph_opt_level = 0; | |||||
uint16_t async_exec_level = 1; | uint16_t async_exec_level = 1; | ||||
//! layout transform options | //! layout transform options | ||||
@@ -366,6 +366,14 @@ public: | |||||
static void shared_weight_with_network( | static void shared_weight_with_network( | ||||
std::shared_ptr<Network> dst_network, | std::shared_ptr<Network> dst_network, | ||||
const std::shared_ptr<Network> src_network); | const std::shared_ptr<Network> src_network); | ||||
//! set global layout transform optimization for network | |||||
static void enable_global_layout_transform(std::shared_ptr<Network> network); | |||||
//! dump network after global layout transform optimization | |||||
static void dump_layout_transform_model( | |||||
std::shared_ptr<Network> network, std::string optimized_model_path); | |||||
}; | }; | ||||
} // namespace lite | } // namespace lite | ||||
@@ -572,6 +572,22 @@ LITE_API int LITE_enable_io_bin_dump(LiteNetwork network, const char* io_bin_out | |||||
LITE_API int LITE_get_static_memory_alloc_info( | LITE_API int LITE_get_static_memory_alloc_info( | ||||
LiteNetwork network, const char* log_dir); | LiteNetwork network, const char* log_dir); | ||||
/** | |||||
* \brief enable the global layout transform optimization | |||||
* \return int if the return is not zero, error happened, the error message | |||||
* can get by LITE_get_last_error | |||||
*/ | |||||
LITE_API int LITE_enable_global_layout_transform(LiteNetwork network); | |||||
/** | |||||
* \brief dump the model after the global layout transform optimization | |||||
* \param[in] dump_file_path The model file path need to dump | |||||
* \return int if the return is not zero, error happened, the error message | |||||
* can get by LITE_get_last_error | |||||
*/ | |||||
LITE_API int LITE_dump_layout_transform_model( | |||||
LiteNetwork network, const char* dump_file_path); | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
@@ -648,4 +648,21 @@ int LITE_get_static_memory_alloc_info(LiteNetwork network, const char* log_dir) | |||||
LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
} | } | ||||
int LITE_enable_global_layout_transform(LiteNetwork network) { | |||||
LITE_CAPI_BEGIN(); | |||||
LITE_ASSERT(network, "The network pass to LITE api is null"); | |||||
std::shared_ptr<lite::Network> network_shared{ | |||||
static_cast<lite::Network*>(network), [](void*) {}}; | |||||
lite::Runtime::enable_global_layout_transform(network_shared); | |||||
LITE_CAPI_END(); | |||||
} | |||||
int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_path) { | |||||
LITE_CAPI_BEGIN(); | |||||
LITE_ASSERT(network, "The network pass to LITE api is null"); | |||||
std::shared_ptr<lite::Network> network_shared{ | |||||
static_cast<lite::Network*>(network), [](void*) {}}; | |||||
lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path); | |||||
LITE_CAPI_END(); | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -121,6 +121,8 @@ inline void call_func<NetworkImplDft, void>( | |||||
CALL_FUNC(use_tensorrt); | CALL_FUNC(use_tensorrt); | ||||
} else if (func_name == "set_cpu_inplace_mode") { | } else if (func_name == "set_cpu_inplace_mode") { | ||||
CALL_FUNC(set_cpu_inplace_mode); | CALL_FUNC(set_cpu_inplace_mode); | ||||
} else if (func_name == "enable_global_layout_transform") { | |||||
CALL_FUNC(enable_global_layout_transform); | |||||
} else { | } else { | ||||
THROW_FUNC_ERROR(func_name); | THROW_FUNC_ERROR(func_name); | ||||
} | } | ||||
@@ -186,6 +188,8 @@ inline void call_func<NetworkImplDft, void>( | |||||
return CALL_FUNC(enable_io_txt_dump, file_name); | return CALL_FUNC(enable_io_txt_dump, file_name); | ||||
} else if (func_name == "enable_io_bin_dump") { | } else if (func_name == "enable_io_bin_dump") { | ||||
return CALL_FUNC(enable_io_bin_dump, file_name); | return CALL_FUNC(enable_io_bin_dump, file_name); | ||||
} else if (func_name == "dump_layout_transform_model") { | |||||
return CALL_FUNC(dump_layout_transform_model, file_name); | |||||
} | } | ||||
THROW_FUNC_ERROR(func_name); | THROW_FUNC_ERROR(func_name); | ||||
} | } | ||||
@@ -22,7 +22,6 @@ | |||||
#include "megbrain/common.h" | #include "megbrain/common.h" | ||||
#include "megbrain/comp_node.h" | #include "megbrain/comp_node.h" | ||||
#include "megbrain/comp_node_env.h" | #include "megbrain/comp_node_env.h" | ||||
#include "megbrain/gopt/inference.h" | |||||
#include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
#include "megbrain/graph/cg.h" | #include "megbrain/graph/cg.h" | ||||
#include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
@@ -364,19 +363,26 @@ void NetworkImplDft::adapt_option_valid() { | |||||
} | } | ||||
} | } | ||||
void NetworkImplDft::global_layout_transform() { | |||||
if (m_set_layout_transform) { | |||||
m_load_result.output_var_list = mgb::gopt::layout_transform( | |||||
m_load_result.output_var_list, m_layout_transform_target); | |||||
} | |||||
} | |||||
void NetworkImplDft::load_model( | void NetworkImplDft::load_model( | ||||
std::shared_ptr<void> model_mem, size_t size, | std::shared_ptr<void> model_mem, size_t size, | ||||
std::unordered_map<std::string, LiteAny> separate_config_map) { | std::unordered_map<std::string, LiteAny> separate_config_map) { | ||||
if (!m_loader) { | if (!m_loader) { | ||||
m_input_file = | m_input_file = | ||||
mgb::serialization::InputFile::make_mem_proxy(model_mem, size, false); | mgb::serialization::InputFile::make_mem_proxy(model_mem, size, false); | ||||
auto format = mgb::serialization::GraphLoader::identify_graph_dump_format( | |||||
m_format = mgb::serialization::GraphLoader::identify_graph_dump_format( | |||||
*m_input_file); | *m_input_file); | ||||
if (!format.valid()) { | |||||
if (!m_format.valid()) { | |||||
LITE_THROW("invalid model format"); | LITE_THROW("invalid model format"); | ||||
} | } | ||||
m_loader = mgb::serialization::GraphLoader::make( | m_loader = mgb::serialization::GraphLoader::make( | ||||
std::move(m_input_file), format.val()); | |||||
std::move(m_input_file), m_format.val()); | |||||
} | } | ||||
//! applay the user configration to mge model | //! applay the user configration to mge model | ||||
@@ -400,7 +406,9 @@ void NetworkImplDft::load_model( | |||||
use_tensorrt(); | use_tensorrt(); | ||||
} | } | ||||
m_load_result = m_loader->load(m_load_config, true); | |||||
m_load_result = m_loader->load(m_load_config, false); | |||||
global_layout_transform(); | |||||
adapt_option_valid(); | adapt_option_valid(); | ||||
@@ -847,9 +855,6 @@ const char* NetworkImplDft::get_input_name(size_t index) const { | |||||
//! Plugin part | //! Plugin part | ||||
void NetworkImplDft::enable_profile_performance(std::string profile_json_file) { | void NetworkImplDft::enable_profile_performance(std::string profile_json_file) { | ||||
#if MGB_ENABLE_JSON | #if MGB_ENABLE_JSON | ||||
#if MGB_OPENCL | |||||
mgb::CompNode::enable_opencl_profile(true); | |||||
#endif | |||||
m_profiler = std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get()); | m_profiler = std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get()); | ||||
m_profiler_output_file = profile_json_file; | m_profiler_output_file = profile_json_file; | ||||
#else | #else | ||||
@@ -889,5 +894,40 @@ void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) co | |||||
LITE_MARK_USED_VAR(log_dir); | LITE_MARK_USED_VAR(log_dir); | ||||
} | } | ||||
void NetworkImplDft::enable_global_layout_transform() { | |||||
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC; | |||||
switch (m_user_config->device_type) { | |||||
case LiteDeviceType::LITE_CPU: | |||||
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU; | |||||
break; | |||||
case LiteDeviceType::LITE_CUDA: | |||||
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA; | |||||
break; | |||||
default: | |||||
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC; | |||||
LITE_WARN( | |||||
"lite compnode type: enum value: %d. is unspecial for layout " | |||||
"transform", | |||||
(int)(m_user_config->device_type)); | |||||
} | |||||
m_set_layout_transform = true; | |||||
} | |||||
void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_path) { | |||||
if (m_set_layout_transform) { | |||||
auto out_file = mgb::serialization::OutputFile::make_fs( | |||||
optimized_model_path.c_str(), 'w'); | |||||
using DumpConfig = mgb::serialization::GraphDumper::DumpConfig; | |||||
DumpConfig config{1, false, false}; | |||||
auto dumper = mgb::serialization::GraphDumper::make( | |||||
std::move(out_file), m_format.val()); | |||||
dumper->dump(m_load_result.output_var_list, config); | |||||
} else { | |||||
LITE_THROW( | |||||
ssprintf("dump layout transform model should call " | |||||
"enable_global_layout_transform before")); | |||||
} | |||||
} | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -19,6 +19,9 @@ | |||||
#include "network_impl_base.h" | #include "network_impl_base.h" | ||||
#include "tensor_impl.h" | #include "tensor_impl.h" | ||||
#include <memory> | |||||
#include <unordered_map> | |||||
#include "megbrain/gopt/inference.h" | |||||
#include "megbrain/graph/bases.h" | #include "megbrain/graph/bases.h" | ||||
#include "megbrain/plugin/opr_io_dump.h" | #include "megbrain/plugin/opr_io_dump.h" | ||||
#include "megbrain/plugin/profiler.h" | #include "megbrain/plugin/profiler.h" | ||||
@@ -28,9 +31,6 @@ | |||||
#include "megbrain/serialization/serializer.h" | #include "megbrain/serialization/serializer.h" | ||||
#include "megbrain/utils/thin/hash_table.h" | #include "megbrain/utils/thin/hash_table.h" | ||||
#include <memory> | |||||
#include <unordered_map> | |||||
namespace lite { | namespace lite { | ||||
/*! | /*! | ||||
@@ -170,11 +170,20 @@ public: | |||||
void get_static_memory_alloc_info( | void get_static_memory_alloc_info( | ||||
const std::string& log_dir = "logs/test") const override; | const std::string& log_dir = "logs/test") const override; | ||||
//! set global layout transform optimization for network | |||||
void enable_global_layout_transform(); | |||||
//! dump network after global layout transform optimization | |||||
void dump_layout_transform_model(std::string optimized_model_path); | |||||
private: | private: | ||||
//! construct the outputspec according to the m_network_io, and set the | //! construct the outputspec according to the m_network_io, and set the | ||||
//! call_back to the outputspec | //! call_back to the outputspec | ||||
void make_output_spec(); | void make_output_spec(); | ||||
//! do the global layout transform for the given platform target | |||||
void global_layout_transform(); | |||||
//! modify the execution policy | //! modify the execution policy | ||||
void modify_exection_policy(); | void modify_exection_policy(); | ||||
@@ -223,6 +232,7 @@ private: | |||||
int m_nr_device_type = 0; | int m_nr_device_type = 0; | ||||
size_t m_nr_threads = 1; | size_t m_nr_threads = 1; | ||||
bool m_compute_configured_output_only = false; | bool m_compute_configured_output_only = false; | ||||
bool m_set_layout_transform = false; | |||||
mgb::CompNode::Locator m_compnode_locator; | mgb::CompNode::Locator m_compnode_locator; | ||||
AsyncCallback m_async_callback = nullptr; | AsyncCallback m_async_callback = nullptr; | ||||
@@ -233,6 +243,9 @@ private: | |||||
//! The model load related data | //! The model load related data | ||||
S m_execution_policy = static_cast<S>(0); | S m_execution_policy = static_cast<S>(0); | ||||
std::unique_ptr<mgb::serialization::InputFile> m_input_file; | std::unique_ptr<mgb::serialization::InputFile> m_input_file; | ||||
mgb::Maybe<mgb::serialization::GraphDumpFormat> m_format; | |||||
mgb::gopt::GraphTuningOptions::Target m_layout_transform_target; | |||||
mgb::serialization::GraphLoadConfig m_load_config; | mgb::serialization::GraphLoadConfig m_load_config; | ||||
mgb::serialization::GraphLoader::LoadResult m_load_result; | mgb::serialization::GraphLoader::LoadResult m_load_result; | ||||
mgb::ComputingGraph::OutputSpec m_output_spec; | mgb::ComputingGraph::OutputSpec m_output_spec; | ||||
@@ -505,4 +505,33 @@ void Runtime::shared_weight_with_network( | |||||
LITE_ERROR_HANDLER_END | LITE_ERROR_HANDLER_END | ||||
} | } | ||||
void Runtime::enable_global_layout_transform(std::shared_ptr<Network> network) { | |||||
LITE_ERROR_HANDLER_BEGIN | |||||
auto network_impl = NetworkHelper::implement(network); | |||||
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) { | |||||
LITE_ASSERT( | |||||
!NetworkHelper::loaded(network), | |||||
"enable_global_layout_transform should be used before model loaded."); | |||||
call_func<NetworkImplDft, void>("enable_global_layout_transform", network_impl); | |||||
return; | |||||
} | |||||
LITE_THROW("enable_global_layout_transform is not aviliable in the backend."); | |||||
LITE_ERROR_HANDLER_END | |||||
} | |||||
void Runtime::dump_layout_transform_model( | |||||
std::shared_ptr<Network> network, std::string optimized_model_path) { | |||||
LITE_ERROR_HANDLER_BEGIN | |||||
auto network_impl = NetworkHelper::implement(network); | |||||
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) { | |||||
LITE_ASSERT( | |||||
NetworkHelper::loaded(network), | |||||
"dump_layout_transform_model should be used after model loaded."); | |||||
call_func<NetworkImplDft, void>( | |||||
"dump_layout_transform_model", network_impl, optimized_model_path); | |||||
return; | |||||
} | |||||
LITE_THROW("dump_layout_transform_model is not aviliable in the backend."); | |||||
LITE_ERROR_HANDLER_END | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -909,6 +909,30 @@ TEST(TestNetWork, LoadPackedModel) { | |||||
network->wait(); | network->wait(); | ||||
} | } | ||||
TEST(TestNetWork, GlabalLayoutTransform) { | |||||
// set_log_level(LiteLogLevel::DEBUG); | |||||
auto tensor = get_input_data("./input_data.npy"); | |||||
std::string model_path = "./shufflenet.mge"; | |||||
std::string input_name = "data"; | |||||
std::string dump_model_name = "./shufflenet_after_trans.mge"; | |||||
NetworkIO IO; | |||||
Config config; | |||||
std::shared_ptr<Network> network = std::make_shared<Network>(config, IO); | |||||
Runtime::enable_global_layout_transform(network); | |||||
network->load_model(model_path); | |||||
std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name); | |||||
auto src_ptr = tensor->get_memory_ptr(); | |||||
auto src_layout = tensor->get_layout(); | |||||
input_tensor->reset(src_ptr, src_layout); | |||||
Runtime::dump_layout_transform_model(network, dump_model_name); | |||||
network->forward(); | |||||
network->wait(); | |||||
ASSERT_TRUE(fopen(dump_model_name.c_str(), "r")); | |||||
} | |||||
TEST(TestNetWork, GetDeviceType) { | TEST(TestNetWork, GetDeviceType) { | ||||
auto tensor = get_input_data("./input_data.npy"); | auto tensor = get_input_data("./input_data.npy"); | ||||
std::string model_path = "./shufflenet.mge"; | std::string model_path = "./shufflenet.mge"; | ||||
@@ -889,6 +889,21 @@ TEST(TestCapiNetWork, ProfileIOdump) { | |||||
LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | ||||
} | } | ||||
TEST(TestCapiNetWork, GlabalLayoutTransform) { | |||||
ForwardMgb; | |||||
MakeNetwork; | |||||
LITE_CAPI_CHECK(LITE_enable_global_layout_transform(c_network)); | |||||
LoadNetwork; | |||||
LITE_CAPI_CHECK(LITE_dump_layout_transform_model( | |||||
c_network, "./shufflenet_after_trans.mge")); | |||||
SetInput; | |||||
ForwardNetwork; | |||||
ASSERT_TRUE(fopen("./shufflenet_after_trans.mge", "r")); | |||||
GetOutput; | |||||
CompareResult; | |||||
LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | |||||
} | |||||
TEST(TestCapiNetWork, GetDeviceType) { | TEST(TestCapiNetWork, GetDeviceType) { | ||||
lite::Config config; | lite::Config config; | ||||
auto lite_tensor = lite::get_input_data("./input_data.npy"); | auto lite_tensor = lite::get_input_data("./input_data.npy"); | ||||