
feat(imperative,mgb): add RegionRestrictedConvolution megbrain and imperative opr

GitOrigin-RevId: c0106ade08
release-1.11.1
Megvii Engine Team, 2 years ago
commit 9e020d23b4
16 changed files with 2140 additions and 7 deletions
  1. +58  -2  imperative/python/megengine/functional/nn.py
  2. +1   -0  imperative/python/megengine/module/__init__.py
  3. +172 -0  imperative/python/megengine/module/conv.py
  4. +173 -0  imperative/python/test/unit/functional/test_functional.py
  5. +171 -0  imperative/src/impl/ops/convolution.cpp
  6. +5   -5  imperative/tablegen/generated/hash.txt
  7. +304 -0  imperative/tablegen/generated/opdef.cpp.inl
  8. +576 -0  imperative/tablegen/generated/opdef.cpy.inl
  9. +52  -0  imperative/tablegen/generated/opdef.h.inl
  10. +46 -0  imperative/tablegen/generated/opdef.py.inl
  11. +5  -0  src/core/include/megbrain/ir/ops.td
  12. +272 -0 src/opr/impl/dnn/convolution.cpp
  13. +1  -0  src/opr/impl/dnn/dnn.sereg.h
  14. +28 -0  src/opr/impl/dnn/dnn.sereg.v2.h
  15. +80 -0  src/opr/include/megbrain/opr/dnn/convolution.h
  16. +196 -0 src/opr/test/dnn/region_restricted_convolution.cpp

+58 -2 imperative/python/megengine/functional/nn.py

@@ -91,6 +91,7 @@ __all__ = [
"warp_affine", "warp_affine",
"warp_perspective", "warp_perspective",
"pixel_shuffle", "pixel_shuffle",
"region_restricted_conv",
] ]




@@ -1213,10 +1214,10 @@ def layer_norm(
):
r"""Applies layer normalization to the input. Supports tensors of any shape as input.
Reference: https://arxiv.org/pdf/1803.08494.pdf.
Args:
inp: input tensor.
normalized_shape: the shape that you want to be normalized
affine: whether to use weight and bias
weight: must not be None when affine is true
bias: must not be None when affine is true
@@ -1974,6 +1975,61 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
return pixel_shuffle_cpp(inp, upscale_factor, _layerPixelShuffle_traceable)




def region_restricted_conv(
inp: Tensor,
weight: Tensor,
rin: Tensor,
rout: Tensor,
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
conv_mode: str = "cross_correlation",
compute_mode="default",
) -> Tensor:
r"""Region Restricted convolution operation.

Refer to :class:`~.RegionRestrictedConv` for more information.

Args:
inp: feature map of the convolution operation.
weight: convolution kernel.
rin: input region mask of shape ``(N, IH, IW)`` (int32 or uint8).
rout: output region mask of shape ``(N, OH, OW)`` (int32 or uint8).
stride: stride of the 2D region restricted convolution operation. Default: 1
padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
dilation: dilation of the 2D convolution operation. Default: 1
groups: number of groups into which the input and output channels are divided,
so as to perform a ``grouped convolution``. When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be ``(groups, out_channels // groups,
in_channels // groups, height, width)``. Default: 1
conv_mode: supports "cross_correlation". Default: "cross_correlation"
compute_mode: when set to "default", no special requirements will be placed
on the precision of intermediate results. When set to "float32", "float32"
would be used for accumulator and intermediate result, but only effective
when input and output are of float16 dtype.

Returns:
output tensor.
"""
assert conv_mode.lower() == "cross_correlation"

pad_h, pad_w = _expand_hw(padding)
stride_h, stride_w = _expand_hw(stride)
dilate_h, dilate_w = _expand_hw(dilation)

sparse_type = "dense" if groups == 1 else "group"
op = builtin.RegionRestrictedConvolution(
stride_h=stride_h,
stride_w=stride_w,
pad_h=pad_h,
pad_w=pad_w,
dilate_h=dilate_h,
dilate_w=dilate_w,
mode=conv_mode,
compute_mode=compute_mode,
sparse=sparse_type,
)
(output,) = apply(op, inp, weight, rin, rout)
return output


from .quantized import conv_bias_activation # isort:skip
from .loss import * # isort:skip
from .vision import * # isort:skip
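
A minimal usage sketch of the new functional API (not part of the diff), mirroring the shapes used in the naive CPU test added by this commit; the ones-valued masks are placeholders for real int32 region labels:

import numpy as np
import megengine.functional as F
from megengine import tensor

src = tensor(np.arange(8).reshape(1, 2, 2, 2).astype(np.float32))  # (N, C, H, W)
weight = tensor(np.ones((2, 1, 1, 2, 2), dtype=np.float32))        # (groups, ocpg, icpg, FH, FW)
rin = tensor(np.ones((1, 2, 2), dtype=np.int32))                   # input region mask (N, IH, IW)
rout = tensor(np.ones((1, 1, 1), dtype=np.int32))                  # output region mask (N, OH, OW)
out = F.region_restricted_conv(src, weight, rin, rout, groups=2)
print(out.numpy().shape)  # (1, 2, 1, 1): 2x2 input, 2x2 kernel, stride 1, no padding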

+1 -0 imperative/python/megengine/module/__init__.py

@@ -14,6 +14,7 @@ from .conv import (
ConvTranspose3d,
DeformableConv2d,
LocalConv2d,
RegionRestrictedConv,
)
from .conv_bn import ConvBn2d, ConvBnRelu2d
from .deformable_psroi_pooling import DeformablePSROIPooling


+172 -0 imperative/python/megengine/module/conv.py

@@ -12,6 +12,7 @@ from ..functional import (
deformable_conv2d,
local_conv2d,
pad,
region_restricted_conv,
relu,
)
from ..tensor import Parameter
@@ -982,3 +983,174 @@ class ConvTranspose3d(_ConvNd):
self.output_padding,
self.dilation,
)


class RegionRestrictedConv(_ConvNd):
r"""Applies a 2D RegionRestricted Convolution over an input tensor.

For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`,
this layer generates an output of the size
:math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
process described as below:

.. math::
\text{out}(N_i, C_{\text{out}_j}) =
\sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

where :math:`\star` is the valid 2D cross-correlation operator,
:math:`N` is batch size, :math:`C` denotes number of channels,
:math:`H` is height of input planes in pixels, and :math:`W` is
width in pixels.

In general, output feature maps' shapes can be inferred as follows:

input: :math:`(N, C_{\text{in}}, H_{\text{in}}, W_{\text{in}})`

output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` where

.. math::
\text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} -
\text{dilation[0]} * (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1 \rfloor

.. math::
\text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} -
\text{dilation[1]} * (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1 \rfloor
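
For example, with :math:`H_{in} = W_{in} = 2`, ``kernel_size = (2, 2)``,
``stride = 1``, ``padding = 0`` and ``dilation = 1``, both formulas give
:math:`\lfloor (2 + 0 - 1 - 1) / 1 + 1 \rfloor = 1`, which matches the
``(1, 2, 1, 1)`` output in the example below.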

When `groups == in_channels` and `out_channels == K * in_channels`,
where K is a positive integer, this operation is also known as depthwise
convolution.

In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
a depthwise convolution with a depthwise multiplier `K` can be constructed
by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
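
For instance, ``RegionRestrictedConv(in_channels=4, out_channels=8, kernel_size=3, groups=4)``
is such a depthwise variant with multiplier :math:`K = 2`.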

Args:
in_channels: number of input channels.
out_channels: number of output channels.
kernel_size: size of weight on spatial dimensions. If kernel_size is
an :class:`int`, the actual kernel size would be
``(kernel_size, kernel_size)``.
stride: stride of the 2D convolution operation. Default: 1
padding: size of the paddings added to the input on both sides of its
spatial dimensions. Default: 0
dilation: dilation of the 2D convolution operation. Default: 1
groups: number of groups into which the input and output channels are divided,
so as to perform a ``grouped convolution``. When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be ``(groups, out_channels // groups,
in_channels // groups, height, width)``. Default: 1
conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
compute_mode: When set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype.
padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
Refer to :class:`~.module.padding.Pad` for more information.

Note:
* ``weight`` usually has shape ``(out_channels, in_channels, height, width)``;
if ``groups`` is not 1, the shape will be ``(groups, out_channels // groups, in_channels // groups, height, width)``

Examples:
>>> import numpy as np
>>> import megengine as mge
>>> import megengine.module as M
>>> rrconv = M.RegionRestrictedConv(in_channels=2, out_channels=2, kernel_size=2, groups=2)
>>> inp = mge.tensor(np.random.randn(1, 2, 2, 2).astype(np.float32))
>>> rin = mge.tensor(np.random.randn(1, 2, 2).astype(np.int32))
>>> rout = mge.tensor(np.random.randn(1, 1, 1).astype(np.int32))
>>> oup = rrconv(inp, rin, rout)
>>> oup.numpy().shape
(1, 2, 1, 1)
"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
groups: int,
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
conv_mode: str = "cross_correlation",
compute_mode: str = "default",
padding_mode: str = "zeros",
**kwargs
):
kernel_size = _pair_nonzero(kernel_size)
stride = _pair_nonzero(stride)
padding = _pair(padding)
dilation = _pair_nonzero(dilation)
self.conv_mode = conv_mode
self.compute_mode = compute_mode
self.padding_mode = padding_mode
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
0,
dilation,
groups,
False,
**kwargs,
)

def _get_fanin(self):
kh, kw = self.kernel_size
ic = self.in_channels
return kh * kw * ic

def _infer_weight_shape(self):
group = self.groups
ichl = self.in_channels
ochl = self.out_channels
kh, kw = self.kernel_size
if group == 1:
# Assume format is NCHW
return (ochl, ichl, kh, kw)

assert (
ichl % group == 0 and ochl % group == 0
), "invalid config: in_channels={} out_channels={} group={}".format(
ichl, ochl, group
)
# Assume format is NCHW
return (group, ochl // group, ichl // group, kh, kw)

def _infer_bias_shape(self):
# Assume format is NCHW
return (1, self.out_channels, 1, 1)

def get_pad_width(self):
return (
(0, 0),
(0, 0),
(self.padding[0], self.padding[0]),
(self.padding[1], self.padding[1]),
)

def calc_conv(self, inp, weight, rin, rout):
assert self.padding_mode in [
"zeros",
"reflect",
"replicate",
]
return region_restricted_conv(
inp,
weight,
rin,
rout,
self.stride,
self.padding,
self.dilation,
self.groups,
self.conv_mode,
self.compute_mode,
)

def forward(self, inp, rin, rout):
return self.calc_conv(inp, self.weight, rin, rout)
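
The unit tests below exercise this forward/backward path; as a quick reference, here is a hedged sketch (not part of the diff) of driving the module with gradients, again using ones-valued masks as placeholders for real region labels:

import numpy as np
import megengine as mge
import megengine.module as M
from megengine.autodiff import GradManager

rrconv = M.RegionRestrictedConv(in_channels=2, out_channels=2, kernel_size=2, groups=2)
inp = mge.tensor(np.random.randn(1, 2, 2, 2).astype(np.float32))
rin = mge.tensor(np.ones((1, 2, 2), dtype=np.int32))
rout = mge.tensor(np.ones((1, 1, 1), dtype=np.int32))

gm = GradManager().attach([inp, *rrconv.parameters()])
with gm:
    out = rrconv(inp, rin, rout)  # forward(inp, rin, rout)
    gm.backward(out, mge.tensor(np.ones((1, 2, 1, 1), dtype=np.float32)))
print(inp.grad.shape)  # (1, 2, 2, 2)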

+173 -0 imperative/python/test/unit/functional/test_functional.py

@@ -930,6 +930,179 @@ def test_batch_conv_bias():
run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)




def test_region_restricted_conv_forward_backward_naive():
import megengine as mge
import megengine.module as M
from megengine.autodiff import GradManager

handle = "cpu0"
src_1 = np.arange(8).reshape(1, 2, 2, 2).astype(np.float32)
filter_1 = np.arange(8).reshape(2, 1, 1, 2, 2).astype(np.float32)
rin_1 = np.array([1, 1, 1, 1]).reshape(1, 2, 2).astype(np.int32)
rout_1 = np.array([1]).reshape(1, 1, 1).astype(np.int32)
cpu_src = tensor(src_1, device=handle)
cpu_filter = tensor(filter_1, device=handle)
gm = GradManager().attach([cpu_src, cpu_filter])
with gm:
cpu_out = F.region_restricted_conv(
cpu_src,
cpu_filter,
tensor(rin_1, device=handle),
tensor(rout_1, device=handle),
groups=2,
)
gm.backward(cpu_out, tensor(np.ones((1, 2, 1, 1)), device=handle))
np.testing.assert_allclose(
cpu_src.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(1, 2, 2, 2)
)
np.testing.assert_allclose(
cpu_filter.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(2, 1, 1, 2, 2)
)


@pytest.mark.skipif(
not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
)
def test_region_restricted_conv_forward_backward_cuda():
import megengine as mge
import megengine.module as M
from megengine.autodiff import GradManager
import megengine.distributed as dist

# params
handle = "gpu0"
N = 1
GROUP = 3
FH = FW = 2
IH = IW = 2
OH = OW = 1
ICPG = OCPG = 1
grad_shape = (N, GROUP * ICPG, IH, IW)
src_shape = grad_shape
filter_shape = (GROUP, OCPG, ICPG, FH, FW)
diff_shape = (N, GROUP * OCPG, OH, OW)
rin_shape = (N, IH, IW)
rout_shape = (N, OH, OW)

def reduce(shape):
mul = 1
for x in shape:
mul *= x
return mul

def get_groundtruth():
src = tensor(
np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
device="cpu0",
)
filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
gm = GradManager().attach([src, filter])
with gm:
expected_out = F.region_restricted_conv(
src, filter, rin, rout, groups=GROUP
)
gm.backward(
expected_out,
tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
)
return src, filter

expected_src, expected_filter = get_groundtruth()

src = tensor(
np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
device=handle,
)
filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
rin = tensor(np.ones(rin_shape).astype(np.int32), device=handle)
rout = tensor(np.ones(rout_shape).astype(np.int32), device=handle)
gm = GradManager().attach([src, filter])
with gm:
gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP)
gm.backward(gpu_out, tensor(np.ones(diff_shape), device=handle))
np.testing.assert_allclose(src.grad, expected_src.grad)
np.testing.assert_allclose(filter.grad, expected_filter.grad)


@pytest.mark.skipif(
not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
)
def test_region_restricted_conv_forward_backward_uint8():
import megengine as mge
import megengine.module as M
from megengine.autodiff import GradManager

# params
handle = "gpu0"
N = 1
GROUP = 2
FH = FW = 1
IH = IW = 4
OH = OW = 4
ICPG = OCPG = 1
grad_shape = (N, GROUP * ICPG, IH, IW)
src_shape = grad_shape
filter_shape = (GROUP, OCPG, ICPG, FH, FW)
diff_shape = (N, GROUP * OCPG, OH, OW)
rin_shape = (N, IH, IW)
rout_shape = (N, OH, OW)

def reduce(shape):
mul = 1
for x in shape:
mul *= x
return mul

def get_groundtruth():
src = tensor(
np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
device="cpu0",
)
filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
gm = GradManager().attach([src, filter])
with gm:
expected_out = F.region_restricted_conv(
src, filter, rin, rout, groups=GROUP
)
gm.backward(
expected_out,
tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
)
return src, filter

expected_src, expected_filter = get_groundtruth()

# forward and dgrad/wgrad
src = tensor(
np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
device=handle,
)
filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
rin = tensor(np.ones(rin_shape).astype(np.uint8), device=handle)
rout = tensor(np.ones(rout_shape).astype(np.uint8), device=handle)

gm = GradManager().attach([src, filter])
with gm:
gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP)
gm.backward(
gpu_out, tensor(np.ones(diff_shape, dtype=np.float32), device=handle)
)
# assert uint8 gpu result close to cpu result
np.testing.assert_allclose(src.grad, expected_src.grad)
np.testing.assert_allclose(filter.grad, expected_filter.grad)


def test_region_restricted_conv():
test_region_restricted_conv_forward_backward_naive()
if is_cuda_available():
test_region_restricted_conv_forward_backward_cuda()
test_region_restricted_conv_forward_backward_uint8()


def test_conv2d_autocast():
"""check amp's result is equal to manually converted result"""
amp.enabled = True


+171 -0 imperative/src/impl/ops/convolution.cpp

@@ -3,9 +3,11 @@
#include "../blob_manager_impl.h" #include "../blob_manager_impl.h"
#include "../dnn_op_helper.h" #include "../dnn_op_helper.h"
#include "../op_trait.h" #include "../op_trait.h"
#include "megbrain/common.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/tensor_gen.h" #include "megbrain/opr/tensor_gen.h"
#include "megdnn/oprs/nn.h"


namespace mgb {
namespace imperative {
@@ -356,5 +358,174 @@ OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
} // namespace convolution3d_backward_data
} // namespace


namespace {
namespace region_restricted_conv {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::RegionRestrictedConvolution>();
return RegionRestrictedConvolution::make(node->param());
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& conv = static_cast<const RegionRestrictedConvolution&>(def);
OperatorNodeConfig config{conv.make_name()};
return opr::RegionRestrictedConvolution::make(
inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), config);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& region_restricted_conv =
def.cast_final_safe<mgb::imperative::RegionRestrictedConvolution>();
DnnOprHelper<megdnn::RegionRestrictedConvolutionForward> dnn_opr(
region_restricted_conv.param());

auto&& src = inputs[0].layout;
auto&& filter = inputs[1].layout;
auto&& rin = inputs[2].layout;
auto&& rout = inputs[3].layout;
TensorLayout output_layout{src.dtype};
if (src.ndim && filter.ndim) {
dnn_opr.opr().deduce_layout(src, filter, rin, rout, output_layout);
}

return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
// create megdnn opr
auto&& region_restricted_conv = def.cast_final_safe<RegionRestrictedConvolution>();
CompNode cn = inputs[0]->comp_node();

auto&& param = region_restricted_conv.param();
DnnOprCaller<megdnn::RegionRestrictedConvolutionForward> dnn_opr(cn, param);

auto srclo = inputs[0]->layout();
auto filterlo = inputs[1]->layout();
auto rinlo = inputs[2]->layout();
auto routlo = inputs[3]->layout();

auto out_layout = [&] {
if (validated) {
return output_descs[0].layout;
} else {
TensorLayout out_layout{inputs[0]->dtype()};
dnn_opr.op()->deduce_layout(srclo, filterlo, rinlo, routlo, out_layout);
return out_layout;
}
}();

auto out = Tensor::make(out_layout, cn);
dnn_opr.exec_with_ws(inputs[0], inputs[1], inputs[2], inputs[3], out);
return {out};
}

OP_TRAIT_REG(
RegionRestrictedConvolution, RegionRestrictedConvolution,
opr::RegionRestrictedConvolution)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace region_restricted_conv
} // namespace

namespace {
namespace region_restricted_conv_backward_data {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node =
&node_->cast_final_safe<opr::RegionRestrictedConvolutionBackwardData>();
return RegionRestrictedConvolutionBackwardData::make(node->param());
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& conv = static_cast<const RegionRestrictedConvolutionBackwardData&>(def);
OperatorNodeConfig config{conv.make_name()};
// output_dtype may be inferred from the inputs within rrconv bwd data (deduce_dtype API)
CompNode cn = inputs[0]->comp_node();
DType output_dtype;
DnnOprCaller<megdnn::RegionRestrictedConvolutionBackwardData> dnn_opr(cn);
dnn_opr.op()->deduce_dtype(
inputs[0]->dtype(), inputs[1]->dtype(), inputs[2]->dtype(),
inputs[3]->dtype(), output_dtype);
if (output_dtype.valid())
config.output_dtype(output_dtype);
if (inputs.size() == 4) {
return opr::RegionRestrictedConvolutionBackwardData::make(
inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), config);
} else if (inputs.size() == 5) {
return opr::RegionRestrictedConvolutionBackwardData::make(
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], conv.param(),
config);
}
mgb_assert(0);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& convbwd = def.cast_final_safe<
mgb::imperative::RegionRestrictedConvolutionBackwardData>();
DnnOprHelper<megdnn::RegionRestrictedConvolutionBackwardData> dnn_opr(
convbwd.param());

TensorLayout filter = inputs[0].layout;
TensorLayout diff = inputs[1].layout;
TensorLayout rin = inputs[2].layout;
TensorLayout rout = inputs[3].layout;

DType output_dtype;
dnn_opr.opr().deduce_dtype(
inputs[0].layout.dtype, inputs[1].layout.dtype, inputs[2].layout.dtype,
inputs[3].layout.dtype, output_dtype);
TensorLayout output_layout{output_dtype};
if (diff.ndim && filter.ndim) {
dnn_opr.opr().deduce_layout(filter, diff, rin, rout, output_layout);
}
return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& convbwd = def.cast_final_safe<RegionRestrictedConvolutionBackwardData>();
CompNode cn = inputs[0]->comp_node();
DnnOprCaller<megdnn::RegionRestrictedConvolutionBackwardData> dnn_opr(
cn, convbwd.param());

auto filterlo = inputs[0]->layout();
auto difflo = inputs[1]->layout();
auto rinlo = inputs[2]->layout();
auto routlo = inputs[3]->layout();

auto out_layout = [&] {
if (validated) {
return output_descs[0].layout;
} else {
TensorLayout out_layout{inputs[0]->dtype()};
dnn_opr.op()->deduce_layout(filterlo, difflo, rinlo, routlo, out_layout);
return out_layout;
}
}();

auto out = Tensor::make(out_layout, cn);
dnn_opr.exec_with_ws(inputs[0], inputs[1], inputs[2], inputs[3], out);
return {out};
}

OP_TRAIT_REG(
RegionRestrictedConvolutionBackwardData,
RegionRestrictedConvolutionBackwardData,
opr::RegionRestrictedConvolutionBackwardData)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace region_restricted_conv_backward_data
} // namespace

} // namespace imperative
} // namespace mgb

+5 -5 imperative/tablegen/generated/hash.txt

@@ -1,7 +1,7 @@
905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py
40708c56b1f05fdb7d06cc097a300330 ../../src/core/include/megbrain/ir/ops.td
9f3af118c7fe8d0c9db433825d5ad77b generated/opdef.h.inl
4041e44a8ba3cca3b3affa1ed9ed44a2 generated/opdef.cpp.inl
319e1d170c989fe793a4e9c45decefc4 generated/opdef.py.inl
26a18a7593566128ecce76e8f74dcc5d generated/opdef.cpy.inl
da03ffe2a15411f902cd88920d3d47ec ../../src/core/include/megbrain/ir/ops.td
5756619f37e4dc130e1b049d7706d4eb generated/opdef.h.inl
98d1291eed73970ee087f898b6241358 generated/opdef.cpp.inl
b1a9c7569392942294c2168d40939eb5 generated/opdef.py.inl
3d88d5358d15a39219957f5257e32f5b generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h

+304 -0 imperative/tablegen/generated/opdef.cpp.inl

@@ -5694,6 +5694,310 @@ OP_TRAIT_REG(Reduce, Reduce)
.props(Reduce_props_impl)
.make_name(Reduce_make_name_impl);


MGB_DYN_TYPE_OBJ_FINAL_IMPL(RegionRestrictedConvolution);

namespace {
size_t RegionRestrictedConvolution_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolution>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.mode));
val = mgb::hash_pair_combine(val, mgb::hash(op_.pad_h));
val = mgb::hash_pair_combine(val, mgb::hash(op_.pad_w));
val = mgb::hash_pair_combine(val, mgb::hash(op_.stride_h));
val = mgb::hash_pair_combine(val, mgb::hash(op_.stride_w));
val = mgb::hash_pair_combine(val, mgb::hash(op_.dilate_h));
val = mgb::hash_pair_combine(val, mgb::hash(op_.dilate_w));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.sparse));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.compute_mode));
return val;
}
bool RegionRestrictedConvolution_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<RegionRestrictedConvolution>(),
&&b_ = rhs_.cast_final_safe<RegionRestrictedConvolution>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.mode != b_.mode) return false;
if (a_.pad_h != b_.pad_h) return false;
if (a_.pad_w != b_.pad_w) return false;
if (a_.stride_h != b_.stride_h) return false;
if (a_.stride_w != b_.stride_w) return false;
if (a_.dilate_h != b_.dilate_h) return false;
if (a_.dilate_w != b_.dilate_w) return false;
if (a_.sparse != b_.sparse) return false;
if (a_.format != b_.format) return false;
if (a_.compute_mode != b_.compute_mode) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> RegionRestrictedConvolution_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolution>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
switch (op_.mode){
case RegionRestrictedConvolution::Mode::CROSS_CORRELATION:
props_.emplace_back("mode", "CROSS_CORRELATION");
break;
case RegionRestrictedConvolution::Mode::CONVOLUTION:
props_.emplace_back("mode", "CONVOLUTION");
break;
default:
props_.emplace_back("mode", "INVALID");
break;
}
props_.emplace_back("pad_h", std::to_string(op_.pad_h));
props_.emplace_back("pad_w", std::to_string(op_.pad_w));
props_.emplace_back("stride_h", std::to_string(op_.stride_h));
props_.emplace_back("stride_w", std::to_string(op_.stride_w));
props_.emplace_back("dilate_h", std::to_string(op_.dilate_h));
props_.emplace_back("dilate_w", std::to_string(op_.dilate_w));
switch (op_.sparse){
case RegionRestrictedConvolution::Sparse::DENSE:
props_.emplace_back("sparse", "DENSE");
break;
case RegionRestrictedConvolution::Sparse::GROUP:
props_.emplace_back("sparse", "GROUP");
break;
default:
props_.emplace_back("sparse", "INVALID");
break;
}
switch (op_.format){
case RegionRestrictedConvolution::Format::NCHW:
props_.emplace_back("format", "NCHW");
break;
case RegionRestrictedConvolution::Format::NHWC:
props_.emplace_back("format", "NHWC");
break;
case RegionRestrictedConvolution::Format::NHWCD4:
props_.emplace_back("format", "NHWCD4");
break;
case RegionRestrictedConvolution::Format::NCHW4:
props_.emplace_back("format", "NCHW4");
break;
case RegionRestrictedConvolution::Format::NCHW8:
props_.emplace_back("format", "NCHW8");
break;
case RegionRestrictedConvolution::Format::NCHW32:
props_.emplace_back("format", "NCHW32");
break;
case RegionRestrictedConvolution::Format::NCHW88:
props_.emplace_back("format", "NCHW88");
break;
case RegionRestrictedConvolution::Format::NCHW44:
props_.emplace_back("format", "NCHW44");
break;
case RegionRestrictedConvolution::Format::NCHW44_DOT:
props_.emplace_back("format", "NCHW44_DOT");
break;
case RegionRestrictedConvolution::Format::NCHW4_NCHW32:
props_.emplace_back("format", "NCHW4_NCHW32");
break;
case RegionRestrictedConvolution::Format::NCHW32_NCHW4:
props_.emplace_back("format", "NCHW32_NCHW4");
break;
case RegionRestrictedConvolution::Format::NCHW4_NCHW:
props_.emplace_back("format", "NCHW4_NCHW");
break;
case RegionRestrictedConvolution::Format::NHWC_NCHW:
props_.emplace_back("format", "NHWC_NCHW");
break;
case RegionRestrictedConvolution::Format::NHWC_NCHW4_IC_SMALL:
props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL");
break;
case RegionRestrictedConvolution::Format::NCHW_NCHW4_IC_SMALL:
props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL");
break;
case RegionRestrictedConvolution::Format::CHWN4:
props_.emplace_back("format", "CHWN4");
break;
case RegionRestrictedConvolution::Format::NCHW64:
props_.emplace_back("format", "NCHW64");
break;
case RegionRestrictedConvolution::Format::NCHW4_NHWC:
props_.emplace_back("format", "NCHW4_NHWC");
break;
default:
props_.emplace_back("format", "INVALID");
break;
}
switch (op_.compute_mode){
case RegionRestrictedConvolution::ComputeMode::DEFAULT:
props_.emplace_back("compute_mode", "DEFAULT");
break;
case RegionRestrictedConvolution::ComputeMode::FLOAT32:
props_.emplace_back("compute_mode", "FLOAT32");
break;
default:
props_.emplace_back("compute_mode", "INVALID");
break;
}
return props_;
}
std::string RegionRestrictedConvolution_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolution>();
static_cast<void>(op_);
return "RegionRestrictedConvolution";
}
} // anonymous namespace
OP_TRAIT_REG(RegionRestrictedConvolution, RegionRestrictedConvolution)
.hash(RegionRestrictedConvolution_hash_impl)
.is_same_st(RegionRestrictedConvolution_is_same_st_impl)
.props(RegionRestrictedConvolution_props_impl)
.make_name(RegionRestrictedConvolution_make_name_impl);

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RegionRestrictedConvolutionBackwardData);

namespace {
size_t RegionRestrictedConvolutionBackwardData_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolutionBackwardData>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.mode));
val = mgb::hash_pair_combine(val, mgb::hash(op_.pad_h));
val = mgb::hash_pair_combine(val, mgb::hash(op_.pad_w));
val = mgb::hash_pair_combine(val, mgb::hash(op_.stride_h));
val = mgb::hash_pair_combine(val, mgb::hash(op_.stride_w));
val = mgb::hash_pair_combine(val, mgb::hash(op_.dilate_h));
val = mgb::hash_pair_combine(val, mgb::hash(op_.dilate_w));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.sparse));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.compute_mode));
return val;
}
bool RegionRestrictedConvolutionBackwardData_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<RegionRestrictedConvolutionBackwardData>(),
&&b_ = rhs_.cast_final_safe<RegionRestrictedConvolutionBackwardData>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.mode != b_.mode) return false;
if (a_.pad_h != b_.pad_h) return false;
if (a_.pad_w != b_.pad_w) return false;
if (a_.stride_h != b_.stride_h) return false;
if (a_.stride_w != b_.stride_w) return false;
if (a_.dilate_h != b_.dilate_h) return false;
if (a_.dilate_w != b_.dilate_w) return false;
if (a_.sparse != b_.sparse) return false;
if (a_.format != b_.format) return false;
if (a_.compute_mode != b_.compute_mode) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> RegionRestrictedConvolutionBackwardData_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolutionBackwardData>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
switch (op_.mode){
case RegionRestrictedConvolutionBackwardData::Mode::CROSS_CORRELATION:
props_.emplace_back("mode", "CROSS_CORRELATION");
break;
case RegionRestrictedConvolutionBackwardData::Mode::CONVOLUTION:
props_.emplace_back("mode", "CONVOLUTION");
break;
default:
props_.emplace_back("mode", "INVALID");
break;
}
props_.emplace_back("pad_h", std::to_string(op_.pad_h));
props_.emplace_back("pad_w", std::to_string(op_.pad_w));
props_.emplace_back("stride_h", std::to_string(op_.stride_h));
props_.emplace_back("stride_w", std::to_string(op_.stride_w));
props_.emplace_back("dilate_h", std::to_string(op_.dilate_h));
props_.emplace_back("dilate_w", std::to_string(op_.dilate_w));
switch (op_.sparse){
case RegionRestrictedConvolutionBackwardData::Sparse::DENSE:
props_.emplace_back("sparse", "DENSE");
break;
case RegionRestrictedConvolutionBackwardData::Sparse::GROUP:
props_.emplace_back("sparse", "GROUP");
break;
default:
props_.emplace_back("sparse", "INVALID");
break;
}
switch (op_.format){
case RegionRestrictedConvolutionBackwardData::Format::NCHW:
props_.emplace_back("format", "NCHW");
break;
case RegionRestrictedConvolutionBackwardData::Format::NHWC:
props_.emplace_back("format", "NHWC");
break;
case RegionRestrictedConvolutionBackwardData::Format::NHWCD4:
props_.emplace_back("format", "NHWCD4");
break;
case RegionRestrictedConvolutionBackwardData::Format::NCHW4:
props_.emplace_back("format", "NCHW4");
break;
case RegionRestrictedConvolutionBackwardData::Format::NCHW8:
props_.emplace_back("format", "NCHW8");
break;
case RegionRestrictedConvolutionBackwardData::Format::NCHW32:
props_.emplace_back("format", "NCHW32");
break;
case RegionRestrictedConvolutionBackwardData::Format::NCHW88:
props_.emplace_back("format", "NCHW88");
break;
case RegionRestrictedConvolutionBackwardData::Format::NCHW44:
props_.emplace_back("format", "NCHW44");
break;
case RegionRestrictedConvolutionBackwardData::Format::NCHW44_DOT:
props_.emplace_back("format", "NCHW44_DOT");
break;
case RegionRestrictedConvolutionBackwardData::Format::NCHW4_NCHW32:
props_.emplace_back("format", "NCHW4_NCHW32");
break;
case RegionRestrictedConvolutionBackwardData::Format::NCHW32_NCHW4:
props_.emplace_back("format", "NCHW32_NCHW4");
break;
case RegionRestrictedConvolutionBackwardData::Format::NCHW4_NCHW:
props_.emplace_back("format", "NCHW4_NCHW");
break;
case RegionRestrictedConvolutionBackwardData::Format::NHWC_NCHW:
props_.emplace_back("format", "NHWC_NCHW");
break;
case RegionRestrictedConvolutionBackwardData::Format::NHWC_NCHW4_IC_SMALL:
props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL");
break;
case RegionRestrictedConvolutionBackwardData::Format::NCHW_NCHW4_IC_SMALL:
props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL");
break;
case RegionRestrictedConvolutionBackwardData::Format::CHWN4:
props_.emplace_back("format", "CHWN4");
break;
case RegionRestrictedConvolutionBackwardData::Format::NCHW64:
props_.emplace_back("format", "NCHW64");
break;
case RegionRestrictedConvolutionBackwardData::Format::NCHW4_NHWC:
props_.emplace_back("format", "NCHW4_NHWC");
break;
default:
props_.emplace_back("format", "INVALID");
break;
}
switch (op_.compute_mode){
case RegionRestrictedConvolutionBackwardData::ComputeMode::DEFAULT:
props_.emplace_back("compute_mode", "DEFAULT");
break;
case RegionRestrictedConvolutionBackwardData::ComputeMode::FLOAT32:
props_.emplace_back("compute_mode", "FLOAT32");
break;
default:
props_.emplace_back("compute_mode", "INVALID");
break;
}
return props_;
}
std::string RegionRestrictedConvolutionBackwardData_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolutionBackwardData>();
static_cast<void>(op_);
return "RegionRestrictedConvolutionBackwardData";
}
} // anonymous namespace
OP_TRAIT_REG(RegionRestrictedConvolutionBackwardData, RegionRestrictedConvolutionBackwardData)
.hash(RegionRestrictedConvolutionBackwardData_hash_impl)
.is_same_st(RegionRestrictedConvolutionBackwardData_is_same_st_impl)
.props(RegionRestrictedConvolutionBackwardData_props_impl)
.make_name(RegionRestrictedConvolutionBackwardData_make_name_impl);

MGB_DYN_TYPE_OBJ_FINAL_IMPL(Remap);


namespace {


+576 -0 imperative/tablegen/generated/opdef.cpy.inl

@@ -15368,6 +15368,580 @@ void _init_py_Reduce(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Reduce::typeinfo(), &py_type).second);
}


void _init_py_RegionRestrictedConvolution_Mode(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<RegionRestrictedConvolution::Mode>::type;

Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "Mode", reinterpret_cast<PyObject*>(e_type)) >= 0);
}

void _init_py_RegionRestrictedConvolution_Sparse(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<RegionRestrictedConvolution::Sparse>::type;

Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "Sparse", reinterpret_cast<PyObject*>(e_type)) >= 0);
}

void _init_py_RegionRestrictedConvolution_Format(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<RegionRestrictedConvolution::Format>::type;

Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0);
}

void _init_py_RegionRestrictedConvolution_ComputeMode(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<RegionRestrictedConvolution::ComputeMode>::type;

Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "ComputeMode", reinterpret_cast<PyObject*>(e_type)) >= 0);
}

PyOpDefBegin(RegionRestrictedConvolution) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"mode", serialization<decltype(opdef.mode)>::dump(opdef.mode)},
{"pad_h", serialization<decltype(opdef.pad_h)>::dump(opdef.pad_h)},
{"pad_w", serialization<decltype(opdef.pad_w)>::dump(opdef.pad_w)},
{"stride_h", serialization<decltype(opdef.stride_h)>::dump(opdef.stride_h)},
{"stride_w", serialization<decltype(opdef.stride_w)>::dump(opdef.stride_w)},
{"dilate_h", serialization<decltype(opdef.dilate_h)>::dump(opdef.dilate_h)},
{"dilate_w", serialization<decltype(opdef.dilate_w)>::dump(opdef.dilate_w)},
{"sparse", serialization<decltype(opdef.sparse)>::dump(opdef.sparse)},
{"format", serialization<decltype(opdef.format)>::dump(opdef.format)},
{"compute_mode", serialization<decltype(opdef.compute_mode)>::dump(opdef.compute_mode)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("mode");
if (iter != state.end()) {
opdef.mode = serialization<decltype(opdef.mode)>::load(iter->second);
}
}

{
auto&& iter = state.find("pad_h");
if (iter != state.end()) {
opdef.pad_h = serialization<decltype(opdef.pad_h)>::load(iter->second);
}
}

{
auto&& iter = state.find("pad_w");
if (iter != state.end()) {
opdef.pad_w = serialization<decltype(opdef.pad_w)>::load(iter->second);
}
}

{
auto&& iter = state.find("stride_h");
if (iter != state.end()) {
opdef.stride_h = serialization<decltype(opdef.stride_h)>::load(iter->second);
}
}

{
auto&& iter = state.find("stride_w");
if (iter != state.end()) {
opdef.stride_w = serialization<decltype(opdef.stride_w)>::load(iter->second);
}
}

{
auto&& iter = state.find("dilate_h");
if (iter != state.end()) {
opdef.dilate_h = serialization<decltype(opdef.dilate_h)>::load(iter->second);
}
}

{
auto&& iter = state.find("dilate_w");
if (iter != state.end()) {
opdef.dilate_w = serialization<decltype(opdef.dilate_w)>::load(iter->second);
}
}

{
auto&& iter = state.find("sparse");
if (iter != state.end()) {
opdef.sparse = serialization<decltype(opdef.sparse)>::load(iter->second);
}
}

{
auto&& iter = state.find("format");
if (iter != state.end()) {
opdef.format = serialization<decltype(opdef.format)>::load(iter->second);
}
}

{
auto&& iter = state.find("compute_mode");
if (iter != state.end()) {
opdef.compute_mode = serialization<decltype(opdef.compute_mode)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd(RegionRestrictedConvolution)

int PyOp(RegionRestrictedConvolution)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"mode", "pad_h", "pad_w", "stride_h", "stride_w", "dilate_h", "dilate_w", "sparse", "format", "compute_mode", "scope", NULL};
PyObject *mode = NULL, *pad_h = NULL, *pad_w = NULL, *stride_h = NULL, *stride_w = NULL, *dilate_h = NULL, *dilate_w = NULL, *sparse = NULL, *format = NULL, *compute_mode = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOO", const_cast<char**>(kwlist), &mode, &pad_h, &pad_w, &stride_h, &stride_w, &dilate_h, &dilate_w, &sparse, &format, &compute_mode, &scope))
return -1;

if (mode) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().mode =
py::cast<decltype(RegionRestrictedConvolution::mode)>(py::handle(mode));
} CATCH_ALL(-1)
}

if (pad_h) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().pad_h =
py::cast<decltype(RegionRestrictedConvolution::pad_h)>(py::handle(pad_h));
} CATCH_ALL(-1)
}

if (pad_w) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().pad_w =
py::cast<decltype(RegionRestrictedConvolution::pad_w)>(py::handle(pad_w));
} CATCH_ALL(-1)
}

if (stride_h) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().stride_h =
py::cast<decltype(RegionRestrictedConvolution::stride_h)>(py::handle(stride_h));
} CATCH_ALL(-1)
}

if (stride_w) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().stride_w =
py::cast<decltype(RegionRestrictedConvolution::stride_w)>(py::handle(stride_w));
} CATCH_ALL(-1)
}

if (dilate_h) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().dilate_h =
py::cast<decltype(RegionRestrictedConvolution::dilate_h)>(py::handle(dilate_h));
} CATCH_ALL(-1)
}

if (dilate_w) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().dilate_w =
py::cast<decltype(RegionRestrictedConvolution::dilate_w)>(py::handle(dilate_w));
} CATCH_ALL(-1)
}

if (sparse) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().sparse =
py::cast<decltype(RegionRestrictedConvolution::sparse)>(py::handle(sparse));
} CATCH_ALL(-1)
}

if (format) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().format =
py::cast<decltype(RegionRestrictedConvolution::format)>(py::handle(format));
} CATCH_ALL(-1)
}

if (compute_mode) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().compute_mode =
py::cast<decltype(RegionRestrictedConvolution::compute_mode)>(py::handle(compute_mode));
} CATCH_ALL(-1)
}

if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}

return 0;
}

PyGetSetDef PyOp(RegionRestrictedConvolution)::py_getsetters[] = {
{const_cast<char*>("mode"), py_get_generic(RegionRestrictedConvolution, mode), py_set_generic(RegionRestrictedConvolution, mode), const_cast<char*>("mode"), NULL},
{const_cast<char*>("pad_h"), py_get_generic(RegionRestrictedConvolution, pad_h), py_set_generic(RegionRestrictedConvolution, pad_h), const_cast<char*>("pad_h"), NULL},
{const_cast<char*>("pad_w"), py_get_generic(RegionRestrictedConvolution, pad_w), py_set_generic(RegionRestrictedConvolution, pad_w), const_cast<char*>("pad_w"), NULL},
{const_cast<char*>("stride_h"), py_get_generic(RegionRestrictedConvolution, stride_h), py_set_generic(RegionRestrictedConvolution, stride_h), const_cast<char*>("stride_h"), NULL},
{const_cast<char*>("stride_w"), py_get_generic(RegionRestrictedConvolution, stride_w), py_set_generic(RegionRestrictedConvolution, stride_w), const_cast<char*>("stride_w"), NULL},
{const_cast<char*>("dilate_h"), py_get_generic(RegionRestrictedConvolution, dilate_h), py_set_generic(RegionRestrictedConvolution, dilate_h), const_cast<char*>("dilate_h"), NULL},
{const_cast<char*>("dilate_w"), py_get_generic(RegionRestrictedConvolution, dilate_w), py_set_generic(RegionRestrictedConvolution, dilate_w), const_cast<char*>("dilate_w"), NULL},
{const_cast<char*>("sparse"), py_get_generic(RegionRestrictedConvolution, sparse), py_set_generic(RegionRestrictedConvolution, sparse), const_cast<char*>("sparse"), NULL},
{const_cast<char*>("format"), py_get_generic(RegionRestrictedConvolution, format), py_set_generic(RegionRestrictedConvolution, format), const_cast<char*>("format"), NULL},
{const_cast<char*>("compute_mode"), py_get_generic(RegionRestrictedConvolution, compute_mode), py_set_generic(RegionRestrictedConvolution, compute_mode), const_cast<char*>("compute_mode"), NULL},
{NULL} /* Sentinel */
};

PyMethodDef PyOp(RegionRestrictedConvolution)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(RegionRestrictedConvolution)::getstate, METH_NOARGS, "RegionRestrictedConvolution getstate"},
{const_cast<char*>("__setstate__"), PyOp(RegionRestrictedConvolution)::setstate, METH_VARARGS, "RegionRestrictedConvolution setstate"},
{NULL} /* Sentinel */
};
void _init_py_RegionRestrictedConvolution(py::module m) {
using py_op = PyOp(RegionRestrictedConvolution);
auto& py_type = PyOpType(RegionRestrictedConvolution);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.RegionRestrictedConvolution";
py_type.tp_basicsize = sizeof(PyOp(RegionRestrictedConvolution));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "RegionRestrictedConvolution";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
_init_py_RegionRestrictedConvolution_Mode(py_type);
_init_py_RegionRestrictedConvolution_Sparse(py_type);
_init_py_RegionRestrictedConvolution_Format(py_type);
_init_py_RegionRestrictedConvolution_ComputeMode(py_type);

PyType_Modified(&py_type);
m.add_object("RegionRestrictedConvolution", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(RegionRestrictedConvolution::typeinfo(), &py_type).second);
}

void _init_py_RegionRestrictedConvolutionBackwardData_Mode(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<RegionRestrictedConvolutionBackwardData::Mode>::type;

Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "Mode", reinterpret_cast<PyObject*>(e_type)) >= 0);
}

void _init_py_RegionRestrictedConvolutionBackwardData_Sparse(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<RegionRestrictedConvolutionBackwardData::Sparse>::type;

Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "Sparse", reinterpret_cast<PyObject*>(e_type)) >= 0);
}

void _init_py_RegionRestrictedConvolutionBackwardData_Format(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<RegionRestrictedConvolutionBackwardData::Format>::type;

Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0);
}

void _init_py_RegionRestrictedConvolutionBackwardData_ComputeMode(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<RegionRestrictedConvolutionBackwardData::ComputeMode>::type;

Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "ComputeMode", reinterpret_cast<PyObject*>(e_type)) >= 0);
}

PyOpDefBegin(RegionRestrictedConvolutionBackwardData) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"mode", serialization<decltype(opdef.mode)>::dump(opdef.mode)},
{"pad_h", serialization<decltype(opdef.pad_h)>::dump(opdef.pad_h)},
{"pad_w", serialization<decltype(opdef.pad_w)>::dump(opdef.pad_w)},
{"stride_h", serialization<decltype(opdef.stride_h)>::dump(opdef.stride_h)},
{"stride_w", serialization<decltype(opdef.stride_w)>::dump(opdef.stride_w)},
{"dilate_h", serialization<decltype(opdef.dilate_h)>::dump(opdef.dilate_h)},
{"dilate_w", serialization<decltype(opdef.dilate_w)>::dump(opdef.dilate_w)},
{"sparse", serialization<decltype(opdef.sparse)>::dump(opdef.sparse)},
{"format", serialization<decltype(opdef.format)>::dump(opdef.format)},
{"compute_mode", serialization<decltype(opdef.compute_mode)>::dump(opdef.compute_mode)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("mode");
if (iter != state.end()) {
opdef.mode = serialization<decltype(opdef.mode)>::load(iter->second);
}
}

{
auto&& iter = state.find("pad_h");
if (iter != state.end()) {
opdef.pad_h = serialization<decltype(opdef.pad_h)>::load(iter->second);
}
}

{
auto&& iter = state.find("pad_w");
if (iter != state.end()) {
opdef.pad_w = serialization<decltype(opdef.pad_w)>::load(iter->second);
}
}

{
auto&& iter = state.find("stride_h");
if (iter != state.end()) {
opdef.stride_h = serialization<decltype(opdef.stride_h)>::load(iter->second);
}
}

{
auto&& iter = state.find("stride_w");
if (iter != state.end()) {
opdef.stride_w = serialization<decltype(opdef.stride_w)>::load(iter->second);
}
}

{
auto&& iter = state.find("dilate_h");
if (iter != state.end()) {
opdef.dilate_h = serialization<decltype(opdef.dilate_h)>::load(iter->second);
}
}

{
auto&& iter = state.find("dilate_w");
if (iter != state.end()) {
opdef.dilate_w = serialization<decltype(opdef.dilate_w)>::load(iter->second);
}
}

{
auto&& iter = state.find("sparse");
if (iter != state.end()) {
opdef.sparse = serialization<decltype(opdef.sparse)>::load(iter->second);
}
}

{
auto&& iter = state.find("format");
if (iter != state.end()) {
opdef.format = serialization<decltype(opdef.format)>::load(iter->second);
}
}

{
auto&& iter = state.find("compute_mode");
if (iter != state.end()) {
opdef.compute_mode = serialization<decltype(opdef.compute_mode)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd(RegionRestrictedConvolutionBackwardData)

int PyOp(RegionRestrictedConvolutionBackwardData)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"mode", "pad_h", "pad_w", "stride_h", "stride_w", "dilate_h", "dilate_w", "sparse", "format", "compute_mode", "scope", NULL};
PyObject *mode = NULL, *pad_h = NULL, *pad_w = NULL, *stride_h = NULL, *stride_w = NULL, *dilate_h = NULL, *dilate_w = NULL, *sparse = NULL, *format = NULL, *compute_mode = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOO", const_cast<char**>(kwlist), &mode, &pad_h, &pad_w, &stride_h, &stride_w, &dilate_h, &dilate_w, &sparse, &format, &compute_mode, &scope))
return -1;

if (mode) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().mode =
py::cast<decltype(RegionRestrictedConvolutionBackwardData::mode)>(py::handle(mode));
} CATCH_ALL(-1)
}

if (pad_h) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().pad_h =
py::cast<decltype(RegionRestrictedConvolutionBackwardData::pad_h)>(py::handle(pad_h));
} CATCH_ALL(-1)
}

if (pad_w) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().pad_w =
py::cast<decltype(RegionRestrictedConvolutionBackwardData::pad_w)>(py::handle(pad_w));
} CATCH_ALL(-1)
}

if (stride_h) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().stride_h =
py::cast<decltype(RegionRestrictedConvolutionBackwardData::stride_h)>(py::handle(stride_h));
} CATCH_ALL(-1)
}

if (stride_w) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().stride_w =
py::cast<decltype(RegionRestrictedConvolutionBackwardData::stride_w)>(py::handle(stride_w));
} CATCH_ALL(-1)
}

if (dilate_h) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().dilate_h =
py::cast<decltype(RegionRestrictedConvolutionBackwardData::dilate_h)>(py::handle(dilate_h));
} CATCH_ALL(-1)
}

if (dilate_w) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().dilate_w =
py::cast<decltype(RegionRestrictedConvolutionBackwardData::dilate_w)>(py::handle(dilate_w));
} CATCH_ALL(-1)
}

if (sparse) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().sparse =
py::cast<decltype(RegionRestrictedConvolutionBackwardData::sparse)>(py::handle(sparse));
} CATCH_ALL(-1)
}

if (format) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().format =
py::cast<decltype(RegionRestrictedConvolutionBackwardData::format)>(py::handle(format));
} CATCH_ALL(-1)
}

if (compute_mode) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().compute_mode =
py::cast<decltype(RegionRestrictedConvolutionBackwardData::compute_mode)>(py::handle(compute_mode));
} CATCH_ALL(-1)
}

if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}

return 0;
}

PyGetSetDef PyOp(RegionRestrictedConvolutionBackwardData)::py_getsetters[] = {
{const_cast<char*>("mode"), py_get_generic(RegionRestrictedConvolutionBackwardData, mode), py_set_generic(RegionRestrictedConvolutionBackwardData, mode), const_cast<char*>("mode"), NULL},
{const_cast<char*>("pad_h"), py_get_generic(RegionRestrictedConvolutionBackwardData, pad_h), py_set_generic(RegionRestrictedConvolutionBackwardData, pad_h), const_cast<char*>("pad_h"), NULL},
{const_cast<char*>("pad_w"), py_get_generic(RegionRestrictedConvolutionBackwardData, pad_w), py_set_generic(RegionRestrictedConvolutionBackwardData, pad_w), const_cast<char*>("pad_w"), NULL},
{const_cast<char*>("stride_h"), py_get_generic(RegionRestrictedConvolutionBackwardData, stride_h), py_set_generic(RegionRestrictedConvolutionBackwardData, stride_h), const_cast<char*>("stride_h"), NULL},
{const_cast<char*>("stride_w"), py_get_generic(RegionRestrictedConvolutionBackwardData, stride_w), py_set_generic(RegionRestrictedConvolutionBackwardData, stride_w), const_cast<char*>("stride_w"), NULL},
{const_cast<char*>("dilate_h"), py_get_generic(RegionRestrictedConvolutionBackwardData, dilate_h), py_set_generic(RegionRestrictedConvolutionBackwardData, dilate_h), const_cast<char*>("dilate_h"), NULL},
{const_cast<char*>("dilate_w"), py_get_generic(RegionRestrictedConvolutionBackwardData, dilate_w), py_set_generic(RegionRestrictedConvolutionBackwardData, dilate_w), const_cast<char*>("dilate_w"), NULL},
{const_cast<char*>("sparse"), py_get_generic(RegionRestrictedConvolutionBackwardData, sparse), py_set_generic(RegionRestrictedConvolutionBackwardData, sparse), const_cast<char*>("sparse"), NULL},
{const_cast<char*>("format"), py_get_generic(RegionRestrictedConvolutionBackwardData, format), py_set_generic(RegionRestrictedConvolutionBackwardData, format), const_cast<char*>("format"), NULL},
{const_cast<char*>("compute_mode"), py_get_generic(RegionRestrictedConvolutionBackwardData, compute_mode), py_set_generic(RegionRestrictedConvolutionBackwardData, compute_mode), const_cast<char*>("compute_mode"), NULL},
{NULL} /* Sentinel */
};

PyMethodDef PyOp(RegionRestrictedConvolutionBackwardData)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(RegionRestrictedConvolutionBackwardData)::getstate, METH_NOARGS, "RegionRestrictedConvolutionBackwardData getstate"},
{const_cast<char*>("__setstate__"), PyOp(RegionRestrictedConvolutionBackwardData)::setstate, METH_VARARGS, "RegionRestrictedConvolutionBackwardData setstate"},
{NULL} /* Sentinel */
};
void _init_py_RegionRestrictedConvolutionBackwardData(py::module m) {
using py_op = PyOp(RegionRestrictedConvolutionBackwardData);
auto& py_type = PyOpType(RegionRestrictedConvolutionBackwardData);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.RegionRestrictedConvolutionBackwardData";
py_type.tp_basicsize = sizeof(PyOp(RegionRestrictedConvolutionBackwardData));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "RegionRestrictedConvolutionBackwardData";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
_init_py_RegionRestrictedConvolutionBackwardData_Mode(py_type);
_init_py_RegionRestrictedConvolutionBackwardData_Sparse(py_type);
_init_py_RegionRestrictedConvolutionBackwardData_Format(py_type);
_init_py_RegionRestrictedConvolutionBackwardData_ComputeMode(py_type);

PyType_Modified(&py_type);
m.add_object("RegionRestrictedConvolutionBackwardData", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(RegionRestrictedConvolutionBackwardData::typeinfo(), &py_type).second);
}
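
A minimal Python sketch of what the registration above exposes once the module is built (an assumption, not part of the diff): the import path follows tp_name, the attributes follow the py_getsetters table, and the defaults come from the py::init signature in opdef.py.inl further down.

    # hypothetical usage of the generated binding; not part of the diff
    from megengine.core._imperative_rt.ops import RegionRestrictedConvolutionBackwardData

    op = RegionRestrictedConvolutionBackwardData()  # all params at their defaults
    op.stride_h = 2            # generated setter
    assert op.stride_h == 2    # generated getter
    state = op.__getstate__()  # dict of the fields listed above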

template<> struct EnumTrait<Remap::InterpolationMode> {
    static constexpr const char *name = "Remap.InterpolationMode";
    static constexpr std::underlying_type_t<Remap::InterpolationMode> max = 5 - 1;
@@ -18700,6 +19274,8 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_ROIAlign(m); \
_init_py_ROIPooling(m); \
_init_py_Reduce(m); \
_init_py_RegionRestrictedConvolution(m); \
_init_py_RegionRestrictedConvolutionBackwardData(m); \
_init_py_Remap(m); \
_init_py_RemoteRecv(m); \
_init_py_RemoteSend(m); \


+ 52
- 0
imperative/tablegen/generated/opdef.h.inl View File

@@ -1517,6 +1517,58 @@ public:
}
};


class RegionRestrictedConvolution : public OpDefImplBase<RegionRestrictedConvolution> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
using Mode = ::megdnn::param::Convolution::Mode;
using Sparse = ::megdnn::param::Convolution::Sparse;
using Format = ::megdnn::param::Convolution::Format;
using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
Format format = ::megdnn::param::Convolution::Format::NCHW;
ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
RegionRestrictedConvolution() = default;
RegionRestrictedConvolution(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_) { set_scope(scope_); }
RegionRestrictedConvolution(::megdnn::param::Convolution packed_param_0): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode) {}
::megdnn::param::Convolution param() const {
return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
}
};

class RegionRestrictedConvolutionBackwardData : public OpDefImplBase<RegionRestrictedConvolutionBackwardData> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
using Mode = ::megdnn::param::Convolution::Mode;
using Sparse = ::megdnn::param::Convolution::Sparse;
using Format = ::megdnn::param::Convolution::Format;
using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
Format format = ::megdnn::param::Convolution::Format::NCHW;
ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
RegionRestrictedConvolutionBackwardData() = default;
RegionRestrictedConvolutionBackwardData(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_) { set_scope(scope_); }
RegionRestrictedConvolutionBackwardData(::megdnn::param::Convolution packed_param_0): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode) {}
::megdnn::param::Convolution param() const {
return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
}
};

class Remap : public OpDefImplBase<Remap> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;




+ 46
- 0
imperative/tablegen/generated/opdef.py.inl View File

@@ -1620,6 +1620,52 @@ ReduceInst
.def_readwrite("data_type", &Reduce::data_type) .def_readwrite("data_type", &Reduce::data_type)
.def_readwrite("keepdim", &Reduce::keepdim); .def_readwrite("keepdim", &Reduce::keepdim);


py::class_<RegionRestrictedConvolution, std::shared_ptr<RegionRestrictedConvolution>, OpDef> RegionRestrictedConvolutionInst(m, "RegionRestrictedConvolution");

RegionRestrictedConvolutionInst.attr("Mode") = BatchConvBiasInst.attr("Mode");

RegionRestrictedConvolutionInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");

RegionRestrictedConvolutionInst.attr("Format") = AdaptivePoolingInst.attr("Format");

RegionRestrictedConvolutionInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");

RegionRestrictedConvolutionInst
.def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("scope") = {})
.def_readwrite("mode", &RegionRestrictedConvolution::mode)
.def_readwrite("pad_h", &RegionRestrictedConvolution::pad_h)
.def_readwrite("pad_w", &RegionRestrictedConvolution::pad_w)
.def_readwrite("stride_h", &RegionRestrictedConvolution::stride_h)
.def_readwrite("stride_w", &RegionRestrictedConvolution::stride_w)
.def_readwrite("dilate_h", &RegionRestrictedConvolution::dilate_h)
.def_readwrite("dilate_w", &RegionRestrictedConvolution::dilate_w)
.def_readwrite("sparse", &RegionRestrictedConvolution::sparse)
.def_readwrite("format", &RegionRestrictedConvolution::format)
.def_readwrite("compute_mode", &RegionRestrictedConvolution::compute_mode);

py::class_<RegionRestrictedConvolutionBackwardData, std::shared_ptr<RegionRestrictedConvolutionBackwardData>, OpDef> RegionRestrictedConvolutionBackwardDataInst(m, "RegionRestrictedConvolutionBackwardData");

RegionRestrictedConvolutionBackwardDataInst.attr("Mode") = BatchConvBiasInst.attr("Mode");

RegionRestrictedConvolutionBackwardDataInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");

RegionRestrictedConvolutionBackwardDataInst.attr("Format") = AdaptivePoolingInst.attr("Format");

RegionRestrictedConvolutionBackwardDataInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");

RegionRestrictedConvolutionBackwardDataInst
.def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("scope") = {})
.def_readwrite("mode", &RegionRestrictedConvolutionBackwardData::mode)
.def_readwrite("pad_h", &RegionRestrictedConvolutionBackwardData::pad_h)
.def_readwrite("pad_w", &RegionRestrictedConvolutionBackwardData::pad_w)
.def_readwrite("stride_h", &RegionRestrictedConvolutionBackwardData::stride_h)
.def_readwrite("stride_w", &RegionRestrictedConvolutionBackwardData::stride_w)
.def_readwrite("dilate_h", &RegionRestrictedConvolutionBackwardData::dilate_h)
.def_readwrite("dilate_w", &RegionRestrictedConvolutionBackwardData::dilate_w)
.def_readwrite("sparse", &RegionRestrictedConvolutionBackwardData::sparse)
.def_readwrite("format", &RegionRestrictedConvolutionBackwardData::format)
.def_readwrite("compute_mode", &RegionRestrictedConvolutionBackwardData::compute_mode);

py::class_<Remap, std::shared_ptr<Remap>, OpDef> RemapInst(m, "Remap");


py::enum_<Remap::InterpolationMode>(RemapInst, "InterpolationMode")


+ 5
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -520,4 +520,9 @@ def MeshGrid: MgbHashableOp<"MeshGrid"> {
MgbStringAttr:$indexing
);
}

def RegionRestrictedConvolution: MgbHashableOp<"RegionRestrictedConvolution", [ConvolutionParam]>;

def RegionRestrictedConvolutionBackwardData: MgbHashableOp<"RegionRestrictedConvolutionBackwardData", [ConvolutionParam]>;

#endif // MGB_OPS

+ 272
- 0
src/opr/impl/dnn/convolution.cpp View File

@@ -25,6 +25,58 @@ using namespace cg::static_infer;
using intl::WorkspaceLimitGetter;


/* ==================== misc impl ==================== */
template <typename MGBOPR, typename DNNOPR>
void mixin::RegionConvBackwardDataMixin::init_output_static_infer_desc_for_bwd_data(
cg::OperatorNodeBase* self) {
using namespace cg::static_infer;
auto&& mgr = self->owner_graph()->static_infer_manager();

DepVal inp_deps;
inp_deps.reserve(6);
for (int i = 0; i < 4; i++) {
inp_deps.push_back({self->input(i), DepType::SHAPE});
}

auto infer_shp = [self](TensorShape& dest, const InpVal& inp) {
TensorLayout ol{self->output(0)->dtype()};
mgb_assert(
self->input(0)->dtype().category() == DTypeCategory::FLOAT &&
self->input(1)->dtype().category() == DTypeCategory::FLOAT &&
self->input(2)->dtype().category() == DTypeCategory::INT &&
self->input(3)->dtype().category() == DTypeCategory::INT,
"region conv dtype assert error!");
static_cast<MGBOPR*>(self)->megdnn_opr()->deduce_layout(
{inp.val.at(0).shape(), self->input(0)->dtype()}, // filter
{inp.val.at(1).shape(), self->input(1)->dtype()}, // diff
{inp.val.at(2).shape(), self->input(2)->dtype()}, // rin
{inp.val.at(3).shape(), self->input(3)->dtype()}, // rout
ol // grad
);
dest = ol;
return true;
};
mgr.register_shape_infer(self->output(0), {SourceType::DEP, inp_deps, infer_shp});

// workspace size
auto infer_wk = [self](TensorShape& dest, const InpVal& inp) {
TensorLayout ol{self->output(0)->dtype()};
dest.ndim = 1;
dest.shape[0] =
static_cast<MGBOPR*>(self)->megdnn_opr()->get_workspace_in_bytes(
{self->input(0)->shape(), self->input(0)->dtype()}, // filter
{self->input(1)->shape(), self->input(1)->dtype()}, // diff
{self->input(2)->shape(), self->input(2)->dtype()}, // rin
{self->input(3)->shape(), self->input(3)->dtype()}, // rout
ol);
return true;
};
inp_deps.push_back({self->output(0), DepType::SHAPE});
auto workspace_dep_var =
intl::WorkspaceLimitGetter::register_to_graph(self->owner_graph());
if (workspace_dep_var)
inp_deps.push_back({workspace_dep_var, DepType::VALUE});
mgr.register_shape_infer(self->output(1), {SourceType::DEP, inp_deps, infer_wk});
}


template <class MgbOpr, class MegDNNOpr>
void mixin::ConvolutionBackwardDataMixin::init_output_static_infer_desc_for_bwd_data(
@@ -1535,6 +1587,226 @@ void BatchConvBiasForward::init_output_format() {
output(0)->format(input(0)->format());
}


/* ========================== RegionRestrictedConvolutionForward
* ========================== */

IMPL_CONV(RegionRestrictedConvolutionForward);

RegionRestrictedConvolutionForward::RegionRestrictedConvolutionForward(
VarNode* src, VarNode* filter, VarNode* region_in, VarNode* region_out,
const Param& param, const OperatorNodeConfig& config)
: Super(src->owner_graph(), config, "region_restricted_conv_fwd",
{src, filter, region_in, region_out}) {
init_megdnn_opr(*this, param);
add_input({src, filter, region_in, region_out});
}

SymbolVar RegionRestrictedConvolutionForward::make(
SymbolVar src, SymbolVar filter, SymbolVar region_in, SymbolVar region_out,
const Param& param, const OperatorNodeConfig& config) {
return src.insert_single_output_opr<RegionRestrictedConvolutionForward>(
src.node(), filter.node(), region_in.node(), region_out.node(), param,
config);
}

void RegionRestrictedConvolutionForward::init_output_dtype() {
mgb_assert(
input(0)->dtype().category() == DTypeCategory::FLOAT,
"input dtype only support FLOAT, \
but got input dtype: %s",
input(0)->dtype().name());
output(0)->dtype(input(0)->dtype());
return;
}

size_t RegionRestrictedConvolutionForward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
return megdnn_opr()->get_workspace_in_bytes(
{input_shapes[0], input(0)->dtype(), input(0)->format()},
{input_shapes[1], input(1)->dtype(), input(1)->format()},
{input_shapes[2], input(2)->dtype(), input(2)->format()},
{input_shapes[3], input(3)->dtype(), input(3)->format()},
{output_shapes[0], output(0)->dtype(), output(0)->format()});
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RegionRestrictedConvolutionForward) {
mgb_assert(
opr.input(0)->dtype().category() == DTypeCategory::FLOAT &&
opr.input(1)->dtype().category() == DTypeCategory::FLOAT &&
opr.input(2)->dtype().category() == DTypeCategory::INT &&
opr.input(3)->dtype().category() == DTypeCategory::INT,
"only float data type supported for grad");
if (wrt_idx == 0) { // src
SymbolVar grad = RegionRestrictedConvolutionBackwardData::make(
opr.input(1), // filter
out_grad[0], // diff
opr.input(2), // rin
opr.input(3), // rout
opr.input(0), // src
opr.param());
return grad.node();
}
// TODO: CUDA WGRAD UNIMPLEMENTED!
if (wrt_idx == 1) { // filter
SymbolVar grad = RegionRestrictedConvolutionBackwardFilter::make(
opr.input(0), // src
out_grad[0], // diff
opr.input(2), // rin
opr.input(3), // rout
opr.input(1), // filter
opr.param());
return grad.node();
}
return nullptr;
}
#endif
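
The grad rule above routes dsrc through BackwardData and dfilter through BackwardFilter; a CPU-side sketch exercising both via the Python wrapper added elsewhere in this commit (assuming F.region_restricted_conv is exported; note the CUDA wgrad TODO above).

    # hypothetical sketch; not part of the diff
    import numpy as np
    import megengine as mge
    import megengine.functional as F
    import megengine.autodiff as ad

    src  = mge.tensor(np.random.randn(1, 1, 4, 4).astype("float32"))
    filt = mge.tensor(np.random.randn(1, 1, 2, 2).astype("float32"))
    rin  = mge.tensor(np.ones((1, 4, 4), dtype="int32"))
    rout = mge.tensor(np.ones((1, 3, 3), dtype="int32"))

    gm = ad.GradManager().attach([filt])
    with gm:
        out = F.region_restricted_conv(src, filt, rin, rout)
        gm.backward(out.sum())  # drives the backward oprs defined here
    print(filt.grad.shape)      # (1, 1, 2, 2)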

/* ========================== RegionRestrictedConvolutionBackwardData
* ========================== */
IMPL_CONV(RegionRestrictedConvolutionBackwardData);

RegionRestrictedConvolutionBackwardData::RegionRestrictedConvolutionBackwardData(
VarNode* filter, VarNode* diff, VarNode* region_in, VarNode* region_out,
VarNode* src, const Param& param, const OperatorNodeConfig& config)
: Super{filter->owner_graph(),
config,
"region_restricted_conv_bwd_data",
{filter, diff, region_in, region_out}} {
init_megdnn_opr(*this, param);
add_input({filter, diff, region_in, region_out});
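    // src is optional: grad mode appends it as a fifth input so the gradient
    // shape can be taken directly from the forward src (see do_make_node_prop)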
if (src)
add_input({src});
}

SymbolVar RegionRestrictedConvolutionBackwardData::make(
SymbolVar filter, SymbolVar diff, SymbolVar region_in, SymbolVar region_out,
SymbolVar src, const Param& param, const OperatorNodeConfig& config) {
return filter.insert_single_output_opr<RegionRestrictedConvolutionBackwardData>(
filter.node(), diff.node(), region_in.node(), region_out.node(), src.node(),
param, config);
}

SymbolVar RegionRestrictedConvolutionBackwardData::make(
SymbolVar filter, SymbolVar diff, SymbolVar region_in, SymbolVar region_out,
const Param& param, const OperatorNodeConfig& config) {
return make(filter, diff, region_in, region_out, {}, param, config);
}

void RegionRestrictedConvolutionBackwardData::init_output_static_infer_desc() {
init_output_static_infer_desc_for_bwd_data<
RegionRestrictedConvolutionBackwardData,
megdnn::RegionRestrictedConvolutionBackwardData>(this);
}

void RegionRestrictedConvolutionBackwardData::init_output_dtype() {
output(0)->dtype(input(0)->dtype());
}

void RegionRestrictedConvolutionBackwardData::scn_do_execute() {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), // filter
input(1)->dev_tensor().as_megdnn(), // diff
input(2)->dev_tensor().as_megdnn(), // rin
input(3)->dev_tensor().as_megdnn(), // rout
output(0)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output().back()));
}

cg::OperatorNodeBase::NodeProp* RegionRestrictedConvolutionBackwardData::
do_make_node_prop() const {
auto prop = Super::Super::do_make_node_prop();
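    // in grad mode the trailing src input only pins the output shape, so its
    // dependency is downgraded from DEV_VALUE to SHAPE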
if (input().size() == 5) {
using D = NodeProp::DepType;
prop->reset_dep_type(
input(),
{D::DEV_VALUE, D::DEV_VALUE, D::DEV_VALUE, D::DEV_VALUE, D::SHAPE});
}
return prop;
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RegionRestrictedConvolutionBackwardData) {
if (wrt_idx == 0) { // filter
return RegionRestrictedConvolutionBackwardFilter::make(
out_grad[0], opr.input(1), opr.input(2), opr.input(3),
opr.input(0), opr.param())
.node();
}
if (wrt_idx == 1) { // diff
return RegionRestrictedConvolution::make(
out_grad[0], opr.input(0), opr.input(2), opr.input(3),
opr.param())
.node();
}
return nullptr;
}
#endif

/* ========================== RegionRestrictedConvolutionBackwardFilter
* ========================== */
IMPL_CONV(RegionRestrictedConvolutionBackwardFilter);

RegionRestrictedConvolutionBackwardFilter::RegionRestrictedConvolutionBackwardFilter(
VarNode* src, VarNode* diff, VarNode* region_in, VarNode* region_out,
VarNode* filter, const Param& param, const OperatorNodeConfig& config)
: Super({src->owner_graph(),
config,
"region_restricted_conv_bwd_filter",
{src, diff, region_in, region_out, filter}},
4, false) {
init_megdnn_opr(*this, param);
add_input({src, diff, region_in, region_out, filter});
}

SymbolVar RegionRestrictedConvolutionBackwardFilter::make(
SymbolVar src, SymbolVar diff, SymbolVar region_in, SymbolVar region_out,
SymbolVar filter, const Param& param, const OperatorNodeConfig& config) {
return src.insert_single_output_opr<RegionRestrictedConvolutionBackwardFilter>(
src.node(), diff.node(), region_in.node(), region_out.node(), filter.node(),
param, config);
}

size_t RegionRestrictedConvolutionBackwardFilter::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
return megdnn_opr()->get_workspace_in_bytes(
{input_shapes[0], input(0)->dtype(), input(0)->format()},
{input_shapes[1], input(1)->dtype(), input(1)->format()},
{input_shapes[2], input(2)->dtype(), input(2)->format()},
{input_shapes[3], input(3)->dtype(), input(3)->format()},
{output_shapes[0], output(0)->dtype(), output(0)->format()});
}

void RegionRestrictedConvolutionBackwardFilter::scn_do_execute() {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), // src
input(1)->dev_tensor().as_megdnn(), // diff
input(2)->dev_tensor().as_megdnn(), // rin
input(3)->dev_tensor().as_megdnn(), // rout
output(0)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output().back()));
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RegionRestrictedConvolutionBackwardFilter) {
if (wrt_idx == 0) {
return RegionRestrictedConvolutionBackwardData::make(
out_grad[0] /*filter*/, opr.input(1) /*diff*/,
opr.input(2) /*rin*/, opr.input(3) /*rout*/,
opr.input(0) /*src*/, opr.param())
.node();
}
if (wrt_idx == 1) {
return RegionRestrictedConvolution::make(
opr.input(0) /*src*/, out_grad[0] /*filter*/,
opr.input(2) /*rin*/, opr.input(3) /*rout*/, opr.param())
.node();
}
return nullptr;
}
#endif

#undef IMPL_CONV


// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 1
- 0
src/opr/impl/dnn/dnn.sereg.h View File

@@ -431,6 +431,7 @@ struct OprLoadDumpImpl<opr::Convolution3DBackwardFilter, 0>
MakeConvCallerEmpty<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};

template <>
struct OprLoadDumpImpl<opr::ConvBiasForward, 0>
        : public ConvLoadDumpImpl<


+ 28
- 0
src/opr/impl/dnn/dnn.sereg.v2.h View File

@@ -194,6 +194,30 @@ struct OprLoadDumpImplV2<opr::DeformableConvBackwardFilter, 0>
MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
megdnn::Convolution> {};


template <>
struct OprMaker<opr::RegionRestrictedConvolutionBackwardData, 0> {
using Opr = opr::RegionRestrictedConvolutionBackwardData;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (inputs.size() == 4) { // deconv mode
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param, config)
.node()
->owner_opr();
} else if (inputs.size() == 5) { // dgrad mode
return Opr::make(
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], param,
config)
.node()
->owner_opr();
} else {
return nullptr;
}
}
};

} // namespace serialization


namespace opr {
@@ -220,6 +244,10 @@ SERGE_OPR_V2_NO_CONVERTER(Convolution3D, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardFilter, 0);


MGB_SEREG_OPR(RegionRestrictedConvolutionBackwardData, 0);
MGB_SEREG_OPR(RegionRestrictedConvolution, 4);
MGB_SEREG_OPR(RegionRestrictedConvolutionBackwardFilter, 5);

SERGE_OPR_V2_NO_CONVERTER(LocalShareForward, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardFilter, 0);


+ 80
- 0
src/opr/include/megbrain/opr/dnn/convolution.h View File

@@ -18,6 +18,12 @@ protected:
static void init_output_static_infer_desc_for_bwd_data(cg::OperatorNodeBase* self);
};


class RegionConvBackwardDataMixin : public cg::OperatorNodeMixinBase {
protected:
template <typename MGBOPR, typename DNNOPR>
static void init_output_static_infer_desc_for_bwd_data(cg::OperatorNodeBase* self);
};

class WeightPreprocessExecutor : public cg::OperatorNodeMixinBase {
    class PreprocessedFilterExecDep;


@@ -83,6 +89,80 @@ class ConvolutionTestingPeer;


} // namespace testing


/* ==================== RegionRestrictedConvolutionForward ==================== */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
RegionRestrictedConvolutionForward,
intl::MegDNNOprWrapperFwd<megdnn::RegionRestrictedConvolutionForward>) // {
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void init_output_dtype() override;

public:
MGE_WIN_DECLSPEC_FUC RegionRestrictedConvolutionForward(
VarNode* src, VarNode* filter, VarNode* region_in, VarNode* region_out,
const Param& param, const OperatorNodeConfig& config);

MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar filter, SymbolVar region_in, SymbolVar region_out,
const Param& param, const OperatorNodeConfig& config = {});
};
using RegionRestrictedConvolution = RegionRestrictedConvolutionForward;

/* ==================== RegionRestrictedConvolutionBackwardData ==================== */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
RegionRestrictedConvolutionBackwardData,
cg::SingleCNOperatorNodeBaseT<mixin::MegDNNOprHolderImpl<
megdnn::RegionRestrictedConvolutionBackwardData>>,
public mixin::RegionConvBackwardDataMixin) // {
void scn_do_execute() override;
void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override;
void init_output_dtype() override;

public:
MGE_WIN_DECLSPEC_FUC RegionRestrictedConvolutionBackwardData(
VarNode* filter, VarNode* diff, VarNode* region_in, VarNode* region_out,
VarNode* src, const Param& param, const OperatorNodeConfig& config);

// grad mode
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar filter, SymbolVar diff, SymbolVar region_in, SymbolVar region_out,
SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});

// sereg for deconv mode
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar filter, SymbolVar diff, SymbolVar region_in, SymbolVar region_out,
const Param& param, const OperatorNodeConfig& config = {});

// user interface for deconv
MGE_WIN_DECLSPEC_FUC static SymbolVar make_deconv(
SymbolVar data, SymbolVar filter, SymbolVar region_in, SymbolVar region_out,
const Param& param = {}, const OperatorNodeConfig& config = {}) {
return make(filter, data, region_in, region_out, param, config);
}
};

/* ==================== RegionRestrictedConvolutionBackwardFilter ==================== */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
RegionRestrictedConvolutionBackwardFilter,
intl::MegDNNOprWrapperBwd<megdnn::RegionRestrictedConvolutionBackwardFilter>) // {
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void scn_do_execute() override;

public:
MGE_WIN_DECLSPEC_FUC RegionRestrictedConvolutionBackwardFilter(
VarNode* src, VarNode* diff, VarNode* region_in, VarNode* region_out,
VarNode* filter, const Param& param, const OperatorNodeConfig& config);

MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar diff, SymbolVar region_in, SymbolVar region_out,
SymbolVar filter, const Param& param,
const OperatorNodeConfig& config = {});
};

MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
        ConvolutionForward, intl::ConvolutionForwardBase,
        public mixin::AlgoChooserHelper) // {


+ 196
- 0
src/opr/test/dnn/region_restricted_convolution.cpp View File

@@ -0,0 +1,196 @@
#include "./legacy_checker.h"
#include "megbrain/comp_node_env.h"

#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"
#include "megdnn/algorithm_cache.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs/base.h"

#include <gmock/gmock.h>

#include <cmath>
#include <memory>
#include <random>

using namespace mgb;

TEST(TestOprDNN, REGIONCONV_FWD_CPU_WRAPPER) {
using Checker = AutoOprChecker<4, 1>;
megdnn::RegionRestrictedConvolution::Param param;
param.sparse = opr::RegionRestrictedConvolution::Param::Sparse::DENSE;

auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
return {opr::RegionRestrictedConvolutionForward::make(
inputs[0], inputs[1], inputs[2], inputs[3], param)};
};

Checker::RunOptions option;
option.numdiff_eps = 0.1;
option.numdiff_max_err = 1e-2;

auto mask_gen = [&](HostTensorND& src) {
HostTensorGenerator<dtype::Int32, RandomDistribution::CONSTANT> gen(1);
src = *gen(src.shape(), src.comp_node());
};
auto float_gen = [&](HostTensorND& src) {
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> gen;
src = *gen(src.shape(), src.comp_node());
};

auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto opr =
megdnn_naive_handle()
->create_operator<megdnn::RegionRestrictedConvolutionForward>();
opr->param() = param;
TensorLayout dest_layout;
opr->deduce_layout(
inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), inp[3]->layout(),
dest_layout);
std::vector<dt_byte> workspace(opr->get_workspace_in_bytes(
inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), inp[3]->layout(),
dest_layout));
dest[0].dtype(inp[0]->dtype())
.comp_node(inp[0]->comp_node())
.resize(dest_layout);
opr->exec(
inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(),
inp[3]->as_megdnn(), dest[0].as_megdnn(),
{workspace.data(), workspace.size()});
};

Checker(make_graph, fwd, CompNode::load("cpu0"))
.set_input_dtype(0, dtype::Float32())
.set_input_dtype(1, dtype::Float32())
.set_input_dtype(2, dtype::Int32())
.set_input_dtype(3, dtype::Int32())
.set_input_generator(0, float_gen)
.set_input_generator(1, float_gen)
.set_input_generator(2, mask_gen)
.set_input_generator(3, mask_gen)
.set_input_allow_grad(2, false)
.set_input_allow_grad(3, false)
// {n,ic,ih,iw}, {oc,ic,fh,fw}, {n,ih,iw}, {n,oh,ow}
.run({TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 2, 2},
TensorShape{1, 2, 2}, TensorShape{1, 1, 1}},
option)
.run({TensorShape{1, 2, 3, 3}, TensorShape{1, 2, 3, 3},
TensorShape{1, 3, 3}, TensorShape{1, 1, 1}},
option)
.run({TensorShape{1, 1, 4, 4}, TensorShape{1, 1, 2, 2},
TensorShape{1, 4, 4}, TensorShape{1, 3, 3}},
option)
.run({TensorShape{2, 2, 8, 8}, TensorShape{4, 2, 2, 2},
TensorShape{2, 8, 8}, TensorShape{2, 7, 7}},
option)
.run({TensorShape{4, 4, 8, 8}, TensorShape{4, 4, 2, 2},
TensorShape{4, 8, 8}, TensorShape{4, 7, 7}},
option);
}

#if MGB_CUDA
TEST(TestOprDNN, REGIONCONV_FWD_GPU_WRAPPER) {
using Checker = AutoOprChecker<4, 1>;
megdnn::RegionRestrictedConvolution::Param param;
param.sparse = opr::RegionRestrictedConvolution::Param::Sparse::GROUP;

auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
return {opr::RegionRestrictedConvolutionForward::make(
inputs[0], inputs[1], inputs[2], inputs[3], param)};
};

Checker::RunOptions option;
option.numdiff_eps = 0.1;
option.numdiff_max_err = 1e-2;

auto mask_gen = [&](HostTensorND& src) {
HostTensorGenerator<dtype::Int32, RandomDistribution::CONSTANT> gen(1);
src = *gen(src.shape(), src.comp_node());
};
auto uint8_mask_gen = [&](HostTensorND& src) {
HostTensorGenerator<dtype::Uint8, RandomDistribution::CONSTANT> gen(1);
src = *gen(src.shape(), src.comp_node());
};
auto float_gen = [&](HostTensorND& src) {
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> gen;
src = *gen(src.shape(), src.comp_node());
};

auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto opr =
megdnn_naive_handle()
->create_operator<megdnn::RegionRestrictedConvolutionForward>();
opr->param() = param;
TensorLayout dest_layout;
opr->deduce_layout(
inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), inp[3]->layout(),
dest_layout);
std::vector<dt_byte> workspace(opr->get_workspace_in_bytes(
inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), inp[3]->layout(),
dest_layout));
dest[0].dtype(inp[0]->dtype())
.comp_node(inp[0]->comp_node())
.resize(dest_layout);
opr->exec(
inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(),
inp[3]->as_megdnn(), dest[0].as_megdnn(),
{workspace.data(), workspace.size()});
};

Checker(make_graph, fwd, CompNode::load("gpu0"))
.set_input_dtype(0, dtype::Float32())
.set_input_dtype(1, dtype::Float32())
.set_input_dtype(2, dtype::Int32())
.set_input_dtype(3, dtype::Int32())
.set_input_generator(0, float_gen)
.set_input_generator(1, float_gen)
.set_input_generator(2, mask_gen)
.set_input_generator(3, mask_gen)
.set_input_allow_grad(2, false)
.set_input_allow_grad(3, false)
// {n,ic,ih,iw}, {g,ocpg,icpg,fh,fw}, {n,ih,iw}, {n,oh,ow}
.run({TensorShape{1, 2, 2, 2}, TensorShape{2, 1, 1, 2, 2},
TensorShape{1, 2, 2}, TensorShape{1, 1, 1}},
option)
.run({TensorShape{1, 2, 3, 3}, TensorShape{2, 1, 1, 3, 3},
TensorShape{1, 3, 3}, TensorShape{1, 1, 1}},
option)
.run({TensorShape{1, 4, 4, 4}, TensorShape{4, 1, 1, 2, 2},
TensorShape{1, 4, 4}, TensorShape{1, 3, 3}},
option)
.run({TensorShape{2, 4, 8, 8}, TensorShape{4, 1, 1, 2, 2},
TensorShape{2, 8, 8}, TensorShape{2, 7, 7}},
option);

Checker(make_graph, fwd, CompNode::load("gpu0"))
.set_input_dtype(0, dtype::Float32())
.set_input_dtype(1, dtype::Float32())
.set_input_dtype(2, dtype::Uint8())
.set_input_dtype(3, dtype::Uint8())
.set_input_generator(0, float_gen)
.set_input_generator(1, float_gen)
.set_input_generator(2, uint8_mask_gen)
.set_input_generator(3, uint8_mask_gen)
.set_input_allow_grad(2, false)
.set_input_allow_grad(3, false)
// {n,ic,ih,iw}, {g,ocpg,icpg,fh,fw}, {n,ih,iw}, {n,oh,ow}
.run({TensorShape{1, 2, 4, 4}, TensorShape{2, 1, 1, 1, 1},
TensorShape{1, 4, 4}, TensorShape{1, 4, 4}},
option)
.run({TensorShape{1, 2, 8, 8}, TensorShape{2, 1, 1, 1, 1},
TensorShape{1, 8, 8}, TensorShape{1, 8, 8}},
option)
.run({TensorShape{1, 4, 8, 8}, TensorShape{4, 1, 1, 5, 5},
TensorShape{1, 8, 8}, TensorShape{1, 4, 4}},
option)
.run({TensorShape{2, 4, 8, 8}, TensorShape{4, 1, 1, 1, 1},
TensorShape{2, 8, 8}, TensorShape{2, 8, 8}},
option);
}
#endif
