Browse Source

feat(mge/module): python wrapper for conv_transpose3d

GitOrigin-RevId: 61097b8713
release-1.4
Megvii Engine Team 4 years ago
parent
commit
11de49b943
6 changed files with 202 additions and 1 deletions
  1. +49
    -0
      imperative/python/megengine/functional/nn.py
  2. +1
    -0
      imperative/python/megengine/module/__init__.py
  3. +73
    -0
      imperative/python/megengine/module/conv.py
  4. +62
    -1
      imperative/python/test/unit/module/test_conv.py
  5. +15
    -0
      imperative/src/impl/ops/convolution.cpp
  6. +2
    -0
      src/core/include/megbrain/ir/ops.td

+ 49
- 0
imperative/python/megengine/functional/nn.py View File

@@ -48,6 +48,7 @@ __all__ = [
"conv2d", "conv2d",
"conv3d", "conv3d",
"conv_transpose2d", "conv_transpose2d",
"conv_transpose3d",
"deformable_conv2d", "deformable_conv2d",
"deformable_psroi_pooling", "deformable_psroi_pooling",
"dropout", "dropout",
@@ -488,6 +489,54 @@ def local_conv2d(
return output return output




def conv_transpose3d(
inp: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
) -> Tensor:
"""
3D transposed convolution operation. Only support the case that group = 1
and conv_mode = "cross_correlation".

Refer to :class:`~.ConvTranspose3d` for more information.

:param inp: feature map of the convolution operation.
:param weight: convolution kernel.
:param bias: bias added to the result of convolution (if given).
:param stride: stride of the 3D convolution operation. Default: 1
:param padding: size of the paddings added to the input on all sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 3D convolution operation. Default: 1
:return: output tensor.
"""
D, H, W = 0, 1, 2

pad = _triple(padding)
stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation)

op = builtin.Convolution3DBackwardData(
pad_d=pad[D],
pad_h=pad[H],
pad_w=pad[W],
stride_d=stride[D],
stride_h=stride[H],
stride_w=stride[W],
dilate_d=dilate[D],
dilate_h=dilate[H],
dilate_w=dilate[W],
strategy=get_execution_strategy(),
)
weight, inp = utils.convert_inputs(weight, inp)
(output,) = apply(op, weight, inp)
if bias is not None:
output += bias
return output


def max_pool2d( def max_pool2d(
inp: Tensor, inp: Tensor,
kernel_size: Union[int, Tuple[int, int]], kernel_size: Union[int, Tuple[int, int]],


+ 1
- 0
imperative/python/megengine/module/__init__.py View File

@@ -18,6 +18,7 @@ from .conv import (
Conv3d, Conv3d,
ConvRelu2d, ConvRelu2d,
ConvTranspose2d, ConvTranspose2d,
ConvTranspose3d,
DeformableConv2d, DeformableConv2d,
LocalConv2d, LocalConv2d,
) )


+ 73
- 0
imperative/python/megengine/module/conv.py View File

@@ -15,6 +15,7 @@ from ..functional import (
conv2d, conv2d,
conv3d, conv3d,
conv_transpose2d, conv_transpose2d,
conv_transpose3d,
deformable_conv2d, deformable_conv2d,
local_conv2d, local_conv2d,
relu, relu,
@@ -842,3 +843,75 @@ class DeformableConv2d(_ConvNd):


def forward(self, inp, offset, mask): def forward(self, inp, offset, mask):
return self.calc_conv(inp, self.weight, offset, mask, self.bias) return self.calc_conv(inp, self.weight, offset, mask, self.bias)


class ConvTranspose3d(_ConvNd):
r"""
Applies a 3D transposed convolution over an input tensor.

Only support the case that group = 1 and conv_mode = "cross_correlation".

:class:`ConvTranspose3d` can be seen as the gradient of :class:`Conv3d` operation
with respect to its input.

Convolution3D usually reduces the size of input, while transposed convolution3d
works the opposite way, transforming a smaller input to a larger output while
preserving the connectivity pattern.

:param in_channels: number of input channels.
:param out_channels: number of output channels.
:param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
an :class:`int`, the actual kernel size would be
``(kernel_size, kernel_size, kernel_size)``. Default: 1
:param stride: stride of the 3D convolution operation. Default: 1
:param padding: size of the paddings added to the input on all sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 3D convolution operation. Default: 1
:param bias: wether to add a bias onto the result of convolution. Default:
True
"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
bias: bool = True,
):
kernel_size = _triple_nonzero(kernel_size)
stride = _triple_nonzero(stride)
padding = _triple(padding)
dilation = _triple_nonzero(dilation)
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=1,
bias=bias,
)

def _get_fanin(self):
kt, kh, kw = self.kernel_size
ic = self.in_channels
return kt * kh * kw * ic

def _infer_weight_shape(self):
ichl = self.in_channels
ochl = self.out_channels
kt, kh, kw = self.kernel_size
return (ochl, ichl, kt, kh, kw)

def _infer_bias_shape(self):
# Assume format is NCTHW
return (1, self.out_channels, 1, 1, 1)

def forward(self, inp):
return conv_transpose3d(
inp, self.weight, self.bias, self.stride, self.padding, self.dilation,
)

+ 62
- 1
imperative/python/test/unit/module/test_conv.py View File

@@ -11,7 +11,7 @@ import itertools
import numpy as np import numpy as np


from megengine import Parameter, tensor from megengine import Parameter, tensor
from megengine.module import ConvTranspose2d, LocalConv2d
from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d




def test_conv_transpose2d(): def test_conv_transpose2d():
@@ -120,3 +120,64 @@ def test_local_conv2d():
test_func(10, 4, 4, 5, 5, 3, 1, 1, 1, 1) test_func(10, 4, 4, 5, 5, 3, 1, 1, 1, 1)
test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 2) test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 2)
test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 4) test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 4)


def test_conv_transpose3d():
def getsize(inp, kernel, stride, dilate):
return (inp - 1) * stride + kernel * dilate - dilate + 1

def test_func(
N,
IC,
ID,
IH,
IW,
OC,
KD,
KH,
KW,
SD,
SH,
SW,
PD,
PH,
PW,
DD,
DH,
DW,
bias=True,
):
conv_transpose3d = ConvTranspose3d(
in_channels=IC,
out_channels=OC,
kernel_size=(KD, KH, KW),
stride=(SD, SH, SW),
padding=(PD, PH, PW),
dilation=(DD, DH, DW),
bias=bias,
)

OD = getsize(ID, KD, SD, DD)
OH = getsize(IH, KH, SH, DH)
OW = getsize(IW, KW, SW, DW)

inp = np.random.normal(size=(N, IC, ID, IH, IW))
weight = np.random.normal(size=(IC, OC, KD, KH, KW))
out_np = np.zeros((N, OC, OD, OH, OW), dtype=np.float32)

for n, ic, idepth, ih, iw in itertools.product(
*map(range, [N, IC, ID, IH, IW])
):
od, oh, ow = idepth * SD, ih * SH, iw * SW
out_np[n, :, od : od + KD, oh : oh + KH, ow : ow + KW] += (
inp[n, ic, idepth, ih, iw] * weight[ic]
)
out_np = out_np[:, :, PD : OD - PD, PH : OH - PH, PW : OW - PW]

conv_transpose3d.weight = Parameter(weight)
out_meg = conv_transpose3d.forward(tensor(inp))

np.testing.assert_almost_equal(out_meg.numpy(), out_np, 1e-5)

test_func(4, 3, 8, 16, 16, 8, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1)
test_func(4, 8, 16, 32, 32, 16, 1, 3, 1, 2, 1, 2, 0, 1, 0, 1, 1, 1)

+ 15
- 0
imperative/src/impl/ops/convolution.cpp View File

@@ -75,5 +75,20 @@ OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
.fallback(); .fallback();
}} // convolution3d }} // convolution3d


namespace { namespace convolution3d_backward_data {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const Convolution3DBackwardData&>(def);
OperatorNodeConfig config{conv.make_name()};
mgb_assert(inputs.size() == 2);
return opr::Convolution3DBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
}

OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // convolution3d_backward_data

} }
} }

+ 2
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -53,6 +53,8 @@ def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [Convoluti


def Convolution3D: MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>; def Convolution3D: MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;


def Convolution3DBackwardData: MgbHashableOp<"Convolution3DBackwardData", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;

def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>; def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;


def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>; def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>;


Loading…
Cancel
Save