Browse Source

feat(lite): add input shape parse for load and run

GitOrigin-RevId: ec44429f55
release-1.10
Megvii Engine Team 3 years ago
parent
commit
2001c49408
5 changed files with 87 additions and 7 deletions
  1. +35
    -1
      lite/load_and_run/src/helpers/data_parser.cpp
  2. +5
    -2
      lite/load_and_run/src/helpers/data_parser.h
  3. +11
    -1
      lite/load_and_run/src/options/io_options.cpp
  4. +34
    -3
      lite/load_and_run/src/options/layout_trans_options.cpp
  5. +2
    -0
      lite/load_and_run/src/options/layout_trans_options.h

+ 35
- 1
lite/load_and_run/src/helpers/data_parser.cpp View File

@@ -180,7 +180,13 @@ void DataParser::parse_npy(const std::string& name, const std::string& path) {
inputs.insert(std::make_pair(name, std::move(hv)));
}

void DataParser::parse_string(const std::string name, const std::string& str) {
void DataParser::parse_string(const std::string& name, const std::string& str) {
//! parse shape
if ('{' == str[0]) {
parse_shape(name, str);
return;
}

// data type
megdnn::DType data_type = mgb::dtype::Int32();
if (str.find(".") != std::string::npos or str.find(".") != std::string::npos) {
@@ -257,3 +263,31 @@ void DataParser::parse_string(const std::string name, const std::string& str) {
}
inputs.insert(std::make_pair(name, std::move(hv)));
}

void DataParser::parse_shape(const std::string& name, const std::string& str) {
//! {d0,d1,..,dn}
mgb_assert(
"{" == str.substr(0, 1),
"invalid value: %s for parse_shape, valid format: {d0,d1,..,dn}\n",
str.c_str());
megdnn::SmallVector<size_t> shape;
std::string shape_size = "";
for (size_t i = 0; i < str.size(); ++i) {
char c = str[i];
if ('{' == c || ' ' == c) {
continue;
} else if (',' == c || '}' == c) {
shape.push_back(std::stoul(shape_size));
shape_size = "";
if ('}' == c) {
break;
}
} else {
shape_size += c;
}
}
mgb::HostTensorND hv(mgb::CompNode::default_cpu(), shape);
mgb::HostTensorStorage storage(mgb::CompNode::default_cpu());
hv.only_reset_raw_storage(storage);
inputs.insert(std::make_pair(name, std::move(hv)));
}

+ 5
- 2
lite/load_and_run/src/helpers/data_parser.h View File

@@ -30,7 +30,10 @@ private:
//! parser for .npy data
void parse_npy(const std::string& name, const std::string& path);

//! parser for user define string
void parse_string(const std::string name, const std::string& str);
//! parser for user defined string
void parse_string(const std::string& name, const std::string& str);

//! parser for user defined shape
void parse_shape(const std::string& name, const std::string& str);
};
} // namespace lar

+ 11
- 1
lite/load_and_run/src/options/io_options.cpp View File

@@ -73,7 +73,17 @@ void InputOption::config_model_internel<ModelMdl>(
tensormap.find(i.first) != tensormap.end(),
"can't find tesnor named %s", i.first.c_str());
auto& in = tensormap.find(i.first)->second;
in->copy_from(i.second);
if (i.second.storage().empty()) {
mgb::HostTensorND hv;
hv.comp_node(mgb::CompNode::default_cpu(), true)
.dtype(in->dtype())
.resize(i.second.shape());
mgb::dt_byte* raw_ptr = hv.raw_ptr();
memset((char*)raw_ptr, 1, hv.layout().total_nr_elems());
in->copy_from(hv);
} else {
in->copy_from(i.second);
}
}
}
}


+ 34
- 3
lite/load_and_run/src/options/layout_trans_options.cpp View File

@@ -39,10 +39,24 @@ void GoptLayoutOption::config_model_internel<ModelLite>(
template <>
void GoptLayoutOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
if (m_layout_transform) {
mgb_log_warn("using global layout transform optimization\n");
mgb_log_debug("update input shape for global layout transform\n");
auto&& load_result = model->get_mdl_load_result();
if (m_force_batch_size > 0) {
for (auto&& i : load_result.tensor_map) {
auto& in = i.second;
mgb::TensorShape new_shape = in->shape();
new_shape[0] = m_force_batch_size;
mgb::HostTensorND new_tensor;
new_tensor.comp_node(mgb::CompNode::default_cpu(), true)
.dtype(in->dtype())
.resize(new_shape);
mgb::dt_byte* raw_ptr = new_tensor.raw_ptr();
memset((char*)raw_ptr, 1, new_tensor.layout().total_nr_elems());
in->copy_from(new_tensor);
}
}
for (auto&& item : load_result.output_var_list) {
if (item.shape()[0] > 1) {
mgb_log_warn(
@@ -81,7 +95,11 @@ void GoptLayoutOption::config_model_internel<ModelMdl>(
}
load_result.output_var_list = output_vars;
}

}
} else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
if (m_layout_transform) {
mgb_log_warn("using global layout transform optimization\n");
auto&& load_result = model->get_mdl_load_result();
load_result.output_var_list = mgb::gopt::layout_transform(
load_result.output_var_list, m_layout_transform_target);

@@ -156,6 +174,8 @@ GoptLayoutOption::GoptLayoutOption() {
}
m_layout_transform_dump_file = FLAGS_layout_transform_dump;

m_force_batch_size = FLAGS_layout_transform_batch_size;

m_option = {
{"layout_transform", lar::String::make("")},
};
@@ -182,6 +202,14 @@ bool GoptLayoutOption::is_valid() {
}
}
ret = ret || !FLAGS_layout_transform_dump.empty();
if (FLAGS_layout_transform_batch_size > 0) {
mgb_assert(
FLAGS_layout_transform_batch_size > 0 &&
!FLAGS_layout_transform.empty(),
"\"layout-transform-batch-size\" should be set with "
"\"layout-transform\"");
ret = ret || FLAGS_layout_transform_batch_size > 0;
}
return ret || m_valid;
}

@@ -233,5 +261,8 @@ DEFINE_string(
"The computing graph after global layout transform will be dumped to the given "
"file path.");

DEFINE_int32(
layout_transform_batch_size, -1,
"the batch size of input for global layout transform optimization working on");
REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option);
REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid);

+ 2
- 0
lite/load_and_run/src/options/layout_trans_options.h View File

@@ -5,6 +5,7 @@
#include "models/model.h"
#include "option_base.h"
DECLARE_string(layout_transform);
DECLARE_int32(layout_transform_batch_size);
DECLARE_string(layout_transform_dump);

namespace lar {
@@ -38,5 +39,6 @@ private:
mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;
static bool m_valid;
OptionValMap m_option;
int32_t m_force_batch_size;
};
} // namespace lar

Loading…
Cancel
Save