Browse Source

Merge pull request #426 from Qsingle:fix-pixel_suffle

GitOrigin-RevId: db9a0f7551
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
5e17b3e4c6
2 changed files with 26 additions and 6 deletions
  1. +3
    -3
      imperative/python/megengine/functional/nn.py
  2. +23
    -3
      imperative/python/test/unit/functional/test_functional.py

+ 3
- 3
imperative/python/megengine/functional/nn.py View File

@@ -1843,7 +1843,7 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
)
shape_1 = (
*high_dim,
shape_ori[-3] / square,
int(shape_ori[-3] / square),
shape_ori[-2] * upscale_factor,
shape_ori[-1] * upscale_factor,
)
@@ -1852,8 +1852,8 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:

layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order)

shape_0 = convert_single_value(shape_0, dtype=inp.dtype, device=inp.device)
shape_1 = convert_single_value(shape_1, dtype=inp.dtype, device=inp.device)
shape_0 = convert_single_value(shape_0, device=inp.device)
shape_1 = convert_single_value(shape_1, device=inp.device)
outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1)

return outvar


+ 23
- 3
imperative/python/test/unit/functional/test_functional.py View File

@@ -1207,41 +1207,61 @@ def test_pixel_shuffle():
out = F.pixel_shuffle(tensor(inp), upscale_factor=4)
golden = pixel_shuffle(inp, 4)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(), golden)

# ndim = 4
inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3)
out = F.pixel_shuffle(tensor(inp), upscale_factor=3)
golden = pixel_shuffle(inp, 3)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(tensor(inp_float), upscale_factor=3)
golden = pixel_shuffle(inp_float, 3)
np.testing.assert_equal(out.numpy(), golden)

# ndim = 5
inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
golden = pixel_shuffle(inp, 2)
np.testing.assert_equal(out.numpy(), golden)

inp_float = np.float32(inp)
out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(), golden)
# ndim = 6
inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=5)
golden = pixel_shuffle(inp, 5)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(tensor(inp_float), upscale_factor=5)
golden = pixel_shuffle(inp_float, 5)
np.testing.assert_equal(out.numpy(), golden)

# ndim = 7
inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
golden = pixel_shuffle(inp, 2)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(), golden)


@pytest.mark.parametrize("type", ["int32", "float32"])
@pytest.mark.parametrize("is_symbolic", [False, True])
def test_pixel_shuffle_symbolic(is_symbolic):
def test_pixel_shuffle_symbolic(is_symbolic, type):
def fn(inp, upscale_factor):
return F.pixel_shuffle(inp, upscale_factor=upscale_factor)

if is_symbolic is not None:
fn = jit.trace(symbolic=is_symbolic)(fn)

inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5))
inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5).astype(type))
golden = pixel_shuffle(inp, 2)
for _ in range(3):
out = fn(inp, 2)


Loading…
Cancel
Save