GitOrigin-RevId: 789f1511ec
release-1.1
@@ -13,7 +13,7 @@ from ..core._imperative_rt import CompNode | |||
from ..core.ops import builtin | |||
from ..core.ops._internal import param_defs as P | |||
from ..core.ops.special import Const | |||
from ..core.tensor import utils | |||
from ..core.tensor import megbrain_graph, utils | |||
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | |||
from ..core.tensor.utils import astensor1d | |||
from ..distributed import WORLD, is_distributed | |||
@@ -27,6 +27,8 @@ from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshap | |||
from .types import _pair, _pair_nonzero | |||
__all__ = [ | |||
"adaptive_avg_pool2d", | |||
"adaptive_max_pool2d", | |||
"avg_pool2d", | |||
"batched_nms", | |||
"batch_norm2d", | |||
@@ -324,6 +326,48 @@ def avg_pool2d( | |||
return output | |||
def adaptive_max_pool2d( | |||
inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor], | |||
) -> Tensor: | |||
"""Applies a 2D max adaptive pooling over an input. | |||
Refer to :class:`~.MaxAdaptivePool2d` for more information. | |||
:param inp: The input tensor. | |||
:param oshp: (OH, OW) size of the output shape. | |||
:return: output tensor. | |||
""" | |||
assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type" | |||
if isinstance(oshp, int): | |||
oshp = (oshp, oshp) | |||
op = builtin.AdaptivePooling(mode="MAX", format="NCHW",) | |||
oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device) | |||
(output,) = apply(op, inp, oshp) | |||
return output | |||
def adaptive_avg_pool2d( | |||
inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor], | |||
) -> Tensor: | |||
"""Applies a 2D average adaptive pooling over an input. | |||
Refer to :class:`~.AvgAdaptivePool2d` for more information. | |||
:param inp: The input tensor. | |||
:param oshp: (OH, OW) size of the output shape. | |||
:return: output tensor. | |||
""" | |||
assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type" | |||
if isinstance(oshp, int): | |||
oshp = (oshp, oshp) | |||
op = builtin.AdaptivePooling(mode="AVERAGE", format="NCHW",) | |||
oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device) | |||
(output,) = apply(op, inp, oshp) | |||
return output | |||
def prelu(inp: Tensor, weight: Tensor) -> Tensor: | |||
r""" | |||
Applies the element-wise PReLU function. | |||
@@ -8,6 +8,7 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | |||
from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d | |||
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | |||
from .concat import Concat | |||
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | |||
@@ -0,0 +1,114 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from abc import abstractmethod | |||
from typing import Tuple, Union | |||
from ..functional import adaptive_avg_pool2d, adaptive_max_pool2d | |||
from ..tensor import Parameter, Tensor | |||
from .module import Module | |||
class _AdaptivePoolNd(Module): | |||
def __init__( | |||
self, oshp: Union[Tuple[int, int], int, Tensor], | |||
): | |||
super(_AdaptivePoolNd, self).__init__() | |||
self.oshp = oshp | |||
@abstractmethod | |||
def forward(self, inp): | |||
pass | |||
class AdaptiveMaxPool2d(_AdaptivePoolNd): | |||
r"""Applies a 2D max adaptive pooling over an input. | |||
For instance, given an input of the size :math:`(N, C, H, W)` and | |||
an output shape :math:`(OH, OW)`, this layer generates the output of | |||
the size :math:`(N, C, OH, OW)` through a process described as: | |||
.. math:: | |||
\begin{aligned} | |||
out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} | |||
\text{input}(N_i, C_j, \text{stride[0]} \times h + m, | |||
\text{stride[1]} \times w + n) | |||
\end{aligned} | |||
Kernel_size and stride can be inferred from input shape and out shape: | |||
padding: (0, 0) | |||
stride: (floor(IH / OH), floor(IW / OW)) | |||
kernel_size: (IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w) | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
m = M.AdaptiveMaxPool2d((2, 2)) | |||
inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4)) | |||
oup = m(inp) | |||
print(oup.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[[[5. 7.] | |||
[13. 15.]]]] | |||
""" | |||
def forward(self, inp): | |||
return adaptive_max_pool2d(inp, self.oshp) | |||
class AdaptiveAvgPool2d(_AdaptivePoolNd): | |||
r"""Applies a 2D average pooling over an input. | |||
For instance, given an input of the size :math:`(N, C, H, W)` and | |||
an output shape :math:`(OH, OW)`, this layer generates the output of | |||
the size :math:`(N, C, OH, OW)` through a process described as: | |||
.. math:: | |||
out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} | |||
input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) | |||
Kernel_size and stride can be inferred from input shape and out shape: | |||
padding: (0, 0) | |||
stride: (floor(IH / OH), floor(IW / OW)) | |||
kernel_size: (IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w) | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
m = M.AdaptiveAvgPool2d((2, 2)) | |||
inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4)) | |||
oup = m(inp) | |||
print(oup.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[[[2.5 4.5] | |||
[10.5 12.5]]]] | |||
""" | |||
def forward(self, inp): | |||
return adaptive_avg_pool2d(inp, self.oshp) |
@@ -206,6 +206,66 @@ def test_roi_pooling(): | |||
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | |||
def test_adaptive_avg_pool2d(): | |||
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | |||
oshp = (2, 2) | |||
grad = Grad().wrt(inp, callback=_save_to(inp)) | |||
outp = F.adaptive_avg_pool2d(inp, oshp,) | |||
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||
np.testing.assert_equal( | |||
outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32) | |||
) | |||
grad(outp, tensor(F.ones_like(outp))) | |||
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | |||
np.testing.assert_equal( | |||
inp.grad.numpy(), | |||
np.array( | |||
[ | |||
[ | |||
[ | |||
[0.25, 0.25, 0.25, 0.25], | |||
[0.25, 0.25, 0.25, 0.25], | |||
[0.25, 0.25, 0.25, 0.25], | |||
[0.25, 0.25, 0.25, 0.25], | |||
] | |||
] | |||
], | |||
dtype=np.float32, | |||
), | |||
) | |||
def test_adaptive_max_pool2d(): | |||
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | |||
oshp = (2, 2) | |||
grad = Grad().wrt(inp, callback=_save_to(inp)) | |||
outp = F.adaptive_max_pool2d(inp, oshp,) | |||
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||
np.testing.assert_equal( | |||
outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32) | |||
) | |||
grad(outp, tensor(F.ones_like(outp))) | |||
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | |||
np.testing.assert_equal( | |||
inp.grad.numpy(), | |||
np.array( | |||
[ | |||
[ | |||
[ | |||
[0.0, 0.0, 0.0, 0.0], | |||
[0.0, 1.0, 0.0, 1.0], | |||
[0.0, 0.0, 0.0, 0.0], | |||
[0.0, 1.0, 0.0, 1.0], | |||
] | |||
] | |||
], | |||
dtype=np.float32, | |||
), | |||
) | |||
def test_one_hot(): | |||
def onehot_low_dimension(): | |||
inp = tensor(np.arange(1, 4, dtype=np.int32)) | |||