@@ -15,6 +15,7 @@ from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | |||
from ..core.ops import builtin | |||
from ..core.ops.builtin import ( | |||
BatchNorm, | |||
Dimshuffle, | |||
Elemwise, | |||
GetVarShape, | |||
Identity, | |||
@@ -86,6 +87,7 @@ __all__ = [ | |||
"sync_batch_norm", | |||
"warp_affine", | |||
"warp_perspective", | |||
"pixel_shuffle", | |||
] | |||
@@ -1733,6 +1735,69 @@ def pad( | |||
return output | |||
@lru_cache(maxsize=None) | |||
def _get_layerPixelShuffle(device, dtype, dim_order): | |||
@subgraph("LayerPixelShuffle", dtype, device, 3) | |||
def layerPixelShuffle(inputs, f, c): | |||
inp, shape_0, shape_1 = inputs | |||
inp = f(Reshape(), inp, shape_0) | |||
inp = f(Dimshuffle(dim_order), inp) | |||
oup = f(Reshape(), inp, shape_1) | |||
return (oup,), (True,) | |||
return layerPixelShuffle | |||
def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor: | |||
""" | |||
Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of | |||
shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero | |||
or more batch dimensions. | |||
:param inp: input tensor. | |||
:param upscale_factor: upscale factor of pixel_shuffle. | |||
:return: output tensor. | |||
""" | |||
assert upscale_factor > 0, "upscale_factor should larger than 0" | |||
assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3" | |||
assert ( | |||
inp.shape[-3] % (upscale_factor ** 2) == 0 | |||
), "the -3 dimension should be divided by (upscale_factor ** 2)" | |||
_device = inp.device | |||
_dtype = inp.dtype | |||
shape_ori = inp.shape | |||
high_dim = shape_ori[:-3] | |||
square = upscale_factor ** 2 | |||
n = 1 | |||
for item in high_dim: | |||
n *= item | |||
shape_0 = ( | |||
n, | |||
int(shape_ori[-3] / square), | |||
upscale_factor, | |||
upscale_factor, | |||
shape_ori[-2], | |||
shape_ori[-1], | |||
) | |||
shape_1 = ( | |||
*high_dim, | |||
shape_ori[-3] / square, | |||
shape_ori[-2] * upscale_factor, | |||
shape_ori[-1] * upscale_factor, | |||
) | |||
dim_order = (0, 1, 4, 2, 5, 3) | |||
layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order) | |||
shape_0 = convert_single_value(shape_0, dtype=inp.dtype, device=inp.device) | |||
shape_1 = convert_single_value(shape_1, dtype=inp.dtype, device=inp.device) | |||
outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1) | |||
return outvar | |||
from .quantized import conv_bias_activation # isort:skip | |||
from .loss import * # isort:skip | |||
from .metric import * # isort:skip | |||
@@ -0,0 +1,24 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..functional.nn import pixel_shuffle | |||
from .module import Module | |||
class PixelShuffle(Module): | |||
r""" | |||
Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of | |||
shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero | |||
or more batch dimensions. | |||
""" | |||
def __init__(self, upscale_factor: int, **kwargs): | |||
super().__init__(**kwargs) | |||
self.upscale_factor = upscale_factor | |||
def forward(self, x): | |||
return pixel_shuffle(x, self.upscale_factor) |
@@ -1177,3 +1177,74 @@ def test_pad(): | |||
dst = np.pad(src, ((2, 2), (2, 2)), "reflect") | |||
res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "REFLECT") | |||
np.testing.assert_allclose(res, dst, atol=1e-5) | |||
def pixel_shuffle(data, r): | |||
high_dim = data.shape[:-3] | |||
data = data.reshape(-1, data.shape[-3], data.shape[-2], data.shape[-1]) | |||
inn, ic, ih, iw = data.shape | |||
res = np.zeros((inn, int(ic / (r * r)), ih * r, iw * r)) | |||
for n in range(inn): | |||
for c in range(ic): | |||
for h in range(ih): | |||
for w in range(iw): | |||
res[ | |||
n, | |||
int(c / r / r), | |||
h * r + int((c % (r * r)) / r), | |||
w * r + c % r, | |||
] = data[n, c, h, w] | |||
if len(high_dim) > 0: | |||
res = res.reshape((*high_dim, int(ic / r / r), ih * r, iw * r)) | |||
else: | |||
res = res[0] | |||
return res | |||
def test_pixel_shuffle(): | |||
# ndim = 3 | |||
inp = np.arange(16 * 3 * 3).reshape(16, 3, 3) | |||
out = F.pixel_shuffle(tensor(inp), upscale_factor=4) | |||
golden = pixel_shuffle(inp, 4) | |||
np.testing.assert_equal(out.numpy(), golden) | |||
# ndim = 4 | |||
inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3) | |||
out = F.pixel_shuffle(tensor(inp), upscale_factor=3) | |||
golden = pixel_shuffle(inp, 3) | |||
np.testing.assert_equal(out.numpy(), golden) | |||
# ndim = 5 | |||
inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4) | |||
out = F.pixel_shuffle(tensor(inp), upscale_factor=2) | |||
golden = pixel_shuffle(inp, 2) | |||
np.testing.assert_equal(out.numpy(), golden) | |||
# ndim = 6 | |||
inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4) | |||
out = F.pixel_shuffle(tensor(inp), upscale_factor=5) | |||
golden = pixel_shuffle(inp, 5) | |||
np.testing.assert_equal(out.numpy(), golden) | |||
# ndim = 7 | |||
inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4) | |||
out = F.pixel_shuffle(tensor(inp), upscale_factor=2) | |||
golden = pixel_shuffle(inp, 2) | |||
np.testing.assert_equal(out.numpy(), golden) | |||
@pytest.mark.parametrize("is_symbolic", [False, True]) | |||
def test_pixel_shuffle_symbolic(is_symbolic): | |||
def fn(inp, upscale_factor): | |||
return F.pixel_shuffle(inp, upscale_factor=upscale_factor) | |||
if is_symbolic is not None: | |||
fn = jit.trace(symbolic=is_symbolic)(fn) | |||
inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5)) | |||
golden = pixel_shuffle(inp, 2) | |||
for _ in range(3): | |||
out = fn(inp, 2) | |||
np.testing.assert_equal(out.numpy(), golden) | |||
if is_symbolic is None: | |||
break |