diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index 574e7851..cc17d924 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -246,4 +246,11 @@ tensor = Tensor class Parameter(Tensor): r""" A kind of Tensor that is to be considered a module parameter. + + .. note:: + + Operations happened on Parameter usually return a Tensor instead of Parameter. + For example, with a Parameter ``x``, ``x.reshape/to/sum/...`` will result into a Tensor. + Any operations between Parameter and Tensor will have Tensor as outputs. + """ diff --git a/imperative/python/src/pyext17.h b/imperative/python/src/pyext17.h index feced951..4f10a207 100644 --- a/imperative/python/src/pyext17.h +++ b/imperative/python/src/pyext17.h @@ -397,6 +397,10 @@ public: return Py_TYPE(op) == &m_type; } + bool same_pytype(PyTypeObject *pt) { + return pt == &m_type; + } + PyObject* finalize() { if (!m_finalized) { m_finalized = true; diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 8d44ae05..67be8ee1 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -140,6 +140,12 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje auto* op = args[0]; PyTypeObject* pytype = args[1]->ob_type; + + // check if pytype is Parameter(and all other python Tensor's derived class), + // if yes, using it's tp_base(python Tensor) + if (TensorWrapper::wrap_t::type().same_pytype(pytype->tp_base->tp_base)) { + pytype = pytype->tp_base; + } ++args; --nargs; diff --git a/imperative/python/test/unit/core/test_tensor_wrapper.py b/imperative/python/test/unit/core/test_tensor_wrapper.py index b3aa6dfc..ba85be70 100644 --- a/imperative/python/test/unit/core/test_tensor_wrapper.py +++ b/imperative/python/test/unit/core/test_tensor_wrapper.py @@ -13,7 +13,7 @@ import pytest from utils import make_tensor from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 -from megengine.tensor import Tensor +from megengine.tensor import Parameter, Tensor from megengine.utils.network import Network @@ -198,3 +198,11 @@ def test_name(): assert x.name == "x" x = Tensor(0, name="x") assert x.name == "x" + + +def test_tensor_type(): + x1 = Parameter(1) + x2 = Tensor(2) + y1 = x1 + x2 + y2 = x2 + x1 + assert type(y1) == type(y2)