Browse Source

fix the type error of pixel shuffle

add the test of dtype float
tags/v1.9.0
Qsingle 3 years ago
parent
commit
14499e83d9
2 changed files with 30 additions and 4 deletions
  1. +3
    -3
      imperative/python/megengine/functional/nn.py
  2. +27
    -1
      imperative/python/test/unit/functional/test_functional.py

+ 3
- 3
imperative/python/megengine/functional/nn.py View File

@@ -1889,7 +1889,7 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
) )
shape_1 = ( shape_1 = (
*high_dim, *high_dim,
shape_ori[-3] / square,
int(shape_ori[-3] / square),
shape_ori[-2] * upscale_factor, shape_ori[-2] * upscale_factor,
shape_ori[-1] * upscale_factor, shape_ori[-1] * upscale_factor,
) )
@@ -1898,8 +1898,8 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:


layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order) layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order)


shape_0 = convert_single_value(shape_0, dtype=inp.dtype, device=inp.device)
shape_1 = convert_single_value(shape_1, dtype=inp.dtype, device=inp.device)
shape_0 = convert_single_value(shape_0, device=inp.device)
shape_1 = convert_single_value(shape_1, device=inp.device)
outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1) outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1)


return outvar return outvar


+ 27
- 1
imperative/python/test/unit/functional/test_functional.py View File

@@ -1218,30 +1218,49 @@ def test_pixel_shuffle():
out = F.pixel_shuffle(tensor(inp), upscale_factor=4) out = F.pixel_shuffle(tensor(inp), upscale_factor=4)
golden = pixel_shuffle(inp, 4) golden = pixel_shuffle(inp, 4)
np.testing.assert_equal(out.numpy(), golden) np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(inp_float, 2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(),golden)


# ndim = 4 # ndim = 4
inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3) inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3)
out = F.pixel_shuffle(tensor(inp), upscale_factor=3) out = F.pixel_shuffle(tensor(inp), upscale_factor=3)
golden = pixel_shuffle(inp, 3) golden = pixel_shuffle(inp, 3)
np.testing.assert_equal(out.numpy(), golden) np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(inp_float, 2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(),golden)


# ndim = 5 # ndim = 5
inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4) inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=2) out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
golden = pixel_shuffle(inp, 2) golden = pixel_shuffle(inp, 2)
np.testing.assert_equal(out.numpy(), golden) np.testing.assert_equal(out.numpy(), golden)

inp_float = np.float32(inp)
out = F.pixel_shuffle(inp_float, 2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(),golden)
# ndim = 6 # ndim = 6
inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4) inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=5) out = F.pixel_shuffle(tensor(inp), upscale_factor=5)
golden = pixel_shuffle(inp, 5) golden = pixel_shuffle(inp, 5)
np.testing.assert_equal(out.numpy(), golden) np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(inp_float, 2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(),golden)


# ndim = 7 # ndim = 7
inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4) inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=2) out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
golden = pixel_shuffle(inp, 2) golden = pixel_shuffle(inp, 2)
np.testing.assert_equal(out.numpy(), golden) np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(inp_float, 2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(),golden)




@pytest.mark.parametrize("is_symbolic", [False, True]) @pytest.mark.parametrize("is_symbolic", [False, True])
@@ -1260,6 +1279,13 @@ def test_pixel_shuffle_symbolic(is_symbolic):
if is_symbolic is None: if is_symbolic is None:
break break


inp = np.float32(inp)
golden = pixel_shuffle(inp, 2)
for _ in range(3):
out = fn(inp, 2)
np.testing.assert_equal(out.numpy(), golden)
if is_symbolic is None:
break


def test_set_conv2d_config(): def test_set_conv2d_config():
"""check setting config by contextmanager is equal to manually converted result""" """check setting config by contextmanager is equal to manually converted result"""


Loading…
Cancel
Save