Browse Source

feat(functional): let interpolate support more modes

GitOrigin-RevId: 9693a1ac63
release-1.5
Megvii Engine Team 3 years ago
parent
commit
536506c3f4
1 changed files with 19 additions and 7 deletions
  1. +19
    -7
      imperative/python/megengine/functional/vision.py

+ 19
- 7
imperative/python/megengine/functional/vision.py View File

@@ -17,7 +17,7 @@ from ..core.tensor.utils import astensor1d
from ..tensor import Tensor from ..tensor import Tensor
from .elemwise import floor from .elemwise import floor
from .math import argsort from .math import argsort
from .tensor import broadcast_to, concat, expand_dims, reshape
from .tensor import broadcast_to, concat, expand_dims, reshape, transpose




def cvt_color(inp: Tensor, mode: str = ""): def cvt_color(inp: Tensor, mode: str = ""):
@@ -474,7 +474,7 @@ def interpolate(
:param size: size of the output tensor. Default: None :param size: size of the output tensor. Default: None
:param scale_factor: scaling factor of the output tensor. Default: None :param scale_factor: scaling factor of the output tensor. Default: None
:param mode: interpolation methods, acceptable values are: :param mode: interpolation methods, acceptable values are:
"bilinear", "linear". Default: "bilinear"
"bilinear", "linear", "bicubic" and "nearest". Default: "bilinear"
:param align_corners: This only has an effect when `mode` :param align_corners: This only has an effect when `mode`
is "bilinear" or "linear". Geometrically, we consider the pixels of the input is "bilinear" or "linear". Geometrically, we consider the pixels of the input
and output as squares rather than points. If set to ``True``, the input and output as squares rather than points. If set to ``True``, the input
@@ -511,8 +511,8 @@ def interpolate(


""" """
mode = mode.lower() mode = mode.lower()
if mode not in ["bilinear", "linear"]:
raise ValueError("interpolate only support linear or bilinear mode")
if mode not in ["bilinear", "linear", "bicubic", "nearest"]:
raise ValueError("unsupported interpolate mode: {}".format(mode))
if mode not in ["bilinear", "linear"]: if mode not in ["bilinear", "linear"]:
if align_corners is not None: if align_corners is not None:
raise ValueError( raise ValueError(
@@ -625,9 +625,21 @@ def interpolate(
weight = broadcast_to(weight, (inp.shape[0], 3, 3)) weight = broadcast_to(weight, (inp.shape[0], 3, 3))


weight = weight.astype("float32") weight = weight.astype("float32")
ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
if mode == "linear":
ret = reshape(ret, ret.shape[0:3])
if mode in ["linear", "bilinear"]:
ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
if mode == "linear":
ret = reshape(ret, ret.shape[0:3])
else:
# only NHWC format support "cubic" and "nearest" mode
inp = transpose(inp, (0, 2, 3, 1))
ret = warp_perspective(
inp,
weight,
dsize,
format="NHWC",
interp_mode="cubic" if mode == "bicubic" else mode,
)
ret = transpose(ret, (0, 3, 1, 2))
return ret return ret






Loading…
Cancel
Save