GitOrigin-RevId: 789f1511ec
release-1.1
@@ -13,7 +13,7 @@ from ..core._imperative_rt import CompNode | |||||
from ..core.ops import builtin | from ..core.ops import builtin | ||||
from ..core.ops._internal import param_defs as P | from ..core.ops._internal import param_defs as P | ||||
from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
from ..core.tensor import utils | |||||
from ..core.tensor import megbrain_graph, utils | |||||
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | ||||
from ..core.tensor.utils import astensor1d | from ..core.tensor.utils import astensor1d | ||||
from ..distributed import WORLD, is_distributed | from ..distributed import WORLD, is_distributed | ||||
@@ -27,6 +27,8 @@ from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshap | |||||
from .types import _pair, _pair_nonzero | from .types import _pair, _pair_nonzero | ||||
__all__ = [ | __all__ = [ | ||||
"adaptive_avg_pool2d", | |||||
"adaptive_max_pool2d", | |||||
"avg_pool2d", | "avg_pool2d", | ||||
"batched_nms", | "batched_nms", | ||||
"batch_norm2d", | "batch_norm2d", | ||||
@@ -324,6 +326,48 @@ def avg_pool2d( | |||||
return output | return output | ||||
def adaptive_max_pool2d( | |||||
inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor], | |||||
) -> Tensor: | |||||
"""Applies a 2D max adaptive pooling over an input. | |||||
Refer to :class:`~.MaxAdaptivePool2d` for more information. | |||||
:param inp: The input tensor. | |||||
:param oshp: (OH, OW) size of the output shape. | |||||
:return: output tensor. | |||||
""" | |||||
assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type" | |||||
if isinstance(oshp, int): | |||||
oshp = (oshp, oshp) | |||||
op = builtin.AdaptivePooling(mode="MAX", format="NCHW",) | |||||
oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device) | |||||
(output,) = apply(op, inp, oshp) | |||||
return output | |||||
def adaptive_avg_pool2d( | |||||
inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor], | |||||
) -> Tensor: | |||||
"""Applies a 2D average adaptive pooling over an input. | |||||
Refer to :class:`~.AvgAdaptivePool2d` for more information. | |||||
:param inp: The input tensor. | |||||
:param oshp: (OH, OW) size of the output shape. | |||||
:return: output tensor. | |||||
""" | |||||
assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type" | |||||
if isinstance(oshp, int): | |||||
oshp = (oshp, oshp) | |||||
op = builtin.AdaptivePooling(mode="AVERAGE", format="NCHW",) | |||||
oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device) | |||||
(output,) = apply(op, inp, oshp) | |||||
return output | |||||
def prelu(inp: Tensor, weight: Tensor) -> Tensor: | def prelu(inp: Tensor, weight: Tensor) -> Tensor: | ||||
r""" | r""" | ||||
Applies the element-wise PReLU function. | Applies the element-wise PReLU function. | ||||
@@ -8,6 +8,7 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | ||||
from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d | |||||
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | ||||
from .concat import Concat | from .concat import Concat | ||||
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | ||||
@@ -0,0 +1,114 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from abc import abstractmethod | |||||
from typing import Tuple, Union | |||||
from ..functional import adaptive_avg_pool2d, adaptive_max_pool2d | |||||
from ..tensor import Parameter, Tensor | |||||
from .module import Module | |||||
class _AdaptivePoolNd(Module): | |||||
def __init__( | |||||
self, oshp: Union[Tuple[int, int], int, Tensor], | |||||
): | |||||
super(_AdaptivePoolNd, self).__init__() | |||||
self.oshp = oshp | |||||
@abstractmethod | |||||
def forward(self, inp): | |||||
pass | |||||
class AdaptiveMaxPool2d(_AdaptivePoolNd): | |||||
r"""Applies a 2D max adaptive pooling over an input. | |||||
For instance, given an input of the size :math:`(N, C, H, W)` and | |||||
an output shape :math:`(OH, OW)`, this layer generates the output of | |||||
the size :math:`(N, C, OH, OW)` through a process described as: | |||||
.. math:: | |||||
\begin{aligned} | |||||
out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} | |||||
\text{input}(N_i, C_j, \text{stride[0]} \times h + m, | |||||
\text{stride[1]} \times w + n) | |||||
\end{aligned} | |||||
Kernel_size and stride can be inferred from input shape and out shape: | |||||
padding: (0, 0) | |||||
stride: (floor(IH / OH), floor(IW / OW)) | |||||
kernel_size: (IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w) | |||||
Examples: | |||||
.. testcode:: | |||||
import numpy as np | |||||
import megengine as mge | |||||
import megengine.module as M | |||||
m = M.AdaptiveMaxPool2d((2, 2)) | |||||
inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4)) | |||||
oup = m(inp) | |||||
print(oup.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
[[[[5. 7.] | |||||
[13. 15.]]]] | |||||
""" | |||||
def forward(self, inp): | |||||
return adaptive_max_pool2d(inp, self.oshp) | |||||
class AdaptiveAvgPool2d(_AdaptivePoolNd): | |||||
r"""Applies a 2D average pooling over an input. | |||||
For instance, given an input of the size :math:`(N, C, H, W)` and | |||||
an output shape :math:`(OH, OW)`, this layer generates the output of | |||||
the size :math:`(N, C, OH, OW)` through a process described as: | |||||
.. math:: | |||||
out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} | |||||
input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) | |||||
Kernel_size and stride can be inferred from input shape and out shape: | |||||
padding: (0, 0) | |||||
stride: (floor(IH / OH), floor(IW / OW)) | |||||
kernel_size: (IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w) | |||||
Examples: | |||||
.. testcode:: | |||||
import numpy as np | |||||
import megengine as mge | |||||
import megengine.module as M | |||||
m = M.AdaptiveAvgPool2d((2, 2)) | |||||
inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4)) | |||||
oup = m(inp) | |||||
print(oup.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
[[[[2.5 4.5] | |||||
[10.5 12.5]]]] | |||||
""" | |||||
def forward(self, inp): | |||||
return adaptive_avg_pool2d(inp, self.oshp) |
@@ -206,6 +206,66 @@ def test_roi_pooling(): | |||||
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | ||||
def test_adaptive_avg_pool2d(): | |||||
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | |||||
oshp = (2, 2) | |||||
grad = Grad().wrt(inp, callback=_save_to(inp)) | |||||
outp = F.adaptive_avg_pool2d(inp, oshp,) | |||||
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||||
np.testing.assert_equal( | |||||
outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32) | |||||
) | |||||
grad(outp, tensor(F.ones_like(outp))) | |||||
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | |||||
np.testing.assert_equal( | |||||
inp.grad.numpy(), | |||||
np.array( | |||||
[ | |||||
[ | |||||
[ | |||||
[0.25, 0.25, 0.25, 0.25], | |||||
[0.25, 0.25, 0.25, 0.25], | |||||
[0.25, 0.25, 0.25, 0.25], | |||||
[0.25, 0.25, 0.25, 0.25], | |||||
] | |||||
] | |||||
], | |||||
dtype=np.float32, | |||||
), | |||||
) | |||||
def test_adaptive_max_pool2d(): | |||||
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | |||||
oshp = (2, 2) | |||||
grad = Grad().wrt(inp, callback=_save_to(inp)) | |||||
outp = F.adaptive_max_pool2d(inp, oshp,) | |||||
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||||
np.testing.assert_equal( | |||||
outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32) | |||||
) | |||||
grad(outp, tensor(F.ones_like(outp))) | |||||
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | |||||
np.testing.assert_equal( | |||||
inp.grad.numpy(), | |||||
np.array( | |||||
[ | |||||
[ | |||||
[ | |||||
[0.0, 0.0, 0.0, 0.0], | |||||
[0.0, 1.0, 0.0, 1.0], | |||||
[0.0, 0.0, 0.0, 0.0], | |||||
[0.0, 1.0, 0.0, 1.0], | |||||
] | |||||
] | |||||
], | |||||
dtype=np.float32, | |||||
), | |||||
) | |||||
def test_one_hot(): | def test_one_hot(): | ||||
def onehot_low_dimension(): | def onehot_low_dimension(): | ||||
inp = tensor(np.arange(1, 4, dtype=np.int32)) | inp = tensor(np.arange(1, 4, dtype=np.int32)) | ||||