From 2001c4940885ecf330eb33c208af5e014fd26b17 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 26 May 2022 20:02:32 +0800
Subject: [PATCH] feat(lite): add input shape parse for load and run

GitOrigin-RevId: ec44429f55050af49a24ab1e2aa671a08aff460e
---
 lite/load_and_run/src/helpers/data_parser.cpp | 36 ++++++++++++++++++++-
 lite/load_and_run/src/helpers/data_parser.h   |  7 ++--
 lite/load_and_run/src/options/io_options.cpp  | 12 ++++++-
 .../src/options/layout_trans_options.cpp      | 37 ++++++++++++++++++++--
 .../src/options/layout_trans_options.h        |  2 ++
 5 files changed, 87 insertions(+), 7 deletions(-)

diff --git a/lite/load_and_run/src/helpers/data_parser.cpp b/lite/load_and_run/src/helpers/data_parser.cpp
index 86527d53..ac378f88 100644
--- a/lite/load_and_run/src/helpers/data_parser.cpp
+++ b/lite/load_and_run/src/helpers/data_parser.cpp
@@ -180,7 +180,13 @@ void DataParser::parse_npy(const std::string& name, const std::string& path) {
     inputs.insert(std::make_pair(name, std::move(hv)));
 }
 
-void DataParser::parse_string(const std::string name, const std::string& str) {
+void DataParser::parse_string(const std::string& name, const std::string& str) {
+    //! parse shape
+    if ('{' == str[0]) {
+        parse_shape(name, str);
+        return;
+    }
+
     // data type
     megdnn::DType data_type = mgb::dtype::Int32();
     if (str.find(".") != std::string::npos or str.find(".") != std::string::npos) {
@@ -257,3 +263,31 @@ void DataParser::parse_string(const std::string name, const std::string& str) {
     }
     inputs.insert(std::make_pair(name, std::move(hv)));
 }
+
+void DataParser::parse_shape(const std::string& name, const std::string& str) {
+    //! {d0,d1,..,dn}
+    mgb_assert(
+            "{" == str.substr(0, 1),
+            "invalid value: %s for parse_shape, valid format: {d0,d1,..,dn}\n",
+            str.c_str());
+    megdnn::SmallVector<size_t> shape;
+    std::string shape_size = "";
+    for (size_t i = 0; i < str.size(); ++i) {
+        char c = str[i];
+        if ('{' == c || ' ' == c) {
+            continue;
+        } else if (',' == c || '}' == c) {
+            shape.push_back(std::stoul(shape_size));
+            shape_size = "";
+            if ('}' == c) {
+                break;
+            }
+        } else {
+            shape_size += c;
+        }
+    }
+    mgb::HostTensorND hv(mgb::CompNode::default_cpu(), shape);
+    mgb::HostTensorStorage storage(mgb::CompNode::default_cpu());
+    hv.only_reset_raw_storage(storage);
+    inputs.insert(std::make_pair(name, std::move(hv)));
+}
diff --git a/lite/load_and_run/src/helpers/data_parser.h b/lite/load_and_run/src/helpers/data_parser.h
index ef9af154..c58e15c8 100644
--- a/lite/load_and_run/src/helpers/data_parser.h
+++ b/lite/load_and_run/src/helpers/data_parser.h
@@ -30,7 +30,10 @@ private:
     //! parser for .npy data
     void parse_npy(const std::string& name, const std::string& path);
 
-    //! parser for user define string
-    void parse_string(const std::string name, const std::string& str);
+    //! parser for user defined string
+    void parse_string(const std::string& name, const std::string& str);
+
+    //! parser for user defined shape
+    void parse_shape(const std::string& name, const std::string& str);
 };
 }  // namespace lar
diff --git a/lite/load_and_run/src/options/io_options.cpp b/lite/load_and_run/src/options/io_options.cpp
index 786e4c2e..105397d3 100644
--- a/lite/load_and_run/src/options/io_options.cpp
+++ b/lite/load_and_run/src/options/io_options.cpp
@@ -73,7 +73,17 @@ void InputOption::config_model_internel(
                     tensormap.find(i.first) != tensormap.end(),
                     "can't find tesnor named %s", i.first.c_str());
             auto& in = tensormap.find(i.first)->second;
-            in->copy_from(i.second);
+            if (i.second.storage().empty()) {
+                mgb::HostTensorND hv;
+                hv.comp_node(mgb::CompNode::default_cpu(), true)
+                        .dtype(in->dtype())
+                        .resize(i.second.shape());
+                mgb::dt_byte* raw_ptr = hv.raw_ptr();
+                memset((char*)raw_ptr, 1, hv.layout().total_nr_elems());
+                in->copy_from(hv);
+            } else {
+                in->copy_from(i.second);
+            }
         }
     }
 }
diff --git a/lite/load_and_run/src/options/layout_trans_options.cpp b/lite/load_and_run/src/options/layout_trans_options.cpp
index ab9f17f2..1d35acbd 100644
--- a/lite/load_and_run/src/options/layout_trans_options.cpp
+++ b/lite/load_and_run/src/options/layout_trans_options.cpp
@@ -39,10 +39,24 @@ void GoptLayoutOption::config_model_internel(
 template <>
 void GoptLayoutOption::config_model_internel<ModelMdl>(
         RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
-    if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
+    if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
         if (m_layout_transform) {
-            mgb_log_warn("using global layout transform optimization\n");
+            mgb_log_debug("update input shape for global layout transform\n");
             auto&& load_result = model->get_mdl_load_result();
+            if (m_force_batch_size > 0) {
+                for (auto&& i : load_result.tensor_map) {
+                    auto& in = i.second;
+                    mgb::TensorShape new_shape = in->shape();
+                    new_shape[0] = m_force_batch_size;
+                    mgb::HostTensorND new_tensor;
+                    new_tensor.comp_node(mgb::CompNode::default_cpu(), true)
+                            .dtype(in->dtype())
+                            .resize(new_shape);
+                    mgb::dt_byte* raw_ptr = new_tensor.raw_ptr();
+                    memset((char*)raw_ptr, 1, new_tensor.layout().total_nr_elems());
+                    in->copy_from(new_tensor);
+                }
+            }
             for (auto&& item : load_result.output_var_list) {
                 if (item.shape()[0] > 1) {
                     mgb_log_warn(
@@ -81,7 +95,11 @@
                 }
                 load_result.output_var_list = output_vars;
             }
-
+        }
+    } else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
+        if (m_layout_transform) {
+            mgb_log_warn("using global layout transform optimization\n");
+            auto&& load_result = model->get_mdl_load_result();
             load_result.output_var_list = mgb::gopt::layout_transform(
                     load_result.output_var_list, m_layout_transform_target);
 
@@ -156,6 +174,8 @@ GoptLayoutOption::GoptLayoutOption() {
     }
     m_layout_transform_dump_file = FLAGS_layout_transform_dump;
 
+    m_force_batch_size = FLAGS_layout_transform_batch_size;
+
     m_option = {
             {"layout_transform", lar::String::make("")},
     };
@@ -182,6 +202,14 @@ bool GoptLayoutOption::is_valid() {
         }
     }
     ret = ret || !FLAGS_layout_transform_dump.empty();
+    if (FLAGS_layout_transform_batch_size > 0) {
+        mgb_assert(
+                FLAGS_layout_transform_batch_size > 0 &&
+                        !FLAGS_layout_transform.empty(),
+                "\"layout-transform-batch-size\" should be set with "
+                "\"layout-transform\"");
+        ret = ret || FLAGS_layout_transform_batch_size > 0;
+    }
 
     return ret || m_valid;
 }
@@ -233,5 +261,8 @@ DEFINE_string(
         "The computing graph after global layout transform will be dumped to the given "
         "file path.");
 
+DEFINE_int32(
+        layout_transform_batch_size, -1,
+        "the batch size of input for global layout transform optimization working on");
 REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option);
 REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid);
\ No newline at end of file
diff --git a/lite/load_and_run/src/options/layout_trans_options.h b/lite/load_and_run/src/options/layout_trans_options.h
index 05f80e87..fac96dcb 100644
--- a/lite/load_and_run/src/options/layout_trans_options.h
+++ b/lite/load_and_run/src/options/layout_trans_options.h
@@ -5,6 +5,7 @@
 #include "models/model.h"
 #include "option_base.h"
 DECLARE_string(layout_transform);
+DECLARE_int32(layout_transform_batch_size);
 DECLARE_string(layout_transform_dump);
 
 namespace lar {
@@ -38,5 +39,6 @@ private:
     mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;
     static bool m_valid;
     OptionValMap m_option;
+    int32_t m_force_batch_size;
 };
 }  // namespace lar
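
Note (not part of the patch): the shape-only input format accepted by the new DataParser::parse_shape() is "{d0,d1,..,dn}"; such entries carry a shape but empty storage, and io_options.cpp later fills a placeholder tensor of that shape (memset to byte value 1) before copying it into the network input. The sketch below is a minimal standalone illustration of that brace-delimited parsing loop in plain C++, with no MegEngine types; parse_shape_string() and main() are hypothetical names used only for illustration.

// Minimal sketch of the "{d0,d1,..,dn}" shape-string format parsed by
// DataParser::parse_shape() above. Plain C++; parse_shape_string() and main()
// are hypothetical and exist only for illustration.
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

std::vector<std::size_t> parse_shape_string(const std::string& str) {
    if (str.empty() || str[0] != '{') {
        throw std::invalid_argument("valid format: {d0,d1,..,dn}");
    }
    std::vector<std::size_t> shape;
    std::string dim;
    for (char c : str) {
        if (c == '{' || c == ' ') {
            continue;  // skip the opening brace and any spaces
        } else if (c == ',' || c == '}') {
            shape.push_back(std::stoul(dim));  // close out the current dimension
            dim.clear();
            if (c == '}') {
                break;  // the closing brace ends the shape
            }
        } else {
            dim += c;  // accumulate the digits of one dimension
        }
    }
    return shape;
}

int main() {
    for (std::size_t d : parse_shape_string("{1, 3, 224, 224}")) {
        std::cout << d << ' ';  // prints: 1 3 224 224
    }
    std::cout << '\n';
    return 0;
}

As in parse_shape() above, spaces are ignored and parsing stops at the first '}', so anything after the closing brace is silently dropped.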