You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_conv.py 848 B

12345678910111213141516171819202122232425262728
  1. # -*- coding: utf-8 -*-
  2. import itertools
  3. import numpy as np
  4. import pytest
  5. import megengine as mge
  6. import megengine.module as M
  7. from megengine import tensor
  8. # NOTE: tested via module for convenience; this should really be tested in functional.
  9. @pytest.mark.parametrize(
  10. "name",
  11. ["Conv1d", "Conv2d", "Conv3d", "ConvTranspose2d", "ConvTranspose3d", "LocalConv2d"],
  12. )
  13. def test_conv_dtype_promotion(name):
  14. old = mge.config.deterministic_kernel
  15. mge.config.deterministic_kernel = True
  16. N, Ci, Co, K = 2, 16, 32, 3
  17. S = (7,) * int(name[-2])
  18. if "Local" in name:
  19. m = getattr(M, name)(Ci, Co, *S, K)
  20. else:
  21. m = getattr(M, name)(Ci, Co, K)
  22. x = tensor(np.random.random(size=(N, Ci) + S).astype("float16"))
  23. np.testing.assert_equal(m(x).numpy(), m(x.astype("float32")).numpy())
  24. mge.config.deterministic_kernel = old