|
|
@@ -26,6 +26,7 @@ |
|
|
|
#include "megbrain/graph.h" |
|
|
|
#include "megbrain/graph/cg.h" |
|
|
|
#include "megbrain/opr/io.h" |
|
|
|
#include "megbrain/opr/tensor_manip.h" |
|
|
|
#include "megbrain/tensor.h" |
|
|
|
|
|
|
|
#if MGB_OPENCL |
|
|
@@ -340,6 +341,31 @@ void NetworkImplDft::cross_compnode_model_detect() { |
|
|
|
m_nr_device_type = nr_used_device_type.size(); |
|
|
|
} |
|
|
|
|
|
|
|
void NetworkImplDft::adapt_option_valid() { |
|
|
|
auto&& options = m_load_config.comp_graph->options(); |
|
|
|
if (m_user_config->options.force_output_use_user_specified_memory) { |
|
|
|
for (auto&& out : m_load_result.output_var_list) { |
|
|
|
auto opr = out.node()->owner_opr(); |
|
|
|
//! all the dest operator inherit from ReadonlyFwdHelper can't |
|
|
|
//! support force_output_use_user_specified_memory options |
|
|
|
if (opr->try_cast_final<mgb::opr::Reshape>() || |
|
|
|
opr->try_cast_final<mgb::opr::Broadcast>() || |
|
|
|
opr->try_cast_final<mgb::opr::Subtensor>() || |
|
|
|
opr->try_cast_final<mgb::opr::AxisAddRemove>() || |
|
|
|
opr->try_cast_final<mgb::opr::Dimshuffle>()) { |
|
|
|
m_user_config->options.force_output_use_user_specified_memory = false; |
|
|
|
options.force_output_use_user_specified_memory = false; |
|
|
|
LITE_WARN( |
|
|
|
"detect the unsupported dest operator %s when config " |
|
|
|
"force_output_use_user_specified_memory, set " |
|
|
|
"force_output_use_user_specified_memory to false\n", |
|
|
|
opr->cname()); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void NetworkImplDft::load_model( |
|
|
|
std::shared_ptr<void> model_mem, size_t size, |
|
|
|
std::unordered_map<std::string, LiteAny> separate_config_map) { |
|
|
@@ -378,6 +404,8 @@ void NetworkImplDft::load_model( |
|
|
|
|
|
|
|
m_load_result = m_loader->load(m_load_config, true); |
|
|
|
|
|
|
|
adapt_option_valid(); |
|
|
|
|
|
|
|
cross_compnode_model_detect(); |
|
|
|
|
|
|
|
//! update the IO of the network |
|
|
|