Browse Source

feat(imperative/opr): deprecate resize op and make it as a special case of interpolate

GitOrigin-RevId: a5668c5779
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
fe9c6e26a8
3 changed files with 29 additions and 43 deletions
  1. +23
    -37
      imperative/python/megengine/functional/nn.py
  2. +2
    -2
      imperative/python/test/unit/core/test_autodiff.py
  3. +4
    -4
      imperative/python/test/unit/functional/test_functional.py

+ 23
- 37
imperative/python/megengine/functional/nn.py View File

@@ -57,7 +57,6 @@ __all__ = [
"one_hot",
"prelu",
"remap",
"resize",
"softmax",
"softplus",
"warp_affine",
@@ -984,41 +983,6 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
return result


def resize(
inp: Tensor, target_shape: Iterable[int], interp_mode: str = "LINEAR"
) -> Tensor:
r"""
Applies resize transformation to batched 2D images.

:param inp: `(N, C, H, W)` input tensor. Currently only support "NCHW" format.
:param target_shape: `(H, W)` target images shape.
:param interp_mode: interpolation methods. Defaule mode is "LINEAR", Currently only support "LINEAR".

Examples:

.. testcode::

import numpy as np
from megengine import tensor
import megengine.functional as F

x = tensor(np.random.randn(10, 3, 32, 32))
out = F.resize(x, (16, 16))
print(out.numpy().shape)

Outputs:

.. testoutput::

(10, 3, 16, 16)

"""
op = builtin.Resize(imode=interp_mode, format="NCHW")
shape = astensor1d(target_shape, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, shape)
return result


def warp_affine(
inp: Tensor,
weight: Tensor,
@@ -1187,7 +1151,7 @@ def interpolate(
size: Optional[Union[int, Tuple[int, int]]] = None,
scale_factor: Optional[Union[float, Tuple[float, float]]] = None,
mode: str = "BILINEAR",
align_corners: bool = None,
align_corners: Optional[bool] = None,
) -> Tensor:
r"""
Down/up samples the input tensor to either the given size or with the given scale_factor. ``size`` can not coexist with ``scale_factor``.
@@ -1197,6 +1161,15 @@ def interpolate(
:param scale_factor: scaling factor of the output tensor. Default: None
:param mode: interpolation methods, acceptable values are:
"BILINEAR", "LINEAR". Default: "BILINEAR"
:param align_corners: This only has an effect when `mode`
is "BILINEAR" or "LINEAR". Geometrically, we consider the pixels of the input
and output as squares rather than points. If set to ``True``, the input
and output tensors are aligned by the center points of their corner
pixels, preserving the values at the corner pixels. If set to ``False``,
the input and output tensors are aligned by the corner points of their
corner pixels, and the interpolation uses edge value padding for
out-of-boundary values, making this operation *independent* of input size
when `scale_factor` is kept the same. Default: None
:return: output tensor.

Examples:
@@ -1235,6 +1208,19 @@ def interpolate(
if align_corners is None:
align_corners = False

if (
size is not None
and scale_factor is None
and not align_corners
and mode == "BILINEAR"
and inp.ndim in [4, 5]
):
# fastpath for interpolate
op = builtin.Resize(imode="LINEAR", format="NCHW")
shape = astensor1d(size, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, shape)
return result

if mode == "LINEAR":
inp = expand_dims(inp, 3)



+ 2
- 2
imperative/python/test/unit/core/test_autodiff.py View File

@@ -367,12 +367,12 @@ def test_Broadcast():
np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())


def test_resize():
def test_interpolate_fastpath():
x_np = np.random.rand(3, 3, 32, 32).astype("float32")
x = mge.Tensor(x_np)

grad = Grad().wrt(x, callback=save_to(x))
y = F.resize(x, (16, 16))
y = F.nn.interpolate(x, size=(16, 16), mode="BILINEAR")

grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy())


+ 4
- 4
imperative/python/test/unit/functional/test_functional.py View File

@@ -325,7 +325,7 @@ def test_one_hot():
onehot_high_dimension()


def test_resize():
def test_interpolate_fastpath():
# check shape
test_cases = [
[(1, 1, 10, 10), (5, 5)],
@@ -335,18 +335,18 @@ def test_resize():
]
for inp_shape, target_shape in test_cases:
x = tensor(np.random.randn(*inp_shape), dtype=np.float32)
out = F.resize(x, target_shape, interp_mode="LINEAR")
out = F.nn.interpolate(x, target_shape, mode="BILINEAR")
assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1]
assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1]

# check value
x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32)
out = F.resize(x, (15, 5), interp_mode="LINEAR")
out = F.nn.interpolate(x, (15, 5), mode="BILINEAR")
np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32))

np_x = np.arange(32)
x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1)
out = F.resize(x, (1, 1), interp_mode="LINEAR")
out = F.nn.interpolate(x, (1, 1), mode="BILINEAR")
np.testing.assert_equal(out.item(), np_x.mean())




Loading…
Cancel
Save