diff --git a/python_module/megengine/_internal/__init__.py b/python_module/megengine/_internal/__init__.py index add99da0..e46ec7cd 100644 --- a/python_module/megengine/_internal/__init__.py +++ b/python_module/megengine/_internal/__init__.py @@ -654,3 +654,14 @@ def to_mgb_supported_dtype(dtype_): ): return dtype_ return _detail._to_mgb_supported_dtype(dtype_) + + +def return_free_memory(): + """return free memory chunks on all devices. + + This function will try it best to free all consecutive free chunks back to + operating system, small pieces may not be returned. + + Please notice that this function will not move any memory in-use. + """ + _detail.CompNode._try_coalesce_all_free_memory() diff --git a/python_module/src/swig/comp_node.i b/python_module/src/swig/comp_node.i index 2d11eeef..eb10bc2a 100644 --- a/python_module/src/swig/comp_node.i +++ b/python_module/src/swig/comp_node.i @@ -49,6 +49,10 @@ class CompNode { str2device_type(type, false)); } + static void _try_coalesce_all_free_memory() { + CompNode::try_coalesce_all_free_memory(); + } + bool _check_eq(const CompNode &rhs) const { return (*$self) == rhs; } @@ -83,6 +87,10 @@ class CompNode { size_t __hash__() { return mgb::hash(*$self); } + + std::pair _get_mem_status_bytes() { + return $self->get_mem_status_bytes(); + } } %pythoncode { @@ -121,6 +129,16 @@ class CompNode { """physical locator: a tuple containing (type, device, stream)""" t, d, s = self._get_locator()[3:] return self.DEVICE_TYPE_MAP[t], d, s + + @property + def mem_status_bytes(self) -> [int, int]: + """get (total, free) memory on the computing device in bytes. + + Free memory includes memory chunks that buffered by the memory manager. + + Please note that the results are the same for different CompNode within same device. + """ + return self._get_mem_status_bytes() } }; %template(_VectorCompNode) std::vector; diff --git a/python_module/src/swig/mgb.i b/python_module/src/swig/mgb.i index a7769807..c9c2a950 100644 --- a/python_module/src/swig/mgb.i +++ b/python_module/src/swig/mgb.i @@ -30,6 +30,7 @@ void _init_bfloat16_types(PyObject *m); // implemented in bfloat16.cpp %template(_VectorInt) std::vector; %template(_VectorString) std::vector; %template(_PairStringSizeT) std::pair; +%template(_PairSizeTSizeT) std::pair; %template(_VectorPairUint64String) std::vector>; %pythoncode %{