@@ -288,6 +288,7 @@ ComputingGraphHolder& get_computing_graph(std::shared_ptr<OpDef> compiled_op, Sm
     cg_holder.graph->options().async_exec_level = 0;
     cg_holder.graph->options().graph_opt_level = compiled_op->cast_final_safe<CompiledOp>().gopt_level;
     cg_holder.graph->options().enable_var_mem_defragment = false;
+    cg_holder.graph->options().comp_seq_sync_device = false;
     cg_holder.graph->set_device_memory_allocator(cg_holder.allocator);
     // cg_holder.graph->options().graph_opt.jit = 2;
     VarNodeArray input_vars;
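This hunk opts the graphs cached for `CompiledOp` out of device synchronization when their computing sequences are waited on. For orientation, a minimal sketch of the same option setup on a standalone graph, assuming the public `ComputingGraph` API from `megbrain/graph.h` (the `cg_holder` plumbing above is internal to the imperative runtime):

```cpp
#include "megbrain/graph.h"

using namespace mgb;

void build_graph_without_wait_sync() {
    auto graph = ComputingGraph::make();
    // Mirror the relevant options from the hunk above: dispatch work
    // synchronously from the host thread, keep var memory in place,
    // and skip the device sync that wait() would otherwise perform.
    graph->options().async_exec_level = 0;
    graph->options().enable_var_mem_defragment = false;
    graph->options().comp_seq_sync_device = false;
}
```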
@@ -385,21 +385,27 @@ void ComputingGraphImpl::ComputingSequence::do_wait(bool explicit_user_wait) {
         }
     }
-    for (auto cn : m_used_comp_node) {
-        m_event_end.at(cn)->host_wait();
+    bool sync_device = m_owner_graph->options().comp_seq_sync_device;
+    if (sync_device) {
+        for (auto cn : m_used_comp_node) {
+            m_event_end.at(cn)->host_wait();
+        }
     }
     m_wait_finished = true;
 #if MGB_NEED_MEGDNN_ASYNC_ERROR
     // FIXME: it CANNOT work well if more than one ComputingSequence has been
     // executed on the same comp_node and got an AsyncError concurrently, because
     // only the first async error on each comp_node would be recorded.
-    for (auto&& cn : m_used_comp_node) {
-        auto error = cn.check_async_error();
-        if (error) {
-            static_cast<const OperatorNodeExcExtraInfo*>(error->extra_info())
-                    ->opr()
-                    ->owner_graph()
-                    ->record_async_error(std::move(error));
+    if (sync_device) {
+        for (auto&& cn : m_used_comp_node) {
+            auto error = cn.check_async_error();
+            if (error) {
+                static_cast<const OperatorNodeExcExtraInfo*>(error->extra_info())
+                        ->opr()
+                        ->owner_graph()
+                        ->record_async_error(std::move(error));
+            }
         }
     }
 #endif
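With `comp_seq_sync_device` unset, `do_wait()` now skips both the `host_wait()` on the end-of-sequence events and the async-error harvesting, so a wait only covers host-side bookkeeping. A hedged caller-side sketch of the consequence; `eval_without_implicit_sync` and its arguments are illustrative, not part of this patch:

```cpp
#include "megbrain/graph.h"

using namespace mgb;

// With comp_seq_sync_device == false, wait() no longer guarantees that
// kernels on the device have finished, so the comp node must be synced
// explicitly before the copied result is safe to read on the host.
HostTensorND eval_without_implicit_sync(
        std::shared_ptr<ComputingGraph> graph, SymbolVar y) {
    HostTensorND host_result;
    auto func = graph->compile(
            {{y, [&](DeviceTensorND& dev) { host_result.copy_from(dev); }}});
    func->execute();
    func->wait();                  // host-side wait only under this option
    y.node()->comp_node().sync();  // explicit device sync
    return host_result;
}
```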
@@ -520,6 +520,9 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
          */
         bool no_force_inplace = false;
 
+        //! whether to sync comp_node when waiting computing sequence
+        bool comp_seq_sync_device = true;
+
         //! add extra deps for the comp seq if a specific var is dependent
         ThinHashMap<VarNode*, VarNodeArray> extra_vardeps;
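Finally, since the new field defaults to `true`, existing graphs keep the blocking wait semantics; only call sites that opt out per graph (such as the imperative runtime above) take the new path. A small usage sketch:

```cpp
#include "megbrain/graph.h"

using namespace mgb;

void toggle_wait_sync() {
    auto graph = ComputingGraph::make();
    mgb_assert(graph->options().comp_seq_sync_device);  // defaults to true
    // Opting out trades the implicit device sync in wait() for lower
    // latency; the caller then owns synchronization (see sketch above).
    graph->options().comp_seq_sync_device = false;
}
```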