Browse Source

feat(imperative): add pixel_shuffle opr

revert-211-master
chenjiahui Lixiangyin 3 years ago
parent
commit
d17cd60d3b
3 changed files with 160 additions and 0 deletions
  1. +65
    -0
      imperative/python/megengine/functional/nn.py
  2. +24
    -0
      imperative/python/megengine/module/pixel_shuffle.py
  3. +71
    -0
      imperative/python/test/unit/functional/test_functional.py

+ 65
- 0
imperative/python/megengine/functional/nn.py View File

@@ -15,6 +15,7 @@ from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
from ..core.ops.builtin import (
BatchNorm,
Dimshuffle,
Elemwise,
GetVarShape,
Identity,
@@ -86,6 +87,7 @@ __all__ = [
"sync_batch_norm",
"warp_affine",
"warp_perspective",
"pixel_shuffle",
]


@@ -1733,6 +1735,69 @@ def pad(
return output


@lru_cache(maxsize=None)
def _get_layerPixelShuffle(device, dtype, dim_order):
@subgraph("LayerPixelShuffle", dtype, device, 3)
def layerPixelShuffle(inputs, f, c):
inp, shape_0, shape_1 = inputs
inp = f(Reshape(), inp, shape_0)
inp = f(Dimshuffle(dim_order), inp)
oup = f(Reshape(), inp, shape_1)
return (oup,), (True,)

return layerPixelShuffle


def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
"""
Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of
shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero
or more batch dimensions.

:param inp: input tensor.
:param upscale_factor: upscale factor of pixel_shuffle.
:return: output tensor.
"""
assert upscale_factor > 0, "upscale_factor should larger than 0"
assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3"
assert (
inp.shape[-3] % (upscale_factor ** 2) == 0
), "the -3 dimension should be divided by (upscale_factor ** 2)"

_device = inp.device
_dtype = inp.dtype
shape_ori = inp.shape
high_dim = shape_ori[:-3]
square = upscale_factor ** 2
n = 1
for item in high_dim:
n *= item
shape_0 = (
n,
int(shape_ori[-3] / square),
upscale_factor,
upscale_factor,
shape_ori[-2],
shape_ori[-1],
)
shape_1 = (
*high_dim,
shape_ori[-3] / square,
shape_ori[-2] * upscale_factor,
shape_ori[-1] * upscale_factor,
)

dim_order = (0, 1, 4, 2, 5, 3)

layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order)

shape_0 = convert_single_value(shape_0, dtype=inp.dtype, device=inp.device)
shape_1 = convert_single_value(shape_1, dtype=inp.dtype, device=inp.device)
outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1)

return outvar


from .quantized import conv_bias_activation # isort:skip
from .loss import * # isort:skip
from .metric import * # isort:skip


+ 24
- 0
imperative/python/megengine/module/pixel_shuffle.py View File

@@ -0,0 +1,24 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..functional.nn import pixel_shuffle
from .module import Module


class PixelShuffle(Module):
r"""
Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of
shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero
or more batch dimensions.
"""

def __init__(self, upscale_factor: int, **kwargs):
super().__init__(**kwargs)
self.upscale_factor = upscale_factor

def forward(self, x):
return pixel_shuffle(x, self.upscale_factor)

+ 71
- 0
imperative/python/test/unit/functional/test_functional.py View File

@@ -1177,3 +1177,74 @@ def test_pad():
dst = np.pad(src, ((2, 2), (2, 2)), "reflect")
res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "REFLECT")
np.testing.assert_allclose(res, dst, atol=1e-5)


def pixel_shuffle(data, r):
high_dim = data.shape[:-3]
data = data.reshape(-1, data.shape[-3], data.shape[-2], data.shape[-1])
inn, ic, ih, iw = data.shape
res = np.zeros((inn, int(ic / (r * r)), ih * r, iw * r))
for n in range(inn):
for c in range(ic):
for h in range(ih):
for w in range(iw):
res[
n,
int(c / r / r),
h * r + int((c % (r * r)) / r),
w * r + c % r,
] = data[n, c, h, w]
if len(high_dim) > 0:
res = res.reshape((*high_dim, int(ic / r / r), ih * r, iw * r))
else:
res = res[0]
return res


def test_pixel_shuffle():
# ndim = 3
inp = np.arange(16 * 3 * 3).reshape(16, 3, 3)
out = F.pixel_shuffle(tensor(inp), upscale_factor=4)
golden = pixel_shuffle(inp, 4)
np.testing.assert_equal(out.numpy(), golden)

# ndim = 4
inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3)
out = F.pixel_shuffle(tensor(inp), upscale_factor=3)
golden = pixel_shuffle(inp, 3)
np.testing.assert_equal(out.numpy(), golden)

# ndim = 5
inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
golden = pixel_shuffle(inp, 2)
np.testing.assert_equal(out.numpy(), golden)

# ndim = 6
inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=5)
golden = pixel_shuffle(inp, 5)
np.testing.assert_equal(out.numpy(), golden)

# ndim = 7
inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
golden = pixel_shuffle(inp, 2)
np.testing.assert_equal(out.numpy(), golden)


@pytest.mark.parametrize("is_symbolic", [False, True])
def test_pixel_shuffle_symbolic(is_symbolic):
def fn(inp, upscale_factor):
return F.pixel_shuffle(inp, upscale_factor=upscale_factor)

if is_symbolic is not None:
fn = jit.trace(symbolic=is_symbolic)(fn)

inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5))
golden = pixel_shuffle(inp, 2)
for _ in range(3):
out = fn(inp, 2)
np.testing.assert_equal(out.numpy(), golden)
if is_symbolic is None:
break

Loading…
Cancel
Save