GitOrigin-RevId: ffaf4c1416
tags/v1.7.1.m1
@@ -26,6 +26,7 @@ | |||||
#include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
#include "megbrain/graph/cg.h" | #include "megbrain/graph/cg.h" | ||||
#include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
#include "megbrain/opr/tensor_manip.h" | |||||
#include "megbrain/tensor.h" | #include "megbrain/tensor.h" | ||||
#if MGB_OPENCL | #if MGB_OPENCL | ||||
@@ -340,6 +341,31 @@ void NetworkImplDft::cross_compnode_model_detect() { | |||||
m_nr_device_type = nr_used_device_type.size(); | m_nr_device_type = nr_used_device_type.size(); | ||||
} | } | ||||
void NetworkImplDft::adapt_option_valid() { | |||||
auto&& options = m_load_config.comp_graph->options(); | |||||
if (m_user_config->options.force_output_use_user_specified_memory) { | |||||
for (auto&& out : m_load_result.output_var_list) { | |||||
auto opr = out.node()->owner_opr(); | |||||
//! all the dest operator inherit from ReadonlyFwdHelper can't | |||||
//! support force_output_use_user_specified_memory options | |||||
if (opr->try_cast_final<mgb::opr::Reshape>() || | |||||
opr->try_cast_final<mgb::opr::Broadcast>() || | |||||
opr->try_cast_final<mgb::opr::Subtensor>() || | |||||
opr->try_cast_final<mgb::opr::AxisAddRemove>() || | |||||
opr->try_cast_final<mgb::opr::Dimshuffle>()) { | |||||
m_user_config->options.force_output_use_user_specified_memory = false; | |||||
options.force_output_use_user_specified_memory = false; | |||||
LITE_WARN( | |||||
"detect the unsupported dest operator %s when config " | |||||
"force_output_use_user_specified_memory, set " | |||||
"force_output_use_user_specified_memory to false\n", | |||||
opr->cname()); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
void NetworkImplDft::load_model( | void NetworkImplDft::load_model( | ||||
std::shared_ptr<void> model_mem, size_t size, | std::shared_ptr<void> model_mem, size_t size, | ||||
std::unordered_map<std::string, LiteAny> separate_config_map) { | std::unordered_map<std::string, LiteAny> separate_config_map) { | ||||
@@ -378,6 +404,8 @@ void NetworkImplDft::load_model( | |||||
m_load_result = m_loader->load(m_load_config, true); | m_load_result = m_loader->load(m_load_config, true); | ||||
adapt_option_valid(); | |||||
cross_compnode_model_detect(); | cross_compnode_model_detect(); | ||||
//! update the IO of the network | //! update the IO of the network | ||||
@@ -214,6 +214,9 @@ private: | |||||
//! optimized output tensor copy | //! optimized output tensor copy | ||||
void output_tensor_copy_optimize(Var var, std::shared_ptr<Tensor> tensor); | void output_tensor_copy_optimize(Var var, std::shared_ptr<Tensor> tensor); | ||||
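//! check the user options against the loaded model and adapt the ones that are not valid; it should be called after the model is loaded (load_model calls it before the IO update) | |||||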
void adapt_option_valid(); | |||||
private: | private: | ||||
bool m_async = false; | bool m_async = false; | ||||
bool m_is_cpu_inplace_mode = false; | bool m_is_cpu_inplace_mode = false; | ||||
@@ -250,14 +250,10 @@ std::unique_ptr<CompNodeSeqRecorder> ComputingGraphImpl::ComputingSequence:: | |||||
"graph."); | "graph."); | ||||
return {}; | return {}; | ||||
} | } | ||||
auto is_graph_dest_varnode = [&](VarNode* var) { | |||||
return ComputingGraphImpl::downcast(owner_graph())->var_receiver(var).size() == | |||||
0; | |||||
}; | |||||
for (auto i : *m_opr_seq) { | for (auto i : *m_opr_seq) { | ||||
for (auto j : i->output()) { | for (auto j : i->output()) { | ||||
if (!is_static_var_storage(j) && !is_graph_dest_varnode(j)) { | |||||
if (!is_static_var_storage(j) && !j->is_graph_dest_varnode()) { | |||||
mgb_log_error( | mgb_log_error( | ||||
"can not enable CompNodeSeqRecorder because var " | "can not enable CompNodeSeqRecorder because var " | ||||
"storage not static: %s", | "storage not static: %s", | ||||
@@ -504,6 +504,10 @@ public: | |||||
*/ | */ | ||||
MGE_WIN_DECLSPEC_FUC bool capable_value_infer(); | MGE_WIN_DECLSPEC_FUC bool capable_value_infer(); | ||||
//! whether the var is a graph output; if it is an output, the | |||||
//! NO_SYS_MEM_ALLOC flag can be modified. | |||||
MGE_WIN_DECLSPEC_FUC bool is_graph_dest_varnode(); | |||||
private: | private: | ||||
//! whether its memory should be allocated by mgb system during graph | //! whether its memory should be allocated by mgb system during graph | ||||
//! execution; initialized in VarNodeMemManager::reset_opr_seq() | //! execution; initialized in VarNodeMemManager::reset_opr_seq() | ||||
@@ -552,10 +556,6 @@ private: | |||||
MGE_WIN_DECLSPEC_FUC void modify_flag(Flag delta, Flag new_flag); | MGE_WIN_DECLSPEC_FUC void modify_flag(Flag delta, Flag new_flag); | ||||
//! whether the var is graph output, if it is output, the Flag of | |||||
//! NO_SYS_MEM_ALLOC can be modified. | |||||
bool is_graph_dest_varnode(); | |||||
MGE_WIN_DECLSPEC_FUC void assign_dev_tensor_from_tensor( | MGE_WIN_DECLSPEC_FUC void assign_dev_tensor_from_tensor( | ||||
const DeviceTensorND& value); | const DeviceTensorND& value); | ||||
@@ -919,7 +919,7 @@ Split::Options Split::Options::make_callback( | |||||
int axis, size_t nr_part, callback_t callback) { | int axis, size_t nr_part, callback_t callback) { | ||||
mgb_assert(nr_part); | mgb_assert(nr_part); | ||||
Options rst; | Options rst; | ||||
rst.method = Method::CALLBACK; | |||||
rst.method = Method::CALL_BACK; | |||||
rst.axis = axis; | rst.axis = axis; | ||||
rst.callback = callback; | rst.callback = callback; | ||||
rst.nr_part = nr_part; | rst.nr_part = nr_part; | ||||
@@ -955,7 +955,7 @@ Split::Split(VarNode* inp, const Options& opt, const OperatorNodeConfig& config) | |||||
// disable dedup | // disable dedup | ||||
add_equivalence_component<ScalarHash<void*>>(this); | add_equivalence_component<ScalarHash<void*>>(this); | ||||
mgb_assert(m_opt.method == Options::Method::CALLBACK); | |||||
mgb_assert(m_opt.method == Options::Method::CALL_BACK); | |||||
mgb_assert(m_opt.nr_part); | mgb_assert(m_opt.nr_part); | ||||
} | } | ||||
@@ -172,7 +172,7 @@ cg::OperatorNodeBase* opr_shallow_copy_split( | |||||
auto option = opr.options(); | auto option = opr.options(); | ||||
using Meth = Split::Options::Method; | using Meth = Split::Options::Method; | ||||
switch (option.method) { | switch (option.method) { | ||||
case Meth::CALLBACK: | |||||
case Meth::CALL_BACK: | |||||
mgb_assert(inputs.size() == 1); | mgb_assert(inputs.size() == 1); | ||||
break; | break; | ||||
case Meth::SPECIFY: | case Meth::SPECIFY: | ||||
@@ -408,8 +408,8 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Split, intl::OutshapeBySymvarOprBase) // { | |||||
public: | public: | ||||
struct Options { | struct Options { | ||||
enum class Method { | enum class Method { | ||||
SPECIFY, //!< specify output sizes | |||||
CALLBACK //!< output sizes obtained from callback | |||||
SPECIFY, //!< specify output sizes | |||||
CALL_BACK //!< output sizes obtained from callback | |||||
}; | }; | ||||
Method method; | Method method; | ||||
size_t nr_part = 0; | size_t nr_part = 0; | ||||