Browse Source

refactor(mge/imperative): move detach into C++

GitOrigin-RevId: 8c0d86cbbf
release-1.2
Megvii Engine Team 4 years ago
parent
commit
34c705fcbf
4 changed files with 12 additions and 11 deletions
  1. +0
    -7
      imperative/python/megengine/tensor.py
  2. +10
    -3
      imperative/python/src/tensor.cpp
  3. +1
    -0
      imperative/python/src/tensor.h
  4. +1
    -1
      imperative/python/test/unit/functional/test_functional.py

+ 0
- 7
imperative/python/megengine/tensor.py View File

@@ -118,13 +118,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
def __setstate__(self, state):
self.q_dict = state.pop("qdict")

def detach(self):
r"""
Returns a new tensor sharing the same data memory, which is treated as a constant
during backward gradient calcuation, i.e. its gradient is zero.
"""
Wrapper = type(self)
return Wrapper(self)


tensor = Tensor


+ 10
- 3
imperative/python/src/tensor.cpp View File

@@ -68,9 +68,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
return nullptr;
}
auto* op = args[0];
if (!strcmp(op->ob_type->tp_base->tp_name,"PodOpVisitor") || !strcmp(op->ob_type->tp_base->tp_name,"IndexingOpBase")){
op = PyObject_CallMethod(op,"to_c","");
}

PyTypeObject* pytype = args[1]->ob_type;
++args;
@@ -195,6 +192,15 @@ void TensorWrapper::reset(PyObject* tensor) {
m_tensor = t->m_tensor;
}

PyObject* TensorWrapper::detach() {
PyObject* self = wrap_t::pycast(this);
PyTypeObject* pytype = self->ob_type;
auto new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);
auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
return ret.release().ptr();

}

PyObject* TensorWrapper::isscalar() {
if(m_tensor->m_flags & Tensor::Flags::SCALAR) {
Py_RETURN_TRUE;
@@ -233,6 +239,7 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::reset>("_reset")
.def<&TensorWrapper::isscalar>("isscalar")
.def<&TensorWrapper::setscalar>("setscalar")
.def<&TensorWrapper::detach>("detach")
.finalize();
if (!tensor_type) throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type);


+ 1
- 0
imperative/python/src/tensor.h View File

@@ -128,6 +128,7 @@ struct TensorWrapper {
PyObject* device();
PyObject* numpy();
void reset(PyObject*);
PyObject* detach();
PyObject* isscalar();
void setscalar();
};


+ 1
- 1
imperative/python/test/unit/functional/test_functional.py View File

@@ -166,7 +166,7 @@ def test_interpolate():


def _save_to(self, name="grad"):
def callback(tensor, grad):
def callback(grad):
setattr(self, name, grad)

return callback


Loading…
Cancel
Save