GitOrigin-RevId: 61097b8713
@@ -48,6 +48,7 @@ __all__ = [
     "conv2d",
     "conv3d",
     "conv_transpose2d",
+    "conv_transpose3d",
     "deformable_conv2d",
     "deformable_psroi_pooling",
     "dropout",
@@ -488,6 +489,54 @@ def local_conv2d(
     return output


+def conv_transpose3d(
+    inp: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: Union[int, Tuple[int, int, int]] = 1,
+    padding: Union[int, Tuple[int, int, int]] = 0,
+    dilation: Union[int, Tuple[int, int, int]] = 1,
+) -> Tensor:
+    """
+    3D transposed convolution operation. Only supports the case where
+    ``group = 1`` and ``conv_mode = "cross_correlation"``.
+
+    Refer to :class:`~.ConvTranspose3d` for more information.
+
+    :param inp: feature map of the convolution operation.
+    :param weight: convolution kernel.
+    :param bias: bias added to the result of convolution (if given).
+    :param stride: stride of the 3D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on all sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 3D convolution operation. Default: 1
+    :return: output tensor.
+    """
+    D, H, W = 0, 1, 2
+
+    pad = _triple(padding)
+    stride = _triple_nonzero(stride)
+    dilate = _triple_nonzero(dilation)
+
+    # The transposed convolution is implemented as the data-gradient of a
+    # forward 3D convolution, so the filter is passed first to apply().
+    op = builtin.Convolution3DBackwardData(
+        pad_d=pad[D],
+        pad_h=pad[H],
+        pad_w=pad[W],
+        stride_d=stride[D],
+        stride_h=stride[H],
+        stride_w=stride[W],
+        dilate_d=dilate[D],
+        dilate_h=dilate[H],
+        dilate_w=dilate[W],
+        strategy=get_execution_strategy(),
+    )
+    weight, inp = utils.convert_inputs(weight, inp)
+    (output,) = apply(op, weight, inp)
+    if bias is not None:
+        output += bias
+    return output
+
+
 def max_pool2d(
     inp: Tensor,
     kernel_size: Union[int, Tuple[int, int]],
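For context, here is a minimal sketch of how the new functional API could be
exercised once this patch is applied. The concrete shapes are illustrative
only, and the top-level ``megengine.functional`` access is inferred from the
``__all__`` entry added above rather than stated by the patch:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    # NCDHW input; weight layout is (in_channels, out_channels, KD, KH, KW),
    # matching the reference test further down in this patch.
    inp = tensor(np.random.randn(1, 3, 4, 4, 4).astype(np.float32))
    weight = tensor(np.random.randn(3, 8, 2, 2, 2).astype(np.float32))
    out = F.conv_transpose3d(inp, weight, stride=2)
    print(out.shape)  # (1, 8, 8, 8, 8): each spatial dim is (4 - 1) * 2 + 2 = 8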
@@ -18,6 +18,7 @@ from .conv import (
     Conv3d,
     ConvRelu2d,
     ConvTranspose2d,
+    ConvTranspose3d,
     DeformableConv2d,
     LocalConv2d,
 )
@@ -15,6 +15,7 @@ from ..functional import (
     conv2d,
     conv3d,
     conv_transpose2d,
+    conv_transpose3d,
     deformable_conv2d,
     local_conv2d,
     relu,
@@ -842,3 +843,75 @@ class DeformableConv2d(_ConvNd):

     def forward(self, inp, offset, mask):
         return self.calc_conv(inp, self.weight, offset, mask, self.bias)
+
+
+class ConvTranspose3d(_ConvNd):
+    r"""
+    Applies a 3D transposed convolution over an input tensor.
+    Only supports the case where ``group = 1`` and ``conv_mode = "cross_correlation"``.
+
+    :class:`ConvTranspose3d` can be seen as the gradient of the :class:`Conv3d`
+    operation with respect to its input.
+
+    3D convolution usually reduces the size of its input, while 3D transposed
+    convolution works the opposite way, transforming a smaller input into a
+    larger output while preserving the connectivity pattern.
+
+    :param in_channels: number of input channels.
+    :param out_channels: number of output channels.
+    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
+        an :class:`int`, the actual kernel size would be
+        ``(kernel_size, kernel_size, kernel_size)``. Default: 1
+    :param stride: stride of the 3D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on all sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 3D convolution operation. Default: 1
+    :param bias: whether to add a bias onto the result of convolution. Default:
+        True
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int, int]],
+        stride: Union[int, Tuple[int, int, int]] = 1,
+        padding: Union[int, Tuple[int, int, int]] = 0,
+        dilation: Union[int, Tuple[int, int, int]] = 1,
+        bias: bool = True,
+    ):
+        kernel_size = _triple_nonzero(kernel_size)
+        stride = _triple_nonzero(stride)
+        padding = _triple(padding)
+        dilation = _triple_nonzero(dilation)
+        super().__init__(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=1,
+            bias=bias,
+        )
+
+    def _get_fanin(self):
+        kt, kh, kw = self.kernel_size
+        ic = self.in_channels
+        return kt * kh * kw * ic
+
+    def _infer_weight_shape(self):
+        ichl = self.in_channels
+        ochl = self.out_channels
+        kt, kh, kw = self.kernel_size
+        # Convolution3DBackwardData expects the filter of the corresponding
+        # forward convolution, i.e. (in_channels, out_channels, KD, KH, KW);
+        # this matches the weight shape used in the test below and mirrors
+        # ConvTranspose2d.
+        return (ichl, ochl, kt, kh, kw)
+
+    def _infer_bias_shape(self):
+        # Assume format is NCTHW
+        return (1, self.out_channels, 1, 1, 1)
+
+    def forward(self, inp):
+        return conv_transpose3d(
+            inp, self.weight, self.bias, self.stride, self.padding, self.dilation,
+        )
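A hedged usage sketch for the module itself (assuming ``megengine.module``
re-exports ``ConvTranspose3d`` as shown in the import hunk above; the shapes
are illustrative only, not taken from the patch):

    import numpy as np
    from megengine import tensor
    from megengine.module import ConvTranspose3d

    m = ConvTranspose3d(in_channels=3, out_channels=8, kernel_size=3,
                        stride=2, padding=1)
    x = tensor(np.random.randn(1, 3, 4, 4, 4).astype(np.float32))
    y = m(x)
    # each spatial dim: (4 - 1) * 2 + 1 * (3 - 1) + 1 - 2 * 1 = 7
    print(y.shape)  # (1, 8, 7, 7, 7)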
@@ -11,7 +11,7 @@ import itertools
 import numpy as np

 from megengine import Parameter, tensor
-from megengine.module import ConvTranspose2d, LocalConv2d
+from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d


 def test_conv_transpose2d():
@@ -120,3 +120,64 @@ def test_local_conv2d():
     test_func(10, 4, 4, 5, 5, 3, 1, 1, 1, 1)
     test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 2)
     test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 4)
+
+
+def test_conv_transpose3d():
+    def getsize(inp, kernel, stride, dilate):
+        # full (uncropped) output size along one spatial axis
+        return (inp - 1) * stride + kernel * dilate - dilate + 1
+
+    def test_func(
+        N,
+        IC,
+        ID,
+        IH,
+        IW,
+        OC,
+        KD,
+        KH,
+        KW,
+        SD,
+        SH,
+        SW,
+        PD,
+        PH,
+        PW,
+        DD,
+        DH,
+        DW,
+        bias=True,
+    ):
+        conv_transpose3d = ConvTranspose3d(
+            in_channels=IC,
+            out_channels=OC,
+            kernel_size=(KD, KH, KW),
+            stride=(SD, SH, SW),
+            padding=(PD, PH, PW),
+            dilation=(DD, DH, DW),
+            bias=bias,
+        )
+
+        OD = getsize(ID, KD, SD, DD)
+        OH = getsize(IH, KH, SH, DH)
+        OW = getsize(IW, KW, SW, DW)
+        inp = np.random.normal(size=(N, IC, ID, IH, IW)).astype(np.float32)
+        weight = np.random.normal(size=(IC, OC, KD, KH, KW)).astype(np.float32)
+        out_np = np.zeros((N, OC, OD, OH, OW), dtype=np.float32)
+
+        # Naive reference: scatter each input element into its output window.
+        # The window offsets below assume dilation == 1, which holds for all
+        # cases exercised by this test.
+        for n, ic, idepth, ih, iw in itertools.product(
+            *map(range, [N, IC, ID, IH, IW])
+        ):
+            od, oh, ow = idepth * SD, ih * SH, iw * SW
+            out_np[n, :, od : od + KD, oh : oh + KH, ow : ow + KW] += (
+                inp[n, ic, idepth, ih, iw] * weight[ic]
+            )
+        # padding crops the output on both sides of each spatial axis
+        out_np = out_np[:, :, PD : OD - PD, PH : OH - PH, PW : OW - PW]
+
+        conv_transpose3d.weight = Parameter(weight)
+        out_meg = conv_transpose3d.forward(tensor(inp))
+
+        np.testing.assert_almost_equal(out_meg.numpy(), out_np, decimal=5)
+
+    test_func(4, 3, 8, 16, 16, 8, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1)
+    test_func(4, 8, 16, 32, 32, 16, 1, 3, 1, 2, 1, 2, 0, 1, 0, 1, 1, 1)
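The reference implementation above encodes the usual transposed-convolution
size relation: ``getsize`` produces the uncropped size and the final slice
removes ``pad`` elements from each side. A quick standalone check of that
arithmetic against the two parameter sets used by the test (the helper name
is ours, not part of the patch):

    def out_size(inp, kernel, stride, pad, dilate=1):
        return (inp - 1) * stride + dilate * (kernel - 1) + 1 - 2 * pad

    assert out_size(8, 3, 1, 1) == 8    # first case, depth axis
    assert out_size(16, 1, 2, 0) == 31  # second case, depth axis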
@@ -75,5 +75,20 @@ OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
     .fallback();
 }} // convolution3d

+namespace { namespace convolution3d_backward_data {
+auto apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& conv = static_cast<const Convolution3DBackwardData&>(def);
+    OperatorNodeConfig config{conv.make_name()};
+    mgb_assert(inputs.size() == 2);
+    return opr::Convolution3DBackwardData::make(
+            inputs[0], inputs[1], conv.param(), conv.policy(), config);
+}
+OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
+    .apply_on_var_node(apply_on_var_node)
+    .fallback();
+}} // convolution3d_backward_data
+
 }
 }
@@ -53,6 +53,8 @@ def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [Convoluti
 def Convolution3D: MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;

+def Convolution3DBackwardData: MgbHashableOp<"Convolution3DBackwardData", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;
+
 def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;

 def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>;