This reverts committags/v0.4.08e4f25bfd8
. GitOrigin-RevId:df696ab8a2
@@ -65,7 +65,6 @@ from .nn import ( | |||
interpolate, | |||
leaky_relu, | |||
linear, | |||
local_conv2d, | |||
matrix_mul, | |||
max_pool2d, | |||
one_hot, | |||
@@ -171,34 +171,6 @@ def conv_transpose2d( | |||
@wrap_io_tensor | |||
def local_conv2d( | |||
inp: Tensor, | |||
weight: Tensor, | |||
stride: Union[int, Tuple[int, int]] = 1, | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
dilation: Union[int, Tuple[int, int]] = 1, | |||
conv_mode="CROSS_CORRELATION", | |||
) -> Tensor: | |||
"""Applies spatial 2D convolution over an image with untied kernels. | |||
Refer to :class:`~.LocalConv2d` for more information. | |||
""" | |||
ret = mgb.opr.group_local( | |||
inp, | |||
weight, | |||
pad_h=padding[0], | |||
pad_w=padding[1], | |||
stride_h=stride[0], | |||
stride_w=stride[1], | |||
dilate_h=dilation[0], | |||
dilate_w=dilation[1], | |||
format="NCHW", | |||
mode=conv_mode, | |||
) | |||
return ret | |||
@wrap_io_tensor | |||
def max_pool2d( | |||
inp: Tensor, | |||
kernel_size: Union[int, Tuple[int, int]], | |||
@@ -9,7 +9,7 @@ | |||
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | |||
from .batchnorm import BatchNorm1d, BatchNorm2d | |||
from .concat import Concat | |||
from .conv import Conv2d, ConvTranspose2d, LocalConv2d | |||
from .conv import Conv2d, ConvTranspose2d | |||
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d | |||
from .dropout import Dropout | |||
from .elemwise import Elemwise | |||
@@ -14,7 +14,7 @@ import numpy as np | |||
import megengine._internal as mgb | |||
from ..core import Parameter | |||
from ..functional import conv2d, conv_transpose2d, local_conv2d | |||
from ..functional import conv2d, conv_transpose2d | |||
from ..utils.types import _pair, _pair_nonzero | |||
from . import init | |||
from .module import Module | |||
@@ -224,7 +224,7 @@ class ConvTranspose2d(_ConvNd): | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and there would be an extra dimension at the beginning of the weight's | |||
shape. Specifically, the shape of weight would be ``(groups, | |||
out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1 | |||
out_channel // groups, in_channels // groups, *kernel_size)``. Default: 1 | |||
:param bias: wether to add a bias onto the result of convolution. Default: | |||
True | |||
:param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | |||
@@ -306,77 +306,3 @@ class ConvTranspose2d(_ConvNd): | |||
self.conv_mode, | |||
self.compute_mode, | |||
) | |||
class LocalConv2d(Conv2d): | |||
r"""Applies a spatial convolution with untied kernels over an input 4D tensor. | |||
It is also known as the locally connected layer. | |||
:param in_channels: number of input channels. | |||
:param out_channels: number of output channels. | |||
:param input_height: the height of the input images. | |||
:param input_width: the width of the input images. | |||
:param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||
an :class:`int`, the actual kernel size would be | |||
``(kernel_size, kernel_size)``. Default: 1 | |||
:param stride: stride of the 2D convolution operation. Default: 1 | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param groups: number of groups to divide input and output channels into, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``. | |||
The shape of weight is ``(groups, output_height, output_width, | |||
in_channels // groups, *kernel_size, out_channels // groups)``. | |||
""" | |||
_conv_mode_type = mgb.opr_param_defs.Convolution.Mode | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
out_channels: int, | |||
input_height: int, | |||
input_width: int, | |||
kernel_size: Union[int, Tuple[int, int]], | |||
stride: Union[int, Tuple[int, int]] = 1, | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
dilation: Union[int, Tuple[int, int]] = 1, | |||
groups: int = 1, | |||
conv_mode: str = "CROSS_CORRELATION", | |||
): | |||
self.input_height = input_height | |||
self.input_width = input_width | |||
super().__init__( | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
stride, | |||
padding, | |||
dilation, | |||
groups, | |||
bias=False, | |||
) | |||
def _infer_weight_shape(self): | |||
group = self.groups | |||
output_height = ( | |||
self.input_height + self.padding[0] * 2 - self.kernel_size[0] | |||
) // self.stride[0] + 1 | |||
output_width = ( | |||
self.input_width + self.padding[1] * 2 - self.kernel_size[1] | |||
) // self.stride[1] + 1 | |||
# Assume format is NCHW | |||
return ( | |||
group, | |||
output_height, | |||
output_width, | |||
self.in_channels // group, | |||
self.kernel_size[0], | |||
self.kernel_size[1], | |||
self.out_channels // group, | |||
) | |||
def forward(self, inp): | |||
return local_conv2d( | |||
inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode | |||
) |
@@ -11,7 +11,7 @@ import itertools | |||
import numpy as np | |||
from megengine import Parameter, tensor | |||
from megengine.module import ConvTranspose2d, LocalConv2d | |||
from megengine.module import ConvTranspose2d | |||
from megengine.test import assertTensorClose | |||
@@ -50,61 +50,3 @@ def test_conv_transpose2d(): | |||
y = conv_transpose2d(tensor(inp)) | |||
assertTensorClose(out, y.numpy(), max_err=2e-6) | |||
def test_local_conv2d(): | |||
batch_size = 10 | |||
in_channels = 4 | |||
out_channels = 8 | |||
input_height = 8 | |||
input_width = 8 | |||
kernel_size = 3 | |||
stride = 1 | |||
padding = 1 | |||
dilation = 1 | |||
groups = 1 | |||
local_conv2d = LocalConv2d( | |||
in_channels=in_channels, | |||
out_channels=out_channels, | |||
input_height=input_height, | |||
input_width=input_width, | |||
kernel_size=kernel_size, | |||
stride=stride, | |||
padding=padding, | |||
dilation=dilation, | |||
groups=groups, | |||
) | |||
inputs = np.random.normal( | |||
size=(batch_size, in_channels, input_height, input_width) | |||
).astype(np.float32) | |||
output_height = (input_height + padding * 2 - kernel_size) // stride + 1 | |||
output_width = (input_width + padding * 2 - kernel_size) // stride + 1 | |||
weights = np.random.normal( | |||
size=( | |||
groups, | |||
output_height, | |||
output_width, | |||
in_channels // groups, | |||
kernel_size, | |||
kernel_size, | |||
out_channels // groups, | |||
) | |||
).astype(np.float32) | |||
local_conv2d.weight = Parameter(weights) | |||
outputs = local_conv2d(tensor(inputs)) | |||
# naive calculation use numpy | |||
# only test output_height == input_height, output_width == input_width, group == 1 | |||
inputs = np.pad(inputs, ((0, 0), (0, 0), (1, 1), (1, 1))) | |||
expected = np.zeros( | |||
(batch_size, out_channels, output_height, output_width), dtype=np.float32, | |||
) | |||
for n, oc, oh, ow in itertools.product( | |||
*map(range, [batch_size, out_channels, output_height, output_width]) | |||
): | |||
ih, iw = oh * stride, ow * stride | |||
expected[n, oc, ih, iw] = np.sum( | |||
inputs[n, :, ih : ih + kernel_size, iw : iw + kernel_size] | |||
* weights[0, oh, ow, :, :, :, oc] | |||
) | |||
assertTensorClose(outputs.numpy(), expected, max_err=1e-5) |