GitOrigin-RevId: 9ad87a4ea9
tags/v0.3.2
@@ -53,6 +53,7 @@ from .nn import (
     batch_norm2d,
     batched_matrix_mul,
     conv2d,
+    conv_transpose2d,
     dropout,
     embedding,
     eye,
@@ -101,6 +101,69 @@ def conv2d(
 @wrap_io_tensor
+def conv_transpose2d(
+    inp: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: Union[int, Tuple[int, int]] = 1,
+    padding: Union[int, Tuple[int, int]] = 0,
+    dilation: Union[int, Tuple[int, int]] = 1,
+    groups: int = 1,
+    conv_mode="CROSS_CORRELATION",
+    compute_mode="DEFAULT",
+) -> Tensor:
+    """2D transposed convolution operation.
+
+    :param inp: The feature map of the convolution operation
+    :param weight: The convolution kernel
+    :param bias: The bias added to the result of convolution (if given)
+    :param stride: Stride of the 2D convolution operation. Default: 1
+    :param padding: Size of the paddings added to the input on both sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: Dilation of the 2D convolution operation. Default: 1
+    :param groups: number of groups to divide input and output channels into,
+        so as to perform a "grouped convolution". When ``groups`` is not 1,
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+        and the shape of weight should be ``(groups, in_channels // groups,
+        out_channels // groups, height, width)``. Default: 1
+    :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
+    :param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
+        'CROSS_CORRELATION'.
+    :type compute_mode: string or
+        :class:`mgb.opr_param_defs.Convolution.ComputeMode`
+    :param compute_mode: When set to 'DEFAULT', no special requirements will be
+        placed on the precision of intermediate results. When set to 'FLOAT32',
+        Float32 would be used for accumulator and intermediate result, but only
+        effective when input and output are of Float16 dtype.
+
+    Refer to :class:`~.ConvTranspose2d` for more information.
+    """
+    ph, pw = _pair(padding)
+    sh, sw = _pair_nonzero(stride)
+    dh, dw = _pair_nonzero(dilation)
+    Sparse = mgb.opr_param_defs.Convolution.Sparse
+    sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP
+    res = mgb.opr.deconvolution(
+        inp,
+        weight,
+        pad_h=ph,
+        pad_w=pw,
+        stride_h=sh,
+        stride_w=sw,
+        dilate_h=dh,
+        dilate_w=dw,
+        format="NCHW",
+        strategy=get_conv_execution_strategy(),
+        mode=conv_mode,
+        compute_mode=compute_mode,
+        sparse=sparse_type,
+    )
+    if bias is not None:
+        res += bias
+    return res
+
+
+@wrap_io_tensor
 def max_pool2d(
     inp: Tensor,
     kernel_size: Union[int, Tuple[int, int]],
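
As a quick illustration of the functional form added above, here is a minimal usage sketch. It assumes `conv_transpose2d` is re-exported from `megengine.functional` (per the import hunk at the top of this diff) and uses the dense weight layout `(in_channels, out_channels, KH, KW)` implied by `_infer_weight_shape` further down; the shapes are otherwise arbitrary:

```python
import numpy as np

import megengine.functional as F
from megengine import tensor

# Dense (groups=1) transposed-conv weight layout: (in_channels, out_channels, KH, KW)
inp = tensor(np.random.normal(size=(2, 8, 10, 10)).astype(np.float32))
weight = tensor(np.random.normal(size=(8, 4, 3, 3)).astype(np.float32))

# Expected spatial size: (H_in - 1) * stride - 2 * padding + dilation * (KH - 1) + 1
#                      = (10 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 = 19
out = F.conv_transpose2d(inp, weight, stride=2, padding=1)
assert out.numpy().shape == (2, 4, 19, 19)
```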
@@ -8,7 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
 from .batchnorm import BatchNorm1d, BatchNorm2d
-from .conv import Conv2d
+from .conv import Conv2d, ConvTranspose2d
 from .dropout import Dropout
 from .embedding import Embedding
 from .identity import Identity
@@ -14,7 +14,7 @@ import numpy as np
 import megengine._internal as mgb

 from ..core import Parameter
-from ..functional import conv2d
+from ..functional import conv2d, conv_transpose2d
 from ..utils.types import _pair, _pair_nonzero
 from . import init
 from .module import Module
@@ -31,7 +31,6 @@ class _ConvNd(Module):
         stride: Union[int, Tuple[int, int]],
         padding: Union[int, Tuple[int, int]],
         dilation: Union[int, Tuple[int, int]],
-        output_padding: Union[int, Tuple[int, int]],
         groups: int,
         bias: bool = True,
     ):
@@ -46,7 +45,6 @@ class _ConvNd(Module):
         self.stride = stride
         self.padding = padding
         self.dilation = dilation
-        self.output_padding = output_padding
         self.groups = groups

         self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32))
@@ -154,7 +152,6 @@ class Conv2d(_ConvNd):
             stride,
             padding,
             dilation,
-            (0, 0),
             groups,
             bias,
         )
@@ -197,3 +194,112 @@ class Conv2d(_ConvNd):
             self.conv_mode,
             self.compute_mode,
         )
+
+
+class ConvTranspose2d(_ConvNd):
+    r"""Applies a 2D transposed convolution over an input tensor.
+
+    This module is also known as a deconvolution or a fractionally-strided
+    convolution. :class:`ConvTranspose2d` can be seen as the gradient of the
+    :class:`Conv2d` operation with respect to its input.
+
+    Convolution usually reduces the size of the input, while transposed
+    convolution works the other way, transforming a smaller input into a
+    larger output while preserving the connectivity pattern.
+
+    :param in_channels: number of input channels.
+    :param out_channels: number of output channels.
+    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
+        an :class:`int`, the actual kernel size would be
+        ``(kernel_size, kernel_size)``. Default: 1
+    :param stride: stride of the 2D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on both sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 2D convolution operation. Default: 1
+    :param groups: number of groups to divide input and output channels into,
+        so as to perform a "grouped convolution". When ``groups`` is not 1,
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+        and there would be an extra dimension at the beginning of the weight's
+        shape. Specifically, the shape of weight would be ``(groups,
+        in_channels // groups, out_channels // groups, *kernel_size)``. Default: 1
+    :param bias: whether to add a bias onto the result of convolution. Default:
+        True
+    :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
+        `CROSS_CORRELATION`.
+    :param compute_mode: When set to `DEFAULT`, no special requirements will be
+        placed on the precision of intermediate results. When set to `FLOAT32`,
+        float32 would be used for accumulator and intermediate result, but only
+        effective when input and output are of float16 dtype.
+    """
+
+    _conv_mode_type = mgb.opr_param_defs.Convolution.Mode
+    _compute_mode_type = mgb.opr_param_defs.Convolution.ComputeMode
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: Union[int, Tuple[int, int]] = 1,
+        padding: Union[int, Tuple[int, int]] = 0,
+        dilation: Union[int, Tuple[int, int]] = 1,
+        groups: int = 1,
+        bias: bool = True,
+        conv_mode: str = "CROSS_CORRELATION",
+        compute_mode: str = "DEFAULT",
+    ):
+        kernel_size = _pair_nonzero(kernel_size)
+        stride = _pair_nonzero(stride)
+        padding = _pair(padding)
+        dilation = _pair_nonzero(dilation)
+        self.conv_mode = self._conv_mode_type.convert(conv_mode)
+        self.compute_mode = self._compute_mode_type.convert(compute_mode)
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+        )
+
+    def _get_fanin(self):
+        kh, kw = self.kernel_size
+        oc = self.out_channels
+        return kh * kw * oc
+
+    def _infer_weight_shape(self):
+        group = self.groups
+        ichl = self.in_channels
+        ochl = self.out_channels
+        kh, kw = self.kernel_size
+        if group == 1:
+            # Assume format is NCHW
+            return (ichl, ochl, kh, kw)
+
+        assert (
+            ichl % group == 0 and ochl % group == 0
+        ), "invalid config: input_channels={} output_channels={} group={}".format(
+            ichl, ochl, group
+        )
+        # Assume format is NCHW
+        return (group, ichl // group, ochl // group, kh, kw)
+
+    def _infer_bias_shape(self):
+        # Assume format is NCHW
+        return (1, self.out_channels, 1, 1)
+
+    def forward(self, inp):
+        return conv_transpose2d(
+            inp,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.conv_mode,
+            self.compute_mode,
+        )
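
A minimal sketch of the module in use, for instance as an upsampling layer in a decoder. The layer name, shapes, and hyperparameters here are illustrative, assuming only the v0.3 API shown in this diff:

```python
import numpy as np

from megengine import tensor
from megengine.module import ConvTranspose2d

# A common decoder pattern: double the spatial resolution with k=4, s=2, p=1.
m = ConvTranspose2d(16, 32, kernel_size=4, stride=2, padding=1)
inp = tensor(np.random.normal(size=(4, 16, 7, 7)).astype(np.float32))
out = m(inp)

# (7 - 1) * 2 - 2 * 1 + (4 - 1) + 1 = 14
assert out.numpy().shape == (4, 32, 14, 14)
```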
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import itertools
+
+import numpy as np
+import pytest
+import torch
+
+import megengine as mge
+from megengine import Parameter, tensor
+from megengine.module import Conv2d, ConvTranspose2d
+from megengine.test import assertTensorClose
+
+
+def test_conv_transpose2d():
+    SH, SW = 3, 1
+    PH, PW = 2, 0
+    N, IC, IH, IW = 4, 5, 8, 6
+    KH, KW = 3, 4
+    OC = 3
+    BIAS = True
+
+    def getsize(inp, kern, stride):
+        return (inp - 1) * stride + kern
+
+    OH = getsize(IH, KH, SH)
+    OW = getsize(IW, KW, SW)
+
+    inp = np.random.normal(size=(N, IC, IH, IW)).astype(np.float32)
+    out = np.zeros((N, OC, OH, OW), dtype=np.float32)
+    weight = np.random.normal(size=(IC, OC, KH, KW)).astype(np.float32)
+    bias = np.random.normal(size=(1, OC, 1, 1)).astype(np.float32)
+
+    # Reference result: scatter each input element into the output, scaled by
+    # the kernel, then crop the padding.
+    for n, ic, ih, iw in itertools.product(*map(range, [N, IC, IH, IW])):
+        oh, ow = ih * SH, iw * SW
+        out[n, :, oh : oh + KH, ow : ow + KW] += inp[n, ic, ih, iw] * weight[ic]
+    out = out[:, :, PH : OH - PH, PW : OW - PW]
+    if BIAS:
+        out += bias
+
+    conv_transpose2d = ConvTranspose2d(IC, OC, (KH, KW), (SH, SW), (PH, PW), bias=BIAS)
+    conv_transpose2d.weight = Parameter(weight, dtype=np.float32)
+    if BIAS:
+        conv_transpose2d.bias = Parameter(bias, dtype=np.float32)
+    y = conv_transpose2d(tensor(inp))
+    assertTensorClose(out, y.numpy(), max_err=2e-6)
+
+    # Cross-check against PyTorch's ConvTranspose2d with the same parameters.
+    torch_conv_transpose2d = torch.nn.ConvTranspose2d(
+        IC, OC, (KH, KW), stride=(SH, SW), padding=(PH, PW), bias=BIAS
+    )
+    torch_conv_transpose2d.weight = torch.nn.parameter.Parameter(torch.Tensor(weight))
+    if BIAS:
+        torch_conv_transpose2d.bias = torch.nn.parameter.Parameter(
+            torch.Tensor(bias).reshape(OC)
+        )
+    torch_y = torch_conv_transpose2d(torch.Tensor(inp))
+    assertTensorClose(torch_y.detach().numpy(), y.numpy(), max_err=2e-6)
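
For reference, the shapes the test exercises work out as follows; this is just a worked check of `getsize` and the padding crop, not part of the diff:

```python
# Worked shape check for the test constants above (illustration only).
SH, SW, PH, PW = 3, 1, 2, 0
IH, IW, KH, KW = 8, 6, 3, 4
OH = (IH - 1) * SH + KH  # 24: full scatter extent before cropping
OW = (IW - 1) * SW + KW  # 9
# Cropping PH/PW on each side reproduces the operator's zero-padding:
assert (OH - 2 * PH, OW - 2 * PW) == (20, 9)  # so y.shape == (4, 3, 20, 9)
```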