GitOrigin-RevId: 501cadda76
tags/v1.3.0
@@ -44,6 +44,8 @@ __all__ = [ | |||||
"batch_norm", | "batch_norm", | ||||
"conv2d", | "conv2d", | ||||
"conv_transpose2d", | "conv_transpose2d", | ||||
"deformable_conv2d", | |||||
"deformable_psroi_pooling", | |||||
"dot", | "dot", | ||||
"dropout", | "dropout", | ||||
"indexing_one_hot", | "indexing_one_hot", | ||||
@@ -119,7 +121,8 @@ def conv2d( | |||||
:param padding: size of the paddings added to the input on both sides of its | :param padding: size of the paddings added to the input on both sides of its | ||||
spatial dimensions. Only zero-padding is supported. Default: 0 | spatial dimensions. Only zero-padding is supported. Default: 0 | ||||
:param dilation: dilation of the 2D convolution operation. Default: 1 | :param dilation: dilation of the 2D convolution operation. Default: 1 | ||||
:param groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||||
:param groups: number of groups into which the input and output channels are divided, | |||||
so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ||||
and the shape of weight should be `(groups, out_channel // groups, | and the shape of weight should be `(groups, out_channel // groups, | ||||
in_channels // groups, height, width)`. | in_channels // groups, height, width)`. | ||||
@@ -141,7 +144,6 @@ def conv2d( | |||||
pad_h, pad_w = expand_hw(padding) | pad_h, pad_w = expand_hw(padding) | ||||
dilate_h, dilate_w = expand_hw(dilation) | dilate_h, dilate_w = expand_hw(dilation) | ||||
Sparse = builtin.Convolution.Sparse | |||||
sparse_type = "DENSE" if groups == 1 else "GROUP" | sparse_type = "DENSE" if groups == 1 else "GROUP" | ||||
op = builtin.Convolution( | op = builtin.Convolution( | ||||
stride_h=stride_h, | stride_h=stride_h, | ||||
@@ -185,7 +187,8 @@ def conv_transpose2d( | |||||
:param padding: size of the paddings added to the input on both sides of its | :param padding: size of the paddings added to the input on both sides of its | ||||
spatial dimensions. Only zero-padding is supported. Default: 0 | spatial dimensions. Only zero-padding is supported. Default: 0 | ||||
:param dilation: dilation of the 2D convolution operation. Default: 1 | :param dilation: dilation of the 2D convolution operation. Default: 1 | ||||
:param groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||||
:param groups: number of groups into which the input and output channels are divided, | |||||
so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||||
``in_channels`` and ``out_channels`` must be divisible by groups, | ``in_channels`` and ``out_channels`` must be divisible by groups, | ||||
and the shape of weight should be `(groups, out_channel // groups, | and the shape of weight should be `(groups, out_channel // groups, | ||||
in_channels // groups, height, width)`. Default: 1 | in_channels // groups, height, width)`. Default: 1 | ||||
@@ -226,6 +229,74 @@ def conv_transpose2d( | |||||
return output | return output | ||||
def deformable_conv2d( | |||||
inp: Tensor, | |||||
weight: Tensor, | |||||
offset: Tensor, | |||||
mask: Tensor, | |||||
bias: Optional[Tensor] = None, | |||||
stride: Union[int, Tuple[int, int]] = 1, | |||||
padding: Union[int, Tuple[int, int]] = 0, | |||||
dilation: Union[int, Tuple[int, int]] = 1, | |||||
groups: int = 1, | |||||
conv_mode="CROSS_CORRELATION", | |||||
compute_mode="DEFAULT", | |||||
) -> Tensor: | |||||
""" | |||||
Deformable Convolution. | |||||
:param inp: input feature map. | |||||
:param weight: convolution kernel. | |||||
:param offset: input offset to kernel, channel of this tensor should match the deformable settings. | |||||
:param mask: input mask to kernel, channel of this tensor should match the deformable settings. | |||||
:param bias: bias added to the result of convolution (if given). | |||||
:param stride: stride of the 2D convolution operation. Default: 1 | |||||
:param padding: size of the paddings added to the input on both sides of its | |||||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||||
:param groups: number of groups into which the input and output channels are divided, | |||||
so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||||
``in_channels`` and ``out_channels`` must be divisible by groups, | |||||
and the shape of weight should be `(groups, out_channel // groups, | |||||
in_channels // groups, height, width)`. Default: 1 | |||||
:type conv_mode: string or :class:`Convolution.Mode` | |||||
:param conv_mode: supports "CROSS_CORRELATION". Default: | |||||
"CROSS_CORRELATION" | |||||
:type compute_mode: string or | |||||
:class:`Convolution.ComputeMode` | |||||
:param compute_mode: when set to "DEFAULT", no special requirements will be | |||||
placed on the precision of intermediate results. When set to "FLOAT32", | |||||
"Float32" would be used for accumulator and intermediate result, but only | |||||
effective when input and output are of Float16 dtype. | |||||
:return: output tensor. | |||||
""" | |||||
assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION" | |||||
assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT" | |||||
stride_h, stride_w = expand_hw(stride) | |||||
pad_h, pad_w = expand_hw(padding) | |||||
dilate_h, dilate_w = expand_hw(dilation) | |||||
sparse_type = "DENSE" if groups == 1 else "GROUP" | |||||
op = builtin.DeformableConv( | |||||
stride_h=stride_h, | |||||
stride_w=stride_w, | |||||
pad_h=pad_h, | |||||
pad_w=pad_w, | |||||
dilate_h=dilate_h, | |||||
dilate_w=dilate_w, | |||||
strategy=get_conv_execution_strategy(), | |||||
mode=conv_mode, | |||||
compute_mode=compute_mode, | |||||
sparse=sparse_type, | |||||
) | |||||
inp, weight, offset, mask = utils.convert_inputs(inp, weight, offset, mask) | |||||
(output,) = apply(op, inp, weight, offset, mask) | |||||
if bias is not None: | |||||
output += bias | |||||
return output | |||||
def local_conv2d( | def local_conv2d( | ||||
inp: Tensor, | inp: Tensor, | ||||
weight: Tensor, | weight: Tensor, | ||||
@@ -380,6 +451,45 @@ def adaptive_avg_pool2d( | |||||
return output | return output | ||||
def deformable_psroi_pooling( | |||||
inp: Tensor, | |||||
rois: Tensor, | |||||
trans: Tensor, | |||||
no_trans: bool, | |||||
part_size: int, | |||||
pooled_h: int, | |||||
pooled_w: int, | |||||
sample_per_part: int, | |||||
spatial_scale: float, | |||||
trans_std: float = 0.1, | |||||
): | |||||
""" | |||||
Deformable PSROI(Position Sensitive Region of Interest) Pooling. | |||||
:param inp: input feature map. | |||||
:param rois: the rois for feature pooling. | |||||
:param trans: input offset to psroi_pooling. | |||||
:param no_trans: check the phase of DeformablePSROIPooling. False to the | |||||
1st phase, True to the 2nd phase. | |||||
:param part_size: part size. | |||||
:param sample_per_part: sample points of each part. | |||||
:param pooled_shape: kernel shape of convolution. | |||||
:param spatial_scale: the spatial_scale w.r.t input image. | |||||
:param trans_std: multiplier used in 2nd phase. | |||||
""" | |||||
op = builtin.DeformablePSROIPooling( | |||||
no_trans=no_trans, | |||||
part_size=part_size, | |||||
pooled_h=pooled_h, | |||||
pooled_w=pooled_w, | |||||
sample_per_part=sample_per_part, | |||||
spatial_scale=spatial_scale, | |||||
trans_std=trans_std, | |||||
) | |||||
output, _ = apply(op, inp, rois, trans) | |||||
return output | |||||
def prelu(inp: Tensor, weight: Tensor) -> Tensor: | def prelu(inp: Tensor, weight: Tensor) -> Tensor: | ||||
r""" | r""" | ||||
Applies the element-wise PReLU function. | Applies the element-wise PReLU function. | ||||
@@ -12,8 +12,16 @@ from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d | |||||
from .batch_matmul_activation import BatchMatMulActivation | from .batch_matmul_activation import BatchMatMulActivation | ||||
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | ||||
from .concat import Concat | from .concat import Concat | ||||
from .conv import Conv1d, Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | |||||
from .conv import ( | |||||
Conv1d, | |||||
Conv2d, | |||||
ConvRelu2d, | |||||
ConvTranspose2d, | |||||
DeformableConv2d, | |||||
LocalConv2d, | |||||
) | |||||
from .conv_bn import ConvBn2d, ConvBnRelu2d | from .conv_bn import ConvBn2d, ConvBnRelu2d | ||||
from .deformable_psroi_pooling import DeformablePSROIPooling | |||||
from .dropout import Dropout | from .dropout import Dropout | ||||
from .elemwise import Elemwise | from .elemwise import Elemwise | ||||
from .embedding import Embedding | from .embedding import Embedding | ||||
@@ -10,7 +10,14 @@ from typing import Tuple, Union | |||||
import numpy as np | import numpy as np | ||||
from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu | |||||
from ..functional import ( | |||||
conv1d, | |||||
conv2d, | |||||
conv_transpose2d, | |||||
deformable_conv2d, | |||||
local_conv2d, | |||||
relu, | |||||
) | |||||
from ..tensor import Parameter | from ..tensor import Parameter | ||||
from ..utils.tuple_function import _pair, _pair_nonzero | from ..utils.tuple_function import _pair, _pair_nonzero | ||||
from . import init | from . import init | ||||
@@ -121,7 +128,8 @@ class Conv1d(_ConvNd): | |||||
:param padding: size of the paddings added to the input on both sides of its | :param padding: size of the paddings added to the input on both sides of its | ||||
spatial dimensions. Only zero-padding is supported. Default: 0 | spatial dimensions. Only zero-padding is supported. Default: 0 | ||||
:param dilation: dilation of the 1D convolution operation. Default: 1 | :param dilation: dilation of the 1D convolution operation. Default: 1 | ||||
:param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
:param groups: number of groups into which the input and output channels are divided, | |||||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ||||
and there would be an extra dimension at the beginning of the weight's | and there would be an extra dimension at the beginning of the weight's | ||||
shape. Specifically, the shape of weight would be `(groups, | shape. Specifically, the shape of weight would be `(groups, | ||||
@@ -250,15 +258,16 @@ class Conv2d(_ConvNd): | |||||
In general, output feature maps' shapes can be inferred as follows: | In general, output feature maps' shapes can be inferred as follows: | ||||
input: :math:`(N, C_{\text{in}}, H_{\text{in}}, W_{\text{in}})` | input: :math:`(N, C_{\text{in}}, H_{\text{in}}, W_{\text{in}})` | ||||
output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` where | output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` where | ||||
.. math:: | .. math:: | ||||
\text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} - | \text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} - | ||||
\text{dilation[0]} * (\text{kernel_size[0]} - 1)}{\text{stride[0]}} + 1 \rfloor | |||||
\text{dilation[0]} * (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1 \rfloor | |||||
.. math:: | .. math:: | ||||
\text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} - | \text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} - | ||||
\text{dilation[1]} * (\text{kernel_size[1]} - 1)}{\text{stride[1]}} + 1 \rfloor | |||||
\text{dilation[1]} * (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1 \rfloor | |||||
When `groups == in_channels` and `out_channels == K * in_channels`, | When `groups == in_channels` and `out_channels == K * in_channels`, | ||||
where K is a positive integer, this operation is also known as depthwise | where K is a positive integer, this operation is also known as depthwise | ||||
@@ -277,7 +286,8 @@ class Conv2d(_ConvNd): | |||||
:param padding: size of the paddings added to the input on both sides of its | :param padding: size of the paddings added to the input on both sides of its | ||||
spatial dimensions. Only zero-padding is supported. Default: 0 | spatial dimensions. Only zero-padding is supported. Default: 0 | ||||
:param dilation: dilation of the 2D convolution operation. Default: 1 | :param dilation: dilation of the 2D convolution operation. Default: 1 | ||||
:param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
:param groups: number of groups into which the input and output channels are divided, | |||||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ||||
and there would be an extra dimension at the beginning of the weight's | and there would be an extra dimension at the beginning of the weight's | ||||
shape. Specifically, the shape of weight would be `(groups, | shape. Specifically, the shape of weight would be `(groups, | ||||
@@ -406,7 +416,8 @@ class ConvTranspose2d(_ConvNd): | |||||
:param padding: size of the paddings added to the input on both sides of its | :param padding: size of the paddings added to the input on both sides of its | ||||
spatial dimensions. Only zero-padding is supported. Default: 0 | spatial dimensions. Only zero-padding is supported. Default: 0 | ||||
:param dilation: dilation of the 2D convolution operation. Default: 1 | :param dilation: dilation of the 2D convolution operation. Default: 1 | ||||
:param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
:param groups: number of groups into which the input and output channels are divided, | |||||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ||||
and there would be an extra dimension at the beginning of the weight's | and there would be an extra dimension at the beginning of the weight's | ||||
shape. Specifically, the shape of weight would be ``(groups, | shape. Specifically, the shape of weight would be ``(groups, | ||||
@@ -579,3 +590,107 @@ class ConvRelu2d(Conv2d): | |||||
def forward(self, inp): | def forward(self, inp): | ||||
return relu(self.calc_conv(inp, self.weight, self.bias)) | return relu(self.calc_conv(inp, self.weight, self.bias)) | ||||
class DeformableConv2d(_ConvNd): | |||||
""" | |||||
Deformable Convolution. | |||||
:param in_channels: number of input channels. | |||||
:param out_channels: number of output channels. | |||||
:param kernel_size: size of weight on spatial dimensions. If kernel_size is | |||||
an :class:`int`, the actual kernel size would be | |||||
`(kernel_size, kernel_size)`. Default: 1 | |||||
:param stride: stride of the 2D convolution operation. Default: 1 | |||||
:param padding: size of the paddings added to the input on both sides of its | |||||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||||
:param groups: number of groups into which the input and output channels are divided, | |||||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||||
and there would be an extra dimension at the beginning of the weight's | |||||
shape. Specifically, the shape of weight would be `(groups, | |||||
out_channel // groups, in_channels // groups, *kernel_size)`. | |||||
:param bias: whether to add a bias onto the result of convolution. Default: | |||||
True | |||||
:param conv_mode: Supports `CROSS_CORRELATION`. Default: | |||||
`CROSS_CORRELATION` | |||||
:param compute_mode: When set to "DEFAULT", no special requirements will be | |||||
placed on the precision of intermediate results. When set to "FLOAT32", | |||||
"Float32" would be used for accumulator and intermediate result, but only | |||||
effective when input and output are of float16 dtype. | |||||
""" | |||||
def __init__( | |||||
self, | |||||
in_channels: int, | |||||
out_channels: int, | |||||
kernel_size: Union[int, Tuple[int, int]], | |||||
stride: Union[int, Tuple[int, int]] = 1, | |||||
padding: Union[int, Tuple[int, int]] = 0, | |||||
dilation: Union[int, Tuple[int, int]] = 1, | |||||
groups: int = 1, | |||||
bias: bool = True, | |||||
conv_mode: str = "CROSS_CORRELATION", | |||||
compute_mode: str = "DEFAULT", | |||||
): | |||||
kernel_size = _pair_nonzero(kernel_size) | |||||
stride = _pair_nonzero(stride) | |||||
padding = _pair(padding) | |||||
dilation = _pair_nonzero(dilation) | |||||
self.conv_mode = conv_mode | |||||
self.compute_mode = compute_mode | |||||
super().__init__( | |||||
in_channels, | |||||
out_channels, | |||||
kernel_size, | |||||
stride, | |||||
padding, | |||||
dilation, | |||||
groups, | |||||
bias, | |||||
) | |||||
def _get_fanin(self): | |||||
kh, kw = self.kernel_size | |||||
ic = self.in_channels | |||||
return kh * kw * ic | |||||
def _infer_weight_shape(self): | |||||
group = self.groups | |||||
ichl = self.in_channels | |||||
ochl = self.out_channels | |||||
kh, kw = self.kernel_size | |||||
if group == 1: | |||||
# Assume format is NCHW | |||||
return (ochl, ichl, kh, kw) | |||||
assert ( | |||||
ichl % group == 0 and ochl % group == 0 | |||||
), "invalid config: input_channels={} output_channels={} group={}".format( | |||||
ichl, ochl, group | |||||
) | |||||
# Assume format is NCHW | |||||
return (group, ochl // group, ichl // group, kh, kw) | |||||
def _infer_bias_shape(self): | |||||
# Assume format is NCHW | |||||
return (1, self.out_channels, 1, 1) | |||||
def calc_conv(self, inp, weight, offset, mask, bias): | |||||
return deformable_conv2d( | |||||
inp, | |||||
weight, | |||||
offset, | |||||
mask, | |||||
bias, | |||||
self.stride, | |||||
self.padding, | |||||
self.dilation, | |||||
self.groups, | |||||
self.conv_mode, | |||||
self.compute_mode, | |||||
) | |||||
def forward(self, inp, offset, mask): | |||||
return self.calc_conv(inp, self.weight, offset, mask, self.bias) |
@@ -0,0 +1,46 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from ..functional import deformable_psroi_pooling | |||||
from .module import Module | |||||
class DeformablePSROIPooling(Module): | |||||
def __init__( | |||||
self, | |||||
no_trans, | |||||
part_size, | |||||
pooled_h, | |||||
pooled_w, | |||||
sample_per_part, | |||||
spatial_scale, | |||||
trans_std: float = 0.1, | |||||
): | |||||
super().__init__() | |||||
self.no_trans = no_trans | |||||
self.part_size = part_size | |||||
self.pooled_h = pooled_h | |||||
self.pooled_w = pooled_w | |||||
self.sample_per_part = sample_per_part | |||||
self.spatial_scale = spatial_scale | |||||
self.trans_std = trans_std | |||||
def forward(self, inp, rois, trans): | |||||
return deformable_psroi_pooling( | |||||
inp, | |||||
rois, | |||||
trans, | |||||
self.no_trans, | |||||
self.part_size, | |||||
self.pooled_h, | |||||
self.pooled_w, | |||||
self.sample_per_part, | |||||
self.spatial_scale, | |||||
self.trans_std, | |||||
) |
@@ -703,6 +703,33 @@ def test_argmxx_on_inf(): | |||||
assert all(run_argmin() >= 0) | assert all(run_argmin() >= 0) | ||||
def test_deformable_psroi_pooling(): | |||||
inp = np.random.random((1, 256, 64, 64)).astype("float32") | |||||
rois = np.random.random((1, 5)).astype("float32") | |||||
trans = np.random.random((24, 2, 7, 7)).astype("float32") | |||||
pooled_h = 7 | |||||
pooled_w = 7 | |||||
sample_per_part = 4 | |||||
no_trans = False | |||||
part_size = 7 | |||||
spatial_scale = 1.0 / 64 | |||||
trans_std = 0.1 | |||||
y = F.deformable_psroi_pooling( | |||||
tensor(inp), | |||||
tensor(rois), | |||||
tensor(trans), | |||||
no_trans, | |||||
part_size, | |||||
pooled_h, | |||||
pooled_w, | |||||
sample_per_part, | |||||
spatial_scale, | |||||
trans_std, | |||||
) | |||||
def test_cvt_color(): | def test_cvt_color(): | ||||
def rgb2gray(rgb): | def rgb2gray(rgb): | ||||
return np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) | return np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) | ||||
@@ -0,0 +1,38 @@ | |||||
/** | |||||
* \file imperative/src/impl/ops/deformable_conv2d.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/opr/dnn/convolution.h" | |||||
#include "../op_trait.h" | |||||
namespace mgb::imperative { | |||||
namespace { namespace deformableconv { | |||||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||||
auto* node = &node_->cast_final_safe<opr::DeformableConv>(); | |||||
return DeformableConv::make(node->param(), node->execution_policy()); | |||||
} | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
auto&& dcn = static_cast<const DeformableConv&>(def); | |||||
mgb_assert(inputs.size() == 4); | |||||
return opr::DeformableConv::make(inputs[0], inputs[1], inputs[2], inputs[3], dcn.param(), dcn.policy()); | |||||
} | |||||
OP_TRAIT_REG(DeformableConv, DeformableConv, opr::DeformableConv) | |||||
.make_from_op_node(make_from_op_node) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // deformableconv | |||||
} // namespace mgb::imperative |
@@ -0,0 +1,32 @@ | |||||
/** | |||||
* \file imperative/src/impl/ops/deformable_psroi_pooling.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/opr/dnn/roi_pooling.h" | |||||
#include "../op_trait.h" | |||||
namespace mgb::imperative { | |||||
namespace { namespace deformable_psroi_pooling { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
mgb_assert(inputs.size() == 3); | |||||
auto&& op = static_cast<const DeformablePSROIPooling&>(def); | |||||
return opr::DeformablePSROIPooling::make_all(inputs[0], inputs[1], inputs[2], op.param()); | |||||
} | |||||
OP_TRAIT_REG(DeformablePSROIPooling, DeformablePSROIPooling) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // deformable_psroi_pooling | |||||
} // namespace mgb::imperative |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file imperative/src/impl/ops/autogen.cpp | |||||
* \file imperative/src/impl/ops/specialzations.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -48,6 +48,8 @@ def Convolution : MgbHashableOp<"Convolution", [ConvolutionParam, ExecutionPolic | |||||
def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>; | def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>; | ||||
def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>; | |||||
def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>; | def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>; | ||||
def Pooling: MgbHashableOp<"Pooling", [PoolingParam]>; | def Pooling: MgbHashableOp<"Pooling", [PoolingParam]>; | ||||
@@ -56,6 +58,8 @@ def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]>; | |||||
def ROIPooling: MgbHashableOp<"ROIPooling", [ROIPoolingParam]>; | def ROIPooling: MgbHashableOp<"ROIPooling", [ROIPoolingParam]>; | ||||
def DeformablePSROIPooling : MgbHashableOp<"DeformablePSROIPooling", [DeformablePSROIPoolingParam]>; | |||||
def ConvBias : MgbHashableOp<"ConvBias", [ConvBiasParam, ExecutionPolicyParamBase<"policy">]> { | def ConvBias : MgbHashableOp<"ConvBias", [ConvBiasParam, ExecutionPolicyParamBase<"policy">]> { | ||||
let extraArguments = (ins | let extraArguments = (ins | ||||
MgbDTypeAttr:$dtype | MgbDTypeAttr:$dtype | ||||