GitOrigin-RevId: c0106ade08
release-1.11.1
@@ -91,6 +91,7 @@ __all__ = [ | |||
"warp_affine", | |||
"warp_perspective", | |||
"pixel_shuffle", | |||
"region_restricted_conv", | |||
] | |||
@@ -1213,10 +1214,10 @@ def layer_norm( | |||
): | |||
r"""Applies layer normalization to the input. Support tensor of any shape as input. | |||
Reference: https://arxiv.org/pdf/1803.08494.pdf. | |||
Args: | |||
inp: input tensor. | |||
normalized_shape: the shape that you want to be normalizated | |||
normalized_shape: the shape that you want to be normalizated | |||
affine: whether to use weight and bias | |||
weight: must not be None when the affine is true | |||
bias: must not be None when the affine is true | |||
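For reference, a minimal usage sketch of this functional API (a sketch only; the keyword spelling below follows the docstring, and eps is assumed to keep its default):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.randn(2, 3, 4).astype(np.float32))
w = mge.Parameter(np.ones((3, 4), dtype=np.float32))
b = mge.Parameter(np.zeros((3, 4), dtype=np.float32))
# normalize over the trailing (3, 4) dimensions
out = F.layer_norm(x, normalized_shape=(3, 4), affine=True, weight=w, bias=b)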
@@ -1974,6 +1975,61 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor: | |||
return pixel_shuffle_cpp(inp, upscale_factor, _layerPixelShuffle_traceable) | |||
def region_restricted_conv( | |||
inp: Tensor, | |||
weight: Tensor, | |||
rin: Tensor, | |||
rout: Tensor, | |||
stride: Union[int, Tuple[int, int]] = 1, | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
dilation: Union[int, Tuple[int, int]] = 1, | |||
groups: int = 1, | |||
conv_mode: str = "cross_correlation", | |||
compute_mode="default", | |||
) -> Tensor: | |||
r"""Region Restricted convolution operation. | |||
Refer to :class:`~.RegionRestrictedConv` for more information. | |||
Args: | |||
inp: feature map of the convolution operation. | |||
weight: convolution kernel. | |||
rin: input region mask with shape ``(N, IH, IW)``. | |||
rout: output region mask with shape ``(N, OH, OW)``. | |||
stride: stride of the 2D region restricted convolution operation. Default: 1 | |||
padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
dilation: dilation of the 2D convolution operation. Default: 1 | |||
groups: number of groups into which the input and output channels are divided, | |||
so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and the shape of weight should be ``(groups, out_channels // groups, | |||
in_channels // groups, height, width)``. Default: 1 | |||
conv_mode: supports "cross_correlation". Default: "cross_correlation" | |||
Returns: | |||
output tensor. | |||
""" | |||
assert conv_mode.lower() == "cross_correlation" | |||
pad_h, pad_w = _expand_hw(padding) | |||
stride_h, stride_w = _expand_hw(stride) | |||
dilate_h, dilate_w = _expand_hw(dilation) | |||
sparse_type = "dense" if groups == 1 else "group" | |||
op = builtin.RegionRestrictedConvolution( | |||
stride_h=stride_h, | |||
stride_w=stride_w, | |||
pad_h=pad_h, | |||
pad_w=pad_w, | |||
dilate_h=dilate_h, | |||
dilate_w=dilate_w, | |||
mode=conv_mode, | |||
compute_mode=compute_mode, | |||
sparse=sparse_type, | |||
) | |||
(output,) = apply(op, inp, weight, rin, rout) | |||
return output | |||
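For illustration, a minimal usage sketch of the functional form, with tensor shapes borrowed from the naive unit test below (all-ones masks, so every position participates):

import numpy as np
import megengine as mge
import megengine.functional as F

inp = mge.tensor(np.arange(8).reshape(1, 2, 2, 2).astype(np.float32))
# grouped weight layout: (groups, out_channels // groups, in_channels // groups, FH, FW)
weight = mge.tensor(np.ones((2, 1, 1, 2, 2), dtype=np.float32))
rin = mge.tensor(np.ones((1, 2, 2), dtype=np.int32))   # region mask over input positions
rout = mge.tensor(np.ones((1, 1, 1), dtype=np.int32))  # region mask over output positions
out = F.region_restricted_conv(inp, weight, rin, rout, groups=2)
print(out.numpy().shape)  # (1, 2, 1, 1)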
from .quantized import conv_bias_activation # isort:skip | |||
from .loss import * # isort:skip | |||
from .vision import * # isort:skip |
@@ -14,6 +14,7 @@ from .conv import ( | |||
ConvTranspose3d, | |||
DeformableConv2d, | |||
LocalConv2d, | |||
RegionRestrictedConv, | |||
) | |||
from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
from .deformable_psroi_pooling import DeformablePSROIPooling | |||
@@ -12,6 +12,7 @@ from ..functional import ( | |||
deformable_conv2d, | |||
local_conv2d, | |||
pad, | |||
region_restricted_conv, | |||
relu, | |||
) | |||
from ..tensor import Parameter | |||
@@ -982,3 +983,174 @@ class ConvTranspose3d(_ConvNd): | |||
self.output_padding, | |||
self.dilation, | |||
) | |||
class RegionRestrictedConv(_ConvNd): | |||
r"""Applies a 2D RegionRestricted Convolution over an input tensor. | |||
For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`, | |||
this layer generates an output of the size | |||
:math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the | |||
process described as below: | |||
.. math:: | |||
\text{out}(N_i, C_{\text{out}_j}) = | |||
\sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) | |||
where :math:`\star` is the valid 2D cross-correlation operator, | |||
:math:`N` is batch size, :math:`C` denotes number of channels, | |||
:math:`H` is height of input planes in pixels, and :math:`W` is | |||
width in pixels. | |||
In general, output feature maps' shapes can be inferred as follows: | |||
input: :math:`(N, C_{\text{in}}, H_{\text{in}}, W_{\text{in}})` | |||
output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` where | |||
.. math:: | |||
\text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} - | |||
\text{dilation[0]} * (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1 \rfloor | |||
.. math:: | |||
\text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} - | |||
\text{dilation[1]} * (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1 \rfloor | |||
When `groups == in_channels` and `out_channels == K * in_channels`, | |||
where K is a positive integer, this operation is also known as depthwise | |||
convolution. | |||
In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`, | |||
a depthwise convolution with a depthwise multiplier `K`, can be constructed | |||
by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`. | |||
Args: | |||
in_channels: number of input channels. | |||
out_channels: number of output channels. | |||
kernel_size: size of weight on spatial dimensions. If kernel_size is | |||
an :class:`int`, the actual kernel size would be | |||
``(kernel_size, kernel_size)``. | |||
stride: stride of the 2D convolution operation. Default: 1 | |||
padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Default: 0 | |||
dilation: dilation of the 2D convolution operation. Default: 1 | |||
groups: number of groups into which the input and output channels are divided, | |||
so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and the shape of weight should be ``(groups, out_channels // groups, | |||
in_channels // groups, height, width)``. Default: 1 | |||
conv_mode: Supports `cross_correlation`. Default: `cross_correlation` | |||
compute_mode: When set to "default", no special requirements will be | |||
placed on the precision of intermediate results. When set to "float32", | |||
"float32" would be used for accumulator and intermediate result, but only | |||
effective when input and output are of float16 dtype. | |||
padding_mode: "zeros", "reflect" or "replicate". Default: "zeros". | |||
Refer to :class:`~.module.padding.Pad` for more information. | |||
Note: | |||
* ``weight`` usually has shape ``(out_channels, in_channels, height, width)``; | |||
if groups is not 1, the shape will be ``(groups, out_channels // groups, in_channels // groups, height, width)`` | |||
Examples: | |||
>>> import numpy as np | |||
>>> import megengine as mge | |||
>>> import megengine.module as M | |||
>>> rrconv = M.RegionRestrictedConv(in_channels=2, out_channels=2, kernel_size=2, groups=2) | |||
>>> inp = mge.tensor(np.random.randn(1, 2, 2, 2).astype(np.float32)) | |||
>>> rin = mge.tensor(np.random.randn(1, 2, 2).astype(np.int32)) | |||
>>> rout = mge.tensor(np.random.randn(1, 1, 1).astype(np.int32)) | |||
>>> oup = rrconv(inp, rin, rout) | |||
>>> oup.numpy().shape | |||
(1, 2, 1, 1) | |||
""" | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
out_channels: int, | |||
kernel_size: Union[int, Tuple[int, int]], | |||
groups: int, | |||
stride: Union[int, Tuple[int, int]] = 1, | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
dilation: Union[int, Tuple[int, int]] = 1, | |||
conv_mode: str = "cross_correlation", | |||
compute_mode: str = "default", | |||
padding_mode: str = "zeros", | |||
**kwargs | |||
): | |||
kernel_size = _pair_nonzero(kernel_size) | |||
stride = _pair_nonzero(stride) | |||
padding = _pair(padding) | |||
dilation = _pair_nonzero(dilation) | |||
self.conv_mode = conv_mode | |||
self.compute_mode = compute_mode | |||
self.padding_mode = padding_mode | |||
super().__init__( | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
stride, | |||
padding, | |||
0, | |||
dilation, | |||
groups, | |||
False, | |||
**kwargs, | |||
) | |||
def _get_fanin(self): | |||
kh, kw = self.kernel_size | |||
ic = self.in_channels | |||
return kh * kw * ic | |||
def _infer_weight_shape(self): | |||
group = self.groups | |||
ichl = self.in_channels | |||
ochl = self.out_channels | |||
kh, kw = self.kernel_size | |||
if group == 1: | |||
# Assume format is NCHW | |||
return (ochl, ichl, kh, kw) | |||
assert ( | |||
ichl % group == 0 and ochl % group == 0 | |||
), "invalid config: in_channels={} out_channels={} group={}".format( | |||
ichl, ochl, group | |||
) | |||
# Assume format is NCHW | |||
return (group, ochl // group, ichl // group, kh, kw) | |||
def _infer_bias_shape(self): | |||
# Assume format is NCHW | |||
return (1, self.out_channels, 1, 1) | |||
def get_pad_width(self): | |||
return ( | |||
(0, 0), | |||
(0, 0), | |||
(self.padding[0], self.padding[0]), | |||
(self.padding[1], self.padding[1]), | |||
) | |||
def calc_conv(self, inp, weight, rin, rout): | |||
assert self.padding_mode in [ | |||
"zeros", | |||
"reflect", | |||
"replicate", | |||
] | |||
return region_restricted_conv( | |||
inp, | |||
weight, | |||
rin, | |||
rout, | |||
self.stride, | |||
self.padding, | |||
self.dilation, | |||
self.groups, | |||
self.conv_mode, | |||
self.compute_mode, | |||
) | |||
def forward(self, inp, rin, rout): | |||
return self.calc_conv(inp, self.weight, rin, rout) |
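To connect the output-shape formula above with the module interface, a short sketch (mask values set to all ones here, unlike the random masks in the docstring example): with IH = IW = 2, a 2x2 kernel, stride 1, padding 0 and dilation 1, the formula gives H_out = floor((2 + 0 - 1*(2 - 1) - 1)/1 + 1) = 1, and likewise W_out = 1.

import numpy as np
import megengine as mge
import megengine.module as M

rrconv = M.RegionRestrictedConv(in_channels=2, out_channels=2, kernel_size=2, groups=2)
inp = mge.tensor(np.random.randn(1, 2, 2, 2).astype(np.float32))
rin = mge.tensor(np.ones((1, 2, 2), dtype=np.int32))   # region mask over input positions
rout = mge.tensor(np.ones((1, 1, 1), dtype=np.int32))  # region mask over output positions
out = rrconv(inp, rin, rout)
assert out.numpy().shape == (1, 2, 1, 1)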
@@ -930,6 +930,179 @@ def test_batch_conv_bias(): | |||
run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True) | |||
def test_region_restricted_conv_forward_backward_naive(): | |||
import megengine as mge | |||
import megengine.module as M | |||
from megengine.autodiff import GradManager | |||
handle = "cpu0" | |||
src_1 = np.arange(8).reshape(1, 2, 2, 2).astype(np.float32) | |||
filter_1 = np.arange(8).reshape(2, 1, 1, 2, 2).astype(np.float32) | |||
rin_1 = np.array([1, 1, 1, 1]).reshape(1, 2, 2).astype(np.int32) | |||
rout_1 = np.array([1]).reshape(1, 1, 1).astype(np.int32) | |||
cpu_src = tensor(src_1, device=handle) | |||
cpu_filter = tensor(filter_1, device=handle) | |||
gm = GradManager().attach([cpu_src, cpu_filter]) | |||
with gm: | |||
cpu_out = F.region_restricted_conv( | |||
cpu_src, | |||
cpu_filter, | |||
tensor(rin_1, device=handle), | |||
tensor(rout_1, device=handle), | |||
groups=2, | |||
) | |||
gm.backward(cpu_out, tensor(np.ones((1, 2, 1, 1)), device=handle)) | |||
np.testing.assert_allclose( | |||
cpu_src.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(1, 2, 2, 2) | |||
) | |||
np.testing.assert_allclose( | |||
cpu_filter.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(2, 1, 1, 2, 2) | |||
) | |||
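For readability, a NumPy reference for the forward pass this test exercises; a minimal sketch only, assuming the following semantics (inferred from the operator name, not confirmed by the source): input position (ih, iw) contributes to output position (oh, ow) only when rin[n, ih, iw] == rout[n, oh, ow]. With all-ones masks, as used here, this reduces to a plain grouped cross-correlation (stride 1, no padding):

import numpy as np

def region_restricted_conv_ref(src, flt, rin, rout):
    # src: (N, G*ICPG, IH, IW), flt: (G, OCPG, ICPG, FH, FW)
    N, C, IH, IW = src.shape
    G, OCPG, ICPG, FH, FW = flt.shape
    OH, OW = IH - FH + 1, IW - FW + 1
    out = np.zeros((N, G * OCPG, OH, OW), dtype=src.dtype)
    for n in range(N):
        for g in range(G):
            for oc in range(OCPG):
                for oh in range(OH):
                    for ow in range(OW):
                        acc = 0.0
                        for ic in range(ICPG):
                            for fh in range(FH):
                                for fw in range(FW):
                                    ih, iw = oh + fh, ow + fw
                                    # assumed region-matching rule
                                    if rin[n, ih, iw] == rout[n, oh, ow]:
                                        acc += src[n, g * ICPG + ic, ih, iw] * flt[g, oc, ic, fh, fw]
                        out[n, g * OCPG + oc, oh, ow] = acc
    return out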
@pytest.mark.skipif( | |||
not is_cuda_available(), reason="rrconv cuda kernel requires cuda available" | |||
) | |||
def test_region_restricted_conv_forward_backward_cuda(): | |||
import megengine as mge | |||
import megengine.module as M | |||
from megengine.autodiff import GradManager | |||
import megengine.distributed as dist | |||
# params | |||
handle = "gpu0" | |||
N = 1 | |||
GROUP = 3 | |||
FH = FW = 2 | |||
IH = IW = 2 | |||
OH = OW = 1 | |||
ICPG = OCPG = 1 | |||
grad_shape = (N, GROUP * ICPG, IH, IW) | |||
src_shape = grad_shape | |||
filter_shape = (GROUP, OCPG, ICPG, FH, FW) | |||
diff_shape = (N, GROUP * OCPG, OH, OW) | |||
rin_shape = (N, IH, IW) | |||
rout_shape = (N, OH, OW) | |||
def reduce(shape): | |||
mul = 1 | |||
for x in shape: | |||
mul *= x | |||
return mul | |||
def get_groundtruth(): | |||
src = tensor( | |||
np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32), | |||
device="cpu0", | |||
) | |||
filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0") | |||
rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0") | |||
rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0") | |||
gm = GradManager().attach([src, filter]) | |||
with gm: | |||
expected_out = F.region_restricted_conv( | |||
src, filter, rin, rout, groups=GROUP | |||
) | |||
gm.backward( | |||
expected_out, | |||
tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"), | |||
) | |||
return src, filter | |||
expected_src, expected_filter = get_groundtruth() | |||
src = tensor( | |||
np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32), | |||
device=handle, | |||
) | |||
filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle) | |||
rin = tensor(np.ones(rin_shape).astype(np.int32), device=handle) | |||
rout = tensor(np.ones(rout_shape).astype(np.int32), device=handle) | |||
gm = GradManager().attach([src, filter]) | |||
with gm: | |||
gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP) | |||
gm.backward(gpu_out, tensor(np.ones(diff_shape), device=handle)) | |||
np.testing.assert_allclose(src.grad, expected_src.grad) | |||
np.testing.assert_allclose(filter.grad, expected_filter.grad) | |||
@pytest.mark.skipif( | |||
not is_cuda_available(), reason="rrconv cuda kernel requires cuda available" | |||
) | |||
def test_region_restricted_conv_forward_backward_uint8(): | |||
import megengine as mge | |||
import megengine.module as M | |||
from megengine.autodiff import GradManager | |||
# params | |||
handle = "gpu0" | |||
N = 1 | |||
GROUP = 2 | |||
FH = FW = 1 | |||
IH = IW = 4 | |||
OH = OW = 4 | |||
ICPG = OCPG = 1 | |||
grad_shape = (N, GROUP * ICPG, IH, IW) | |||
src_shape = grad_shape | |||
filter_shape = (GROUP, OCPG, ICPG, FH, FW) | |||
diff_shape = (N, GROUP * OCPG, OH, OW) | |||
rin_shape = (N, IH, IW) | |||
rout_shape = (N, OH, OW) | |||
def reduce(shape): | |||
mul = 1 | |||
for x in shape: | |||
mul *= x | |||
return mul | |||
def get_groundtruth(): | |||
src = tensor( | |||
np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32), | |||
device="cpu0", | |||
) | |||
filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0") | |||
rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0") | |||
rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0") | |||
gm = GradManager().attach([src, filter]) | |||
with gm: | |||
expected_out = F.region_restricted_conv( | |||
src, filter, rin, rout, groups=GROUP | |||
) | |||
gm.backward( | |||
expected_out, | |||
tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"), | |||
) | |||
return src, filter | |||
expected_src, expected_filter = get_groundtruth() | |||
# forward and dgrad/wgrad | |||
src = tensor( | |||
np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32), | |||
device=handle, | |||
) | |||
filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle) | |||
rin = tensor(np.ones(rin_shape).astype(np.uint8), device=handle) | |||
rout = tensor(np.ones(rout_shape).astype(np.uint8), device=handle) | |||
gm = GradManager().attach([src, filter]) | |||
with gm: | |||
gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP) | |||
gm.backward( | |||
gpu_out, tensor(np.ones(diff_shape, dtype=np.float32), device=handle) | |||
) | |||
# assert that the gpu result with uint8 masks is close to the cpu result with int32 masks | |||
np.testing.assert_allclose(src.grad, expected_src.grad) | |||
np.testing.assert_allclose(filter.grad, expected_filter.grad) | |||
def test_region_restricted_conv(): | |||
test_region_restricted_conv_forward_backward_naive() | |||
if is_cuda_available(): | |||
test_region_restricted_conv_forward_backward_cuda() | |||
test_region_restricted_conv_forward_backward_uint8() | |||
def test_conv2d_autocast(): | |||
"""check amp's result is equal to manually converted result""" | |||
amp.enabled = True | |||
@@ -3,9 +3,11 @@ | |||
#include "../blob_manager_impl.h" | |||
#include "../dnn_op_helper.h" | |||
#include "../op_trait.h" | |||
#include "megbrain/common.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#include "megbrain/opr/tensor_gen.h" | |||
#include "megdnn/oprs/nn.h" | |||
namespace mgb { | |||
namespace imperative { | |||
@@ -356,5 +358,174 @@ OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData) | |||
} // namespace convolution3d_backward_data | |||
} // namespace | |||
namespace { | |||
namespace region_restricted_conv { | |||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
auto* node = &node_->cast_final_safe<opr::RegionRestrictedConvolution>(); | |||
return RegionRestrictedConvolution::make(node->param()); | |||
} | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& conv = static_cast<const RegionRestrictedConvolution&>(def); | |||
OperatorNodeConfig config{conv.make_name()}; | |||
return opr::RegionRestrictedConvolution::make( | |||
inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), config); | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
auto&& region_restricted_conv = | |||
def.cast_final_safe<mgb::imperative::RegionRestrictedConvolution>(); | |||
DnnOprHelper<megdnn::RegionRestrictedConvolutionForward> dnn_opr( | |||
region_restricted_conv.param()); | |||
auto&& src = inputs[0].layout; | |||
auto&& filter = inputs[1].layout; | |||
auto&& rin = inputs[2].layout; | |||
auto&& rout = inputs[3].layout; | |||
TensorLayout output_layout{src.dtype}; | |||
if (src.ndim && filter.ndim) { | |||
dnn_opr.opr().deduce_layout(src, filter, rin, rout, output_layout); | |||
} | |||
return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0}; | |||
} | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
// create megdnn opr | |||
auto&& region_restricted_conv = def.cast_final_safe<RegionRestrictedConvolution>(); | |||
CompNode cn = inputs[0]->comp_node(); | |||
auto&& param = region_restricted_conv.param(); | |||
DnnOprCaller<megdnn::RegionRestrictedConvolutionForward> dnn_opr(cn, param); | |||
auto srclo = inputs[0]->layout(); | |||
auto filterlo = inputs[1]->layout(); | |||
auto rinlo = inputs[2]->layout(); | |||
auto routlo = inputs[3]->layout(); | |||
// reuse the validated output layout if available, otherwise deduce it | |||
auto out_layout = [&] { | |||
if (validated) { | |||
return output_descs[0].layout; | |||
} else { | |||
TensorLayout out_layout{inputs[0]->dtype()}; | |||
dnn_opr.op()->deduce_layout(srclo, filterlo, rinlo, routlo, out_layout); | |||
return out_layout; | |||
} | |||
}(); | |||
auto out = Tensor::make(out_layout, cn); | |||
dnn_opr.exec_with_ws(inputs[0], inputs[1], inputs[2], inputs[3], out); | |||
return {out}; | |||
} | |||
OP_TRAIT_REG( | |||
RegionRestrictedConvolution, RegionRestrictedConvolution, | |||
opr::RegionRestrictedConvolution) | |||
.make_from_op_node(make_from_op_node) | |||
.apply_on_var_node(apply_on_var_node) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||
.fallback(); | |||
} // namespace region_restricted_conv | |||
} // namespace | |||
namespace { | |||
namespace region_restricted_conv_backward_data { | |||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
auto* node = | |||
&node_->cast_final_safe<opr::RegionRestrictedConvolutionBackwardData>(); | |||
return RegionRestrictedConvolutionBackwardData::make(node->param()); | |||
} | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& conv = static_cast<const RegionRestrictedConvolutionBackwardData&>(def); | |||
OperatorNodeConfig config{conv.make_name()}; | |||
// output_dtype may be inferred from the inputs within rrconv bwd data (deduce_dtype API) | |||
CompNode cn = inputs[0]->comp_node(); | |||
DType output_dtype; | |||
DnnOprCaller<megdnn::RegionRestrictedConvolutionBackwardData> dnn_opr(cn); | |||
dnn_opr.op()->deduce_dtype( | |||
inputs[0]->dtype(), inputs[1]->dtype(), inputs[2]->dtype(), | |||
inputs[3]->dtype(), output_dtype); | |||
if (output_dtype.valid()) | |||
config.output_dtype(output_dtype); | |||
if (inputs.size() == 4) { | |||
return opr::RegionRestrictedConvolutionBackwardData::make( | |||
inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), config); | |||
} else if (inputs.size() == 5) { | |||
return opr::RegionRestrictedConvolutionBackwardData::make( | |||
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], conv.param(), | |||
config); | |||
} | |||
mgb_assert(0); | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
auto&& convbwd = def.cast_final_safe< | |||
mgb::imperative::RegionRestrictedConvolutionBackwardData>(); | |||
DnnOprHelper<megdnn::RegionRestrictedConvolutionBackwardData> dnn_opr( | |||
convbwd.param()); | |||
TensorLayout filter = inputs[0].layout; | |||
TensorLayout diff = inputs[1].layout; | |||
TensorLayout rin = inputs[2].layout; | |||
TensorLayout rout = inputs[3].layout; | |||
DType output_dtype; | |||
dnn_opr.opr().deduce_dtype( | |||
inputs[0].layout.dtype, inputs[1].layout.dtype, inputs[2].layout.dtype, | |||
inputs[3].layout.dtype, output_dtype); | |||
TensorLayout output_layout{output_dtype}; | |||
if (diff.ndim && filter.ndim) { | |||
dnn_opr.opr().deduce_layout(filter, diff, rin, rout, output_layout); | |||
} | |||
return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0}; | |||
} | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
auto&& convbwd = def.cast_final_safe<RegionRestrictedConvolutionBackwardData>(); | |||
CompNode cn = inputs[0]->comp_node(); | |||
DnnOprCaller<megdnn::RegionRestrictedConvolutionBackwardData> dnn_opr( | |||
cn, convbwd.param()); | |||
auto filterlo = inputs[0]->layout(); | |||
auto difflo = inputs[1]->layout(); | |||
auto rinlo = inputs[2]->layout(); | |||
auto routlo = inputs[3]->layout(); | |||
// reuse the validated output layout if available, otherwise deduce it | |||
auto out_layout = [&] { | |||
if (validated) { | |||
return output_descs[0].layout; | |||
} else { | |||
TensorLayout out_layout{inputs[0]->dtype()}; | |||
dnn_opr.op()->deduce_layout(filterlo, difflo, rinlo, routlo, out_layout); | |||
return out_layout; | |||
} | |||
}(); | |||
auto out = Tensor::make(out_layout, cn); | |||
dnn_opr.exec_with_ws(inputs[0], inputs[1], inputs[2], inputs[3], out); | |||
return {out}; | |||
} | |||
OP_TRAIT_REG( | |||
RegionRestrictedConvolutionBackwardData, | |||
RegionRestrictedConvolutionBackwardData, | |||
opr::RegionRestrictedConvolutionBackwardData) | |||
.make_from_op_node(make_from_op_node) | |||
.apply_on_var_node(apply_on_var_node) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||
.fallback(); | |||
} // namespace region_restricted_conv_backward_data | |||
} // namespace | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -1,7 +1,7 @@ | |||
905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py | |||
40708c56b1f05fdb7d06cc097a300330 ../../src/core/include/megbrain/ir/ops.td | |||
9f3af118c7fe8d0c9db433825d5ad77b generated/opdef.h.inl | |||
4041e44a8ba3cca3b3affa1ed9ed44a2 generated/opdef.cpp.inl | |||
319e1d170c989fe793a4e9c45decefc4 generated/opdef.py.inl | |||
26a18a7593566128ecce76e8f74dcc5d generated/opdef.cpy.inl | |||
da03ffe2a15411f902cd88920d3d47ec ../../src/core/include/megbrain/ir/ops.td | |||
5756619f37e4dc130e1b049d7706d4eb generated/opdef.h.inl | |||
98d1291eed73970ee087f898b6241358 generated/opdef.cpp.inl | |||
b1a9c7569392942294c2168d40939eb5 generated/opdef.py.inl | |||
3d88d5358d15a39219957f5257e32f5b generated/opdef.cpy.inl | |||
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h |
@@ -5694,6 +5694,310 @@ OP_TRAIT_REG(Reduce, Reduce) | |||
.props(Reduce_props_impl) | |||
.make_name(Reduce_make_name_impl); | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RegionRestrictedConvolution); | |||
namespace { | |||
size_t RegionRestrictedConvolution_hash_impl(const OpDef& def_) { | |||
auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolution>(); | |||
static_cast<void>(op_); | |||
size_t val = mgb::hash(op_.dyn_typeinfo()); | |||
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.mode)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.pad_h)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.pad_w)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.stride_h)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.stride_w)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.dilate_h)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.dilate_w)); | |||
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.sparse)); | |||
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format)); | |||
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.compute_mode)); | |||
return val; | |||
} | |||
bool RegionRestrictedConvolution_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { | |||
auto &&a_ = lhs_.cast_final_safe<RegionRestrictedConvolution>(), | |||
&&b_ = rhs_.cast_final_safe<RegionRestrictedConvolution>(); | |||
static_cast<void>(a_); | |||
static_cast<void>(b_); | |||
if (a_.mode != b_.mode) return false; | |||
if (a_.pad_h != b_.pad_h) return false; | |||
if (a_.pad_w != b_.pad_w) return false; | |||
if (a_.stride_h != b_.stride_h) return false; | |||
if (a_.stride_w != b_.stride_w) return false; | |||
if (a_.dilate_h != b_.dilate_h) return false; | |||
if (a_.dilate_w != b_.dilate_w) return false; | |||
if (a_.sparse != b_.sparse) return false; | |||
if (a_.format != b_.format) return false; | |||
if (a_.compute_mode != b_.compute_mode) return false; | |||
return true; | |||
} | |||
std::vector<std::pair<const char*, std::string>> RegionRestrictedConvolution_props_impl(const OpDef& def_) { | |||
auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolution>(); | |||
static_cast<void>(op_); | |||
std::vector<std::pair<const char*, std::string>> props_; | |||
switch (op_.mode){ | |||
case RegionRestrictedConvolution::Mode::CROSS_CORRELATION: | |||
props_.emplace_back("mode", "CROSS_CORRELATION"); | |||
break; | |||
case RegionRestrictedConvolution::Mode::CONVOLUTION: | |||
props_.emplace_back("mode", "CONVOLUTION"); | |||
break; | |||
default: | |||
props_.emplace_back("mode", "INVALID"); | |||
break; | |||
} | |||
props_.emplace_back("pad_h", std::to_string(op_.pad_h)); | |||
props_.emplace_back("pad_w", std::to_string(op_.pad_w)); | |||
props_.emplace_back("stride_h", std::to_string(op_.stride_h)); | |||
props_.emplace_back("stride_w", std::to_string(op_.stride_w)); | |||
props_.emplace_back("dilate_h", std::to_string(op_.dilate_h)); | |||
props_.emplace_back("dilate_w", std::to_string(op_.dilate_w)); | |||
switch (op_.sparse){ | |||
case RegionRestrictedConvolution::Sparse::DENSE: | |||
props_.emplace_back("sparse", "DENSE"); | |||
break; | |||
case RegionRestrictedConvolution::Sparse::GROUP: | |||
props_.emplace_back("sparse", "GROUP"); | |||
break; | |||
default: | |||
props_.emplace_back("sparse", "INVALID"); | |||
break; | |||
} | |||
switch (op_.format){ | |||
case RegionRestrictedConvolution::Format::NCHW: | |||
props_.emplace_back("format", "NCHW"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NHWC: | |||
props_.emplace_back("format", "NHWC"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NHWCD4: | |||
props_.emplace_back("format", "NHWCD4"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NCHW4: | |||
props_.emplace_back("format", "NCHW4"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NCHW8: | |||
props_.emplace_back("format", "NCHW8"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NCHW32: | |||
props_.emplace_back("format", "NCHW32"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NCHW88: | |||
props_.emplace_back("format", "NCHW88"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NCHW44: | |||
props_.emplace_back("format", "NCHW44"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NCHW44_DOT: | |||
props_.emplace_back("format", "NCHW44_DOT"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NCHW4_NCHW32: | |||
props_.emplace_back("format", "NCHW4_NCHW32"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NCHW32_NCHW4: | |||
props_.emplace_back("format", "NCHW32_NCHW4"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NCHW4_NCHW: | |||
props_.emplace_back("format", "NCHW4_NCHW"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NHWC_NCHW: | |||
props_.emplace_back("format", "NHWC_NCHW"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NHWC_NCHW4_IC_SMALL: | |||
props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NCHW_NCHW4_IC_SMALL: | |||
props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL"); | |||
break; | |||
case RegionRestrictedConvolution::Format::CHWN4: | |||
props_.emplace_back("format", "CHWN4"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NCHW64: | |||
props_.emplace_back("format", "NCHW64"); | |||
break; | |||
case RegionRestrictedConvolution::Format::NCHW4_NHWC: | |||
props_.emplace_back("format", "NCHW4_NHWC"); | |||
break; | |||
default: | |||
props_.emplace_back("format", "INVALID"); | |||
break; | |||
} | |||
switch (op_.compute_mode){ | |||
case RegionRestrictedConvolution::ComputeMode::DEFAULT: | |||
props_.emplace_back("compute_mode", "DEFAULT"); | |||
break; | |||
case RegionRestrictedConvolution::ComputeMode::FLOAT32: | |||
props_.emplace_back("compute_mode", "FLOAT32"); | |||
break; | |||
default: | |||
props_.emplace_back("compute_mode", "INVALID"); | |||
break; | |||
} | |||
return props_; | |||
} | |||
std::string RegionRestrictedConvolution_make_name_impl(const OpDef& def_) { | |||
auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolution>(); | |||
static_cast<void>(op_); | |||
return "RegionRestrictedConvolution"; | |||
} | |||
} // anonymous namespace | |||
OP_TRAIT_REG(RegionRestrictedConvolution, RegionRestrictedConvolution) | |||
.hash(RegionRestrictedConvolution_hash_impl) | |||
.is_same_st(RegionRestrictedConvolution_is_same_st_impl) | |||
.props(RegionRestrictedConvolution_props_impl) | |||
.make_name(RegionRestrictedConvolution_make_name_impl); | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RegionRestrictedConvolutionBackwardData); | |||
namespace { | |||
size_t RegionRestrictedConvolutionBackwardData_hash_impl(const OpDef& def_) { | |||
auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolutionBackwardData>(); | |||
static_cast<void>(op_); | |||
size_t val = mgb::hash(op_.dyn_typeinfo()); | |||
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.mode)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.pad_h)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.pad_w)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.stride_h)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.stride_w)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.dilate_h)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.dilate_w)); | |||
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.sparse)); | |||
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format)); | |||
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.compute_mode)); | |||
return val; | |||
} | |||
bool RegionRestrictedConvolutionBackwardData_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { | |||
auto &&a_ = lhs_.cast_final_safe<RegionRestrictedConvolutionBackwardData>(), | |||
&&b_ = rhs_.cast_final_safe<RegionRestrictedConvolutionBackwardData>(); | |||
static_cast<void>(a_); | |||
static_cast<void>(b_); | |||
if (a_.mode != b_.mode) return false; | |||
if (a_.pad_h != b_.pad_h) return false; | |||
if (a_.pad_w != b_.pad_w) return false; | |||
if (a_.stride_h != b_.stride_h) return false; | |||
if (a_.stride_w != b_.stride_w) return false; | |||
if (a_.dilate_h != b_.dilate_h) return false; | |||
if (a_.dilate_w != b_.dilate_w) return false; | |||
if (a_.sparse != b_.sparse) return false; | |||
if (a_.format != b_.format) return false; | |||
if (a_.compute_mode != b_.compute_mode) return false; | |||
return true; | |||
} | |||
std::vector<std::pair<const char*, std::string>> RegionRestrictedConvolutionBackwardData_props_impl(const OpDef& def_) { | |||
auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolutionBackwardData>(); | |||
static_cast<void>(op_); | |||
std::vector<std::pair<const char*, std::string>> props_; | |||
switch (op_.mode){ | |||
case RegionRestrictedConvolutionBackwardData::Mode::CROSS_CORRELATION: | |||
props_.emplace_back("mode", "CROSS_CORRELATION"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Mode::CONVOLUTION: | |||
props_.emplace_back("mode", "CONVOLUTION"); | |||
break; | |||
default: | |||
props_.emplace_back("mode", "INVALID"); | |||
break; | |||
} | |||
props_.emplace_back("pad_h", std::to_string(op_.pad_h)); | |||
props_.emplace_back("pad_w", std::to_string(op_.pad_w)); | |||
props_.emplace_back("stride_h", std::to_string(op_.stride_h)); | |||
props_.emplace_back("stride_w", std::to_string(op_.stride_w)); | |||
props_.emplace_back("dilate_h", std::to_string(op_.dilate_h)); | |||
props_.emplace_back("dilate_w", std::to_string(op_.dilate_w)); | |||
switch (op_.sparse){ | |||
case RegionRestrictedConvolutionBackwardData::Sparse::DENSE: | |||
props_.emplace_back("sparse", "DENSE"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Sparse::GROUP: | |||
props_.emplace_back("sparse", "GROUP"); | |||
break; | |||
default: | |||
props_.emplace_back("sparse", "INVALID"); | |||
break; | |||
} | |||
switch (op_.format){ | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW: | |||
props_.emplace_back("format", "NCHW"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NHWC: | |||
props_.emplace_back("format", "NHWC"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NHWCD4: | |||
props_.emplace_back("format", "NHWCD4"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW4: | |||
props_.emplace_back("format", "NCHW4"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW8: | |||
props_.emplace_back("format", "NCHW8"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW32: | |||
props_.emplace_back("format", "NCHW32"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW88: | |||
props_.emplace_back("format", "NCHW88"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW44: | |||
props_.emplace_back("format", "NCHW44"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW44_DOT: | |||
props_.emplace_back("format", "NCHW44_DOT"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW4_NCHW32: | |||
props_.emplace_back("format", "NCHW4_NCHW32"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW32_NCHW4: | |||
props_.emplace_back("format", "NCHW32_NCHW4"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW4_NCHW: | |||
props_.emplace_back("format", "NCHW4_NCHW"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NHWC_NCHW: | |||
props_.emplace_back("format", "NHWC_NCHW"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NHWC_NCHW4_IC_SMALL: | |||
props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW_NCHW4_IC_SMALL: | |||
props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::CHWN4: | |||
props_.emplace_back("format", "CHWN4"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW64: | |||
props_.emplace_back("format", "NCHW64"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::Format::NCHW4_NHWC: | |||
props_.emplace_back("format", "NCHW4_NHWC"); | |||
break; | |||
default: | |||
props_.emplace_back("format", "INVALID"); | |||
break; | |||
} | |||
switch (op_.compute_mode){ | |||
case RegionRestrictedConvolutionBackwardData::ComputeMode::DEFAULT: | |||
props_.emplace_back("compute_mode", "DEFAULT"); | |||
break; | |||
case RegionRestrictedConvolutionBackwardData::ComputeMode::FLOAT32: | |||
props_.emplace_back("compute_mode", "FLOAT32"); | |||
break; | |||
default: | |||
props_.emplace_back("compute_mode", "INVALID"); | |||
break; | |||
} | |||
return props_; | |||
} | |||
std::string RegionRestrictedConvolutionBackwardData_make_name_impl(const OpDef& def_) { | |||
auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolutionBackwardData>(); | |||
static_cast<void>(op_); | |||
return "RegionRestrictedConvolutionBackwardData"; | |||
} | |||
} // anonymous namespace | |||
OP_TRAIT_REG(RegionRestrictedConvolutionBackwardData, RegionRestrictedConvolutionBackwardData) | |||
.hash(RegionRestrictedConvolutionBackwardData_hash_impl) | |||
.is_same_st(RegionRestrictedConvolutionBackwardData_is_same_st_impl) | |||
.props(RegionRestrictedConvolutionBackwardData_props_impl) | |||
.make_name(RegionRestrictedConvolutionBackwardData_make_name_impl); | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Remap); | |||
namespace { | |||
@@ -15368,6 +15368,580 @@ void _init_py_Reduce(py::module m) { | |||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Reduce::typeinfo(), &py_type).second); | |||
} | |||
void _init_py_RegionRestrictedConvolution_Mode(PyTypeObject& py_type) { | |||
auto& e_type = EnumWrapper<RegionRestrictedConvolution::Mode>::type; | |||
Py_INCREF(e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "Mode", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||
} | |||
void _init_py_RegionRestrictedConvolution_Sparse(PyTypeObject& py_type) { | |||
auto& e_type = EnumWrapper<RegionRestrictedConvolution::Sparse>::type; | |||
Py_INCREF(e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "Sparse", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||
} | |||
void _init_py_RegionRestrictedConvolution_Format(PyTypeObject& py_type) { | |||
auto& e_type = EnumWrapper<RegionRestrictedConvolution::Format>::type; | |||
Py_INCREF(e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||
} | |||
void _init_py_RegionRestrictedConvolution_ComputeMode(PyTypeObject& py_type) { | |||
auto& e_type = EnumWrapper<RegionRestrictedConvolution::ComputeMode>::type; | |||
Py_INCREF(e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "ComputeMode", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||
} | |||
PyOpDefBegin(RegionRestrictedConvolution) // { | |||
static PyGetSetDef py_getsetters[]; | |||
static PyMethodDef tp_methods[]; | |||
static PyObject* getstate(PyObject* self, PyObject*) { | |||
auto& opdef = reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst(); | |||
static_cast<void>(opdef); | |||
std::unordered_map<std::string, py::object> state { | |||
{"mode", serialization<decltype(opdef.mode)>::dump(opdef.mode)}, | |||
{"pad_h", serialization<decltype(opdef.pad_h)>::dump(opdef.pad_h)}, | |||
{"pad_w", serialization<decltype(opdef.pad_w)>::dump(opdef.pad_w)}, | |||
{"stride_h", serialization<decltype(opdef.stride_h)>::dump(opdef.stride_h)}, | |||
{"stride_w", serialization<decltype(opdef.stride_w)>::dump(opdef.stride_w)}, | |||
{"dilate_h", serialization<decltype(opdef.dilate_h)>::dump(opdef.dilate_h)}, | |||
{"dilate_w", serialization<decltype(opdef.dilate_w)>::dump(opdef.dilate_w)}, | |||
{"sparse", serialization<decltype(opdef.sparse)>::dump(opdef.sparse)}, | |||
{"format", serialization<decltype(opdef.format)>::dump(opdef.format)}, | |||
{"compute_mode", serialization<decltype(opdef.compute_mode)>::dump(opdef.compute_mode)} | |||
}; | |||
return py::cast(state).release().ptr(); | |||
} | |||
static PyObject* setstate(PyObject* self, PyObject* args) { | |||
PyObject* dict = PyTuple_GetItem(args, 0); | |||
if (!dict) return NULL; | |||
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict); | |||
auto& opdef = reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst(); | |||
static_cast<void>(opdef); | |||
{ | |||
auto&& iter = state.find("mode"); | |||
if (iter != state.end()) { | |||
opdef.mode = serialization<decltype(opdef.mode)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("pad_h"); | |||
if (iter != state.end()) { | |||
opdef.pad_h = serialization<decltype(opdef.pad_h)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("pad_w"); | |||
if (iter != state.end()) { | |||
opdef.pad_w = serialization<decltype(opdef.pad_w)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("stride_h"); | |||
if (iter != state.end()) { | |||
opdef.stride_h = serialization<decltype(opdef.stride_h)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("stride_w"); | |||
if (iter != state.end()) { | |||
opdef.stride_w = serialization<decltype(opdef.stride_w)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("dilate_h"); | |||
if (iter != state.end()) { | |||
opdef.dilate_h = serialization<decltype(opdef.dilate_h)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("dilate_w"); | |||
if (iter != state.end()) { | |||
opdef.dilate_w = serialization<decltype(opdef.dilate_w)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("sparse"); | |||
if (iter != state.end()) { | |||
opdef.sparse = serialization<decltype(opdef.sparse)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("format"); | |||
if (iter != state.end()) { | |||
opdef.format = serialization<decltype(opdef.format)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("compute_mode"); | |||
if (iter != state.end()) { | |||
opdef.compute_mode = serialization<decltype(opdef.compute_mode)>::load(iter->second); | |||
} | |||
} | |||
Py_RETURN_NONE; | |||
} | |||
static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||
// }; | |||
PyOpDefEnd(RegionRestrictedConvolution) | |||
int PyOp(RegionRestrictedConvolution)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { | |||
static const char* kwlist[] = {"mode", "pad_h", "pad_w", "stride_h", "stride_w", "dilate_h", "dilate_w", "sparse", "format", "compute_mode", "scope", NULL}; | |||
PyObject *mode = NULL, *pad_h = NULL, *pad_w = NULL, *stride_h = NULL, *stride_w = NULL, *dilate_h = NULL, *dilate_w = NULL, *sparse = NULL, *format = NULL, *compute_mode = NULL, *scope = NULL; | |||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOO", const_cast<char**>(kwlist), &mode, &pad_h, &pad_w, &stride_h, &stride_w, &dilate_h, &dilate_w, &sparse, &format, &compute_mode, &scope)) | |||
return -1; | |||
if (mode) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().mode = | |||
py::cast<decltype(RegionRestrictedConvolution::mode)>(py::handle(mode)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (pad_h) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().pad_h = | |||
py::cast<decltype(RegionRestrictedConvolution::pad_h)>(py::handle(pad_h)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (pad_w) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().pad_w = | |||
py::cast<decltype(RegionRestrictedConvolution::pad_w)>(py::handle(pad_w)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (stride_h) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().stride_h = | |||
py::cast<decltype(RegionRestrictedConvolution::stride_h)>(py::handle(stride_h)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (stride_w) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().stride_w = | |||
py::cast<decltype(RegionRestrictedConvolution::stride_w)>(py::handle(stride_w)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (dilate_h) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().dilate_h = | |||
py::cast<decltype(RegionRestrictedConvolution::dilate_h)>(py::handle(dilate_h)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (dilate_w) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().dilate_w = | |||
py::cast<decltype(RegionRestrictedConvolution::dilate_w)>(py::handle(dilate_w)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (sparse) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().sparse = | |||
py::cast<decltype(RegionRestrictedConvolution::sparse)>(py::handle(sparse)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (format) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().format = | |||
py::cast<decltype(RegionRestrictedConvolution::format)>(py::handle(format)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (compute_mode) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().compute_mode = | |||
py::cast<decltype(RegionRestrictedConvolution::compute_mode)>(py::handle(compute_mode)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (scope) { | |||
try { | |||
reinterpret_cast<PyOp(OpDef)*>(self)->op | |||
->set_scope(py::cast<std::string>(py::handle(scope))); | |||
} CATCH_ALL(-1) | |||
} | |||
return 0; | |||
} | |||
PyGetSetDef PyOp(RegionRestrictedConvolution)::py_getsetters[] = { | |||
{const_cast<char*>("mode"), py_get_generic(RegionRestrictedConvolution, mode), py_set_generic(RegionRestrictedConvolution, mode), const_cast<char*>("mode"), NULL}, | |||
{const_cast<char*>("pad_h"), py_get_generic(RegionRestrictedConvolution, pad_h), py_set_generic(RegionRestrictedConvolution, pad_h), const_cast<char*>("pad_h"), NULL}, | |||
{const_cast<char*>("pad_w"), py_get_generic(RegionRestrictedConvolution, pad_w), py_set_generic(RegionRestrictedConvolution, pad_w), const_cast<char*>("pad_w"), NULL}, | |||
{const_cast<char*>("stride_h"), py_get_generic(RegionRestrictedConvolution, stride_h), py_set_generic(RegionRestrictedConvolution, stride_h), const_cast<char*>("stride_h"), NULL}, | |||
{const_cast<char*>("stride_w"), py_get_generic(RegionRestrictedConvolution, stride_w), py_set_generic(RegionRestrictedConvolution, stride_w), const_cast<char*>("stride_w"), NULL}, | |||
{const_cast<char*>("dilate_h"), py_get_generic(RegionRestrictedConvolution, dilate_h), py_set_generic(RegionRestrictedConvolution, dilate_h), const_cast<char*>("dilate_h"), NULL}, | |||
{const_cast<char*>("dilate_w"), py_get_generic(RegionRestrictedConvolution, dilate_w), py_set_generic(RegionRestrictedConvolution, dilate_w), const_cast<char*>("dilate_w"), NULL}, | |||
{const_cast<char*>("sparse"), py_get_generic(RegionRestrictedConvolution, sparse), py_set_generic(RegionRestrictedConvolution, sparse), const_cast<char*>("sparse"), NULL}, | |||
{const_cast<char*>("format"), py_get_generic(RegionRestrictedConvolution, format), py_set_generic(RegionRestrictedConvolution, format), const_cast<char*>("format"), NULL}, | |||
{const_cast<char*>("compute_mode"), py_get_generic(RegionRestrictedConvolution, compute_mode), py_set_generic(RegionRestrictedConvolution, compute_mode), const_cast<char*>("compute_mode"), NULL}, | |||
{NULL} /* Sentinel */ | |||
}; | |||
PyMethodDef PyOp(RegionRestrictedConvolution)::tp_methods[] = { | |||
{const_cast<char*>("__getstate__"), PyOp(RegionRestrictedConvolution)::getstate, METH_NOARGS, "RegionRestrictedConvolution getstate"}, | |||
{const_cast<char*>("__setstate__"), PyOp(RegionRestrictedConvolution)::setstate, METH_VARARGS, "RegionRestrictedConvolution setstate"}, | |||
{NULL} /* Sentinel */ | |||
}; | |||
void _init_py_RegionRestrictedConvolution(py::module m) { | |||
using py_op = PyOp(RegionRestrictedConvolution); | |||
auto& py_type = PyOpType(RegionRestrictedConvolution); | |||
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
py_type.tp_name = "megengine.core._imperative_rt.ops.RegionRestrictedConvolution"; | |||
py_type.tp_basicsize = sizeof(PyOp(RegionRestrictedConvolution)); | |||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
py_type.tp_doc = "RegionRestrictedConvolution"; | |||
py_type.tp_base = &PyOpType(OpDef); | |||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
py_type.tp_new = py_new_generic<py_op>; | |||
py_type.tp_init = py_op::py_init; | |||
py_type.tp_methods = py_op::tp_methods; | |||
py_type.tp_getset = py_op::py_getsetters; | |||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||
_init_py_RegionRestrictedConvolution_Mode(py_type); | |||
_init_py_RegionRestrictedConvolution_Sparse(py_type); | |||
_init_py_RegionRestrictedConvolution_Format(py_type); | |||
_init_py_RegionRestrictedConvolution_ComputeMode(py_type); | |||
PyType_Modified(&py_type); | |||
m.add_object("RegionRestrictedConvolution", reinterpret_cast<PyObject*>(&py_type)); | |||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(RegionRestrictedConvolution::typeinfo(), &py_type).second); | |||
} | |||
void _init_py_RegionRestrictedConvolutionBackwardData_Mode(PyTypeObject& py_type) { | |||
auto& e_type = EnumWrapper<RegionRestrictedConvolutionBackwardData::Mode>::type; | |||
Py_INCREF(e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "Mode", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||
} | |||
void _init_py_RegionRestrictedConvolutionBackwardData_Sparse(PyTypeObject& py_type) { | |||
auto& e_type = EnumWrapper<RegionRestrictedConvolutionBackwardData::Sparse>::type; | |||
Py_INCREF(e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "Sparse", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||
} | |||
void _init_py_RegionRestrictedConvolutionBackwardData_Format(PyTypeObject& py_type) { | |||
auto& e_type = EnumWrapper<RegionRestrictedConvolutionBackwardData::Format>::type; | |||
Py_INCREF(e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||
} | |||
void _init_py_RegionRestrictedConvolutionBackwardData_ComputeMode(PyTypeObject& py_type) { | |||
auto& e_type = EnumWrapper<RegionRestrictedConvolutionBackwardData::ComputeMode>::type; | |||
Py_INCREF(e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "ComputeMode", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||
} | |||
PyOpDefBegin(RegionRestrictedConvolutionBackwardData) // { | |||
static PyGetSetDef py_getsetters[]; | |||
static PyMethodDef tp_methods[]; | |||
static PyObject* getstate(PyObject* self, PyObject*) { | |||
auto& opdef = reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst(); | |||
static_cast<void>(opdef); | |||
std::unordered_map<std::string, py::object> state { | |||
{"mode", serialization<decltype(opdef.mode)>::dump(opdef.mode)}, | |||
{"pad_h", serialization<decltype(opdef.pad_h)>::dump(opdef.pad_h)}, | |||
{"pad_w", serialization<decltype(opdef.pad_w)>::dump(opdef.pad_w)}, | |||
{"stride_h", serialization<decltype(opdef.stride_h)>::dump(opdef.stride_h)}, | |||
{"stride_w", serialization<decltype(opdef.stride_w)>::dump(opdef.stride_w)}, | |||
{"dilate_h", serialization<decltype(opdef.dilate_h)>::dump(opdef.dilate_h)}, | |||
{"dilate_w", serialization<decltype(opdef.dilate_w)>::dump(opdef.dilate_w)}, | |||
{"sparse", serialization<decltype(opdef.sparse)>::dump(opdef.sparse)}, | |||
{"format", serialization<decltype(opdef.format)>::dump(opdef.format)}, | |||
{"compute_mode", serialization<decltype(opdef.compute_mode)>::dump(opdef.compute_mode)} | |||
}; | |||
return py::cast(state).release().ptr(); | |||
} | |||
static PyObject* setstate(PyObject* self, PyObject* args) { | |||
PyObject* dict = PyTuple_GetItem(args, 0); | |||
if (!dict) return NULL; | |||
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict); | |||
auto& opdef = reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst(); | |||
static_cast<void>(opdef); | |||
{ | |||
auto&& iter = state.find("mode"); | |||
if (iter != state.end()) { | |||
opdef.mode = serialization<decltype(opdef.mode)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("pad_h"); | |||
if (iter != state.end()) { | |||
opdef.pad_h = serialization<decltype(opdef.pad_h)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("pad_w"); | |||
if (iter != state.end()) { | |||
opdef.pad_w = serialization<decltype(opdef.pad_w)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("stride_h"); | |||
if (iter != state.end()) { | |||
opdef.stride_h = serialization<decltype(opdef.stride_h)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("stride_w"); | |||
if (iter != state.end()) { | |||
opdef.stride_w = serialization<decltype(opdef.stride_w)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("dilate_h"); | |||
if (iter != state.end()) { | |||
opdef.dilate_h = serialization<decltype(opdef.dilate_h)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("dilate_w"); | |||
if (iter != state.end()) { | |||
opdef.dilate_w = serialization<decltype(opdef.dilate_w)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("sparse"); | |||
if (iter != state.end()) { | |||
opdef.sparse = serialization<decltype(opdef.sparse)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("format"); | |||
if (iter != state.end()) { | |||
opdef.format = serialization<decltype(opdef.format)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("compute_mode"); | |||
if (iter != state.end()) { | |||
opdef.compute_mode = serialization<decltype(opdef.compute_mode)>::load(iter->second); | |||
} | |||
} | |||
Py_RETURN_NONE; | |||
} | |||
static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||
// }; | |||
PyOpDefEnd(RegionRestrictedConvolutionBackwardData) | |||
int PyOp(RegionRestrictedConvolutionBackwardData)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { | |||
static const char* kwlist[] = {"mode", "pad_h", "pad_w", "stride_h", "stride_w", "dilate_h", "dilate_w", "sparse", "format", "compute_mode", "scope", NULL}; | |||
PyObject *mode = NULL, *pad_h = NULL, *pad_w = NULL, *stride_h = NULL, *stride_w = NULL, *dilate_h = NULL, *dilate_w = NULL, *sparse = NULL, *format = NULL, *compute_mode = NULL, *scope = NULL; | |||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOO", const_cast<char**>(kwlist), &mode, &pad_h, &pad_w, &stride_h, &stride_w, &dilate_h, &dilate_w, &sparse, &format, &compute_mode, &scope)) | |||
return -1; | |||
if (mode) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().mode = | |||
py::cast<decltype(RegionRestrictedConvolutionBackwardData::mode)>(py::handle(mode)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (pad_h) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().pad_h = | |||
py::cast<decltype(RegionRestrictedConvolutionBackwardData::pad_h)>(py::handle(pad_h)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (pad_w) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().pad_w = | |||
py::cast<decltype(RegionRestrictedConvolutionBackwardData::pad_w)>(py::handle(pad_w)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (stride_h) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().stride_h = | |||
py::cast<decltype(RegionRestrictedConvolutionBackwardData::stride_h)>(py::handle(stride_h)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (stride_w) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().stride_w = | |||
py::cast<decltype(RegionRestrictedConvolutionBackwardData::stride_w)>(py::handle(stride_w)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (dilate_h) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().dilate_h = | |||
py::cast<decltype(RegionRestrictedConvolutionBackwardData::dilate_h)>(py::handle(dilate_h)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (dilate_w) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().dilate_w = | |||
py::cast<decltype(RegionRestrictedConvolutionBackwardData::dilate_w)>(py::handle(dilate_w)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (sparse) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().sparse = | |||
py::cast<decltype(RegionRestrictedConvolutionBackwardData::sparse)>(py::handle(sparse)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (format) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().format = | |||
py::cast<decltype(RegionRestrictedConvolutionBackwardData::format)>(py::handle(format)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (compute_mode) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().compute_mode = | |||
py::cast<decltype(RegionRestrictedConvolutionBackwardData::compute_mode)>(py::handle(compute_mode)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (scope) { | |||
try { | |||
reinterpret_cast<PyOp(OpDef)*>(self)->op | |||
->set_scope(py::cast<std::string>(py::handle(scope))); | |||
} CATCH_ALL(-1) | |||
} | |||
return 0; | |||
} | |||
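// Standalone sketch of the py_init idiom above (hypothetical names, no | |||
// Python involved): every keyword is optional, and only the fields whose | |||
// keyword was supplied get overwritten, so defaults survive partial | |||
// construction. | |||
#include <cstdint> | |||
#include <optional> | |||
struct ConvFields { | |||
    uint32_t pad_h = 0, stride_h = 1; | |||
}; | |||
void apply_kwargs(ConvFields& f, std::optional<uint32_t> pad_h, | |||
                  std::optional<uint32_t> stride_h) { | |||
    if (pad_h) f.pad_h = *pad_h;  // mirrors `if (pad_h) { ... .pad_h = cast(...); }` | |||
    if (stride_h) f.stride_h = *stride_h; | |||
} | |||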
PyGetSetDef PyOp(RegionRestrictedConvolutionBackwardData)::py_getsetters[] = { | |||
{const_cast<char*>("mode"), py_get_generic(RegionRestrictedConvolutionBackwardData, mode), py_set_generic(RegionRestrictedConvolutionBackwardData, mode), const_cast<char*>("mode"), NULL}, | |||
{const_cast<char*>("pad_h"), py_get_generic(RegionRestrictedConvolutionBackwardData, pad_h), py_set_generic(RegionRestrictedConvolutionBackwardData, pad_h), const_cast<char*>("pad_h"), NULL}, | |||
{const_cast<char*>("pad_w"), py_get_generic(RegionRestrictedConvolutionBackwardData, pad_w), py_set_generic(RegionRestrictedConvolutionBackwardData, pad_w), const_cast<char*>("pad_w"), NULL}, | |||
{const_cast<char*>("stride_h"), py_get_generic(RegionRestrictedConvolutionBackwardData, stride_h), py_set_generic(RegionRestrictedConvolutionBackwardData, stride_h), const_cast<char*>("stride_h"), NULL}, | |||
{const_cast<char*>("stride_w"), py_get_generic(RegionRestrictedConvolutionBackwardData, stride_w), py_set_generic(RegionRestrictedConvolutionBackwardData, stride_w), const_cast<char*>("stride_w"), NULL}, | |||
{const_cast<char*>("dilate_h"), py_get_generic(RegionRestrictedConvolutionBackwardData, dilate_h), py_set_generic(RegionRestrictedConvolutionBackwardData, dilate_h), const_cast<char*>("dilate_h"), NULL}, | |||
{const_cast<char*>("dilate_w"), py_get_generic(RegionRestrictedConvolutionBackwardData, dilate_w), py_set_generic(RegionRestrictedConvolutionBackwardData, dilate_w), const_cast<char*>("dilate_w"), NULL}, | |||
{const_cast<char*>("sparse"), py_get_generic(RegionRestrictedConvolutionBackwardData, sparse), py_set_generic(RegionRestrictedConvolutionBackwardData, sparse), const_cast<char*>("sparse"), NULL}, | |||
{const_cast<char*>("format"), py_get_generic(RegionRestrictedConvolutionBackwardData, format), py_set_generic(RegionRestrictedConvolutionBackwardData, format), const_cast<char*>("format"), NULL}, | |||
{const_cast<char*>("compute_mode"), py_get_generic(RegionRestrictedConvolutionBackwardData, compute_mode), py_set_generic(RegionRestrictedConvolutionBackwardData, compute_mode), const_cast<char*>("compute_mode"), NULL}, | |||
{NULL} /* Sentinel */ | |||
}; | |||
PyMethodDef PyOp(RegionRestrictedConvolutionBackwardData)::tp_methods[] = { | |||
{const_cast<char*>("__getstate__"), PyOp(RegionRestrictedConvolutionBackwardData)::getstate, METH_NOARGS, "RegionRestrictedConvolutionBackwardData getstate"}, | |||
{const_cast<char*>("__setstate__"), PyOp(RegionRestrictedConvolutionBackwardData)::setstate, METH_VARARGS, "RegionRestrictedConvolutionBackwardData setstate"}, | |||
{NULL} /* Sentinel */ | |||
}; | |||
void _init_py_RegionRestrictedConvolutionBackwardData(py::module m) { | |||
using py_op = PyOp(RegionRestrictedConvolutionBackwardData); | |||
auto& py_type = PyOpType(RegionRestrictedConvolutionBackwardData); | |||
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
py_type.tp_name = "megengine.core._imperative_rt.ops.RegionRestrictedConvolutionBackwardData"; | |||
py_type.tp_basicsize = sizeof(PyOp(RegionRestrictedConvolutionBackwardData)); | |||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
py_type.tp_doc = "RegionRestrictedConvolutionBackwardData"; | |||
py_type.tp_base = &PyOpType(OpDef); | |||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
py_type.tp_new = py_new_generic<py_op>; | |||
py_type.tp_init = py_op::py_init; | |||
py_type.tp_methods = py_op::tp_methods; | |||
py_type.tp_getset = py_op::py_getsetters; | |||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||
_init_py_RegionRestrictedConvolutionBackwardData_Mode(py_type); | |||
_init_py_RegionRestrictedConvolutionBackwardData_Sparse(py_type); | |||
_init_py_RegionRestrictedConvolutionBackwardData_Format(py_type); | |||
_init_py_RegionRestrictedConvolutionBackwardData_ComputeMode(py_type); | |||
PyType_Modified(&py_type); | |||
m.add_object("RegionRestrictedConvolutionBackwardData", reinterpret_cast<PyObject*>(&py_type)); | |||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(RegionRestrictedConvolutionBackwardData::typeinfo(), &py_type).second); | |||
} | |||
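// Once registered, the type should be reachable from Python roughly like | |||
// this (a hedged sketch; the attribute values are illustrative): | |||
// | |||
//   from megengine.core._imperative_rt import ops | |||
//   op = ops.RegionRestrictedConvolutionBackwardData(pad_h=1, pad_w=1) | |||
//   state = op.__getstate__()   # dict covering all ten fields | |||
//   op.__setstate__(state)      # restores from the same dict | |||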
template<> struct EnumTrait<Remap::InterpolationMode> { | |||
static constexpr const char *name = "Remap.InterpolationMode"; | |||
static constexpr std::underlying_type_t<Remap::InterpolationMode> max = 5 - 1; | |||
@@ -18700,6 +19274,8 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { | |||
_init_py_ROIAlign(m); \ | |||
_init_py_ROIPooling(m); \ | |||
_init_py_Reduce(m); \ | |||
_init_py_RegionRestrictedConvolution(m); \ | |||
_init_py_RegionRestrictedConvolutionBackwardData(m); \ | |||
_init_py_Remap(m); \ | |||
_init_py_RemoteRecv(m); \ | |||
_init_py_RemoteSend(m); \ | |||
@@ -1517,6 +1517,58 @@ public: | |||
} | |||
}; | |||
class RegionRestrictedConvolution : public OpDefImplBase<RegionRestrictedConvolution> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
using Mode = ::megdnn::param::Convolution::Mode; | |||
using Sparse = ::megdnn::param::Convolution::Sparse; | |||
using Format = ::megdnn::param::Convolution::Format; | |||
using ComputeMode = ::megdnn::param::Convolution::ComputeMode; | |||
Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION; | |||
uint32_t pad_h = 0; | |||
uint32_t pad_w = 0; | |||
uint32_t stride_h = 1; | |||
uint32_t stride_w = 1; | |||
uint32_t dilate_h = 1; | |||
uint32_t dilate_w = 1; | |||
Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE; | |||
Format format = ::megdnn::param::Convolution::Format::NCHW; | |||
ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT; | |||
RegionRestrictedConvolution() = default; | |||
RegionRestrictedConvolution(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_) { set_scope(scope_); } | |||
RegionRestrictedConvolution(::megdnn::param::Convolution packed_param_0): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode) {} | |||
::megdnn::param::Convolution param() const { | |||
return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode}; | |||
} | |||
}; | |||
class RegionRestrictedConvolutionBackwardData : public OpDefImplBase<RegionRestrictedConvolutionBackwardData> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
using Mode = ::megdnn::param::Convolution::Mode; | |||
using Sparse = ::megdnn::param::Convolution::Sparse; | |||
using Format = ::megdnn::param::Convolution::Format; | |||
using ComputeMode = ::megdnn::param::Convolution::ComputeMode; | |||
Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION; | |||
uint32_t pad_h = 0; | |||
uint32_t pad_w = 0; | |||
uint32_t stride_h = 1; | |||
uint32_t stride_w = 1; | |||
uint32_t dilate_h = 1; | |||
uint32_t dilate_w = 1; | |||
Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE; | |||
Format format = ::megdnn::param::Convolution::Format::NCHW; | |||
ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT; | |||
RegionRestrictedConvolutionBackwardData() = default; | |||
RegionRestrictedConvolutionBackwardData(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_) { set_scope(scope_); } | |||
RegionRestrictedConvolutionBackwardData(::megdnn::param::Convolution packed_param_0): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode) {} | |||
::megdnn::param::Convolution param() const { | |||
return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode}; | |||
} | |||
}; | |||
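// Sketch of the packed-param round trip shared by the two opdefs above, | |||
// using simplified stand-in types (not the real ::megdnn::param::Convolution): | |||
// fields are stored unpacked for attribute access, but the op can be built | |||
// from, and lowered back to, a single param struct. | |||
#include <cstdint> | |||
struct PackedParam { | |||
    uint32_t pad_h, stride_h; | |||
}; | |||
struct OpDefSketch { | |||
    uint32_t pad_h = 0, stride_h = 1; | |||
    OpDefSketch() = default; | |||
    explicit OpDefSketch(PackedParam p) : pad_h(p.pad_h), stride_h(p.stride_h) {} | |||
    PackedParam param() const { return {pad_h, stride_h}; } | |||
}; | |||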
class Remap : public OpDefImplBase<Remap> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
@@ -1620,6 +1620,52 @@ ReduceInst | |||
.def_readwrite("data_type", &Reduce::data_type) | |||
.def_readwrite("keepdim", &Reduce::keepdim); | |||
py::class_<RegionRestrictedConvolution, std::shared_ptr<RegionRestrictedConvolution>, OpDef> RegionRestrictedConvolutionInst(m, "RegionRestrictedConvolution"); | |||
RegionRestrictedConvolutionInst.attr("Mode") = BatchConvBiasInst.attr("Mode"); | |||
RegionRestrictedConvolutionInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse"); | |||
RegionRestrictedConvolutionInst.attr("Format") = AdaptivePoolingInst.attr("Format"); | |||
RegionRestrictedConvolutionInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode"); | |||
RegionRestrictedConvolutionInst | |||
.def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("scope") = {}) | |||
.def_readwrite("mode", &RegionRestrictedConvolution::mode) | |||
.def_readwrite("pad_h", &RegionRestrictedConvolution::pad_h) | |||
.def_readwrite("pad_w", &RegionRestrictedConvolution::pad_w) | |||
.def_readwrite("stride_h", &RegionRestrictedConvolution::stride_h) | |||
.def_readwrite("stride_w", &RegionRestrictedConvolution::stride_w) | |||
.def_readwrite("dilate_h", &RegionRestrictedConvolution::dilate_h) | |||
.def_readwrite("dilate_w", &RegionRestrictedConvolution::dilate_w) | |||
.def_readwrite("sparse", &RegionRestrictedConvolution::sparse) | |||
.def_readwrite("format", &RegionRestrictedConvolution::format) | |||
.def_readwrite("compute_mode", &RegionRestrictedConvolution::compute_mode); | |||
py::class_<RegionRestrictedConvolutionBackwardData, std::shared_ptr<RegionRestrictedConvolutionBackwardData>, OpDef> RegionRestrictedConvolutionBackwardDataInst(m, "RegionRestrictedConvolutionBackwardData"); | |||
RegionRestrictedConvolutionBackwardDataInst.attr("Mode") = BatchConvBiasInst.attr("Mode"); | |||
RegionRestrictedConvolutionBackwardDataInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse"); | |||
RegionRestrictedConvolutionBackwardDataInst.attr("Format") = AdaptivePoolingInst.attr("Format"); | |||
RegionRestrictedConvolutionBackwardDataInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode"); | |||
RegionRestrictedConvolutionBackwardDataInst | |||
.def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("scope") = {}) | |||
.def_readwrite("mode", &RegionRestrictedConvolutionBackwardData::mode) | |||
.def_readwrite("pad_h", &RegionRestrictedConvolutionBackwardData::pad_h) | |||
.def_readwrite("pad_w", &RegionRestrictedConvolutionBackwardData::pad_w) | |||
.def_readwrite("stride_h", &RegionRestrictedConvolutionBackwardData::stride_h) | |||
.def_readwrite("stride_w", &RegionRestrictedConvolutionBackwardData::stride_w) | |||
.def_readwrite("dilate_h", &RegionRestrictedConvolutionBackwardData::dilate_h) | |||
.def_readwrite("dilate_w", &RegionRestrictedConvolutionBackwardData::dilate_w) | |||
.def_readwrite("sparse", &RegionRestrictedConvolutionBackwardData::sparse) | |||
.def_readwrite("format", &RegionRestrictedConvolutionBackwardData::format) | |||
.def_readwrite("compute_mode", &RegionRestrictedConvolutionBackwardData::compute_mode); | |||
py::class_<Remap, std::shared_ptr<Remap>, OpDef> RemapInst(m, "Remap"); | |||
py::enum_<Remap::InterpolationMode>(RemapInst, "InterpolationMode") | |||
@@ -520,4 +520,9 @@ def MeshGrid: MgbHashableOp<"MeshGrid"> { | |||
MgbStringAttr:$indexing | |||
); | |||
} | |||
def RegionRestrictedConvolution: MgbHashableOp<"RegionRestrictedConvolution", [ConvolutionParam]>; | |||
def RegionRestrictedConvolutionBackwardData: MgbHashableOp<"RegionRestrictedConvolutionBackwardData", [ConvolutionParam]>; | |||
#endif // MGB_OPS |
@@ -25,6 +25,58 @@ using namespace cg::static_infer; | |||
using intl::WorkspaceLimitGetter; | |||
/* ==================== misc impl ==================== */ | |||
template <typename MGBOPR, typename DNNOPR> | |||
void mixin::RegionConvBackwardDataMixin::init_output_static_infer_desc_for_bwd_data( | |||
cg::OperatorNodeBase* self) { | |||
using namespace cg::static_infer; | |||
auto&& mgr = self->owner_graph()->static_infer_manager(); | |||
DepVal inp_deps; | |||
inp_deps.reserve(6); | |||
for (int i = 0; i < 4; i++) { | |||
inp_deps.push_back({self->input(i), DepType::SHAPE}); | |||
} | |||
auto infer_shp = [self](TensorShape& dest, const InpVal& inp) { | |||
TensorLayout ol{self->output(0)->dtype()}; | |||
mgb_assert( | |||
self->input(0)->dtype().category() == DTypeCategory::FLOAT && | |||
self->input(1)->dtype().category() == DTypeCategory::FLOAT && | |||
self->input(2)->dtype().category() == DTypeCategory::INT && | |||
self->input(3)->dtype().category() == DTypeCategory::INT, | |||
"region conv dtype assert error!"); | |||
static_cast<MGBOPR*>(self)->megdnn_opr()->deduce_layout( | |||
{inp.val.at(0).shape(), self->input(0)->dtype()}, // filter | |||
{inp.val.at(1).shape(), self->input(1)->dtype()}, // diff | |||
{inp.val.at(2).shape(), self->input(2)->dtype()}, // rin | |||
{inp.val.at(3).shape(), self->input(3)->dtype()}, // rout | |||
ol // grad | |||
); | |||
dest = ol; | |||
return true; | |||
}; | |||
mgr.register_shape_infer(self->output(0), {SourceType::DEP, inp_deps, infer_shp}); | |||
// workspace size | |||
    auto infer_wk = [self](TensorShape& dest, const InpVal& inp) { | |||
        // consume the shapes declared in inp_deps (4 inputs + deduced grad | |||
        // shape) rather than reading the var shapes directly, so the functor | |||
        // only depends on values the static-infer manager actually provides | |||
        TensorLayout ol{inp.val.at(4).shape(), self->output(0)->dtype()}; | |||
        dest.ndim = 1; | |||
        dest.shape[0] = | |||
                static_cast<MGBOPR*>(self)->megdnn_opr()->get_workspace_in_bytes( | |||
                        {inp.val.at(0).shape(), self->input(0)->dtype()},  // filter | |||
                        {inp.val.at(1).shape(), self->input(1)->dtype()},  // diff | |||
                        {inp.val.at(2).shape(), self->input(2)->dtype()},  // rin | |||
                        {inp.val.at(3).shape(), self->input(3)->dtype()},  // rout | |||
                        ol);  // grad | |||
        return true; | |||
    }; | |||
inp_deps.push_back({self->output(0), DepType::SHAPE}); | |||
auto workspace_dep_var = | |||
intl::WorkspaceLimitGetter::register_to_graph(self->owner_graph()); | |||
if (workspace_dep_var) | |||
inp_deps.push_back({workspace_dep_var, DepType::VALUE}); | |||
mgr.register_shape_infer(self->output(1), {SourceType::DEP, inp_deps, infer_wk}); | |||
} | |||
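// The shape inference above bottoms out in megdnn's deduce_layout; for one | |||
// spatial dimension it reduces to the standard transposed-convolution size | |||
// formula (a sketch, assuming NCHW layout and no output padding): | |||
#include <cstdint> | |||
uint32_t deduce_src_extent( | |||
        uint32_t oh, uint32_t fh, uint32_t stride, uint32_t pad, | |||
        uint32_t dilate) { | |||
    // inverse of: oh = (ih + 2*pad - dilate*(fh - 1) - 1) / stride + 1 | |||
    return (oh - 1) * stride - 2 * pad + dilate * (fh - 1) + 1; | |||
} | |||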
template <class MgbOpr, class MegDNNOpr> | |||
void mixin::ConvolutionBackwardDataMixin::init_output_static_infer_desc_for_bwd_data( | |||
@@ -1535,6 +1587,226 @@ void BatchConvBiasForward::init_output_format() { | |||
output(0)->format(input(0)->format()); | |||
} | |||
/* ========================== RegionRestrictedConvolutionForward | |||
* ========================== */ | |||
IMPL_CONV(RegionRestrictedConvolutionForward); | |||
RegionRestrictedConvolutionForward::RegionRestrictedConvolutionForward( | |||
VarNode* src, VarNode* filter, VarNode* region_in, VarNode* region_out, | |||
const Param& param, const OperatorNodeConfig& config) | |||
: Super(src->owner_graph(), config, "region_restricted_conv_fwd", | |||
{src, filter, region_in, region_out}) { | |||
init_megdnn_opr(*this, param); | |||
add_input({src, filter, region_in, region_out}); | |||
} | |||
SymbolVar RegionRestrictedConvolutionForward::make( | |||
SymbolVar src, SymbolVar filter, SymbolVar region_in, SymbolVar region_out, | |||
const Param& param, const OperatorNodeConfig& config) { | |||
return src.insert_single_output_opr<RegionRestrictedConvolutionForward>( | |||
src.node(), filter.node(), region_in.node(), region_out.node(), param, | |||
config); | |||
} | |||
void RegionRestrictedConvolutionForward::init_output_dtype() { | |||
    mgb_assert( | |||
            input(0)->dtype().category() == DTypeCategory::FLOAT, | |||
            "input dtype only supports FLOAT, but got input dtype: %s", | |||
            input(0)->dtype().name()); | |||
    output(0)->dtype(input(0)->dtype()); | |||
} | |||
size_t RegionRestrictedConvolutionForward::get_workspace_size_bytes( | |||
const TensorShapeArray& input_shapes, | |||
const TensorShapeArray& output_shapes) const { | |||
return megdnn_opr()->get_workspace_in_bytes( | |||
{input_shapes[0], input(0)->dtype(), input(0)->format()}, | |||
{input_shapes[1], input(1)->dtype(), input(1)->format()}, | |||
{input_shapes[2], input(2)->dtype(), input(2)->format()}, | |||
{input_shapes[3], input(3)->dtype(), input(3)->format()}, | |||
{output_shapes[0], output(0)->dtype(), output(0)->format()}); | |||
} | |||
#if MGB_ENABLE_GRAD | |||
MGB_IMPL_OPR_GRAD(RegionRestrictedConvolutionForward) { | |||
mgb_assert( | |||
opr.input(0)->dtype().category() == DTypeCategory::FLOAT && | |||
opr.input(1)->dtype().category() == DTypeCategory::FLOAT && | |||
opr.input(2)->dtype().category() == DTypeCategory::INT && | |||
opr.input(3)->dtype().category() == DTypeCategory::INT, | |||
"only float data type supported for grad"); | |||
if (wrt_idx == 0) { // src | |||
SymbolVar grad = RegionRestrictedConvolutionBackwardData::make( | |||
opr.input(1), // filter | |||
out_grad[0], // diff | |||
opr.input(2), // rin | |||
opr.input(3), // rout | |||
opr.input(0), // src | |||
opr.param()); | |||
return grad.node(); | |||
} | |||
// TODO: CUDA WGRAD UNIMPLEMENTED! | |||
if (wrt_idx == 1) { // filter | |||
SymbolVar grad = RegionRestrictedConvolutionBackwardFilter::make( | |||
opr.input(0), // src | |||
out_grad[0], // diff | |||
opr.input(2), // rin | |||
opr.input(3), // rout | |||
opr.input(1), // filter | |||
opr.param()); | |||
return grad.node(); | |||
} | |||
return nullptr; | |||
} | |||
#endif | |||
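// The grad wiring above, summarized (sketch): for y = RRConv(x, w, rin, rout), | |||
//   dx = RegionRestrictedConvolutionBackwardData(w, dy, rin, rout, x) | |||
//   dw = RegionRestrictedConvolutionBackwardFilter(x, dy, rin, rout, w) | |||
// and the integer region masks rin/rout are never differentiated. | |||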
/* ========================== RegionRestrictedConvolutionBackwardData | |||
* ========================== */ | |||
IMPL_CONV(RegionRestrictedConvolutionBackwardData); | |||
RegionRestrictedConvolutionBackwardData::RegionRestrictedConvolutionBackwardData( | |||
VarNode* filter, VarNode* diff, VarNode* region_in, VarNode* region_out, | |||
VarNode* src, const Param& param, const OperatorNodeConfig& config) | |||
: Super{filter->owner_graph(), | |||
config, | |||
"region_restricted_conv_bwd_data", | |||
{filter, diff, region_in, region_out}} { | |||
init_megdnn_opr(*this, param); | |||
add_input({filter, diff, region_in, region_out}); | |||
if (src) | |||
add_input({src}); | |||
} | |||
SymbolVar RegionRestrictedConvolutionBackwardData::make( | |||
SymbolVar filter, SymbolVar diff, SymbolVar region_in, SymbolVar region_out, | |||
SymbolVar src, const Param& param, const OperatorNodeConfig& config) { | |||
return filter.insert_single_output_opr<RegionRestrictedConvolutionBackwardData>( | |||
filter.node(), diff.node(), region_in.node(), region_out.node(), src.node(), | |||
param, config); | |||
} | |||
SymbolVar RegionRestrictedConvolutionBackwardData::make( | |||
SymbolVar filter, SymbolVar diff, SymbolVar region_in, SymbolVar region_out, | |||
const Param& param, const OperatorNodeConfig& config) { | |||
return make(filter, diff, region_in, region_out, {}, param, config); | |||
} | |||
void RegionRestrictedConvolutionBackwardData::init_output_static_infer_desc() { | |||
init_output_static_infer_desc_for_bwd_data< | |||
RegionRestrictedConvolutionBackwardData, | |||
megdnn::RegionRestrictedConvolutionBackwardData>(this); | |||
} | |||
void RegionRestrictedConvolutionBackwardData::init_output_dtype() { | |||
output(0)->dtype(input(0)->dtype()); | |||
} | |||
void RegionRestrictedConvolutionBackwardData::scn_do_execute() { | |||
megdnn_opr()->exec( | |||
input(0)->dev_tensor().as_megdnn(), // filter | |||
input(1)->dev_tensor().as_megdnn(), // diff | |||
input(2)->dev_tensor().as_megdnn(), // rin | |||
input(3)->dev_tensor().as_megdnn(), // rout | |||
output(0)->dev_tensor().as_megdnn(), | |||
intl::get_megdnn_workspace_from_var(output().back())); | |||
} | |||
cg::OperatorNodeBase::NodeProp* RegionRestrictedConvolutionBackwardData:: | |||
do_make_node_prop() const { | |||
auto prop = Super::Super::do_make_node_prop(); | |||
if (input().size() == 5) { | |||
using D = NodeProp::DepType; | |||
prop->reset_dep_type( | |||
input(), | |||
{D::DEV_VALUE, D::DEV_VALUE, D::DEV_VALUE, D::DEV_VALUE, D::SHAPE}); | |||
} | |||
return prop; | |||
} | |||
#if MGB_ENABLE_GRAD | |||
MGB_IMPL_OPR_GRAD(RegionRestrictedConvolutionBackwardData) { | |||
if (wrt_idx == 0) { // filter | |||
return RegionRestrictedConvolutionBackwardFilter::make( | |||
out_grad[0], opr.input(1), opr.input(2), opr.input(3), | |||
opr.input(0), opr.param()) | |||
.node(); | |||
} | |||
if (wrt_idx == 1) { // diff | |||
return RegionRestrictedConvolution::make( | |||
out_grad[0], opr.input(0), opr.input(2), opr.input(3), | |||
opr.param()) | |||
.node(); | |||
} | |||
return nullptr; | |||
} | |||
#endif | |||
/* ========================== RegionRestrictedConvolutionBackwardFilter | |||
* ========================== */ | |||
IMPL_CONV(RegionRestrictedConvolutionBackwardFilter); | |||
RegionRestrictedConvolutionBackwardFilter::RegionRestrictedConvolutionBackwardFilter( | |||
VarNode* src, VarNode* diff, VarNode* region_in, VarNode* region_out, | |||
VarNode* filter, const Param& param, const OperatorNodeConfig& config) | |||
: Super({src->owner_graph(), | |||
config, | |||
"region_restricted_conv_bwd_filter", | |||
{src, diff, region_in, region_out, filter}}, | |||
4, false) { | |||
init_megdnn_opr(*this, param); | |||
add_input({src, diff, region_in, region_out, filter}); | |||
} | |||
SymbolVar RegionRestrictedConvolutionBackwardFilter::make( | |||
SymbolVar src, SymbolVar diff, SymbolVar region_in, SymbolVar region_out, | |||
SymbolVar filter, const Param& param, const OperatorNodeConfig& config) { | |||
return src.insert_single_output_opr<RegionRestrictedConvolutionBackwardFilter>( | |||
src.node(), diff.node(), region_in.node(), region_out.node(), filter.node(), | |||
param, config); | |||
} | |||
size_t RegionRestrictedConvolutionBackwardFilter::get_workspace_size_bytes( | |||
const TensorShapeArray& input_shapes, | |||
const TensorShapeArray& output_shapes) const { | |||
return megdnn_opr()->get_workspace_in_bytes( | |||
{input_shapes[0], input(0)->dtype(), input(0)->format()}, | |||
{input_shapes[1], input(1)->dtype(), input(1)->format()}, | |||
{input_shapes[2], input(2)->dtype(), input(2)->format()}, | |||
{input_shapes[3], input(3)->dtype(), input(3)->format()}, | |||
{output_shapes[0], output(0)->dtype(), output(0)->format()}); | |||
} | |||
void RegionRestrictedConvolutionBackwardFilter::scn_do_execute() { | |||
megdnn_opr()->exec( | |||
input(0)->dev_tensor().as_megdnn(), // src | |||
input(1)->dev_tensor().as_megdnn(), // diff | |||
input(2)->dev_tensor().as_megdnn(), // rin | |||
input(3)->dev_tensor().as_megdnn(), // rout | |||
output(0)->dev_tensor().as_megdnn(), | |||
intl::get_megdnn_workspace_from_var(output().back())); | |||
} | |||
#if MGB_ENABLE_GRAD | |||
MGB_IMPL_OPR_GRAD(RegionRestrictedConvolutionBackwardFilter) { | |||
if (wrt_idx == 0) { | |||
return RegionRestrictedConvolutionBackwardData::make( | |||
out_grad[0] /*filter*/, opr.input(1) /*diff*/, | |||
opr.input(2) /*rin*/, opr.input(3) /*rout*/, | |||
opr.input(0) /*src*/, opr.param()) | |||
.node(); | |||
} | |||
if (wrt_idx == 1) { | |||
return RegionRestrictedConvolution::make( | |||
opr.input(0) /*src*/, out_grad[0] /*filter*/, | |||
opr.input(2) /*rin*/, opr.input(3) /*rout*/, opr.param()) | |||
.node(); | |||
} | |||
return nullptr; | |||
} | |||
#endif | |||
#undef IMPL_CONV | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -431,6 +431,7 @@ struct OprLoadDumpImpl<opr::Convolution3DBackwardFilter, 0> | |||
MakeConvCallerEmpty<megdnn::Convolution3D>, | |||
MakeConvCallerEmpty<megdnn::Convolution3D>, | |||
megdnn::param::Convolution3D> {}; | |||
template <> | |||
struct OprLoadDumpImpl<opr::ConvBiasForward, 0> | |||
: public ConvLoadDumpImpl< | |||
@@ -194,6 +194,30 @@ struct OprLoadDumpImplV2<opr::DeformableConvBackwardFilter, 0> | |||
MakeConvCaller5<megdnn::DeformableConvBackwardFilter>, | |||
megdnn::Convolution> {}; | |||
template <> | |||
struct OprMaker<opr::RegionRestrictedConvolutionBackwardData, 0> { | |||
using Opr = opr::RegionRestrictedConvolutionBackwardData; | |||
using Param = Opr::Param; | |||
static cg::OperatorNodeBase* make( | |||
const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph, | |||
const OperatorNodeConfig& config) { | |||
MGB_MARK_USED_VAR(graph); | |||
if (inputs.size() == 4) { // deconv mode | |||
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param, config) | |||
.node() | |||
->owner_opr(); | |||
} else if (inputs.size() == 5) { // dgrad mode | |||
return Opr::make( | |||
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], param, | |||
config) | |||
.node() | |||
->owner_opr(); | |||
} else { | |||
return nullptr; | |||
} | |||
} | |||
}; | |||
} // namespace serialization | |||
namespace opr { | |||
@@ -220,6 +244,10 @@ SERGE_OPR_V2_NO_CONVERTER(Convolution3D, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardData, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardFilter, 0); | |||
MGB_SEREG_OPR(RegionRestrictedConvolutionBackwardData, 0); | |||
MGB_SEREG_OPR(RegionRestrictedConvolution, 4); | |||
MGB_SEREG_OPR(RegionRestrictedConvolutionBackwardFilter, 5); | |||
SERGE_OPR_V2_NO_CONVERTER(LocalShareForward, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardData, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardFilter, 0); | |||
@@ -18,6 +18,12 @@ protected: | |||
static void init_output_static_infer_desc_for_bwd_data(cg::OperatorNodeBase* self); | |||
}; | |||
class RegionConvBackwardDataMixin : public cg::OperatorNodeMixinBase { | |||
protected: | |||
template <typename MGBOPR, typename DNNOPR> | |||
static void init_output_static_infer_desc_for_bwd_data(cg::OperatorNodeBase* self); | |||
}; | |||
class WeightPreprocessExecutor : public cg::OperatorNodeMixinBase { | |||
class PreprocessedFilterExecDep; | |||
@@ -83,6 +89,80 @@ class ConvolutionTestingPeer; | |||
} // namespace testing | |||
/* ==================== RegionRestrictedConvolutionForward ==================== */ | |||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
RegionRestrictedConvolutionForward, | |||
intl::MegDNNOprWrapperFwd<megdnn::RegionRestrictedConvolutionForward>) // { | |||
size_t get_workspace_size_bytes( | |||
const TensorShapeArray& input_shapes, | |||
const TensorShapeArray& output_shapes) const override; | |||
void init_output_dtype() override; | |||
public: | |||
MGE_WIN_DECLSPEC_FUC RegionRestrictedConvolutionForward( | |||
VarNode* src, VarNode* filter, VarNode* region_in, VarNode* region_out, | |||
const Param& param, const OperatorNodeConfig& config); | |||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||
SymbolVar src, SymbolVar filter, SymbolVar region_in, SymbolVar region_out, | |||
const Param& param, const OperatorNodeConfig& config = {}); | |||
}; | |||
using RegionRestrictedConvolution = RegionRestrictedConvolutionForward; | |||
/* ==================== RegionRestrictedConvolutionBackwardData ==================== */ | |||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
RegionRestrictedConvolutionBackwardData, | |||
cg::SingleCNOperatorNodeBaseT<mixin::MegDNNOprHolderImpl< | |||
megdnn::RegionRestrictedConvolutionBackwardData>>, | |||
public mixin::RegionConvBackwardDataMixin) // { | |||
void scn_do_execute() override; | |||
void init_output_static_infer_desc() override; | |||
NodeProp* do_make_node_prop() const override; | |||
void init_output_dtype() override; | |||
public: | |||
MGE_WIN_DECLSPEC_FUC RegionRestrictedConvolutionBackwardData( | |||
VarNode* filter, VarNode* diff, VarNode* region_in, VarNode* region_out, | |||
VarNode* src, const Param& param, const OperatorNodeConfig& config); | |||
// grad mode | |||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||
SymbolVar filter, SymbolVar diff, SymbolVar region_in, SymbolVar region_out, | |||
SymbolVar src, const Param& param, const OperatorNodeConfig& config = {}); | |||
// sereg for deconv mode | |||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||
SymbolVar filter, SymbolVar diff, SymbolVar region_in, SymbolVar region_out, | |||
const Param& param, const OperatorNodeConfig& config = {}); | |||
// user interface for deconv | |||
MGE_WIN_DECLSPEC_FUC static SymbolVar make_deconv( | |||
SymbolVar data, SymbolVar filter, SymbolVar region_in, SymbolVar region_out, | |||
const Param& param = {}, const OperatorNodeConfig& config = {}) { | |||
return make(filter, data, region_in, region_out, param, config); | |||
} | |||
}; | |||
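// Hedged usage sketch of the two entry points declared above, with symbolic | |||
// vars w (filter), dy (diff), rin/rout (region masks) and x (src): | |||
//   auto dx = RegionRestrictedConvolutionBackwardData::make( | |||
//           w, dy, rin, rout, x, param);    // grad mode: src shape known | |||
//   auto y  = RegionRestrictedConvolutionBackwardData::make_deconv( | |||
//           x, w, rin, rout, param);        // deconv: forwards to make(w, x, ...) | |||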
/* ==================== RegionRestrictedConvolutionBackwardFilter ==================== */ | |||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
RegionRestrictedConvolutionBackwardFilter, | |||
intl::MegDNNOprWrapperBwd<megdnn::RegionRestrictedConvolutionBackwardFilter>) // { | |||
size_t get_workspace_size_bytes( | |||
const TensorShapeArray& input_shapes, | |||
const TensorShapeArray& output_shapes) const override; | |||
void scn_do_execute() override; | |||
public: | |||
MGE_WIN_DECLSPEC_FUC RegionRestrictedConvolutionBackwardFilter( | |||
VarNode* src, VarNode* diff, VarNode* region_in, VarNode* region_out, | |||
VarNode* filter, const Param& param, const OperatorNodeConfig& config); | |||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||
SymbolVar src, SymbolVar diff, SymbolVar region_in, SymbolVar region_out, | |||
SymbolVar filter, const Param& param, | |||
const OperatorNodeConfig& config = {}); | |||
}; | |||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
ConvolutionForward, intl::ConvolutionForwardBase, | |||
public mixin::AlgoChooserHelper) // { | |||
@@ -0,0 +1,196 @@ | |||
#include "./legacy_checker.h" | |||
#include "megbrain/comp_node_env.h" | |||
#include "megbrain/gopt/inference.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/serialization/serializer.h" | |||
#include "megbrain/test/autocheck.h" | |||
#include "megbrain/test/helper.h" | |||
#include "megbrain/test/megdnn_helper.h" | |||
#include "megdnn/algorithm_cache.h" | |||
#include "megdnn/dtype.h" | |||
#include "megdnn/oprs/base.h" | |||
#include <gmock/gmock.h> | |||
#include <cmath> | |||
#include <memory> | |||
#include <random> | |||
using namespace mgb; | |||
TEST(TestOprDNN, REGIONCONV_FWD_CPU_WRAPPER) { | |||
using Checker = AutoOprChecker<4, 1>; | |||
megdnn::RegionRestrictedConvolution::Param param; | |||
param.sparse = opr::RegionRestrictedConvolution::Param::Sparse::DENSE; | |||
auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||
return {opr::RegionRestrictedConvolutionForward::make( | |||
inputs[0], inputs[1], inputs[2], inputs[3], param)}; | |||
}; | |||
Checker::RunOptions option; | |||
option.numdiff_eps = 0.1; | |||
option.numdiff_max_err = 1e-2; | |||
auto mask_gen = [&](HostTensorND& src) { | |||
HostTensorGenerator<dtype::Int32, RandomDistribution::CONSTANT> gen(1); | |||
src = *gen(src.shape(), src.comp_node()); | |||
}; | |||
auto float_gen = [&](HostTensorND& src) { | |||
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> gen; | |||
src = *gen(src.shape(), src.comp_node()); | |||
}; | |||
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||
auto opr = | |||
megdnn_naive_handle() | |||
->create_operator<megdnn::RegionRestrictedConvolutionForward>(); | |||
opr->param() = param; | |||
TensorLayout dest_layout; | |||
opr->deduce_layout( | |||
inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), inp[3]->layout(), | |||
dest_layout); | |||
std::vector<dt_byte> workspace(opr->get_workspace_in_bytes( | |||
inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), inp[3]->layout(), | |||
dest_layout)); | |||
dest[0].dtype(inp[0]->dtype()) | |||
.comp_node(inp[0]->comp_node()) | |||
.resize(dest_layout); | |||
opr->exec( | |||
inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(), | |||
inp[3]->as_megdnn(), dest[0].as_megdnn(), | |||
{workspace.data(), workspace.size()}); | |||
}; | |||
Checker(make_graph, fwd, CompNode::load("cpu0")) | |||
.set_input_dtype(0, dtype::Float32()) | |||
.set_input_dtype(1, dtype::Float32()) | |||
.set_input_dtype(2, dtype::Int32()) | |||
.set_input_dtype(3, dtype::Int32()) | |||
.set_input_generator(0, float_gen) | |||
.set_input_generator(1, float_gen) | |||
.set_input_generator(2, mask_gen) | |||
.set_input_generator(3, mask_gen) | |||
.set_input_allow_grad(2, false) | |||
.set_input_allow_grad(3, false) | |||
// {n,ic,ih,iw}, {oc,ic,fh,fw}, {n,ih,iw}, {n,oh,ow} | |||
.run({TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 2, 2}, | |||
TensorShape{1, 2, 2}, TensorShape{1, 1, 1}}, | |||
option) | |||
.run({TensorShape{1, 2, 3, 3}, TensorShape{1, 2, 3, 3}, | |||
TensorShape{1, 3, 3}, TensorShape{1, 1, 1}}, | |||
option) | |||
.run({TensorShape{1, 1, 4, 4}, TensorShape{1, 1, 2, 2}, | |||
TensorShape{1, 4, 4}, TensorShape{1, 3, 3}}, | |||
option) | |||
.run({TensorShape{2, 2, 8, 8}, TensorShape{4, 2, 2, 2}, | |||
TensorShape{2, 8, 8}, TensorShape{2, 7, 7}}, | |||
option) | |||
.run({TensorShape{4, 4, 8, 8}, TensorShape{4, 4, 2, 2}, | |||
TensorShape{4, 8, 8}, TensorShape{4, 7, 7}}, | |||
option); | |||
} | |||
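// Quick sanity check for the dense-mode shapes exercised above (stride 1, | |||
// pad 0, dilation 1): oh = ih - fh + 1, e.g. {2,2,8,8} src with {4,2,2,2} | |||
// filter gives 8 - 2 + 1 = 7, matching the {2,7,7} rout mask. A sketch: | |||
static inline size_t conv_out_extent(size_t ih, size_t fh) { | |||
    return ih - fh + 1;  // valid convolution, unit stride, no padding/dilation | |||
} | |||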
#if MGB_CUDA | |||
TEST(TestOprDNN, REGIONCONV_FWD_GPU_WRAPPER) { | |||
using Checker = AutoOprChecker<4, 1>; | |||
megdnn::RegionRestrictedConvolution::Param param; | |||
param.sparse = opr::RegionRestrictedConvolution::Param::Sparse::GROUP; | |||
auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||
return {opr::RegionRestrictedConvolutionForward::make( | |||
inputs[0], inputs[1], inputs[2], inputs[3], param)}; | |||
}; | |||
Checker::RunOptions option; | |||
option.numdiff_eps = 0.1; | |||
option.numdiff_max_err = 1e-2; | |||
auto mask_gen = [&](HostTensorND& src) { | |||
HostTensorGenerator<dtype::Int32, RandomDistribution::CONSTANT> gen(1); | |||
src = *gen(src.shape(), src.comp_node()); | |||
}; | |||
auto uint8_mask_gen = [&](HostTensorND& src) { | |||
HostTensorGenerator<dtype::Uint8, RandomDistribution::CONSTANT> gen(1); | |||
src = *gen(src.shape(), src.comp_node()); | |||
}; | |||
auto float_gen = [&](HostTensorND& src) { | |||
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> gen; | |||
src = *gen(src.shape(), src.comp_node()); | |||
}; | |||
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||
auto opr = | |||
megdnn_naive_handle() | |||
->create_operator<megdnn::RegionRestrictedConvolutionForward>(); | |||
opr->param() = param; | |||
TensorLayout dest_layout; | |||
opr->deduce_layout( | |||
inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), inp[3]->layout(), | |||
dest_layout); | |||
std::vector<dt_byte> workspace(opr->get_workspace_in_bytes( | |||
inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), inp[3]->layout(), | |||
dest_layout)); | |||
dest[0].dtype(inp[0]->dtype()) | |||
.comp_node(inp[0]->comp_node()) | |||
.resize(dest_layout); | |||
opr->exec( | |||
inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(), | |||
inp[3]->as_megdnn(), dest[0].as_megdnn(), | |||
{workspace.data(), workspace.size()}); | |||
}; | |||
Checker(make_graph, fwd, CompNode::load("gpu0")) | |||
.set_input_dtype(0, dtype::Float32()) | |||
.set_input_dtype(1, dtype::Float32()) | |||
.set_input_dtype(2, dtype::Int32()) | |||
.set_input_dtype(3, dtype::Int32()) | |||
.set_input_generator(0, float_gen) | |||
.set_input_generator(1, float_gen) | |||
.set_input_generator(2, mask_gen) | |||
.set_input_generator(3, mask_gen) | |||
.set_input_allow_grad(2, false) | |||
.set_input_allow_grad(3, false) | |||
            // {n,ic,ih,iw}, {g,oc/g,ic/g,fh,fw} (group mode), {n,ih,iw}, {n,oh,ow} | |||
.run({TensorShape{1, 2, 2, 2}, TensorShape{2, 1, 1, 2, 2}, | |||
TensorShape{1, 2, 2}, TensorShape{1, 1, 1}}, | |||
option) | |||
.run({TensorShape{1, 2, 3, 3}, TensorShape{2, 1, 1, 3, 3}, | |||
TensorShape{1, 3, 3}, TensorShape{1, 1, 1}}, | |||
option) | |||
.run({TensorShape{1, 4, 4, 4}, TensorShape{4, 1, 1, 2, 2}, | |||
TensorShape{1, 4, 4}, TensorShape{1, 3, 3}}, | |||
option) | |||
.run({TensorShape{2, 4, 8, 8}, TensorShape{4, 1, 1, 2, 2}, | |||
TensorShape{2, 8, 8}, TensorShape{2, 7, 7}}, | |||
option); | |||
Checker(make_graph, fwd, CompNode::load("gpu0")) | |||
.set_input_dtype(0, dtype::Float32()) | |||
.set_input_dtype(1, dtype::Float32()) | |||
.set_input_dtype(2, dtype::Uint8()) | |||
.set_input_dtype(3, dtype::Uint8()) | |||
.set_input_generator(0, float_gen) | |||
.set_input_generator(1, float_gen) | |||
.set_input_generator(2, uint8_mask_gen) | |||
.set_input_generator(3, uint8_mask_gen) | |||
.set_input_allow_grad(2, false) | |||
.set_input_allow_grad(3, false) | |||
            // {n,ic,ih,iw}, {g,oc/g,ic/g,fh,fw} (group mode), {n,ih,iw}, {n,oh,ow} | |||
.run({TensorShape{1, 2, 4, 4}, TensorShape{2, 1, 1, 1, 1}, | |||
TensorShape{1, 4, 4}, TensorShape{1, 4, 4}}, | |||
option) | |||
.run({TensorShape{1, 2, 8, 8}, TensorShape{2, 1, 1, 1, 1}, | |||
TensorShape{1, 8, 8}, TensorShape{1, 8, 8}}, | |||
option) | |||
.run({TensorShape{1, 4, 8, 8}, TensorShape{4, 1, 1, 5, 5}, | |||
TensorShape{1, 8, 8}, TensorShape{1, 4, 4}}, | |||
option) | |||
.run({TensorShape{2, 4, 8, 8}, TensorShape{4, 1, 1, 1, 1}, | |||
TensorShape{2, 8, 8}, TensorShape{2, 8, 8}}, | |||
option); | |||
} | |||
#endif |