Browse Source

feat(mge): add functional test

GitOrigin-RevId: aa0be62686
release-1.7
Megvii Engine Team 3 years ago
parent
commit
7591718d20
1 changed files with 32 additions and 0 deletions
  1. +32
    -0
      imperative/python/test/unit/functional/test_functional.py

+ 32
- 0
imperative/python/test/unit/functional/test_functional.py View File

@@ -15,6 +15,7 @@ import pytest
from utils import opr_test from utils import opr_test


import megengine.amp as amp import megengine.amp as amp
import megengine.config as config
import megengine.core.ops.builtin as builtin import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype import megengine.core.tensor.dtype as dtype
import megengine.functional as F import megengine.functional as F
@@ -1258,3 +1259,34 @@ def test_pixel_shuffle_symbolic(is_symbolic):
np.testing.assert_equal(out.numpy(), golden) np.testing.assert_equal(out.numpy(), golden)
if is_symbolic is None: if is_symbolic is None:
break break


def test_set_conv2d_config():
"""check setting config by contextmanager is equal to manually converted result"""
config._compute_mode = "float32"
inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float16)
weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float16)
config_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
config._compute_mode = "default"
with config._override(compute_mode="float32"):
context_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
expected = F.conv2d(
inp, weight, None, (2, 2), (3, 3), (1, 1), 1, compute_mode="float32",
)
np.testing.assert_allclose(config_out.numpy(), expected.numpy())
np.testing.assert_allclose(context_out.numpy(), expected.numpy())


def test_set_warp_perspective_config():
config._conv_format = "NHWC"
inp_shape = (1, 1, 4, 4)
inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
M_shape = (1, 3, 3)
M = Tensor(np.random.randn(3, 3), dtype=np.float32).reshape(M_shape)
config_out = F.vision.warp_perspective(inp, M, (2, 2))
config._conv_format = "default"
with config._override(conv_format="NHWC"):
context_out = F.vision.warp_perspective(inp, M, (2, 2))
expected = F.vision.warp_perspective(inp, M, (2, 2), format="NHWC")
np.testing.assert_allclose(config_out.numpy(), expected.numpy())
np.testing.assert_allclose(context_out.numpy(), expected.numpy())

Loading…
Cancel
Save