@@ -442,6 +442,24 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) {
     }
 }
 
+void NetworkImplDft::try_infer_tensor_layout(
+        std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var) {
+    auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
+    auto infer_trait = var.node()->get_static_infer_trait();
+    if (std::get<0>(infer_trait)) {
+        auto shape = static_infer_mgr.infer_shape_fallible(var.node());
+        if (!shape) {
+            LITE_WARN(
+                    "Lite infer output shape failed, maybe the model is "
+                    "dynamic "
+                    "shape.\n");
+            return;
+        }
+        Layout layout = to_lite_layout(mgb::TensorLayout{*shape, var.dtype()});
+        tensor->set_layout(layout);
+    }
+}
+
 void NetworkImplDft::update_io() {
     update_input();
     update_output();
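
Note on the new helper: try_infer_tensor_layout asks the graph's static-infer
manager for the var's shape at load time. infer_shape_fallible returns null
when the shape cannot be deduced statically (e.g. for dynamic-shape models);
the helper then only warns and returns, leaving the layout to be filled at
execution time as before. A caller-side sketch of the observable effect is
below; the model path and IO name are placeholders, and the exact
lite::Network/Tensor signatures are assumptions, not shown by this patch:

    #include <iostream>
    #include <memory>

    #include "lite/network.h"  // assumed public headers for lite::Network
    #include "lite/tensor.h"   // and lite::Tensor / lite::Layout

    int main() {
        auto network = std::make_shared<lite::Network>();
        network->load_model("model.mge");  // hypothetical model file

        // With this patch, a statically shaped model already reports a
        // valid output layout right after load, before the first forward().
        auto out = network->get_io_tensor("output0");  // hypothetical IO name
        const lite::Layout layout = out->get_layout();
        for (size_t i = 0; i < layout.ndim; ++i) {
            std::cout << layout.shapes[i]
                      << (i + 1 < layout.ndim ? " x " : "\n");
        }
        return 0;
    }
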
@@ -564,6 +582,14 @@ void NetworkImplDft::update_output() {
                 out_it->lite_tensor =
                         std::make_shared<Tensor>(device_id, stream_id, device_type);
             }
+            mgb::SymbolVar var;
+            for (auto&& out_var : m_load_result.output_var_list) {
+                if (out_var.node()->name() == out_it->name) {
+                    var = out_var;
+                    break;
+                }
+            }
+            try_infer_tensor_layout(out_it->lite_tensor, var);
         }
         //! user not set, use default output
     } else {
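
Review note on the hunk above: in the configured-output path there is no
SymbolVar at hand, so each user-named output is matched back to its graph var
by name before the layout is inferred. If no graph output matches
out_it->name, var keeps its default-constructed null node, and
try_infer_tensor_layout then dereferences it via var.node(). A defensive
guard along these lines (hypothetical, not part of this patch) would close
that path:

    // Hypothetical guard at the call site: skip layout inference when the
    // by-name lookup found no matching graph output var.
    if (var.node() != nullptr) {
        try_infer_tensor_layout(out_it->lite_tensor, var);
    }
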
@@ -579,12 +605,14 @@ void NetworkImplDft::update_output() {
                     it->lite_tensor =
                             std::make_shared<Tensor>(device_id, stream_id, device_type);
                 }
+                try_infer_tensor_layout(it->lite_tensor, out);
             } else {
                 IOInner output;
                 output.name = out.node()->name();
                 output.lite_tensor = std::make_shared<Tensor>(
                         device_id, stream_id, device_type, true);
                 m_network_io->outputs.push_back({output});
+                try_infer_tensor_layout(output.lite_tensor, out);
             }
         }
     }
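
Both branches of the default-output path now call the helper: a user-supplied
output tensor is reused (or plainly allocated), an anonymous output gets a
tensor whose trailing true argument reads as a pinned-host flag, and either
way the statically inferred layout is applied via to_lite_layout. A hedged
sketch of the shape half of that conversion, assuming the usual
mgb::TensorLayout and lite::Layout fields; the real helper lives elsewhere in
lite and also maps the dtype:

    #include "lite/tensor.h"      // lite::Layout
    #include "megbrain/tensor.h"  // mgb::TensorLayout (assumed include path)

    // Sketch only: copy ndim/shape across; LiteDataType mapping elided.
    static lite::Layout to_lite_layout_sketch(const mgb::TensorLayout& in) {
        lite::Layout out;
        out.ndim = in.ndim;
        for (size_t i = 0; i < in.ndim; ++i) {
            out.shapes[i] = in.shape[i];
        }
        return out;
    }
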