GitOrigin-RevId: ec44429f55
release-1.10
@@ -180,7 +180,13 @@ void DataParser::parse_npy(const std::string& name, const std::string& path) { | |||
inputs.insert(std::make_pair(name, std::move(hv))); | |||
} | |||
void DataParser::parse_string(const std::string name, const std::string& str) { | |||
void DataParser::parse_string(const std::string& name, const std::string& str) { | |||
//! parse shape | |||
if ('{' == str[0]) { | |||
parse_shape(name, str); | |||
return; | |||
} | |||
// data type | |||
megdnn::DType data_type = mgb::dtype::Int32(); | |||
if (str.find(".") != std::string::npos or str.find(".") != std::string::npos) { | |||
@@ -257,3 +263,31 @@ void DataParser::parse_string(const std::string name, const std::string& str) { | |||
} | |||
inputs.insert(std::make_pair(name, std::move(hv))); | |||
} | |||
void DataParser::parse_shape(const std::string& name, const std::string& str) {
    //! parse a user supplied shape string of the form {d0,d1,..,dn} and
    //! register a tensor of that shape under \p name.
    //! NOTE: the tensor is registered with deliberately EMPTY storage
    //! (only_reset_raw_storage with an unallocated storage), so that the
    //! consumer can detect a shape-only input via storage().empty() and fill
    //! dummy data itself.
    mgb_assert(
            !str.empty() && '{' == str.front(),
            "invalid value: %s for parse_shape, valid format: {d0,d1,..,dn}\n",
            str.c_str());
    megdnn::SmallVector<size_t> shape;
    std::string digits;
    for (char c : str) {
        if ('{' == c || ' ' == c) {
            continue;
        } else if (',' == c || '}' == c) {
            //! reject empty dims such as "{}" or "{1,,2}": std::stoul would
            //! otherwise throw std::invalid_argument with no useful context
            mgb_assert(
                    !digits.empty(),
                    "invalid value: %s for parse_shape, valid format: "
                    "{d0,d1,..,dn}\n",
                    str.c_str());
            shape.push_back(std::stoul(digits));
            digits.clear();
            if ('}' == c) {
                break;
            }
        } else {
            digits += c;
        }
    }
    mgb::HostTensorND hv(mgb::CompNode::default_cpu(), shape);
    mgb::HostTensorStorage storage(mgb::CompNode::default_cpu());
    hv.only_reset_raw_storage(storage);
    inputs.insert(std::make_pair(name, std::move(hv)));
}
@@ -30,7 +30,10 @@ private: | |||
//! parser for .npy data | |||
void parse_npy(const std::string& name, const std::string& path); | |||
//! parser for user define string | |||
void parse_string(const std::string name, const std::string& str); | |||
//! parser for user defined string | |||
void parse_string(const std::string& name, const std::string& str); | |||
//! parser for user defined shape | |||
void parse_shape(const std::string& name, const std::string& str); | |||
}; | |||
} // namespace lar |
@@ -73,7 +73,17 @@ void InputOption::config_model_internel<ModelMdl>( | |||
tensormap.find(i.first) != tensormap.end(), | |||
"can't find tesnor named %s", i.first.c_str()); | |||
auto& in = tensormap.find(i.first)->second; | |||
in->copy_from(i.second); | |||
if (i.second.storage().empty()) { | |||
mgb::HostTensorND hv; | |||
hv.comp_node(mgb::CompNode::default_cpu(), true) | |||
.dtype(in->dtype()) | |||
.resize(i.second.shape()); | |||
mgb::dt_byte* raw_ptr = hv.raw_ptr(); | |||
memset((char*)raw_ptr, 1, hv.layout().total_nr_elems()); | |||
in->copy_from(hv); | |||
} else { | |||
in->copy_from(i.second); | |||
} | |||
} | |||
} | |||
} | |||
@@ -39,10 +39,24 @@ void GoptLayoutOption::config_model_internel<ModelLite>( | |||
template <> | |||
void GoptLayoutOption::config_model_internel<ModelMdl>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) { | |||
if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
if (m_layout_transform) { | |||
mgb_log_warn("using global layout transform optimization\n"); | |||
mgb_log_debug("update input shape for global layout transform\n"); | |||
auto&& load_result = model->get_mdl_load_result(); | |||
if (m_force_batch_size > 0) { | |||
for (auto&& i : load_result.tensor_map) { | |||
auto& in = i.second; | |||
mgb::TensorShape new_shape = in->shape(); | |||
new_shape[0] = m_force_batch_size; | |||
mgb::HostTensorND new_tensor; | |||
new_tensor.comp_node(mgb::CompNode::default_cpu(), true) | |||
.dtype(in->dtype()) | |||
.resize(new_shape); | |||
mgb::dt_byte* raw_ptr = new_tensor.raw_ptr(); | |||
memset((char*)raw_ptr, 1, new_tensor.layout().total_nr_elems()); | |||
in->copy_from(new_tensor); | |||
} | |||
} | |||
for (auto&& item : load_result.output_var_list) { | |||
if (item.shape()[0] > 1) { | |||
mgb_log_warn( | |||
@@ -81,7 +95,11 @@ void GoptLayoutOption::config_model_internel<ModelMdl>( | |||
} | |||
load_result.output_var_list = output_vars; | |||
} | |||
} | |||
} else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) { | |||
if (m_layout_transform) { | |||
mgb_log_warn("using global layout transform optimization\n"); | |||
auto&& load_result = model->get_mdl_load_result(); | |||
load_result.output_var_list = mgb::gopt::layout_transform( | |||
load_result.output_var_list, m_layout_transform_target); | |||
@@ -156,6 +174,8 @@ GoptLayoutOption::GoptLayoutOption() { | |||
} | |||
m_layout_transform_dump_file = FLAGS_layout_transform_dump; | |||
m_force_batch_size = FLAGS_layout_transform_batch_size; | |||
m_option = { | |||
{"layout_transform", lar::String::make("")}, | |||
}; | |||
@@ -182,6 +202,14 @@ bool GoptLayoutOption::is_valid() { | |||
} | |||
} | |||
ret = ret || !FLAGS_layout_transform_dump.empty(); | |||
if (FLAGS_layout_transform_batch_size > 0) { | |||
mgb_assert( | |||
FLAGS_layout_transform_batch_size > 0 && | |||
!FLAGS_layout_transform.empty(), | |||
"\"layout-transform-batch-size\" should be set with " | |||
"\"layout-transform\""); | |||
ret = ret || FLAGS_layout_transform_batch_size > 0; | |||
} | |||
return ret || m_valid; | |||
} | |||
@@ -233,5 +261,8 @@ DEFINE_string( | |||
"The computing graph after global layout transform will be dumped to the given " | |||
"file path."); | |||
DEFINE_int32( | |||
layout_transform_batch_size, -1, | |||
"the batch size of input for global layout transform optimization working on"); | |||
REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option); | |||
REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid); |
@@ -5,6 +5,7 @@ | |||
#include "models/model.h" | |||
#include "option_base.h" | |||
DECLARE_string(layout_transform); | |||
DECLARE_int32(layout_transform_batch_size); | |||
DECLARE_string(layout_transform_dump); | |||
namespace lar { | |||
@@ -38,5 +39,6 @@ private: | |||
mgb::gopt::GraphTuningOptions::Target m_layout_transform_target; | |||
static bool m_valid; | |||
OptionValMap m_option; | |||
int32_t m_force_batch_size; | |||
}; | |||
} // namespace lar |