
feat(functional): add conv3d

GitOrigin-RevId: c06ac77c6c
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
d970b85deb
8 changed files with 308 additions and 45 deletions
  1. imperative/python/megengine/functional/nn.py (+66, -4)
  2. imperative/python/megengine/module/__init__.py (+1, -0)
  3. imperative/python/megengine/module/conv.py (+138, -1)
  4. imperative/python/megengine/utils/tuple_function.py (+1, -0)
  5. imperative/python/test/unit/functional/test_functional.py (+21, -1)
  6. imperative/src/impl/ops/convolution.cpp (+79, -0)
  7. imperative/src/impl/ops/specializations.cpp (+0, -39)
  8. src/core/include/megbrain/ir/ops.td (+2, -0)

imperative/python/megengine/functional/nn.py (+66, -4)

@@ -22,8 +22,8 @@ from ..device import get_default_device
from ..distributed import WORLD, is_distributed
from ..random import uniform
from ..tensor import Tensor
from ..utils.tuple_function import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy, get_execution_strategy
from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
from .debug_param import get_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum
from .math import argsort, matmul, max, prod, sum
@@ -43,7 +43,9 @@ __all__ = [
"adaptive_max_pool2d", "adaptive_max_pool2d",
"avg_pool2d", "avg_pool2d",
"batch_norm", "batch_norm",
"conv1d",
"conv2d", "conv2d",
"conv3d",
"conv_transpose2d", "conv_transpose2d",
"deformable_conv2d", "deformable_conv2d",
"deformable_psroi_pooling", "deformable_psroi_pooling",
@@ -166,6 +168,66 @@ def conv2d(
return output




def conv3d(
inp: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1,
conv_mode: str = "CROSS_CORRELATION",
) -> Tensor:
"""
3D convolution operation.

Refer to :class:`~.Conv3d` for more information.

:param inp: feature map of the convolution operation.
:param weight: convolution kernel.
:param bias: bias added to the result of convolution (if given).
:param stride: stride of the 3D convolution operation. Default: 1
:param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 3D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channels // groups,
in_channels // groups, t, height, width)`.
:param conv_mode: supports "CROSS_CORRELATION". Default:
"CROSS_CORRELATION"
:return: output tensor.
"""
assert conv_mode == "CROSS_CORRELATION"

D, H, W = 0, 1, 2

pad = _triple(padding)
stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation)

sparse_type = "DENSE" if groups == 1 else "GROUP"
op = builtin.Convolution3D(
pad_d=pad[D],
pad_h=pad[H],
pad_w=pad[W],
stride_d=stride[D],
stride_h=stride[H],
stride_w=stride[W],
dilate_d=dilate[D],
dilate_h=dilate[H],
dilate_w=dilate[W],
strategy=get_execution_strategy(),
mode=conv_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
return output
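For readers trying out the new API, here is a minimal usage sketch (not part of the diff; it assumes a MegEngine build containing this commit and the NCTHW layout used above):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    # NCTHW input: batch=2, channels=3, depth=4, height=4, width=4
    inp = mge.tensor(np.random.randn(2, 3, 4, 4, 4).astype("float32"))
    # dense weight layout: (out_channels, in_channels, kt, kh, kw)
    weight = mge.tensor(np.random.randn(8, 3, 3, 3, 3).astype("float32"))
    out = F.conv3d(inp, weight, None, stride=1, padding=1, dilation=1, groups=1)
    print(out.numpy().shape)  # (2, 8, 4, 4, 4): (4 + 2*1 - 3) // 1 + 1 = 4 per spatial dim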


def conv_transpose2d(
inp: Tensor,
weight: Tensor,
@@ -1094,7 +1156,7 @@ def matmul(
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=get_conv_execution_strategy(),
strategy=get_execution_strategy(),
)
else:
op = builtin.MatrixMul(
@@ -1102,7 +1164,7 @@ def matmul(
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=get_conv_execution_strategy(),
strategy=get_execution_strategy(),
)

(result,) = apply(op, inp1, inp2)


imperative/python/megengine/module/__init__.py (+1, -0)

@@ -15,6 +15,7 @@ from .concat import Concat
from .conv import (
Conv1d,
Conv2d,
Conv3d,
ConvRelu2d,
ConvTranspose2d,
DeformableConv2d,


imperative/python/megengine/module/conv.py (+138, -1)

@@ -13,13 +13,14 @@ import numpy as np
from ..functional import (
conv1d,
conv2d,
conv3d,
conv_transpose2d,
deformable_conv2d,
local_conv2d,
relu,
)
from ..tensor import Parameter
from ..utils.tuple_function import _pair, _pair_nonzero
from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
from . import init
from .module import Module


@@ -400,6 +401,142 @@ class Conv2d(_ConvNd):
return self.calc_conv(inp, self.weight, self.bias)




class Conv3d(_ConvNd):

r"""
Applies a 3D convolution over an input tensor.

For instance, given an input of size :math:`(N, C_{\text{in}}, T, H, W)`,
this layer generates an output of size
:math:`(N, C_{\text{out}}, T_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
process described below:

.. math::
\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
\sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

where :math:`\star` is the valid 3D cross-correlation operator,
:math:`N` is the batch size, and :math:`C` denotes the number of channels.


When `groups == in_channels` and `out_channels == K * in_channels`,
where K is a positive integer, this operation is also known as depthwise
convolution.

In other words, for an input of size :math:`(N, C_{in}, T_{in}, H_{in}, W_{in})`,
a depthwise convolution with a depthwise multiplier `K`, can be constructed
by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

:param in_channels: number of input channels.
:param out_channels: number of output channels.
:param kernel_size: size of weight on spatial dimensions. If kernel_size is
an :class:`int`, the actual kernel size would be
`(kernel_size, kernel_size, kernel_size)`. Default: 1
:param stride: stride of the 3D convolution operation. Default: 1
:param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 3D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and there would be an extra dimension at the beginning of the weight's
shape. Specifically, the shape of weight would be `(groups,
out_channels // groups, in_channels // groups, *kernel_size)`.
:param bias: whether to add a bias onto the result of convolution. Default:
True
:param conv_mode: Supports `CROSS_CORRELATION`. Default:
`CROSS_CORRELATION`

Examples:

.. testcode::

import numpy as np
import megengine as mge
import megengine.module as M

m = M.Conv3d(in_channels=3, out_channels=1, kernel_size=3)
inp = mge.tensor(np.arange(0, 384).astype("float32").reshape(2, 3, 4, 4, 4))
oup = m(inp)
print(oup.numpy().shape)

Outputs:

.. testoutput::

(2, 1, 2, 2, 2)

"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1,
bias: bool = True,
conv_mode: str = "CROSS_CORRELATION",
):
kernel_size = _triple_nonzero(kernel_size)
stride = _triple_nonzero(stride)
padding = _triple(padding)
dilation = _triple_nonzero(dilation)
self.conv_mode = conv_mode
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)

def _get_fanin(self):
kt, kh, kw = self.kernel_size
ic = self.in_channels
return kt * kh * kw * ic

def _infer_weight_shape(self):
group = self.groups
ichl = self.in_channels
ochl = self.out_channels
kt, kh, kw = self.kernel_size
if group == 1:
# Assume format is NCTHW
return (ochl, ichl, kt, kh, kw)

assert (
ichl % group == 0 and ochl % group == 0
), "invalid config: input_channels={} output_channels={} group={}".format(
ichl, ochl, group
)
# Assume format is NCTHW
return (group, ochl // group, ichl // group, kt, kh, kw)

def _infer_bias_shape(self):
# Assume format is NCTHW
return (1, self.out_channels, 1, 1, 1)

def calc_conv(self, inp, weight, bias):
return conv3d(
inp,
weight,
bias,
self.stride,
self.padding,
self.dilation,
self.groups,
self.conv_mode,
)

def forward(self, inp):
return self.calc_conv(inp, self.weight, self.bias)
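To illustrate the grouped case documented above, a small sketch (illustrative only; it relies on the weight layout `(groups, out_channels // groups, in_channels // groups, kt, kh, kw)` stated in the docstring):

    import numpy as np
    import megengine as mge
    import megengine.module as M

    # groups=2 requires in_channels and out_channels to be divisible by 2
    m = M.Conv3d(in_channels=4, out_channels=4, kernel_size=3, padding=1, groups=2)
    print(m.weight.shape)  # expected: (2, 2, 2, 3, 3, 3)
    inp = mge.tensor(np.ones((1, 4, 5, 5, 5), dtype="float32"))
    print(m(inp).numpy().shape)  # (1, 4, 5, 5, 5) with padding=1 and a 3x3x3 kernel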


class ConvTranspose2d(_ConvNd):
r"""
Applies a 2D transposed convolution over an input tensor.


imperative/python/megengine/utils/tuple_function.py (+1, -0)

@@ -35,4 +35,5 @@ _single = functools.partial(get_ndtuple, n=1, allow_zero=True)
_pair = functools.partial(get_ndtuple, n=2, allow_zero=True)
_pair_nonzero = functools.partial(get_ndtuple, n=2, allow_zero=False)
_triple = functools.partial(get_ndtuple, n=3, allow_zero=True)
_triple_nonzero = functools.partial(get_ndtuple, n=3, allow_zero=False)
_quadruple = functools.partial(get_ndtuple, n=4, allow_zero=True)
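For context, `_triple` and `_triple_nonzero` normalize a scalar or length-3 sequence into a 3-tuple, with `_triple_nonzero` additionally rejecting zeros. A simplified stand-in with the same intent, not the actual `get_ndtuple` code:

    from typing import Sequence, Tuple, Union

    def triple(value: Union[int, Sequence[int]], allow_zero: bool = True) -> Tuple[int, int, int]:
        """Broadcast an int to (v, v, v), or validate a length-3 sequence."""
        if isinstance(value, int):
            value = (value, value, value)
        value = tuple(int(v) for v in value)
        assert len(value) == 3, "expected an int or a length-3 sequence"
        if not allow_zero:
            assert all(v != 0 for v in value), "zero entries are not allowed here"
        return value

    print(triple(2))          # (2, 2, 2)
    print(triple((1, 2, 3)))  # (1, 2, 3)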

imperative/python/test/unit/functional/test_functional.py (+21, -1)

@@ -637,7 +637,7 @@ def test_batch_conv_bias():
run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)




def test_zero_stride_numpy_array():
def test_conv2d_zero_stride_numpy_array():
inp = np.random.randn(3, 224, 224).astype(np.float32)
inp = inp[np.newaxis, :]


@@ -646,6 +646,16 @@ def test_zero_stride_numpy_array():
out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)




def test_conv3d_zero_stride_numpy_array():
inp = np.random.randn(3, 224, 224, 224).astype(np.float32)
inp = inp[np.newaxis, :]

inp = tensor(inp, dtype=np.float32)
weight = tensor(np.random.randn(16, 3, 3, 3, 3), dtype=np.float32)
out = F.conv3d(inp, weight, None, (2, 2, 2), (3, 3, 3), (1, 1, 1), 1)
out.numpy()


def test_conv1d():
inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))
@@ -658,6 +668,16 @@ def test_conv1d():
)




def test_conv3d():
inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4))
weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2))
out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
print(out.numpy().shape)
np.testing.assert_equal(
out.numpy(), np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16
)
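The constant 16 asserted in `test_conv3d` follows from the all-ones input and weight: with stride 2 and no padding, each output element is a sum over in_channels * kt * kh * kw products of ones. A quick check, shown here only to make the expected value explicit:

    in_channels, kt, kh, kw = 2, 2, 2, 2
    expected = in_channels * kt * kh * kw  # every product term is 1 * 1
    print(expected)  # 16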


def test_condtake():
x = np.array([[1, 2, 3], [4, 5, 6]])
y = np.array([[True, False, True], [False, True, True]])


imperative/src/impl/ops/convolution.cpp (+79, -0)

@@ -0,0 +1,79 @@
/**
* \file imperative/src/impl/ops/convolution.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/dnn/convolution.h"

#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace { namespace convolution {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Convolution>();
return Convolution::make(node->param(), node->execution_policy());
}

auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const Convolution&>(def);
OperatorNodeConfig config{conv.make_name()};
return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
}

OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // convolution

namespace { namespace convolution_backward_data {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
OperatorNodeConfig config{conv.make_name()};
if (inputs.size() == 2) {
return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
} else {
mgb_assert(inputs.size() == 3);
return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
}
}

OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // convolution_backward_data

namespace { namespace convolution3d {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Convolution3D>();
return Convolution3D::make(node->param(), node->execution_policy());
}

auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const Convolution3D&>(def);
return opr::Convolution3D::make(inputs[0], inputs[1], conv.param(), conv.policy());
}

OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // convolution3d

}
}

imperative/src/impl/ops/specializations.cpp (+0, -39)

@@ -36,45 +36,6 @@


namespace mgb::imperative {


namespace { namespace convolution {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Convolution>();
return Convolution::make(node->param(), node->execution_policy());
}

auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const Convolution&>(def);
OperatorNodeConfig config{conv.make_name()};
return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
}

OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // convolution

namespace { namespace convolution_backward_data {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
OperatorNodeConfig config{conv.make_name()};
if (inputs.size() == 2) {
return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
} else {
mgb_assert(inputs.size() == 3);
return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
}
}

OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // convolution_backward_data

namespace { namespace dimshuffle {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Dimshuffle>();


src/core/include/megbrain/ir/ops.td (+2, -0)

@@ -51,6 +51,8 @@ def Convolution : MgbHashableOp<"Convolution", [ConvolutionParam, ExecutionPolic


def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;


def Convolution3D: MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;

def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;


def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>;
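The new `Convolution3D` entry above is what generates the `builtin.Convolution3D` op that the Python functional code constructs. A minimal sanity check, assuming the `megengine.core.ops.builtin` module path used elsewhere in the Python sources; the keyword arguments simply mirror those passed in `F.conv3d`:

    from megengine.core.ops import builtin

    # parameter fields mirror the keyword arguments passed in F.conv3d
    op = builtin.Convolution3D(
        pad_d=0, pad_h=0, pad_w=0,
        stride_d=1, stride_h=1, stride_w=1,
        dilate_d=1, dilate_h=1, dilate_w=1,
        mode="CROSS_CORRELATION", sparse="DENSE",
    )
    print(type(op))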

