Browse Source

fix(mge/functional): fix trace warp_perspective

GitOrigin-RevId: 2071bb63a8
release-1.1
Megvii Engine Team 4 years ago
parent
commit
06041f8a7e
2 changed files with 37 additions and 1 deletions
  1. +3
    -1
      imperative/python/megengine/functional/nn.py
  2. +34
    -0
      imperative/python/test/unit/test_tracing.py

+ 3
- 1
imperative/python/megengine/functional/nn.py View File

@@ -15,6 +15,7 @@ from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import utils from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.utils import astensor1d
from ..distributed import WORLD, is_distributed from ..distributed import WORLD, is_distributed
from ..random import uniform from ..random import uniform
from ..tensor import Tensor from ..tensor import Tensor
@@ -868,7 +869,8 @@ def warp_perspective(
imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val
) )
inp, M = utils.convert_inputs(inp, M) inp, M = utils.convert_inputs(inp, M)
(result,) = apply(op, inp, M, Tensor(dsize))
dsize = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, M, dsize)
return result return result






+ 34
- 0
imperative/python/test/unit/test_tracing.py View File

@@ -13,6 +13,7 @@ import numpy as np
import pytest import pytest


import megengine.core.tensor.megbrain_graph as G import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
from megengine import cgtools, tensor from megengine import cgtools, tensor
from megengine.core._trace_option import set_tensor_shape from megengine.core._trace_option import set_tensor_shape
from megengine.core.ops import builtin as ops from megengine.core.ops import builtin as ops
@@ -261,3 +262,36 @@ def test_trace_reshape():
f(x1) f(x1)
f(x2) f(x2)
f(x3) f(x3)


def test_trace_topk():
x = tensor([5, 2, 7, 1, 0, 3, 2])

@trace(symbolic=True)
def f(x):
y = F.topk(x, 3)
np.testing.assert_equal(y[0].shape.numpy(), np.array([3,]))
return y

for i in range(3):
f(x)


def test_trace_warp_perspective():
inp_shape = (1, 1, 4, 4)
x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
M_shape = (1, 3, 3)
M = tensor(
np.array(
[[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
).reshape(M_shape)
)

@trace(symbolic=True)
def f(x, M):
out = F.warp_perspective(x, M, (2, 2))
np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
return out

for i in range(1):
f(x, M)

Loading…
Cancel
Save