@@ -22,8 +22,8 @@ from ..device import get_default_device
from ..distributed import WORLD, is_distributed
from ..random import uniform
from ..tensor import Tensor
from ..utils.tuple_function import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy, get_execution_strategy
from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
from .debug_param import get_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum
from .math import argsort, matmul, max, prod, sum

@@ -43,7 +43,9 @@ __all__ = [
    "adaptive_max_pool2d",
    "avg_pool2d",
    "batch_norm",
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose2d",
    "deformable_conv2d",
    "deformable_psroi_pooling",

@@ -166,6 +168,66 @@ def conv2d(
    return output

def conv3d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int, int]] = 1,
    padding: Union[int, Tuple[int, int, int]] = 0,
    dilation: Union[int, Tuple[int, int, int]] = 1,
    groups: int = 1,
    conv_mode: str = "CROSS_CORRELATION",
) -> Tensor:
    """
    3D convolution operation.

    Refer to :class:`~.Conv3d` for more information.

    :param inp: feature map of the convolution operation.
    :param weight: convolution kernel.
    :param bias: bias added to the result of convolution (if given).
    :param stride: stride of the 3D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 3D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a ``grouped convolution``. When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and the shape of weight should be ``(groups, out_channels // groups,
        in_channels // groups, t, height, width)``.
    :param conv_mode: supports "CROSS_CORRELATION". Default:
        "CROSS_CORRELATION"
    :return: output tensor.
    """
    assert conv_mode == "CROSS_CORRELATION"
    D, H, W = 0, 1, 2
    pad = _triple(padding)
    stride = _triple_nonzero(stride)
    dilate = _triple_nonzero(dilation)
    sparse_type = "DENSE" if groups == 1 else "GROUP"
    op = builtin.Convolution3D(
        pad_d=pad[D],
        pad_h=pad[H],
        pad_w=pad[W],
        stride_d=stride[D],
        stride_h=stride[H],
        stride_w=stride[W],
        dilate_d=dilate[D],
        dilate_h=dilate[H],
        dilate_w=dilate[W],
        strategy=get_execution_strategy(),
        mode=conv_mode,
        sparse=sparse_type,
    )
    inp, weight = utils.convert_inputs(inp, weight)
    (output,) = apply(op, inp, weight)
    if bias is not None:
        output += bias
    return output
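
Illustrative usage of the new functional entry point (a minimal sketch, not part of the patch; it assumes the public megengine.functional import path and NCTHW layout). Each spatial output extent follows floor((size + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    # NCTHW input: batch=2, channels=3, depth=4, height=4, width=4
    inp = tensor(np.random.randn(2, 3, 4, 4, 4).astype("float32"))
    # dense (groups=1) kernel layout: (out_channels, in_channels, kt, kh, kw)
    weight = tensor(np.random.randn(8, 3, 3, 3, 3).astype("float32"))
    out = F.conv3d(inp, weight, None, stride=1, padding=1, dilation=1, groups=1)
    print(out.shape)  # (2, 8, 4, 4, 4) since (4 + 2*1 - 1*(3-1) - 1)//1 + 1 == 4
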
def conv_transpose2d(
    inp: Tensor,
    weight: Tensor,

@@ -1094,7 +1156,7 @@ def matmul(
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=get_conv_execution_strategy(),
            strategy=get_execution_strategy(),
        )
    else:
        op = builtin.MatrixMul(

@@ -1102,7 +1164,7 @@ def matmul(
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=get_conv_execution_strategy(),
            strategy=get_execution_strategy(),
        )
    (result,) = apply(op, inp1, inp2)

@@ -15,6 +15,7 @@ from .concat import Concat
from .conv import (
    Conv1d,
    Conv2d,
    Conv3d,
    ConvRelu2d,
    ConvTranspose2d,
    DeformableConv2d,

@@ -13,13 +13,14 @@ import numpy as np
from ..functional import (
    conv1d,
    conv2d,
    conv3d,
    conv_transpose2d,
    deformable_conv2d,
    local_conv2d,
    relu,
)
from ..tensor import Parameter
from ..utils.tuple_function import _pair, _pair_nonzero
from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
from . import init
from .module import Module

@@ -400,6 +401,142 @@ class Conv2d(_ConvNd):
        return self.calc_conv(inp, self.weight, self.bias)

class Conv3d(_ConvNd):
    r"""
    Applies a 3D convolution over an input tensor.

    For instance, given an input of the size :math:`(N, C_{\text{in}}, T, H, W)`,
    this layer generates an output of the size
    :math:`(N, C_{\text{out}}, T_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
    process described as below:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 3D cross-correlation operator,
    :math:`N` is batch size and :math:`C` denotes the number of channels.

    When ``groups == in_channels`` and ``out_channels == K * in_channels``,
    where K is a positive integer, this operation is also known as depthwise
    convolution.

    In other words, for an input of size :math:`(N, C_{in}, T_{in}, H_{in}, W_{in})`,
    a depthwise convolution with a depthwise multiplier `K` can be constructed
    by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
        an :class:`int`, the actual kernel size would be
        `(kernel_size, kernel_size, kernel_size)`. Default: 1
    :param stride: stride of the 3D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 3D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be `(groups,
        out_channels // groups, in_channels // groups, *kernel_size)`.
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default:
        `CROSS_CORRELATION`

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.Conv3d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 384).astype("float32").reshape(2, 3, 4, 4, 4))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2, 2, 2)

    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
    ):
        kernel_size = _triple_nonzero(kernel_size)
        stride = _triple_nonzero(stride)
        padding = _triple(padding)
        dilation = _triple_nonzero(dilation)
        self.conv_mode = conv_mode
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def _get_fanin(self):
        kt, kh, kw = self.kernel_size
        ic = self.in_channels
        return kt * kh * kw * ic

    def _infer_weight_shape(self):
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kt, kh, kw = self.kernel_size
        if group == 1:
            # Assume format is NCTHW
            return (ochl, ichl, kt, kh, kw)

        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCTHW
        return (group, ochl // group, ichl // group, kt, kh, kw)

    def _infer_bias_shape(self):
        # Assume format is NCTHW
        return (1, self.out_channels, 1, 1, 1)

    def calc_conv(self, inp, weight, bias):
        return conv3d(
            inp,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
        )

    def forward(self, inp):
        return self.calc_conv(inp, self.weight, self.bias)

class ConvTranspose2d(_ConvNd):
    r"""
    Applies a 2D transposed convolution over an input tensor.

@@ -35,4 +35,5 @@ _single = functools.partial(get_ndtuple, n=1, allow_zero=True)
_pair = functools.partial(get_ndtuple, n=2, allow_zero=True)
_pair_nonzero = functools.partial(get_ndtuple, n=2, allow_zero=False)
_triple = functools.partial(get_ndtuple, n=3, allow_zero=True)
_triple_nonzero = functools.partial(get_ndtuple, n=3, allow_zero=False)
_quadruple = functools.partial(get_ndtuple, n=4, allow_zero=True)
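
A small sketch of what the new triple helpers are expected to do (assuming get_ndtuple keeps its current int-or-tuple semantics; the import path mirrors the module being patched):

    from megengine.utils.tuple_function import _triple, _triple_nonzero

    # an int is broadcast to all three spatial dims; a 3-tuple passes through
    assert _triple(0) == (0, 0, 0)            # zero padding is allowed
    assert _triple_nonzero(2) == (2, 2, 2)    # stride/dilation/kernel_size must be nonzero
    assert _triple_nonzero((1, 2, 3)) == (1, 2, 3)
    # _triple_nonzero(0) is rejected, since a zero stride or kernel is meaningless
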
@@ -637,7 +637,7 @@ def test_batch_conv_bias():
    run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)

def test_zero_stride_numpy_array():
def test_conv2d_zero_stride_numpy_array():
    inp = np.random.randn(3, 224, 224).astype(np.float32)
    inp = inp[np.newaxis, :]

@@ -646,6 +646,16 @@ def test_zero_stride_numpy_array():
    out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)

def test_conv3d_zero_stride_numpy_array():
    inp = np.random.randn(3, 224, 224, 224).astype(np.float32)
    inp = inp[np.newaxis, :]
    inp = tensor(inp, dtype=np.float32)
    weight = tensor(np.random.randn(16, 3, 3, 3, 3), dtype=np.float32)
    out = F.conv3d(inp, weight, None, (2, 2, 2), (3, 3, 3), (1, 1, 1), 1)
    out.numpy()

def test_conv1d():
    inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
    weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))

@@ -658,6 +668,16 @@ def test_conv1d():
    )

def test_conv3d():
    inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4))
    weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2))
    out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
    print(out.numpy().shape)
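    # all-ones input and all-ones 2x2x2 kernels over 2 input channels: every output
    # element sums 2 * 2 * 2 * 2 = 16 ones; stride 2 on a size-4 input with no
    # padding gives (4 - 2) // 2 + 1 = 2 per spatial dim, hence shape (2, 3, 2, 2, 2)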
    np.testing.assert_equal(
        out.numpy(), np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16
    )

def test_condtake():
    x = np.array([[1, 2, 3], [4, 5, 6]])
    y = np.array([[True, False, True], [False, True, True]])

@@ -0,0 +1,79 @@
/**
 * \file imperative/src/impl/ops/dnn/convolution.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/dnn/convolution.h"

#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace { namespace convolution {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Convolution>();
    return Convolution::make(node->param(), node->execution_policy());
}

auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const Convolution&>(def);
    OperatorNodeConfig config{conv.make_name()};
    return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
}

OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .fallback();

}} // convolution

namespace { namespace convolution_backward_data {

auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
    OperatorNodeConfig config{conv.make_name()};
    if (inputs.size() == 2) {
        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else {
        mgb_assert(inputs.size() == 3);
        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    }
}

OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
    .apply_on_var_node(apply_on_var_node)
    .fallback();

}} // convolution_backward_data

namespace { namespace convolution3d {
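// make_from_op_node rebuilds the imperative OpDef from an existing graph operator,
// while apply_on_var_node lowers the OpDef back onto graph VarNodes; both reuse the
// Convolution3D param and execution policy carried by the op.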
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Convolution3D>();
    return Convolution3D::make(node->param(), node->execution_policy());
}

auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const Convolution3D&>(def);
    return opr::Convolution3D::make(inputs[0], inputs[1], conv.param(), conv.policy());
}

OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .fallback();

}} // convolution3d

}
}
@@ -36,45 +36,6 @@
namespace mgb::imperative {

namespace { namespace convolution {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Convolution>();
    return Convolution::make(node->param(), node->execution_policy());
}

auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const Convolution&>(def);
    OperatorNodeConfig config{conv.make_name()};
    return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
}

OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .fallback();

}} // convolution

namespace { namespace convolution_backward_data {

auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
    OperatorNodeConfig config{conv.make_name()};
    if (inputs.size() == 2) {
        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else {
        mgb_assert(inputs.size() == 3);
        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    }
}

OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
    .apply_on_var_node(apply_on_var_node)
    .fallback();

}} // convolution_backward_data

namespace { namespace dimshuffle {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Dimshuffle>();

@@ -51,6 +51,8 @@ def Convolution : MgbHashableOp<"Convolution", [ConvolutionParam, ExecutionPolic
def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
def Convolution3D: MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;
def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>; | |||