Browse Source

feat(mge): add functional test

GitOrigin-RevId: aa0be62686
release-1.7
Megvii Engine Team 3 years ago
parent
commit
7591718d20
1 changed files with 32 additions and 0 deletions
  1. +32
    -0
      imperative/python/test/unit/functional/test_functional.py

+ 32
- 0
imperative/python/test/unit/functional/test_functional.py View File

@@ -15,6 +15,7 @@ import pytest
from utils import opr_test

import megengine.amp as amp
import megengine.config as config
import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype
import megengine.functional as F
@@ -1258,3 +1259,34 @@ def test_pixel_shuffle_symbolic(is_symbolic):
np.testing.assert_equal(out.numpy(), golden)
if is_symbolic is None:
break


def test_set_conv2d_config():
"""check setting config by contextmanager is equal to manually converted result"""
config._compute_mode = "float32"
inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float16)
weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float16)
config_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
config._compute_mode = "default"
with config._override(compute_mode="float32"):
context_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
expected = F.conv2d(
inp, weight, None, (2, 2), (3, 3), (1, 1), 1, compute_mode="float32",
)
np.testing.assert_allclose(config_out.numpy(), expected.numpy())
np.testing.assert_allclose(context_out.numpy(), expected.numpy())


def test_set_warp_perspective_config():
config._conv_format = "NHWC"
inp_shape = (1, 1, 4, 4)
inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
M_shape = (1, 3, 3)
M = Tensor(np.random.randn(3, 3), dtype=np.float32).reshape(M_shape)
config_out = F.vision.warp_perspective(inp, M, (2, 2))
config._conv_format = "default"
with config._override(conv_format="NHWC"):
context_out = F.vision.warp_perspective(inp, M, (2, 2))
expected = F.vision.warp_perspective(inp, M, (2, 2), format="NHWC")
np.testing.assert_allclose(config_out.numpy(), expected.numpy())
np.testing.assert_allclose(context_out.numpy(), expected.numpy())

Loading…
Cancel
Save