@@ -22,8 +22,8 @@ from ..device import get_default_device
 from ..distributed import WORLD, is_distributed
 from ..random import uniform
 from ..tensor import Tensor
-from ..utils.tuple_function import _pair, _pair_nonzero
-from .debug_param import get_conv_execution_strategy, get_execution_strategy
+from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
+from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum
 from .math import argsort, matmul, max, prod, sum
@@ -43,7 +43,9 @@ __all__ = [
     "adaptive_max_pool2d",
     "avg_pool2d",
     "batch_norm",
+    "conv1d",
     "conv2d",
+    "conv3d",
     "conv_transpose2d",
     "deformable_conv2d",
     "deformable_psroi_pooling",
@@ -166,6 +168,66 @@ def conv2d(
     return output


+def conv3d(
+    inp: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: Union[int, Tuple[int, int, int]] = 1,
+    padding: Union[int, Tuple[int, int, int]] = 0,
+    dilation: Union[int, Tuple[int, int, int]] = 1,
+    groups: int = 1,
+    conv_mode: str = "CROSS_CORRELATION",
+) -> Tensor:
+    """
+    3D convolution operation.
+
+    Refer to :class:`~.Conv3d` for more information.
+
+    :param inp: feature map of the convolution operation.
+    :param weight: convolution kernel.
+    :param bias: bias added to the result of convolution (if given).
+    :param stride: stride of the 3D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on both sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 3D convolution operation. Default: 1
+    :param groups: number of groups into which the input and output channels are divided,
+        so as to perform a ``grouped convolution``. When ``groups`` is not 1,
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+        and the shape of weight should be ``(groups, out_channels // groups,
+        in_channels // groups, t, height, width)``.
+    :param conv_mode: supports "CROSS_CORRELATION". Default:
+        "CROSS_CORRELATION"
+    :return: output tensor.
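+
+    Examples:
+
+    A minimal usage sketch (values are illustrative; the input follows the
+    NCTHW layout assumed by :class:`~.Conv3d`):
+
+    .. testcode::
+
+        import numpy as np
+        import megengine as mge
+        import megengine.functional as F
+
+        inp = mge.tensor(np.arange(0, 384).astype("float32").reshape(2, 3, 4, 4, 4))
+        weight = mge.tensor(np.ones((1, 3, 3, 3, 3), dtype="float32"))
+        out = F.conv3d(inp, weight)
+        print(out.numpy().shape)
+
+    Outputs:
+
+    .. testoutput::
+
+        (2, 1, 2, 2, 2)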
""" | |||||
assert conv_mode == "CROSS_CORRELATION" | |||||
D, H, W = 0, 1, 2 | |||||
pad = _triple(padding) | |||||
stride = _triple_nonzero(stride) | |||||
dilate = _triple_nonzero(dilation) | |||||
sparse_type = "DENSE" if groups == 1 else "GROUP" | |||||
op = builtin.Convolution3D( | |||||
pad_d=pad[D], | |||||
pad_h=pad[H], | |||||
pad_w=pad[W], | |||||
stride_d=stride[D], | |||||
stride_h=stride[H], | |||||
stride_w=stride[W], | |||||
dilate_d=dilate[D], | |||||
dilate_h=dilate[H], | |||||
dilate_w=dilate[W], | |||||
strategy=get_execution_strategy(), | |||||
mode=conv_mode, | |||||
sparse=sparse_type, | |||||
) | |||||
inp, weight = utils.convert_inputs(inp, weight) | |||||
(output,) = apply(op, inp, weight) | |||||
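+    # bias (if given) is assumed to broadcast against the NCTHW output,
+    # e.g. a shape like (1, out_channels, 1, 1, 1) as produced by Conv3d._infer_bias_shape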
+    if bias is not None:
+        output += bias
+    return output
+
+
 def conv_transpose2d(
     inp: Tensor,
     weight: Tensor,
@@ -1094,7 +1156,7 @@ def matmul(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=get_conv_execution_strategy(),
+            strategy=get_execution_strategy(),
         )
     else:
         op = builtin.MatrixMul(
@@ -1102,7 +1164,7 @@ def matmul(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=get_conv_execution_strategy(),
+            strategy=get_execution_strategy(),
         )

     (result,) = apply(op, inp1, inp2)
@@ -15,6 +15,7 @@ from .concat import Concat
 from .conv import (
     Conv1d,
     Conv2d,
+    Conv3d,
     ConvRelu2d,
     ConvTranspose2d,
     DeformableConv2d,
@@ -13,13 +13,14 @@ import numpy as np
 from ..functional import (
     conv1d,
     conv2d,
+    conv3d,
     conv_transpose2d,
     deformable_conv2d,
     local_conv2d,
     relu,
 )
 from ..tensor import Parameter
-from ..utils.tuple_function import _pair, _pair_nonzero
+from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
 from . import init
 from .module import Module
@@ -400,6 +401,142 @@ class Conv2d(_ConvNd):
         return self.calc_conv(inp, self.weight, self.bias)


+class Conv3d(_ConvNd):
+    r"""
+    Applies a 3D convolution over an input tensor.
+
+    For instance, given an input of the size :math:`(N, C_{\text{in}}, T, H, W)`,
+    this layer generates an output of the size
+    :math:`(N, C_{\text{out}}, T_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
+    process described as below:
+
+    .. math::
+        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
+        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
+
+    where :math:`\star` is the valid 3D cross-correlation operator,
+    :math:`N` is batch size, and :math:`C` denotes the number of channels.
+
+    When `groups == in_channels` and `out_channels == K * in_channels`,
+    where `K` is a positive integer, this operation is also known as depthwise
+    convolution.
+
+    In other words, for an input of size :math:`(N, C_{in}, T_{in}, H_{in}, W_{in})`,
+    a depthwise convolution with a depthwise multiplier `K` can be constructed
+    by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
+
+    :param in_channels: number of input channels.
+    :param out_channels: number of output channels.
+    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
+        an :class:`int`, the actual kernel size would be
+        `(kernel_size, kernel_size, kernel_size)`.
+    :param stride: stride of the 3D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on both sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 3D convolution operation. Default: 1
+    :param groups: number of groups into which the input and output channels are divided,
+        so as to perform a "grouped convolution". When ``groups`` is not 1,
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+        and there would be an extra dimension at the beginning of the weight's
+        shape. Specifically, the shape of weight would be `(groups,
+        out_channels // groups, in_channels // groups, *kernel_size)`.
+    :param bias: whether to add a bias onto the result of convolution. Default:
+        True
+    :param conv_mode: Supports `CROSS_CORRELATION`. Default:
+        `CROSS_CORRELATION`
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine as mge
+        import megengine.module as M
+
+        m = M.Conv3d(in_channels=3, out_channels=1, kernel_size=3)
+        inp = mge.tensor(np.arange(0, 384).astype("float32").reshape(2, 3, 4, 4, 4))
+        oup = m(inp)
+        print(oup.numpy().shape)
+
+    Outputs:
+
+    .. testoutput::
+
+        (2, 1, 2, 2, 2)
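+
+    A grouped (here, depthwise-style) variant is sketched below; the values are
+    illustrative and only meant to show the grouped weight layout described above:
+
+    .. testcode::
+
+        import megengine.module as M
+
+        m = M.Conv3d(in_channels=4, out_channels=8, kernel_size=3, groups=4)
+        print(m.weight.numpy().shape)
+
+    Outputs:
+
+    .. testoutput::
+
+        (4, 2, 1, 3, 3, 3)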
""" | |||||
def __init__( | |||||
self, | |||||
in_channels: int, | |||||
out_channels: int, | |||||
kernel_size: Union[int, Tuple[int, int, int]], | |||||
stride: Union[int, Tuple[int, int, int]] = 1, | |||||
padding: Union[int, Tuple[int, int, int]] = 0, | |||||
dilation: Union[int, Tuple[int, int, int]] = 1, | |||||
groups: int = 1, | |||||
bias: bool = True, | |||||
conv_mode: str = "CROSS_CORRELATION", | |||||
): | |||||
kernel_size = _triple_nonzero(kernel_size) | |||||
stride = _triple_nonzero(stride) | |||||
padding = _triple(padding) | |||||
dilation = _triple_nonzero(dilation) | |||||
self.conv_mode = conv_mode | |||||
super().__init__( | |||||
in_channels, | |||||
out_channels, | |||||
kernel_size, | |||||
stride, | |||||
padding, | |||||
dilation, | |||||
groups, | |||||
bias, | |||||
) | |||||
+
+    def _get_fanin(self):
+        kt, kh, kw = self.kernel_size
+        ic = self.in_channels
+        return kt * kh * kw * ic
+
+    def _infer_weight_shape(self):
+        group = self.groups
+        ichl = self.in_channels
+        ochl = self.out_channels
+        kt, kh, kw = self.kernel_size
+        if group == 1:
+            # Assume format is NCTHW
+            return (ochl, ichl, kt, kh, kw)
+
+        assert (
+            ichl % group == 0 and ochl % group == 0
+        ), "invalid config: input_channels={} output_channels={} group={}".format(
+            ichl, ochl, group
+        )
+        # Assume format is NCTHW
+        return (group, ochl // group, ichl // group, kt, kh, kw)
+
+    def _infer_bias_shape(self):
+        # Assume format is NCTHW
+        return (1, self.out_channels, 1, 1, 1)
+
+    def calc_conv(self, inp, weight, bias):
+        return conv3d(
+            inp,
+            weight,
+            bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.conv_mode,
+        )
+
+    def forward(self, inp):
+        return self.calc_conv(inp, self.weight, self.bias)
+
+
 class ConvTranspose2d(_ConvNd):
     r"""
     Applies a 2D transposed convolution over an input tensor.
@@ -35,4 +35,5 @@ _single = functools.partial(get_ndtuple, n=1, allow_zero=True)
 _pair = functools.partial(get_ndtuple, n=2, allow_zero=True)
 _pair_nonzero = functools.partial(get_ndtuple, n=2, allow_zero=False)
 _triple = functools.partial(get_ndtuple, n=3, allow_zero=True)
+_triple_nonzero = functools.partial(get_ndtuple, n=3, allow_zero=False)
 _quadruple = functools.partial(get_ndtuple, n=4, allow_zero=True)
@@ -637,7 +637,7 @@ def test_batch_conv_bias():
     run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)


-def test_zero_stride_numpy_array():
+def test_conv2d_zero_stride_numpy_array():
     inp = np.random.randn(3, 224, 224).astype(np.float32)
     inp = inp[np.newaxis, :]
@@ -646,6 +646,16 @@ def test_zero_stride_numpy_array():
     out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
+
+
+def test_conv3d_zero_stride_numpy_array():
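+    # the leading axis added by np.newaxis below has a zero stride in the numpy array;
+    # as in the conv2d test above, this checks that conv3d accepts tensors built from such arrays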
+    inp = np.random.randn(3, 224, 224, 224).astype(np.float32)
+    inp = inp[np.newaxis, :]
+    inp = tensor(inp, dtype=np.float32)
+    weight = tensor(np.random.randn(16, 3, 3, 3, 3), dtype=np.float32)
+    out = F.conv3d(inp, weight, None, (2, 2, 2), (3, 3, 3), (1, 1, 1), 1)
+    out.numpy()
 def test_conv1d():
     inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
     weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))
@@ -658,6 +668,16 @@ def test_conv1d():
     )
+
+
+def test_conv3d():
+    inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4))
+    weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2))
+    out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
+    print(out.numpy().shape)
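+    # expected result: with stride 2 and no padding each spatial dim is (4 - 2) // 2 + 1 = 2,
+    # and every output element sums an all-ones 2x2x2x2 patch, i.e. 2 * 2 * 2 * 2 = 16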
+    np.testing.assert_equal(
+        out.numpy(), np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16
+    )


 def test_condtake():
     x = np.array([[1, 2, 3], [4, 5, 6]])
     y = np.array([[True, False, True], [False, True, True]])
@@ -0,0 +1,79 @@
+/**
+ * \file imperative/src/impl/ops/dnn/convolution.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/opr/dnn/convolution.h" | |||||
#include "../op_trait.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
namespace { namespace convolution { | |||||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||||
auto* node = &node_->cast_final_safe<opr::Convolution>(); | |||||
return Convolution::make(node->param(), node->execution_policy()); | |||||
} | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
auto&& conv = static_cast<const Convolution&>(def); | |||||
OperatorNodeConfig config{conv.make_name()}; | |||||
return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||||
} | |||||
OP_TRAIT_REG(Convolution, Convolution, opr::Convolution) | |||||
.make_from_op_node(make_from_op_node) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // convolution | |||||
namespace { namespace convolution_backward_data { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
auto&& conv = static_cast<const ConvolutionBackwardData&>(def); | |||||
OperatorNodeConfig config{conv.make_name()}; | |||||
if (inputs.size() == 2) { | |||||
return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||||
} else { | |||||
mgb_assert(inputs.size() == 3); | |||||
return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||||
} | |||||
} | |||||
OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // convolution_backward_data | |||||
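+
+// Convolution3D: imperative op trait registration, mirroring the 2D Convolution block above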
+namespace { namespace convolution3d {
+std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
+    auto* node = &node_->cast_final_safe<opr::Convolution3D>();
+    return Convolution3D::make(node->param(), node->execution_policy());
+}
+
+auto apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& conv = static_cast<const Convolution3D&>(def);
+    return opr::Convolution3D::make(inputs[0], inputs[1], conv.param(), conv.policy());
+}
+
+OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
+    .make_from_op_node(make_from_op_node)
+    .apply_on_var_node(apply_on_var_node)
+    .fallback();
+}} // convolution3d
+
+}
+}
@@ -36,45 +36,6 @@

 namespace mgb::imperative {

-namespace { namespace convolution {
-std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
-    auto* node = &node_->cast_final_safe<opr::Convolution>();
-    return Convolution::make(node->param(), node->execution_policy());
-}
-
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& conv = static_cast<const Convolution&>(def);
-    OperatorNodeConfig config{conv.make_name()};
-    return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
-}
-
-OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
-    .make_from_op_node(make_from_op_node)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // convolution
-
-namespace { namespace convolution_backward_data {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
-    OperatorNodeConfig config{conv.make_name()};
-    if (inputs.size() == 2) {
-        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
-    } else {
-        mgb_assert(inputs.size() == 3);
-        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
-    }
-}
-
-OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // convolution_backward_data
-
 namespace { namespace dimshuffle {
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
     auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
@@ -51,6 +51,8 @@ def Convolution : MgbHashableOp<"Convolution", [ConvolutionParam, ExecutionPolic
 def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
+
+def Convolution3D: MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;
 def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
 def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>;