diff --git a/src/core/impl/graph/event.cpp b/src/core/impl/graph/event.cpp index 401d8fac..48746026 100644 --- a/src/core/impl/graph/event.cpp +++ b/src/core/impl/graph/event.cpp @@ -28,6 +28,9 @@ MGB_TYPEINFO_OBJ_IMPL(CompSeqExecBeforeStart); MGB_TYPEINFO_OBJ_IMPL(CompSeqExecFinished); MGB_TYPEINFO_OBJ_IMPL(CompSeqExecError); MGB_TYPEINFO_OBJ_IMPL(SubgraphAssociated); +#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER +MGB_TYPEINFO_OBJ_IMPL(BeforeMemDefrag); +#endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/core/impl/graph/symbol_var.cpp b/src/core/impl/graph/symbol_var.cpp index 878ff750..323de296 100644 --- a/src/core/impl/graph/symbol_var.cpp +++ b/src/core/impl/graph/symbol_var.cpp @@ -49,6 +49,11 @@ SymbolVar SymbolVar::flatten() const { return opr::Reshape::make(*this, make_scalar(1), 0); } +SymbolVar SymbolVar::add_axis(size_t idx) const { + return opr::AxisAddRemove::make(*this, + {opr::AxisAddRemove::AxisDesc::make_add(idx)}); +} + Maybe SymbolVar::as_immutable_scalar() const { using IT = static_infer::InferType; auto &&mgr = node()->owner_graph()->static_infer_manager(); diff --git a/src/core/impl/graph/var_node_mem_mgr.cpp b/src/core/impl/graph/var_node_mem_mgr.cpp index aaa7bb5c..401f2390 100644 --- a/src/core/impl/graph/var_node_mem_mgr.cpp +++ b/src/core/impl/graph/var_node_mem_mgr.cpp @@ -294,6 +294,14 @@ VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl *graph): on_comp_seq_finish); graph->event().register_receiver_permanent( on_comp_seq_error); + +#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER + auto on_mem_defrag_start = [this](const event::BeforeMemDefrag&) { + m_cuda_asyn_var_releaser->wait_release_finish(); + }; + graph->event().register_receiver_permanent( + on_mem_defrag_start); +#endif } VarNodeMemManager::~VarNodeMemManager() noexcept = default; diff --git a/src/core/impl/graph/var_node_mem_mgr/defrag.cpp b/src/core/impl/graph/var_node_mem_mgr/defrag.cpp index 5049310b..5053ce10 100644 --- a/src/core/impl/graph/var_node_mem_mgr/defrag.cpp +++ b/src/core/impl/graph/var_node_mem_mgr/defrag.cpp @@ -77,6 +77,7 @@ void VarDevMemDefragmenter::defrag(VarNode* req_var, ->current_exec_env(); mgb_assert(exec_env); exec_env->pause_exec(); + m_mem_mgr->owner_graph()->event().signal_inplace(); MGB_TRY { defrag_impl(req_var, cn_info, extra_size); } MGB_FINALLY(exec_env->resume_exec();); } @@ -123,7 +124,8 @@ void VarDevMemDefragmenter::defrag_impl(VarNode* req_var, if (refcnt == iter->second.readers.size()) { tot_size += get_aligned_power2(iter->first->size(), alignment); nr_var += iter->second.readers.size(); - auto&& tensor = iter->first->owner_var->dev_tensor(); + auto owner_var = iter->first->owner_var; + auto&& tensor = owner_var->m_dev_tensor; iter->second.value.comp_node(cn) .ensure_size(iter->first->size()) .copy_from(tensor.storage(), iter->first->size()); @@ -132,6 +134,17 @@ void VarDevMemDefragmenter::defrag_impl(VarNode* req_var, for (auto var : iter->second.readers) { const_cast(var->dev_tensor()).storage({}); } + // release memory of owner_var + auto&& mem_plan = owner_var->mem_plan(); + if (!mem_plan.valid()) { + // mem_plan of owner_var was invalid here if all reader oprs + // of owner_var have already been executed, but its tensor + // storage should not be released until the refcnt of chunk + // decreasing to zero (see release_chunk() for more details) + mgb_assert(tensor.storage().comp_node_valid() && + tensor.layout().eq_layout(mem_plan.layout())); + tensor.storage({}); + } } else { mgb_assert(refcnt > iter->second.readers.size()); ++nr_refcnt_mismatch; @@ -170,6 +183,11 @@ void VarDevMemDefragmenter::defrag_impl(VarNode* req_var, } mgb_assert(var->dev_tensor_valid()); } + auto owner_var = i.first->owner_var; + if (!owner_var->mem_plan().valid()) { + owner_var->m_dev_tensor.reset( + storage, owner_var->mem_plan().layout()); + } } mgb_assert(offset + extra_size == tot_size); cn.sync(); // wait copy finish before destructing host values diff --git a/src/core/impl/graph/var_node_mem_mgr/defrag.h b/src/core/impl/graph/var_node_mem_mgr/defrag.h index 92dd5989..2b8aad9d 100644 --- a/src/core/impl/graph/var_node_mem_mgr/defrag.h +++ b/src/core/impl/graph/var_node_mem_mgr/defrag.h @@ -13,12 +13,6 @@ #include "../impl_common.h" -#if MGB_CUDA && MGB_ENABLE_EXCEPTION -#define MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER 1 -#else -#define MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER 0 -#endif - namespace mgb { namespace cg { diff --git a/src/core/include/megbrain/graph/bases.h b/src/core/include/megbrain/graph/bases.h index 98078640..deaedbf7 100644 --- a/src/core/include/megbrain/graph/bases.h +++ b/src/core/include/megbrain/graph/bases.h @@ -40,6 +40,12 @@ #define MGB_IF_COND_EXEC(x...) #endif +#if MGB_CUDA && MGB_ENABLE_EXCEPTION +#define MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER 1 +#else +#define MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER 0 +#endif // whether enable memory defragment + namespace mgb { //! computing graph diff --git a/src/core/include/megbrain/graph/event.h b/src/core/include/megbrain/graph/event.h index 23e259d8..1be05b78 100644 --- a/src/core/include/megbrain/graph/event.h +++ b/src/core/include/megbrain/graph/event.h @@ -224,6 +224,15 @@ struct SubgraphAssociated { MGB_TYPEINFO_OBJ_DECL; }; +#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER +/*! + * \brief signaled before graph memory defragementation + */ +struct BeforeMemDefrag { + MGB_TYPEINFO_OBJ_DECL; +}; +#endif + } // namespace event } // namespace cg } // namespace mgb diff --git a/src/core/include/megbrain/graph/symbol_var.h b/src/core/include/megbrain/graph/symbol_var.h index ea819a55..a89598dc 100644 --- a/src/core/include/megbrain/graph/symbol_var.h +++ b/src/core/include/megbrain/graph/symbol_var.h @@ -66,6 +66,7 @@ class SymbolVar { SymbolVar broadcast(SymbolVar tshape) const; SymbolVar symshape() const; SymbolVar flatten() const; + SymbolVar add_axis(size_t idx) const; const TensorShape& shape() const { return m_node->shape(); diff --git a/src/core/test/graph/defrag.cpp b/src/core/test/graph/defrag.cpp index 48fbc6e6..cd025c6a 100644 --- a/src/core/test/graph/defrag.cpp +++ b/src/core/test/graph/defrag.cpp @@ -49,8 +49,8 @@ void run_graph(size_t mem_reserved, bool enable_defrag) { graph->options().var_sanity_check_first_run = false; auto x0 = opr::SharedDeviceTensor::make(*graph, dev_x).rename("x0"), - // x1 has rdonly fwd - x1 = opr::Concat::make({x0, x0}, 0).reshape({size*2}).rename("x1"), + // x1 has rdonly fwd chain + x1 = opr::Concat::make({x0, x0}, 0).add_axis(0).reshape({size*2}).rename("x1"), x2 = opr::Concat::make({x1, x0}, 0).rename("x2"), x3 = opr::Concat::make({x2, x0}, 0).rename("x3"), x4 = opr::Concat::make({x3, x0}, 0).rename("x4"),