Browse Source

feat(mge/imperative): expose c++ tensor reference count

GitOrigin-RevId: 1940881adc
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
adc49de803
2 changed files with 5 additions and 1 deletions
  1. +4
    -1
      imperative/python/src/tensor.cpp
  2. +1
    -0
      imperative/python/src/tensor.h

+ 4
- 1
imperative/python/src/tensor.cpp View File

@@ -541,6 +541,7 @@ struct TensorWeakRef {
}
return py::none();
}
int _use_cnt() { return wptr.use_count(); }
};

/* ============== convert inputs ============== */
@@ -774,6 +775,7 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::_swap_in>("_swap_in")
.def<&TensorWrapper::_drop>("_drop")
.def<&TensorWrapper::reset_varnode>("_reset_varnode")
.def<&TensorWrapper::_use_cnt>("_use_cnt")
.def_getset<&TensorWrapper::varnode>("_varnode")
.def_getset<&TensorWrapper::copied>("_copied")
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
@@ -787,7 +789,8 @@ void init_tensor(py::module m) {

py::class_<TensorWeakRef>(m, "TensorWeakRef")
.def(py::init<const TensorWrapper&>())
.def("__call__", &TensorWeakRef::operator());
.def("__call__", &TensorWeakRef::operator())
.def("_use_cnt", &TensorWeakRef::_use_cnt);

static PyMethodDef method_defs[] = {
MGE_PY_INTERFACE(apply, py_apply),


+ 1
- 0
imperative/python/src/tensor.h View File

@@ -170,6 +170,7 @@ struct TensorWrapper {
void set_compiled_info(PyObject *);
PyObject* trace_mixin_info();
void set_trace_mixin_info(PyObject *);
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
};




Loading…
Cancel
Save