GitOrigin-RevId: 6abfa06ada
release-1.5
@@ -246,4 +246,11 @@ tensor = Tensor | |||||
class Parameter(Tensor): | class Parameter(Tensor): | ||||
r""" | r""" | ||||
A kind of Tensor that is to be considered a module parameter. | A kind of Tensor that is to be considered a module parameter. | ||||
.. note:: | |||||
Operations happened on Parameter usually return a Tensor instead of Parameter. | |||||
For example, with a Parameter ``x``, ``x.reshape/to/sum/...`` will result into a Tensor. | |||||
Any operations between Parameter and Tensor will have Tensor as outputs. | |||||
""" | """ |
@@ -397,6 +397,10 @@ public: | |||||
return Py_TYPE(op) == &m_type; | return Py_TYPE(op) == &m_type; | ||||
} | } | ||||
bool same_pytype(PyTypeObject *pt) { | |||||
return pt == &m_type; | |||||
} | |||||
PyObject* finalize() { | PyObject* finalize() { | ||||
if (!m_finalized) { | if (!m_finalized) { | ||||
m_finalized = true; | m_finalized = true; | ||||
@@ -140,6 +140,12 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||||
auto* op = args[0]; | auto* op = args[0]; | ||||
PyTypeObject* pytype = args[1]->ob_type; | PyTypeObject* pytype = args[1]->ob_type; | ||||
// check if pytype is Parameter(and all other python Tensor's derived class), | |||||
// if yes, using it's tp_base(python Tensor) | |||||
if (TensorWrapper::wrap_t::type().same_pytype(pytype->tp_base->tp_base)) { | |||||
pytype = pytype->tp_base; | |||||
} | |||||
++args; | ++args; | ||||
--nargs; | --nargs; | ||||
@@ -13,7 +13,7 @@ import pytest | |||||
from utils import make_tensor | from utils import make_tensor | ||||
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | ||||
from megengine.tensor import Tensor | |||||
from megengine.tensor import Parameter, Tensor | |||||
from megengine.utils.network import Network | from megengine.utils.network import Network | ||||
@@ -198,3 +198,11 @@ def test_name(): | |||||
assert x.name == "x" | assert x.name == "x" | ||||
x = Tensor(0, name="x") | x = Tensor(0, name="x") | ||||
assert x.name == "x" | assert x.name == "x" | ||||
def test_tensor_type(): | |||||
x1 = Parameter(1) | |||||
x2 = Tensor(2) | |||||
y1 = x1 + x2 | |||||
y2 = x2 + x1 | |||||
assert type(y1) == type(y2) |