|
|
@@ -13,6 +13,7 @@ import numpy as np |
|
|
|
import pytest |
|
|
|
|
|
|
|
import megengine.core.tensor.megbrain_graph as G |
|
|
|
import megengine.functional as F |
|
|
|
from megengine import cgtools, tensor |
|
|
|
from megengine.core._trace_option import set_tensor_shape |
|
|
|
from megengine.core.ops import builtin as ops |
|
|
@@ -261,3 +262,36 @@ def test_trace_reshape(): |
|
|
|
f(x1) |
|
|
|
f(x2) |
|
|
|
f(x3) |
|
|
|
|
|
|
|
|
|
|
|
def test_trace_topk(): |
|
|
|
x = tensor([5, 2, 7, 1, 0, 3, 2]) |
|
|
|
|
|
|
|
@trace(symbolic=True) |
|
|
|
def f(x): |
|
|
|
y = F.topk(x, 3) |
|
|
|
np.testing.assert_equal(y[0].shape.numpy(), np.array([3,])) |
|
|
|
return y |
|
|
|
|
|
|
|
for i in range(3): |
|
|
|
f(x) |
|
|
|
|
|
|
|
|
|
|
|
def test_trace_warp_perspective(): |
|
|
|
inp_shape = (1, 1, 4, 4) |
|
|
|
x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) |
|
|
|
M_shape = (1, 3, 3) |
|
|
|
M = tensor( |
|
|
|
np.array( |
|
|
|
[[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32 |
|
|
|
).reshape(M_shape) |
|
|
|
) |
|
|
|
|
|
|
|
@trace(symbolic=True) |
|
|
|
def f(x, M): |
|
|
|
out = F.warp_perspective(x, M, (2, 2)) |
|
|
|
np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2])) |
|
|
|
return out |
|
|
|
|
|
|
|
for i in range(1): |
|
|
|
f(x, M) |