GitOrigin-RevId: 2071bb63a8
release-1.1
@@ -15,6 +15,7 @@ from ..core.ops._internal import param_defs as P | |||||
from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
from ..core.tensor import utils | from ..core.tensor import utils | ||||
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | ||||
from ..core.tensor.utils import astensor1d | |||||
from ..distributed import WORLD, is_distributed | from ..distributed import WORLD, is_distributed | ||||
from ..random import uniform | from ..random import uniform | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
@@ -868,7 +869,8 @@ def warp_perspective( | |||||
imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val | imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val | ||||
) | ) | ||||
inp, M = utils.convert_inputs(inp, M) | inp, M = utils.convert_inputs(inp, M) | ||||
(result,) = apply(op, inp, M, Tensor(dsize)) | |||||
dsize = astensor1d(dsize, inp, dtype="int32", device=inp.device) | |||||
(result,) = apply(op, inp, M, dsize) | |||||
return result | return result | ||||
@@ -13,6 +13,7 @@ import numpy as np | |||||
import pytest | import pytest | ||||
import megengine.core.tensor.megbrain_graph as G | import megengine.core.tensor.megbrain_graph as G | ||||
import megengine.functional as F | |||||
from megengine import cgtools, tensor | from megengine import cgtools, tensor | ||||
from megengine.core._trace_option import set_tensor_shape | from megengine.core._trace_option import set_tensor_shape | ||||
from megengine.core.ops import builtin as ops | from megengine.core.ops import builtin as ops | ||||
@@ -261,3 +262,36 @@ def test_trace_reshape(): | |||||
f(x1) | f(x1) | ||||
f(x2) | f(x2) | ||||
f(x3) | f(x3) | ||||
def test_trace_topk(): | |||||
x = tensor([5, 2, 7, 1, 0, 3, 2]) | |||||
@trace(symbolic=True) | |||||
def f(x): | |||||
y = F.topk(x, 3) | |||||
np.testing.assert_equal(y[0].shape.numpy(), np.array([3,])) | |||||
return y | |||||
for i in range(3): | |||||
f(x) | |||||
def test_trace_warp_perspective(): | |||||
inp_shape = (1, 1, 4, 4) | |||||
x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | |||||
M_shape = (1, 3, 3) | |||||
M = tensor( | |||||
np.array( | |||||
[[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32 | |||||
).reshape(M_shape) | |||||
) | |||||
@trace(symbolic=True) | |||||
def f(x, M): | |||||
out = F.warp_perspective(x, M, (2, 2)) | |||||
np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2])) | |||||
return out | |||||
for i in range(1): | |||||
f(x, M) |