@@ -67,6 +67,7 @@ from .nn import ( | |||||
interpolate, | interpolate, | ||||
leaky_relu, | leaky_relu, | ||||
linear, | linear, | ||||
local_conv2d, | |||||
matrix_mul, | matrix_mul, | ||||
max_pool2d, | max_pool2d, | ||||
one_hot, | one_hot, | ||||
@@ -172,6 +172,34 @@ def conv_transpose2d( | |||||
@wrap_io_tensor
def local_conv2d(
    inp: Tensor,
    weight: Tensor,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    conv_mode="CROSS_CORRELATION",
) -> Tensor:
    """Applies spatial 2D convolution over an image with untied kernels.

    Refer to :class:`~.LocalConv2d` for more information.

    :param inp: the input tensor; the operator is invoked with ``format="NCHW"``.
    :param weight: the untied convolution kernels.
    :param stride: stride of the operation, either a single int applied to both
        spatial dimensions or a ``(stride_h, stride_w)`` pair. Default: 1
    :param padding: padding added on both sides of each spatial dimension,
        either a single int or a ``(pad_h, pad_w)`` pair. Default: 0
    :param dilation: dilation of the operation, either a single int or a
        ``(dilate_h, dilate_w)`` pair. Default: 1
    :param conv_mode: passed through as the operator's ``mode``.
        Default: ``"CROSS_CORRELATION"``
    :return: the result of the ``group_local`` operator.
    """

    def _as_pair(value):
        # The signature advertises (and defaults to) plain ints, which cannot
        # be indexed; expand them to ``(value, value)``. Anything else is
        # indexed exactly as before so sequence inputs keep their behavior.
        if isinstance(value, int):
            return value, value
        return value[0], value[1]

    pad_h, pad_w = _as_pair(padding)
    stride_h, stride_w = _as_pair(stride)
    dilate_h, dilate_w = _as_pair(dilation)

    ret = mgb.opr.group_local(
        inp,
        weight,
        pad_h=pad_h,
        pad_w=pad_w,
        stride_h=stride_h,
        stride_w=stride_w,
        dilate_h=dilate_h,
        dilate_w=dilate_w,
        format="NCHW",
        mode=conv_mode,
    )
    return ret
@wrap_io_tensor | |||||
def max_pool2d( | def max_pool2d( | ||||
inp: Tensor, | inp: Tensor, | ||||
kernel_size: Union[int, Tuple[int, int]], | kernel_size: Union[int, Tuple[int, int]], | ||||
@@ -9,7 +9,7 @@ | |||||
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | ||||
from .batchnorm import BatchNorm1d, BatchNorm2d | from .batchnorm import BatchNorm1d, BatchNorm2d | ||||
from .concat import Concat | from .concat import Concat | ||||
from .conv import Conv2d, ConvTranspose2d | |||||
from .conv import Conv2d, ConvTranspose2d, LocalConv2d | |||||
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d | from .conv_bn_relu import ConvBn2d, ConvBnRelu2d | ||||
from .dropout import Dropout | from .dropout import Dropout | ||||
from .elemwise import Elemwise | from .elemwise import Elemwise | ||||
@@ -14,7 +14,7 @@ import numpy as np | |||||
import megengine._internal as mgb | import megengine._internal as mgb | ||||
from ..core import Parameter | from ..core import Parameter | ||||
from ..functional import conv2d, conv_transpose2d | |||||
from ..functional import conv2d, conv_transpose2d, local_conv2d | |||||
from ..utils.types import _pair, _pair_nonzero | from ..utils.types import _pair, _pair_nonzero | ||||
from . import init | from . import init | ||||
from .module import Module | from .module import Module | ||||
@@ -224,7 +224,7 @@ class ConvTranspose2d(_ConvNd): | |||||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ||||
and there would be an extra dimension at the beginning of the weight's | and there would be an extra dimension at the beginning of the weight's | ||||
shape. Specifically, the shape of weight would be ``(groups, | shape. Specifically, the shape of weight would be ``(groups, | ||||
out_channel // groups, in_channels // groups, *kernel_size)``. Default: 1 | |||||
out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1 | |||||
:param bias: whether to add a bias onto the result of convolution. Default: | :param bias: whether to add a bias onto the result of convolution. Default: | ||||
True | True | ||||
:param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | ||||
@@ -306,3 +306,77 @@ class ConvTranspose2d(_ConvNd): | |||||
self.conv_mode, | self.conv_mode, | ||||
self.compute_mode, | self.compute_mode, | ||||
) | ) | ||||
class LocalConv2d(Conv2d):
    r"""Applies a spatial convolution with untied kernels over an input 4D tensor.
    It is also known as the locally connected layer.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param input_height: the height of the input images.
    :param input_width: the width of the input images.
    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
        The shape of weight is ``(groups, output_height, output_width,
        in_channels // groups, *kernel_size, out_channels // groups)``.
        Default: 1
    :param conv_mode: convolution mode handed to the underlying operator.
        Default: ``"CROSS_CORRELATION"``
    """

    _conv_mode_type = mgb.opr_param_defs.Convolution.Mode

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        input_height: int,
        input_width: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        conv_mode: str = "CROSS_CORRELATION",
    ):
        # The input size must be stored before calling the parent constructor:
        # ``_infer_weight_shape`` reads it to derive the per-position kernels.
        self.input_height = input_height
        self.input_width = input_width
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias=False,
            # Forward the caller's choice; it used to be accepted but dropped.
            # NOTE(review): confirm which modes the group_local operator accepts.
            conv_mode=conv_mode,
        )

    def _infer_weight_shape(self):
        """Return the weight shape, which carries one kernel per output position."""
        group = self.groups
        # NOTE(review): dilation is not factored into the output size below;
        # confirm whether the operator supports dilation != 1.
        output_height = (
            self.input_height + self.padding[0] * 2 - self.kernel_size[0]
        ) // self.stride[0] + 1
        output_width = (
            self.input_width + self.padding[1] * 2 - self.kernel_size[1]
        ) // self.stride[1] + 1
        # Assume format is NCHW
        return (
            group,
            output_height,
            output_width,
            self.in_channels // group,
            self.kernel_size[0],
            self.kernel_size[1],
            self.out_channels // group,
        )

    def forward(self, inp):
        """Apply the locally connected convolution to ``inp``."""
        return local_conv2d(
            inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode
        )
@@ -11,7 +11,7 @@ import itertools | |||||
import numpy as np | import numpy as np | ||||
from megengine import Parameter, tensor | from megengine import Parameter, tensor | ||||
from megengine.module import ConvTranspose2d | |||||
from megengine.module import ConvTranspose2d, LocalConv2d | |||||
from megengine.test import assertTensorClose | from megengine.test import assertTensorClose | ||||
@@ -50,3 +50,61 @@ def test_conv_transpose2d(): | |||||
y = conv_transpose2d(tensor(inp)) | y = conv_transpose2d(tensor(inp)) | ||||
assertTensorClose(out, y.numpy(), max_err=2e-6) | assertTensorClose(out, y.numpy(), max_err=2e-6) | ||||
def test_local_conv2d():
    """Check ``LocalConv2d`` against a naive numpy reference implementation."""
    batch_size = 10
    in_channels = 4
    out_channels = 8
    input_height = 8
    input_width = 8
    kernel_size = 3
    stride = 1
    padding = 1
    dilation = 1
    groups = 1
    local_conv2d = LocalConv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        input_height=input_height,
        input_width=input_width,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
    inputs = np.random.normal(
        size=(batch_size, in_channels, input_height, input_width)
    ).astype(np.float32)
    output_height = (input_height + padding * 2 - kernel_size) // stride + 1
    output_width = (input_width + padding * 2 - kernel_size) // stride + 1
    weights = np.random.normal(
        size=(
            groups,
            output_height,
            output_width,
            in_channels // groups,
            kernel_size,
            kernel_size,
            out_channels // groups,
        )
    ).astype(np.float32)
    local_conv2d.weight = Parameter(weights)
    outputs = local_conv2d(tensor(inputs))

    # Naive reference computed with numpy. It only models groups == 1 and
    # dilation == 1, but follows the configured stride and padding.
    padded = np.pad(
        inputs,
        ((0, 0), (0, 0), (padding, padding), (padding, padding)),
        mode="constant",
    )
    expected = np.zeros(
        (batch_size, out_channels, output_height, output_width), dtype=np.float32,
    )
    for n, oc, oh, ow in itertools.product(
        *map(range, [batch_size, out_channels, output_height, output_width])
    ):
        # (ih, iw) is the top-left corner of the receptive field in the padded
        # input; the result belongs at the *output* coordinate (oh, ow).
        ih, iw = oh * stride, ow * stride
        expected[n, oc, oh, ow] = np.sum(
            padded[n, :, ih : ih + kernel_size, iw : iw + kernel_size]
            * weights[0, oh, ow, :, :, :, oc]
        )

    assertTensorClose(outputs.numpy(), expected, max_err=1e-5)
@@ -112,9 +112,10 @@ decl_opr('GroupLocal', | |||||
'convolution kernel in ' | 'convolution kernel in ' | ||||
'(group, out row, out col, in channel / group, ' | '(group, out row, out col, in channel / group, ' | ||||
'kern row, kern col, out channel / group) format')], | 'kern row, kern col, out channel / group) format')], | ||||
params='ConvolutionV0', | |||||
params=[('param', 'Convolution')], | |||||
desc='batched convolution on groupped channeled 2D images, but ' | desc='batched convolution on groupped channeled 2D images, but ' | ||||
'kernels are not shared across different output positions') | |||||
'kernels are not shared across different output positions', | |||||
version=1) | |||||
decl_opr('LRN', | decl_opr('LRN', | ||||
inputs=['src'], | inputs=['src'], | ||||