Browse Source

fix(mge/functional): fix trace warp_perspective

GitOrigin-RevId: 2071bb63a8
release-1.1
Megvii Engine Team 4 years ago
parent
commit
06041f8a7e
2 changed files with 37 additions and 1 deletions
  1. +3
    -1
      imperative/python/megengine/functional/nn.py
  2. +34
    -0
      imperative/python/test/unit/test_tracing.py

+ 3
- 1
imperative/python/megengine/functional/nn.py View File

@@ -15,6 +15,7 @@ from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const
from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.utils import astensor1d
from ..distributed import WORLD, is_distributed
from ..random import uniform
from ..tensor import Tensor
@@ -868,7 +869,8 @@ def warp_perspective(
imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val
)
inp, M = utils.convert_inputs(inp, M)
(result,) = apply(op, inp, M, Tensor(dsize))
dsize = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, M, dsize)
return result




+ 34
- 0
imperative/python/test/unit/test_tracing.py View File

@@ -13,6 +13,7 @@ import numpy as np
import pytest

import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
from megengine import cgtools, tensor
from megengine.core._trace_option import set_tensor_shape
from megengine.core.ops import builtin as ops
@@ -261,3 +262,36 @@ def test_trace_reshape():
f(x1)
f(x2)
f(x3)


def test_trace_topk():
x = tensor([5, 2, 7, 1, 0, 3, 2])

@trace(symbolic=True)
def f(x):
y = F.topk(x, 3)
np.testing.assert_equal(y[0].shape.numpy(), np.array([3,]))
return y

for i in range(3):
f(x)


def test_trace_warp_perspective():
inp_shape = (1, 1, 4, 4)
x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
M_shape = (1, 3, 3)
M = tensor(
np.array(
[[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
).reshape(M_shape)
)

@trace(symbolic=True)
def f(x, M):
out = F.warp_perspective(x, M, (2, 2))
np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
return out

for i in range(1):
f(x, M)

Loading…
Cancel
Save