@@ -28,6 +28,9 @@ MGB_TYPEINFO_OBJ_IMPL(CompSeqExecBeforeStart);
 MGB_TYPEINFO_OBJ_IMPL(CompSeqExecFinished);
 MGB_TYPEINFO_OBJ_IMPL(CompSeqExecError);
 MGB_TYPEINFO_OBJ_IMPL(SubgraphAssociated);
+#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER
+MGB_TYPEINFO_OBJ_IMPL(BeforeMemDefrag);
+#endif
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -49,6 +49,11 @@ SymbolVar SymbolVar::flatten() const {
     return opr::Reshape::make(*this, make_scalar(1), 0);
 }
 
+SymbolVar SymbolVar::add_axis(size_t idx) const {
+    return opr::AxisAddRemove::make(*this,
+            {opr::AxisAddRemove::AxisDesc::make_add(idx)});
+}
+
 Maybe<DTypeScalar> SymbolVar::as_immutable_scalar() const {
     using IT = static_infer::InferType;
     auto &&mgr = node()->owner_graph()->static_infer_manager();
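A quick usage sketch of the new helper (not part of the patch; the graph setup and `host_x` are illustrative). `add_axis(idx)` wraps `opr::AxisAddRemove` to insert a new axis of size 1 at position `idx`:

```cpp
// Illustrative only: assumes the usual MegBrain graph setup.
auto graph = ComputingGraph::make();
auto host_x = std::make_shared<HostTensorND>(
        CompNode::load("xpu0"), TensorShape{4});
SymbolVar x = opr::Host2DeviceCopy::make(*graph, host_x);  // shape {4}
SymbolVar y = x.add_axis(0);                               // shape {1, 4}
SymbolVar z = y.add_axis(2);                               // shape {1, 4, 1}
```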
@@ -294,6 +294,14 @@ VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl *graph):
             on_comp_seq_finish);
     graph->event().register_receiver_permanent<event::CompSeqExecError>(
             on_comp_seq_error);
+
+#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER
+    auto on_mem_defrag_start = [this](const event::BeforeMemDefrag&) {
+        m_cuda_asyn_var_releaser->wait_release_finish();
+    };
+    graph->event().register_receiver_permanent<event::BeforeMemDefrag>(
+            on_mem_defrag_start);
+#endif
 }
 
 VarNodeMemManager::~VarNodeMemManager() noexcept = default;
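A condensed sketch of how the two sides of this event pair up (the signaling line is added in the defragmenter hunk below). Since the receiver runs synchronously inside `signal_inplace()`, the CUDA async var releaser is fully drained before any storage gets moved:

```cpp
// Receiver side (VarNodeMemManager ctor): drain pending async releases.
graph->event().register_receiver_permanent<event::BeforeMemDefrag>(
        [this](const event::BeforeMemDefrag&) {
            m_cuda_asyn_var_releaser->wait_release_finish();
        });

// Signal side (VarDevMemDefragmenter::defrag, see the hunk below):
m_mem_mgr->owner_graph()->event().signal_inplace<event::BeforeMemDefrag>();
```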
@@ -77,6 +77,7 @@ void VarDevMemDefragmenter::defrag(VarNode* req_var,
                             ->current_exec_env();
     mgb_assert(exec_env);
     exec_env->pause_exec();
+    m_mem_mgr->owner_graph()->event().signal_inplace<event::BeforeMemDefrag>();
     MGB_TRY { defrag_impl(req_var, cn_info, extra_size); }
     MGB_FINALLY(exec_env->resume_exec(););
 }
@@ -123,7 +124,8 @@ void VarDevMemDefragmenter::defrag_impl(VarNode* req_var,
         if (refcnt == iter->second.readers.size()) {
             tot_size += get_aligned_power2(iter->first->size(), alignment);
             nr_var += iter->second.readers.size();
-            auto&& tensor = iter->first->owner_var->dev_tensor();
+            auto owner_var = iter->first->owner_var;
+            auto&& tensor = owner_var->m_dev_tensor;
             iter->second.value.comp_node(cn)
                     .ensure_size(iter->first->size())
                     .copy_from(tensor.storage(), iter->first->size());
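For orientation, `get_aligned_power2(size, alignment)` rounds a chunk size up to a multiple of a power-of-two alignment before it is accumulated into `tot_size`. A plausible stand-in (an assumption; the real helper lives elsewhere in MegBrain):

```cpp
#include <cassert>
#include <cstddef>

// Plausible stand-in for get_aligned_power2(): round `size` up to the
// next multiple of `align`, which must be a power of two.
inline size_t aligned_power2(size_t size, size_t align) {
    assert(align && !(align & (align - 1)));  // power-of-two alignment
    return (size + align - 1) & ~(align - 1);
}
// e.g. aligned_power2(100, 64) == 128; summing these per chunk makes
// tot_size large enough to place every chunk at its required alignment.
```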
@@ -132,6 +134,17 @@ void VarDevMemDefragmenter::defrag_impl(VarNode* req_var,
             for (auto var : iter->second.readers) {
                 const_cast<DeviceTensorND&>(var->dev_tensor()).storage({});
             }
+            // release the memory held by owner_var
+            auto&& mem_plan = owner_var->mem_plan();
+            if (!mem_plan.valid()) {
+                // the mem_plan of owner_var is invalid here if all reader
+                // oprs of owner_var have already been executed, but its
+                // tensor storage must not be released until the refcnt of
+                // the chunk drops to zero (see release_chunk() for details)
+                mgb_assert(tensor.storage().comp_node_valid() &&
+                           tensor.layout().eq_layout(mem_plan.layout()));
+                tensor.storage({});
+            }
         } else {
             mgb_assert(refcnt > iter->second.readers.size());
             ++nr_refcnt_mismatch;
@@ -170,6 +183,11 @@ void VarDevMemDefragmenter::defrag_impl(VarNode* req_var,
             }
             mgb_assert(var->dev_tensor_valid());
         }
+        auto owner_var = i.first->owner_var;
+        if (!owner_var->mem_plan().valid()) {
+            owner_var->m_dev_tensor.reset(
+                    storage, owner_var->mem_plan().layout());
+        }
     }
     mgb_assert(offset + extra_size == tot_size);
     cn.sync();  // wait for copies to finish before destructing host values
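Read together, the two `defrag_impl()` hunks complete a copy-out/copy-back cycle around the existing reallocation logic. An outline for orientation (descriptive comments only, not the literal code):

```cpp
// Outline of the defrag cycle as visible in the hunks above:
// 1. defrag(): pause kernel dispatch, then signal BeforeMemDefrag so the
//    CUDA async var releaser is drained before any storage moves.
// 2. Copy-out: for each chunk whose refcnt equals its reader count, copy
//    the device storage into a host-side staging value, then drop the
//    device storage of every reader and, when its mem_plan is already
//    invalid, of owner_var's m_dev_tensor as well.
// 3. Copy-back: allocate one contiguous region, copy each staged value
//    back at its aligned offset, and re-point the readers (and owner_var,
//    if its mem_plan is invalid) at the new storage.
// 4. cn.sync(): ensure the copies finish before staging values die.
```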
@@ -13,12 +13,6 @@
 #include "../impl_common.h"
 
-#if MGB_CUDA && MGB_ENABLE_EXCEPTION
-#define MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER 1
-#else
-#define MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER 0
-#endif
-
 namespace mgb {
 namespace cg {
@@ -40,6 +40,12 @@
 #define MGB_IF_COND_EXEC(x...)
 #endif
 
+#if MGB_CUDA && MGB_ENABLE_EXCEPTION
+#define MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER 1
+#else
+#define MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER 0
+#endif  // whether to enable memory defragmentation
+
 namespace mgb {
 
 //! computing graph
@@ -224,6 +224,15 @@ struct SubgraphAssociated {
     MGB_TYPEINFO_OBJ_DECL;
 };
 
+#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER
+/*!
+ * \brief signaled before graph memory defragmentation
+ */
+struct BeforeMemDefrag {
+    MGB_TYPEINFO_OBJ_DECL;
+};
+#endif
+
 }  // namespace event
 }  // namespace cg
 }  // namespace mgb
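For readers new to the typeinfo plumbing: each event struct pairs an in-class `MGB_TYPEINFO_OBJ_DECL` with a matching `MGB_TYPEINFO_OBJ_IMPL` in a source file, which is why this patch touches both the header here and the `.cpp` in the first hunk. The same pattern for a hypothetical event (`AfterMemDefrag` is made up for illustration):

```cpp
// header (inside namespace mgb::cg::event):
struct AfterMemDefrag {  // hypothetical, for illustration only
    MGB_TYPEINFO_OBJ_DECL;
};

// source file:
MGB_TYPEINFO_OBJ_IMPL(AfterMemDefrag);
```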
@@ -66,6 +66,7 @@ class SymbolVar {
     SymbolVar broadcast(SymbolVar tshape) const;
     SymbolVar symshape() const;
     SymbolVar flatten() const;
+    SymbolVar add_axis(size_t idx) const;
 
     const TensorShape& shape() const {
         return m_node->shape();
@@ -49,8 +49,8 @@ void run_graph(size_t mem_reserved, bool enable_defrag) {
     graph->options().var_sanity_check_first_run = false;
 
     auto x0 = opr::SharedDeviceTensor::make(*graph, dev_x).rename("x0"),
-         // x1 has rdonly fwd
-         x1 = opr::Concat::make({x0, x0}, 0).reshape({size*2}).rename("x1"),
+         // x1 has rdonly fwd chain
+         x1 = opr::Concat::make({x0, x0}, 0).add_axis(0).reshape({size*2}).rename("x1"),
          x2 = opr::Concat::make({x1, x0}, 0).rename("x2"),
          x3 = opr::Concat::make({x2, x0}, 0).rename("x3"),
          x4 = opr::Concat::make({x3, x0}, 0).rename("x4"),
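Why the extra `add_axis(0)`: it lengthens the read-only forwarding chain that the test exercises against the defragmenter. A shape trace, under the assumption that `dev_x` has shape `{size}`:

```cpp
// Shape trace for the new x1 (assuming dev_x has shape {size}):
// opr::Concat::make({x0, x0}, 0)  ->  {2*size}     fresh storage
// .add_axis(0)                    ->  {1, 2*size}  rdonly fwd of the above
// .reshape({size*2})              ->  {2*size}     rdonly fwd again
// x1 thus sits behind two read-only memory forwards instead of one,
// giving the defragmenter a longer rdonly-fwd chain to handle.
```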