|
|
@@ -20,8 +20,8 @@ template <> |
|
|
|
void InputOption::config_model_internel<ModelLite>( |
|
|
|
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { |
|
|
|
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { |
|
|
|
auto parser = model->get_input_parser(); |
|
|
|
auto io = model->get_networkIO(); |
|
|
|
auto&& parser = model->get_input_parser(); |
|
|
|
auto&& io = model->get_networkIO(); |
|
|
|
for (size_t idx = 0; idx < data_path.size(); ++idx) { |
|
|
|
parser.feed(data_path[idx].c_str()); |
|
|
|
} |
|
|
@@ -32,9 +32,8 @@ void InputOption::config_model_internel<ModelLite>( |
|
|
|
io.inputs.push_back({i.first, is_host}); |
|
|
|
} |
|
|
|
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { |
|
|
|
auto config = model->get_config(); |
|
|
|
auto parser = model->get_input_parser(); |
|
|
|
auto network = model->get_lite_network(); |
|
|
|
auto&& parser = model->get_input_parser(); |
|
|
|
auto&& network = model->get_lite_network(); |
|
|
|
|
|
|
|
//! datd type map from mgb data type to lite data type |
|
|
|
std::map<megdnn::DTypeEnum, LiteDataType> type_map = { |
|
|
@@ -75,8 +74,8 @@ void InputOption::config_model_internel<ModelMdl>( |
|
|
|
parser.feed(data_path[idx].c_str()); |
|
|
|
} |
|
|
|
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { |
|
|
|
auto parser = model->get_input_parser(); |
|
|
|
auto network = model->get_mdl_load_result(); |
|
|
|
auto&& parser = model->get_input_parser(); |
|
|
|
auto&& network = model->get_mdl_load_result(); |
|
|
|
auto tensormap = network.tensor_map; |
|
|
|
for (auto& i : parser.inputs) { |
|
|
|
mgb_assert( |
|
|
@@ -156,7 +155,7 @@ void IOdumpOption::config_model_internel<ModelMdl>( |
|
|
|
} |
|
|
|
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { |
|
|
|
if (enable_bin_out_dump) { |
|
|
|
auto load_result = model->get_mdl_load_result(); |
|
|
|
auto&& load_result = model->get_mdl_load_result(); |
|
|
|
out_dumper->set(load_result.output_var_list); |
|
|
|
|
|
|
|
std::vector<mgb::ComputingGraph::Callback> cb; |
|
|
@@ -166,7 +165,7 @@ void IOdumpOption::config_model_internel<ModelMdl>( |
|
|
|
model->set_output_callback(cb); |
|
|
|
} |
|
|
|
if (enable_copy_to_host) { |
|
|
|
auto load_result = model->get_mdl_load_result(); |
|
|
|
auto&& load_result = model->get_mdl_load_result(); |
|
|
|
|
|
|
|
std::vector<mgb::ComputingGraph::Callback> cb; |
|
|
|
for (size_t i = 0; i < load_result.output_var_list.size(); i++) { |
|
|
|