|
|
@@ -15,6 +15,7 @@ import pytest |
|
|
|
from utils import opr_test |
|
|
|
|
|
|
|
import megengine.amp as amp |
|
|
|
import megengine.config as config |
|
|
|
import megengine.core.ops.builtin as builtin |
|
|
|
import megengine.core.tensor.dtype as dtype |
|
|
|
import megengine.functional as F |
|
|
@@ -1258,3 +1259,34 @@ def test_pixel_shuffle_symbolic(is_symbolic): |
|
|
|
np.testing.assert_equal(out.numpy(), golden) |
|
|
|
if is_symbolic is None: |
|
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
def test_set_conv2d_config(): |
|
|
|
"""check setting config by contextmanager is equal to manually converted result""" |
|
|
|
config._compute_mode = "float32" |
|
|
|
inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float16) |
|
|
|
weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float16) |
|
|
|
config_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1) |
|
|
|
config._compute_mode = "default" |
|
|
|
with config._override(compute_mode="float32"): |
|
|
|
context_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1) |
|
|
|
expected = F.conv2d( |
|
|
|
inp, weight, None, (2, 2), (3, 3), (1, 1), 1, compute_mode="float32", |
|
|
|
) |
|
|
|
np.testing.assert_allclose(config_out.numpy(), expected.numpy()) |
|
|
|
np.testing.assert_allclose(context_out.numpy(), expected.numpy()) |
|
|
|
|
|
|
|
|
|
|
|
def test_set_warp_perspective_config(): |
|
|
|
config._conv_format = "NHWC" |
|
|
|
inp_shape = (1, 1, 4, 4) |
|
|
|
inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) |
|
|
|
M_shape = (1, 3, 3) |
|
|
|
M = Tensor(np.random.randn(3, 3), dtype=np.float32).reshape(M_shape) |
|
|
|
config_out = F.vision.warp_perspective(inp, M, (2, 2)) |
|
|
|
config._conv_format = "default" |
|
|
|
with config._override(conv_format="NHWC"): |
|
|
|
context_out = F.vision.warp_perspective(inp, M, (2, 2)) |
|
|
|
expected = F.vision.warp_perspective(inp, M, (2, 2), format="NHWC") |
|
|
|
np.testing.assert_allclose(config_out.numpy(), expected.numpy()) |
|
|
|
np.testing.assert_allclose(context_out.numpy(), expected.numpy()) |