
feat(lite): add layout transform interface for lite model

GitOrigin-RevId: 57c7678419
tags/v1.9.0
Megvii Engine Team
parent commit: 49d92d9c3f
4 changed files with 61 additions and 24 deletions
  1. lite/load_and_run/src/models/model_lite.cpp (+3, -0)
  2. lite/load_and_run/src/models/model_lite.h (+5, -0)
  3. lite/load_and_run/src/options/layout_trans_options.cpp (+50, -21)
  4. lite/load_and_run/src/options/layout_trans_options.h (+3, -3)

lite/load_and_run/src/models/model_lite.cpp (+3, -0)

@@ -19,6 +19,9 @@ ModelLite::ModelLite(const std::string& path) : model_path(path) {
 };
 void ModelLite::load_model() {
     m_network = std::make_shared<lite::Network>(config, IO);
+    if (enable_layout_transform) {
+        lite::Runtime::enable_global_layout_transform(m_network);
+    }
     if (share_model_mem) {
         //! WARNNING:maybe not right to share param memmory for this
         LITE_WARN("enable share model memory");


lite/load_and_run/src/models/model_lite.h (+5, -0)

@@ -39,6 +39,10 @@ public:
     //! wait the end of asynchronous function execution
     void wait() override;
+
+    //! enable global layout transform
+    void set_layout_transform(bool state) { enable_layout_transform = state; }
+
     //! get the network of lite model
     std::shared_ptr<lite::Network>& get_lite_network() { return m_network; }


@@ -59,6 +63,7 @@ public:
 
 private:
     bool share_model_mem;
+    bool enable_layout_transform;
     std::string model_path;
 
     DataParser parser;
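
The setter above only records intent; ModelLite::load_model() (first file) is where lite::Runtime::enable_global_layout_transform() is actually applied, so the flag has to be set before the model is loaded. A sketch of the expected call order, with a hypothetical model path:

// Assumed call order (illustration only, not from this commit).
auto model = std::make_shared<ModelLite>("model.mge");  // hypothetical path
model->set_layout_transform(true);  // must precede load_model()
model->load_model();                // enable_global_layout_transform runs here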


lite/load_and_run/src/options/layout_trans_options.cpp (+50, -21)

@@ -16,9 +16,30 @@ namespace lar {
 
 template <>
 void GoptLayoutOption::config_model_internel<ModelLite>(
-        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> /* model */) {
+        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
     if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
-        LITE_THROW("lite model don't support global graph optimization");
+        if (m_layout_transform) {
+            if (m_layout_transform_target ==
+                mgb::gopt::GraphTuningOptions::Target::CPU) {
+                model->get_config().device_type = LiteDeviceType::LITE_CPU;
+            }
+#if LITE_WITH_CUDA
+            else if (
+                    m_layout_transform_target ==
+                    mgb::gopt::GraphTuningOptions::Target::CUDA) {
+                model->get_config().device_type = LiteDeviceType::LITE_CUDA;
+            }
+#endif
+            model->set_layout_transform(true);
+        }
+    } else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
+        if (m_layout_transform) {
+            auto&& network = model->get_lite_network();
+            if (!m_layout_transform_dump_file.empty()) {
+                lite::Runtime::dump_layout_transform_model(
+                        network, m_layout_transform_dump_file);
+            }
+        }
     }
 }


@@ -26,14 +47,14 @@ template <>
 void GoptLayoutOption::config_model_internel<ModelMdl>(
         RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
     if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
-        if (layout_transform) {
+        if (m_layout_transform) {
             auto&& load_result = model->get_mdl_load_result();
             load_result.output_var_list = mgb::gopt::layout_transform(
-                    load_result.output_var_list, layout_transform_target);
+                    load_result.output_var_list, m_layout_transform_target);
 
-            if (!layout_transform_dump_file.empty()) {
+            if (!m_layout_transform_dump_file.empty()) {
                 auto out_file = mgb::serialization::OutputFile::make_fs(
-                        layout_transform_dump_file.c_str(), 'w');
+                        m_layout_transform_dump_file.c_str(), 'w');
                 auto testcase_num = model->get_testcase_num();
 
                 if (testcase_num) {
@@ -56,7 +77,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>(
                     mgb::serialization::GraphDumper::DumpConfig config{1, false, false};
                     for (size_t i = 0; i < testcase_num; ++i) {
                         auto casefile = mgb::serialization::OutputFile::make_fs(
-                                layout_transform_dump_file.c_str(), 'a');
+                                m_layout_transform_dump_file.c_str(), 'a');
                         auto casedumper = model->get_dumper(std::move(casefile));
                         casedumper->dump(testcase.output_var_list, config);
                         if (i != testcase_num - 1) {
@@ -80,29 +101,37 @@ using namespace lar;
 
 GoptLayoutOption::GoptLayoutOption() {
     m_option_name = "gopt_layout";
-    if (FLAGS_layout_transform != "cuda" && FLAGS_layout_transform != "cpu" &&
-        FLAGS_layout_transform != "opencl") {
-        layout_transform = false;
-        layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
+    if (FLAGS_layout_transform != "cpu"
+#if LITE_WITH_CUDA
+        && FLAGS_layout_transform != "cuda"
+#endif
+    ) {
+        m_layout_transform = false;
+        m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
 
     } else {
-        layout_transform = true;
-        if (FLAGS_layout_transform == "cuda") {
-            layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
-        } else if (FLAGS_layout_transform == "cpu") {
-            layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
-        } else if (FLAGS_layout_transform == "opencl") {
-            layout_transform_target = mgb::gopt::GraphTuningOptions::Target::OPENCL;
+        m_layout_transform = true;
+
+        if (FLAGS_layout_transform == "cpu") {
+            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
+        }
+#if LITE_WITH_CUDA
+        else if (FLAGS_layout_transform == "cuda") {
+            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
         }
+#endif
     }
-    layout_transform_dump_file = FLAGS_layout_transform_dump;
+    m_layout_transform_dump_file = FLAGS_layout_transform_dump;
 }
 
 bool GoptLayoutOption::is_valid() {
     bool ret = false;
     if (!FLAGS_layout_transform.empty()) {
-        if (FLAGS_layout_transform != "cuda" && FLAGS_layout_transform != "cpu" &&
-            FLAGS_layout_transform != "opencl") {
+        if (FLAGS_layout_transform != "cpu"
+#if LITE_WITH_CUDA
+            && FLAGS_layout_transform != "cuda"
+#endif
+        ) {
             mgb_assert(
                     false,
                     "unsupported target(got:%s) for global layout "


lite/load_and_run/src/options/layout_trans_options.h (+3, -3)

@@ -37,9 +37,9 @@ private:
     //! config template for different model
     template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
-    bool layout_transform;
+    bool m_layout_transform;
     std::string m_option_name;
-    std::string layout_transform_dump_file;
-    mgb::gopt::GraphTuningOptions::Target layout_transform_target;
+    std::string m_layout_transform_dump_file;
+    mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;
 };
 } // namespace lar
