GitOrigin-RevId: b7cc6dd829
tags/v1.3.0
@@ -7,7 +7,7 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
# pylint: disable=too-many-lines | # pylint: disable=too-many-lines | ||||
from typing import Optional, Sequence, Tuple, Union | |||||
from typing import Iterable, Optional, Sequence, Tuple, Union | |||||
from ..core._imperative_rt import CompNode | from ..core._imperative_rt import CompNode | ||||
from ..core._imperative_rt.core2 import apply | from ..core._imperative_rt.core2 import apply | ||||
@@ -58,6 +58,7 @@ __all__ = [ | |||||
"one_hot", | "one_hot", | ||||
"prelu", | "prelu", | ||||
"remap", | "remap", | ||||
"resize", | |||||
"softmax", | "softmax", | ||||
"softplus", | "softplus", | ||||
"svd", | "svd", | ||||
@@ -878,6 +879,41 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: | |||||
return result | return result | ||||
def resize( | |||||
inp: Tensor, target_shape: Iterable[int], interp_mode: str = "LINEAR" | |||||
) -> Tensor: | |||||
r""" | |||||
Applies resize transformation to batched 2D images. | |||||
:param inp: `(N, C, H, W)` input tensor. Currently only support "NCHW" format. | |||||
:param target_shape: `(H, W)` target images shape. | |||||
:param interp_mode: interpolation methods. Defaule mode is "LINEAR", Currently only support "LINEAR". | |||||
Examples: | |||||
.. testcode:: | |||||
import numpy as np | |||||
from megengine import tensor | |||||
import megengine.functional as F | |||||
x = tensor(np.random.randn(10, 3, 32, 32)) | |||||
out = F.resize(x, (16, 16)) | |||||
print(out.numpy().shape) | |||||
Outputs: | |||||
.. testoutput:: | |||||
(10, 3, 16, 16) | |||||
""" | |||||
op = builtin.Resize(imode=interp_mode, format="NCHW") | |||||
shape = astensor1d(target_shape, inp, dtype="int32", device=inp.device) | |||||
(result,) = apply(op, inp, shape) | |||||
return result | |||||
def warp_perspective( | def warp_perspective( | ||||
inp: Tensor, | inp: Tensor, | ||||
M: Tensor, | M: Tensor, | ||||
@@ -373,6 +373,17 @@ def test_Broadcast(): | |||||
np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy()) | np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy()) | ||||
def test_resize(): | |||||
x_np = np.random.rand(3, 3, 32, 32).astype("float32") | |||||
x = mge.Tensor(x_np) | |||||
grad = Grad().wrt(x, callback=save_to(x)) | |||||
y = F.resize(x, (16, 16)) | |||||
grad(y, F.ones_like(y)) | |||||
np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy()) | |||||
def test_Reduce_sum(): | def test_Reduce_sum(): | ||||
x_np = np.random.rand(3, 3).astype("float32") | x_np = np.random.rand(3, 3).astype("float32") | ||||
x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
@@ -328,6 +328,31 @@ def test_one_hot(): | |||||
onehot_high_dimension() | onehot_high_dimension() | ||||
def test_resize(): | |||||
# check shape | |||||
test_cases = [ | |||||
[(1, 1, 10, 10), (5, 5)], | |||||
[(1, 3, 10, 10), (20, 20)], | |||||
[(10, 1, 10, 10), (1, 1)], | |||||
[(10, 10, 1, 1), (10, 10)], | |||||
] | |||||
for inp_shape, target_shape in test_cases: | |||||
x = tensor(np.random.randn(*inp_shape), dtype=np.float32) | |||||
out = F.resize(x, target_shape, interp_mode="LINEAR") | |||||
assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1] | |||||
assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1] | |||||
# check value | |||||
x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32) | |||||
out = F.resize(x, (15, 5), interp_mode="LINEAR") | |||||
np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32)) | |||||
np_x = np.arange(32) | |||||
x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1) | |||||
out = F.resize(x, (1, 1), interp_mode="LINEAR") | |||||
np.testing.assert_equal(out.item(), np_x.mean()) | |||||
def test_warp_perspective(): | def test_warp_perspective(): | ||||
inp_shape = (1, 1, 4, 4) | inp_shape = (1, 1, 4, 4) | ||||
x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | ||||
@@ -0,0 +1,38 @@ | |||||
/** | |||||
* \file imperative/src/impl/ops/resize.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/opr/imgproc.h" | |||||
#include "../op_trait.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
namespace { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const Resize&>(def); | |||||
mgb_assert(inputs.size() == 2); | |||||
return opr::Resize::make(inputs[0], inputs[1], op.param()); | |||||
} | |||||
OP_TRAIT_REG(Resize, Resize) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // anonymous namespace | |||||
} // namespace imperative | |||||
} // namespace mgb | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -76,6 +76,8 @@ def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>; | |||||
def Remap: MgbHashableOp<"Remap", [RemapParam]>; | def Remap: MgbHashableOp<"Remap", [RemapParam]>; | ||||
def Resize: MgbHashableOp<"Resize", [ResizeParam]>; | |||||
def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>; | def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>; | ||||
def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>; | def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>; | ||||