Browse Source

feat(functional): let interpolate support more modes

GitOrigin-RevId: 9693a1ac63
release-1.5
Megvii Engine Team 3 years ago
parent
commit
536506c3f4
1 changed files with 19 additions and 7 deletions
  1. +19
    -7
      imperative/python/megengine/functional/vision.py

+ 19
- 7
imperative/python/megengine/functional/vision.py View File

@@ -17,7 +17,7 @@ from ..core.tensor.utils import astensor1d
from ..tensor import Tensor
from .elemwise import floor
from .math import argsort
from .tensor import broadcast_to, concat, expand_dims, reshape
from .tensor import broadcast_to, concat, expand_dims, reshape, transpose


def cvt_color(inp: Tensor, mode: str = ""):
@@ -474,7 +474,7 @@ def interpolate(
:param size: size of the output tensor. Default: None
:param scale_factor: scaling factor of the output tensor. Default: None
:param mode: interpolation methods, acceptable values are:
"bilinear", "linear". Default: "bilinear"
"bilinear", "linear", "bicubic" and "nearest". Default: "bilinear"
:param align_corners: This only has an effect when `mode`
is "bilinear" or "linear". Geometrically, we consider the pixels of the input
and output as squares rather than points. If set to ``True``, the input
@@ -511,8 +511,8 @@ def interpolate(

"""
mode = mode.lower()
if mode not in ["bilinear", "linear"]:
raise ValueError("interpolate only support linear or bilinear mode")
if mode not in ["bilinear", "linear", "bicubic", "nearest"]:
raise ValueError("unsupported interpolate mode: {}".format(mode))
if mode not in ["bilinear", "linear"]:
if align_corners is not None:
raise ValueError(
@@ -625,9 +625,21 @@ def interpolate(
weight = broadcast_to(weight, (inp.shape[0], 3, 3))

weight = weight.astype("float32")
ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
if mode == "linear":
ret = reshape(ret, ret.shape[0:3])
if mode in ["linear", "bilinear"]:
ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
if mode == "linear":
ret = reshape(ret, ret.shape[0:3])
else:
# only NHWC format support "cubic" and "nearest" mode
inp = transpose(inp, (0, 2, 3, 1))
ret = warp_perspective(
inp,
weight,
dsize,
format="NHWC",
interp_mode="cubic" if mode == "bicubic" else mode,
)
ret = transpose(ret, (0, 3, 1, 2))
return ret




Loading…
Cancel
Save