Browse Source

fix(mge): correct pytype when calling apply from python

GitOrigin-RevId: 6abfa06ada
release-1.5
Megvii Engine Team 4 years ago
parent
commit
e417798f9e
4 changed files with 26 additions and 1 deletions
  1. +7
    -0
      imperative/python/megengine/tensor.py
  2. +4
    -0
      imperative/python/src/pyext17.h
  3. +6
    -0
      imperative/python/src/tensor.cpp
  4. +9
    -1
      imperative/python/test/unit/core/test_tensor_wrapper.py

+ 7
- 0
imperative/python/megengine/tensor.py View File

@@ -246,4 +246,11 @@ tensor = Tensor
class Parameter(Tensor): class Parameter(Tensor):
r""" r"""
A kind of Tensor that is to be considered a module parameter. A kind of Tensor that is to be considered a module parameter.

.. note::

Operations happened on Parameter usually return a Tensor instead of Parameter.
For example, with a Parameter ``x``, ``x.reshape/to/sum/...`` will result into a Tensor.
Any operations between Parameter and Tensor will have Tensor as outputs.

""" """

+ 4
- 0
imperative/python/src/pyext17.h View File

@@ -397,6 +397,10 @@ public:
return Py_TYPE(op) == &m_type; return Py_TYPE(op) == &m_type;
} }


bool same_pytype(PyTypeObject *pt) {
return pt == &m_type;
}

PyObject* finalize() { PyObject* finalize() {
if (!m_finalized) { if (!m_finalized) {
m_finalized = true; m_finalized = true;


+ 6
- 0
imperative/python/src/tensor.cpp View File

@@ -140,6 +140,12 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
auto* op = args[0]; auto* op = args[0];


PyTypeObject* pytype = args[1]->ob_type; PyTypeObject* pytype = args[1]->ob_type;

// check if pytype is Parameter(and all other python Tensor's derived class),
// if yes, using it's tp_base(python Tensor)
if (TensorWrapper::wrap_t::type().same_pytype(pytype->tp_base->tp_base)) {
pytype = pytype->tp_base;
}
++args; ++args;
--nargs; --nargs;




+ 9
- 1
imperative/python/test/unit/core/test_tensor_wrapper.py View File

@@ -13,7 +13,7 @@ import pytest
from utils import make_tensor from utils import make_tensor


from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.tensor import Tensor
from megengine.tensor import Parameter, Tensor
from megengine.utils.network import Network from megengine.utils.network import Network




@@ -198,3 +198,11 @@ def test_name():
assert x.name == "x" assert x.name == "x"
x = Tensor(0, name="x") x = Tensor(0, name="x")
assert x.name == "x" assert x.name == "x"


def test_tensor_type():
x1 = Parameter(1)
x2 = Tensor(2)
y1 = x1 + x2
y2 = x2 + x1
assert type(y1) == type(y2)

Loading…
Cancel
Save