GitOrigin-RevId: 501cadda76
tags/v1.3.0
@@ -44,6 +44,8 @@ __all__ = [ | |||
"batch_norm", | |||
"conv2d", | |||
"conv_transpose2d", | |||
"deformable_conv2d", | |||
"deformable_psroi_pooling", | |||
"dot", | |||
"dropout", | |||
"indexing_one_hot", | |||
@@ -119,7 +121,8 @@ def conv2d( | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||
:param groups: number of groups into which the input and output channels are divided, | |||
so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and the shape of weight should be `(groups, out_channel // groups, | |||
in_channels // groups, height, width)`. | |||
@@ -141,7 +144,6 @@ def conv2d( | |||
pad_h, pad_w = expand_hw(padding) | |||
dilate_h, dilate_w = expand_hw(dilation) | |||
Sparse = builtin.Convolution.Sparse | |||
sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
op = builtin.Convolution( | |||
stride_h=stride_h, | |||
@@ -185,7 +187,8 @@ def conv_transpose2d( | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||
:param groups: number of groups into which the input and output channels are divided, | |||
so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by groups, | |||
and the shape of weight should be `(groups, out_channel // groups, | |||
in_channels // groups, height, width)`. Default: 1 | |||
@@ -226,6 +229,74 @@ def conv_transpose2d( | |||
return output | |||
def deformable_conv2d(
    inp: Tensor,
    weight: Tensor,
    offset: Tensor,
    mask: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    conv_mode="CROSS_CORRELATION",
    compute_mode="DEFAULT",
) -> Tensor:
    """
    Deformable Convolution.

    :param inp: input feature map.
    :param weight: convolution kernel.
    :param offset: input offset to kernel, channel of this tensor should match
        the deformable settings.
    :param mask: input mask to kernel, channel of this tensor should match
        the deformable settings.
    :param bias: bias added to the result of convolution (if given).
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a ``grouped convolution``. When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by groups,
        and the shape of weight should be `(groups, out_channel // groups,
        in_channels // groups, height, width)`. Default: 1
    :type conv_mode: string or :class:`Convolution.Mode`
    :param conv_mode: supports "CROSS_CORRELATION". Default:
        "CROSS_CORRELATION"
    :type compute_mode: string or
        :class:`Convolution.ComputeMode`
    :param compute_mode: only "DEFAULT" is accepted by this function (any other
        value fails the assertion at the top of the body). With "DEFAULT", no
        special requirements are placed on the precision of intermediate
        results. Default: "DEFAULT"
    :return: output tensor.
    """
    # Each mode may be given as the plain string or as an object exposing
    # ``.name``.  ``getattr`` is used for the second form so that an unsupported
    # *string* fails the assertion instead of raising AttributeError
    # (``str`` has no ``.name`` attribute).
    assert (
        conv_mode == "CROSS_CORRELATION"
        or getattr(conv_mode, "name", None) == "CROSS_CORRELATION"
    ), "unsupported conv_mode: {!r}".format(conv_mode)
    assert (
        compute_mode == "DEFAULT" or getattr(compute_mode, "name", None) == "DEFAULT"
    ), "unsupported compute_mode: {!r}".format(compute_mode)

    # Expand the stride/padding/dilation arguments into (h, w) pairs.
    stride_h, stride_w = expand_hw(stride)
    pad_h, pad_w = expand_hw(padding)
    dilate_h, dilate_w = expand_hw(dilation)

    sparse_type = "DENSE" if groups == 1 else "GROUP"
    op = builtin.DeformableConv(
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        dilate_h=dilate_h,
        dilate_w=dilate_w,
        strategy=get_conv_execution_strategy(),
        mode=conv_mode,
        compute_mode=compute_mode,
        sparse=sparse_type,
    )
    inp, weight, offset, mask = utils.convert_inputs(inp, weight, offset, mask)
    (output,) = apply(op, inp, weight, offset, mask)
    # The bias is not an input of the op; it is added to the op result here.
    if bias is not None:
        output += bias
    return output
def local_conv2d( | |||
inp: Tensor, | |||
weight: Tensor, | |||
@@ -380,6 +451,45 @@ def adaptive_avg_pool2d( | |||
return output | |||
def deformable_psroi_pooling(
    inp: Tensor,
    rois: Tensor,
    trans: Tensor,
    no_trans: bool,
    part_size: int,
    pooled_h: int,
    pooled_w: int,
    sample_per_part: int,
    spatial_scale: float,
    trans_std: float = 0.1,
) -> Tensor:
    """
    Deformable PSROI(Position Sensitive Region of Interest) Pooling.

    :param inp: input feature map.
    :param rois: the rois for feature pooling.
    :param trans: input offset to psroi_pooling.
    :param no_trans: check the phase of DeformablePSROIPooling. False to the
        1st phase, True to the 2nd phase.
    :param part_size: part size.
    :param pooled_h: forwarded to the op as ``pooled_h``; presumably the height
        of the pooled output (TODO confirm against the op definition).
    :param pooled_w: forwarded to the op as ``pooled_w``; presumably the width
        of the pooled output (TODO confirm against the op definition).
    :param sample_per_part: sample points of each part.
    :param spatial_scale: the spatial_scale w.r.t input image.
    :param trans_std: multiplier used in 2nd phase. Default: 0.1
    :return: the first output of the underlying op; the second output is
        discarded.
    """
    op = builtin.DeformablePSROIPooling(
        no_trans=no_trans,
        part_size=part_size,
        pooled_h=pooled_h,
        pooled_w=pooled_w,
        sample_per_part=sample_per_part,
        spatial_scale=spatial_scale,
        trans_std=trans_std,
    )
    # The op yields two outputs; only the first one is returned to the caller.
    output, _ = apply(op, inp, rois, trans)
    return output
def prelu(inp: Tensor, weight: Tensor) -> Tensor: | |||
r""" | |||
Applies the element-wise PReLU function. | |||
@@ -12,8 +12,16 @@ from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d | |||
from .batch_matmul_activation import BatchMatMulActivation | |||
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | |||
from .concat import Concat | |||
from .conv import Conv1d, Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | |||
from .conv import ( | |||
Conv1d, | |||
Conv2d, | |||
ConvRelu2d, | |||
ConvTranspose2d, | |||
DeformableConv2d, | |||
LocalConv2d, | |||
) | |||
from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
from .deformable_psroi_pooling import DeformablePSROIPooling | |||
from .dropout import Dropout | |||
from .elemwise import Elemwise | |||
from .embedding import Embedding | |||
@@ -10,7 +10,14 @@ from typing import Tuple, Union | |||
import numpy as np | |||
from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu | |||
from ..functional import ( | |||
conv1d, | |||
conv2d, | |||
conv_transpose2d, | |||
deformable_conv2d, | |||
local_conv2d, | |||
relu, | |||
) | |||
from ..tensor import Parameter | |||
from ..utils.tuple_function import _pair, _pair_nonzero | |||
from . import init | |||
@@ -121,7 +128,8 @@ class Conv1d(_ConvNd): | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: dilation of the 1D convolution operation. Default: 1 | |||
:param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
:param groups: number of groups into which the input and output channels are divided, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and there would be an extra dimension at the beginning of the weight's | |||
shape. Specifically, the shape of weight would be `(groups, | |||
@@ -250,15 +258,16 @@ class Conv2d(_ConvNd): | |||
In general, output feature maps' shapes can be inferred as follows: | |||
input: :math:`(N, C_{\text{in}}, H_{\text{in}}, W_{\text{in}})` | |||
output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` where | |||
.. math:: | |||
\text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} - | |||
\text{dilation[0]} * (\text{kernel_size[0]} - 1)}{\text{stride[0]}} + 1 \rfloor | |||
\text{dilation[0]} * (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1 \rfloor | |||
.. math:: | |||
\text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} - | |||
\text{dilation[1]} * (\text{kernel_size[1]} - 1)}{\text{stride[1]}} + 1 \rfloor | |||
\text{dilation[1]} * (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1 \rfloor | |||
When `groups == in_channels` and `out_channels == K * in_channels`, | |||
where K is a positive integer, this operation is also known as depthwise | |||
@@ -277,7 +286,8 @@ class Conv2d(_ConvNd): | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
:param groups: number of groups into which the input and output channels are divided, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and there would be an extra dimension at the beginning of the weight's | |||
shape. Specifically, the shape of weight would be `(groups, | |||
@@ -406,7 +416,8 @@ class ConvTranspose2d(_ConvNd): | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
:param groups: number of groups into which the input and output channels are divided, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and there would be an extra dimension at the beginning of the weight's | |||
shape. Specifically, the shape of weight would be ``(groups, | |||
@@ -579,3 +590,107 @@ class ConvRelu2d(Conv2d): | |||
def forward(self, inp): | |||
return relu(self.calc_conv(inp, self.weight, self.bias)) | |||
class DeformableConv2d(_ConvNd):
    """
    Deformable Convolution.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
        an :class:`int`, the actual kernel size would be
        `(kernel_size, kernel_size)`.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be `(groups,
        out_channel // groups, in_channels // groups, *kernel_size)`.
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default:
        `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        # Run the int-or-pair arguments through the pair helpers first.
        kernel_hw = _pair_nonzero(kernel_size)
        stride_hw = _pair_nonzero(stride)
        padding_hw = _pair(padding)
        dilation_hw = _pair_nonzero(dilation)
        # The modes are recorded before the base-class initialiser runs.
        self.conv_mode = conv_mode
        self.compute_mode = compute_mode
        super().__init__(
            in_channels,
            out_channels,
            kernel_hw,
            stride_hw,
            padding_hw,
            dilation_hw,
            groups,
            bias,
        )

    def _get_fanin(self):
        """Return kernel_h * kernel_w * in_channels."""
        kernel_h, kernel_w = self.kernel_size
        return kernel_h * kernel_w * self.in_channels

    def _infer_weight_shape(self):
        """Return the weight shape, with a leading group axis when groups != 1."""
        n_groups = self.groups
        in_ch, out_ch = self.in_channels, self.out_channels
        kernel_h, kernel_w = self.kernel_size

        # Assume format is NCHW
        if n_groups != 1:
            assert (
                in_ch % n_groups == 0 and out_ch % n_groups == 0
            ), "invalid config: input_channels={} output_channels={} group={}".format(
                in_ch, out_ch, n_groups
            )
            return (n_groups, out_ch // n_groups, in_ch // n_groups, kernel_h, kernel_w)
        return (out_ch, in_ch, kernel_h, kernel_w)

    def _infer_bias_shape(self):
        """Return the bias shape ``(1, out_channels, 1, 1)``."""
        # Assume format is NCHW
        return (1, self.out_channels, 1, 1)

    def calc_conv(self, inp, weight, offset, mask, bias):
        """Apply ``deformable_conv2d`` using this module's stored settings."""
        return deformable_conv2d(
            inp,
            weight,
            offset,
            mask,
            bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            conv_mode=self.conv_mode,
            compute_mode=self.compute_mode,
        )

    def forward(self, inp, offset, mask):
        """Run the convolution on ``inp`` with the given ``offset`` and ``mask``."""
        return self.calc_conv(inp, self.weight, offset, mask, self.bias)
@@ -0,0 +1,46 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..functional import deformable_psroi_pooling | |||
from .module import Module | |||
class DeformablePSROIPooling(Module):
    """
    Module form of :func:`deformable_psroi_pooling`.

    The constructor arguments are stored on the instance unchanged and passed
    to the functional call on every ``forward``.

    :param no_trans: forwarded as ``no_trans``.
    :param part_size: forwarded as ``part_size``.
    :param pooled_h: forwarded as ``pooled_h``.
    :param pooled_w: forwarded as ``pooled_w``.
    :param sample_per_part: forwarded as ``sample_per_part``.
    :param spatial_scale: forwarded as ``spatial_scale``.
    :param trans_std: forwarded as ``trans_std``. Default: 0.1
    """

    def __init__(
        self,
        no_trans,
        part_size,
        pooled_h,
        pooled_w,
        sample_per_part,
        spatial_scale,
        trans_std: float = 0.1,
    ):
        super().__init__()
        # Pooling configuration, kept verbatim for use in forward().
        self.no_trans = no_trans
        self.part_size = part_size
        self.pooled_h = pooled_h
        self.pooled_w = pooled_w
        self.sample_per_part = sample_per_part
        self.spatial_scale = spatial_scale
        self.trans_std = trans_std

    def forward(self, inp, rois, trans):
        """Pool ``inp`` over ``rois`` with offsets ``trans`` using the stored settings."""
        return deformable_psroi_pooling(
            inp,
            rois,
            trans,
            no_trans=self.no_trans,
            part_size=self.part_size,
            pooled_h=self.pooled_h,
            pooled_w=self.pooled_w,
            sample_per_part=self.sample_per_part,
            spatial_scale=self.spatial_scale,
            trans_std=self.trans_std,
        )
@@ -703,6 +703,33 @@ def test_argmxx_on_inf(): | |||
assert all(run_argmin() >= 0) | |||
def test_deformable_psroi_pooling():
    """Smoke test: call F.deformable_psroi_pooling on random float32 inputs."""
    feature_map = np.random.random((1, 256, 64, 64)).astype("float32")
    roi_boxes = np.random.random((1, 5)).astype("float32")
    offsets = np.random.random((24, 2, 7, 7)).astype("float32")

    pooling_args = dict(
        no_trans=False,
        part_size=7,
        pooled_h=7,
        pooled_w=7,
        sample_per_part=4,
        spatial_scale=1.0 / 64,
        trans_std=0.1,
    )

    # The result is not inspected; the call only has to complete.
    F.deformable_psroi_pooling(
        tensor(feature_map), tensor(roi_boxes), tensor(offsets), **pooling_args
    )
def test_cvt_color(): | |||
def rgb2gray(rgb): | |||
return np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) | |||
@@ -0,0 +1,38 @@ | |||
/** | |||
* \file imperative/src/impl/ops/deformable_conv2d.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "../op_trait.h" | |||
namespace mgb::imperative { | |||
namespace { namespace deformableconv { | |||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
auto* node = &node_->cast_final_safe<opr::DeformableConv>(); | |||
return DeformableConv::make(node->param(), node->execution_policy()); | |||
} | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& dcn = static_cast<const DeformableConv&>(def); | |||
mgb_assert(inputs.size() == 4); | |||
return opr::DeformableConv::make(inputs[0], inputs[1], inputs[2], inputs[3], dcn.param(), dcn.policy()); | |||
} | |||
OP_TRAIT_REG(DeformableConv, DeformableConv, opr::DeformableConv) | |||
.make_from_op_node(make_from_op_node) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // deformableconv | |||
} // namespace mgb::imperative |
@@ -0,0 +1,32 @@ | |||
/** | |||
* \file imperative/src/impl/ops/deformable_psroi_pooling.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/dnn/roi_pooling.h" | |||
#include "../op_trait.h" | |||
namespace mgb::imperative { | |||
namespace { namespace deformable_psroi_pooling { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
mgb_assert(inputs.size() == 3); | |||
auto&& op = static_cast<const DeformablePSROIPooling&>(def); | |||
return opr::DeformablePSROIPooling::make_all(inputs[0], inputs[1], inputs[2], op.param()); | |||
} | |||
OP_TRAIT_REG(DeformablePSROIPooling, DeformablePSROIPooling) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // deformable_psroi_pooling | |||
} // namespace mgb::imperative |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file imperative/src/impl/ops/autogen.cpp | |||
* \file imperative/src/impl/ops/specializations.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -48,6 +48,8 @@ def Convolution : MgbHashableOp<"Convolution", [ConvolutionParam, ExecutionPolic | |||
// Hashable op records. ConvolutionBackwardData and DeformableConv carry a
// ConvolutionParam plus an execution-policy base named "policy"; GroupLocal
// carries only a ConvolutionParam and Pooling only a PoolingParam.
def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>; | |||
def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>; | |||
def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>; | |||
def Pooling: MgbHashableOp<"Pooling", [PoolingParam]>; | |||
@@ -56,6 +58,8 @@ def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]>; | |||
// ROI pooling op records, each parameterised by its own param struct.
def ROIPooling: MgbHashableOp<"ROIPooling", [ROIPoolingParam]>; | |||
def DeformablePSROIPooling : MgbHashableOp<"DeformablePSROIPooling", [DeformablePSROIPoolingParam]>; | |||
def ConvBias : MgbHashableOp<"ConvBias", [ConvBiasParam, ExecutionPolicyParamBase<"policy">]> { | |||
let extraArguments = (ins | |||
MgbDTypeAttr:$dtype | |||