diff --git a/lite/load_and_run/src/models/model_lite.cpp b/lite/load_and_run/src/models/model_lite.cpp
index 2cdf3834..7a913af0 100644
--- a/lite/load_and_run/src/models/model_lite.cpp
+++ b/lite/load_and_run/src/models/model_lite.cpp
@@ -19,6 +19,9 @@
 ModelLite::ModelLite(const std::string& path) : model_path(path){};
 void ModelLite::load_model() {
     m_network = std::make_shared<lite::Network>(config, IO);
+    if (enable_layout_transform) {
+        lite::Runtime::enable_global_layout_transform(m_network);
+    }
     if (share_model_mem) {
         //! WARNNING:maybe not right to share param memmory for this
         LITE_WARN("enable share model memory");
diff --git a/lite/load_and_run/src/models/model_lite.h b/lite/load_and_run/src/models/model_lite.h
index dc22fc7d..88323beb 100644
--- a/lite/load_and_run/src/models/model_lite.h
+++ b/lite/load_and_run/src/models/model_lite.h
@@ -39,6 +39,10 @@ public:
     //! wait the end of asynchronous function execution
     void wait() override;
 
+    //! enable global layout transform
+
+    void set_layout_transform(bool state) { enable_layout_transform = state; }
+
     //! get the network of lite model
     std::shared_ptr<lite::Network>& get_lite_network() { return m_network; }
 
@@ -59,6 +63,7 @@ public:
 
 private:
     bool share_model_mem;
+    bool enable_layout_transform;
     std::string model_path;
 
     DataParser parser;
diff --git a/lite/load_and_run/src/options/layout_trans_options.cpp b/lite/load_and_run/src/options/layout_trans_options.cpp
index 05d3ddcc..22abb0b2 100644
--- a/lite/load_and_run/src/options/layout_trans_options.cpp
+++ b/lite/load_and_run/src/options/layout_trans_options.cpp
@@ -16,9 +16,30 @@
 namespace lar {
 template <>
 void GoptLayoutOption::config_model_internel<ModelLite>(
-        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> /* model */) {
+        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
     if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
-        LITE_THROW("lite model don't support global graph optimization");
+        if (m_layout_transform) {
+            if (m_layout_transform_target ==
+                mgb::gopt::GraphTuningOptions::Target::CPU) {
+                model->get_config().device_type = LiteDeviceType::LITE_CPU;
+            }
+#if LITE_WITH_CUDA
+            else if (
+                    m_layout_transform_target ==
+                    mgb::gopt::GraphTuningOptions::Target::CUDA) {
+                model->get_config().device_type = LiteDeviceType::LITE_CUDA;
+            }
+#endif
+            model->set_layout_transform(true);
+        }
+    } else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
+        if (m_layout_transform) {
+            auto&& network = model->get_lite_network();
+            if (!m_layout_transform_dump_file.empty()) {
+                lite::Runtime::dump_layout_transform_model(
+                        network, m_layout_transform_dump_file);
+            }
+        }
     }
 }
 
@@ -26,14 +47,14 @@ template <>
 void GoptLayoutOption::config_model_internel<ModelMdl>(
         RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
     if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
-        if (layout_transform) {
+        if (m_layout_transform) {
             auto&& load_result = model->get_mdl_load_result();
             load_result.output_var_list = mgb::gopt::layout_transform(
-                    load_result.output_var_list, layout_transform_target);
+                    load_result.output_var_list, m_layout_transform_target);
 
-            if (!layout_transform_dump_file.empty()) {
+            if (!m_layout_transform_dump_file.empty()) {
                 auto out_file = mgb::serialization::OutputFile::make_fs(
-                        layout_transform_dump_file.c_str(), 'w');
+                        m_layout_transform_dump_file.c_str(), 'w');
                 auto testcase_num = model->get_testcase_num();
 
                 if (testcase_num) {
@@ -56,7 +77,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>(
                 mgb::serialization::GraphDumper::DumpConfig config{1, false, false};
                 for (size_t i = 0; i < testcase_num; ++i) {
                     auto casefile = mgb::serialization::OutputFile::make_fs(
-                            layout_transform_dump_file.c_str(), 'a');
+                            m_layout_transform_dump_file.c_str(), 'a');
                     auto casedumper = model->get_dumper(std::move(casefile));
                     casedumper->dump(testcase.output_var_list, config);
                     if (i != testcase_num - 1) {
@@ -80,29 +101,37 @@
 using namespace lar;
 
 GoptLayoutOption::GoptLayoutOption() {
     m_option_name = "gopt_layout";
-    if (FLAGS_layout_transform != "cuda" && FLAGS_layout_transform != "cpu" &&
-        FLAGS_layout_transform != "opencl") {
-        layout_transform = false;
-        layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
+    if (FLAGS_layout_transform != "cpu"
+#if LITE_WITH_CUDA
+        && FLAGS_layout_transform != "cuda"
+#endif
+    ) {
+        m_layout_transform = false;
+        m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
     } else {
-        layout_transform = true;
-        if (FLAGS_layout_transform == "cuda") {
-            layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
-        } else if (FLAGS_layout_transform == "cpu") {
-            layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
-        } else if (FLAGS_layout_transform == "opencl") {
-            layout_transform_target = mgb::gopt::GraphTuningOptions::Target::OPENCL;
+        m_layout_transform = true;
+
+        if (FLAGS_layout_transform == "cpu") {
+            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
+        }
+#if LITE_WITH_CUDA
+        else if (FLAGS_layout_transform == "cuda") {
+            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
         }
+#endif
     }
-    layout_transform_dump_file = FLAGS_layout_transform_dump;
+    m_layout_transform_dump_file = FLAGS_layout_transform_dump;
 }
 
 bool GoptLayoutOption::is_valid() {
     bool ret = false;
     if (!FLAGS_layout_transform.empty()) {
-        if (FLAGS_layout_transform != "cuda" && FLAGS_layout_transform != "cpu" &&
-            FLAGS_layout_transform != "opencl") {
+        if (FLAGS_layout_transform != "cpu"
+#if LITE_WITH_CUDA
+            && FLAGS_layout_transform != "cuda"
+#endif
+            ) {
             mgb_assert(
                     false,
                     "unsupported target(got:%s) for global layout "
diff --git a/lite/load_and_run/src/options/layout_trans_options.h b/lite/load_and_run/src/options/layout_trans_options.h
index de94084d..ad233a04 100644
--- a/lite/load_and_run/src/options/layout_trans_options.h
+++ b/lite/load_and_run/src/options/layout_trans_options.h
@@ -37,9 +37,9 @@ private:
     //! config template for different model
     template <typename ModelImpl>
     void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
-    bool layout_transform;
+    bool m_layout_transform;
     std::string m_option_name;
-    std::string layout_transform_dump_file;
-    mgb::gopt::GraphTuningOptions::Target layout_transform_target;
+    std::string m_layout_transform_dump_file;
+    mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;
 };
 } // namespace lar
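Reviewer note: for context, below is a minimal sketch of the call sequence this patch wires up for the lite model path. It uses only the `lite` APIs referenced in the diff (`lite::Runtime::enable_global_layout_transform`, `lite::Runtime::dump_layout_transform_model`, `lite::Config::device_type`) plus the standard `lite::Network` load/forward calls; the file paths and the default-constructed `NetworkIO` are placeholder assumptions, not values from this patch.

```cpp
#include <memory>

#include "lite/network.h"  // lite::Network, lite::Config, lite::Runtime

int main() {
    // Mirror the BEFORE_MODEL_LOAD stage: pick the device type that matches
    // the layout-transform target before the network is created.
    lite::Config config;
    config.device_type = LiteDeviceType::LITE_CPU;

    lite::NetworkIO io;  // default IO description (placeholder assumption)
    auto network = std::make_shared<lite::Network>(config, io);

    // Mirror ModelLite::load_model(): the global layout transform is enabled
    // before load_model(), which is why the patch defers the call behind the
    // enable_layout_transform flag instead of issuing it at option-parse time.
    lite::Runtime::enable_global_layout_transform(network);
    network->load_model("model.mge");  // placeholder path

    network->forward();
    network->wait();

    // Mirror the GLOBAL_OPTIMIZATION stage: persist the transformed model so
    // later runs can reuse it without repeating the tuning pass.
    lite::Runtime::dump_layout_transform_model(network, "model_opt.mge");
    return 0;
}
```

With gflags' default naming, the same flow should be reachable through the tool as `load_and_run <model> --layout_transform=cpu --layout_transform_dump=model_opt.mge`; the exact flag spelling depends on how `FLAGS_layout_transform` and `FLAGS_layout_transform_dump` are registered.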