From 99309fa3d6f649f4dce44025bfb85a96fcaafab0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 1 Aug 2022 15:18:41 +0800 Subject: [PATCH] feat(mge/functional): add param output_padding for deconv ops GitOrigin-RevId: 8a69608953a69b40db4d23b489435a7ae03c9523 --- imperative/python/megengine/functional/nn.py | 70 +++++++++++++++- .../python/megengine/functional/quantized.py | 29 ++++++- imperative/python/megengine/module/conv.py | 26 +++++- imperative/python/megengine/module/qat/conv.py | 1 + .../python/megengine/module/quantized/conv.py | 6 ++ .../python/test/unit/functional/test_functional.py | 37 ++++++++- imperative/src/impl/ops/convolution.cpp | 97 +++++++++++++++------- 7 files changed, 232 insertions(+), 34 deletions(-) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 58a60b4d..7f53e4c9 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -335,6 +335,7 @@ def conv_transpose2d( bias: Optional[Tensor] = None, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, + output_padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, conv_mode="cross_correlation", @@ -352,6 +353,7 @@ def conv_transpose2d( stride: stride of the 2D convolution operation. Default: 1 padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0 + output_padding: size of paddings appended to output. Default: 0 dilation: dilation of the 2D convolution operation. Default: 1 groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1, @@ -374,6 +376,7 @@ def conv_transpose2d( stride_h, stride_w = expand_hw(stride) pad_h, pad_w = expand_hw(padding) + output_pad_h, output_pad_w = expand_hw(output_padding) dilate_h, dilate_w = expand_hw(dilation) compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) @@ -389,7 +392,32 @@ def conv_transpose2d( compute_mode=compute_mode, sparse=sparse_type, ) - (output,) = apply(op, weight, inp) + if output_pad_h != 0 or output_pad_w != 0: + assert ( + output_pad_h < stride_h + ), "output_padding[0] should be less than stride[0]" + assert ( + output_pad_w < stride_w + ), "output_padding[1] should be less than stride[1]" + Hout = ( + (inp.shape[2] - 1) * stride_h + - 2 * pad_h + + dilate_h * (weight.shape[2] - 1) + + output_pad_h + + 1 + ) + Wout = ( + (inp.shape[3] - 1) * stride_w + - 2 * pad_w + + dilate_w * (weight.shape[3] - 1) + + output_pad_w + + 1 + ) + output_shape = [inp.shape[0], weight.shape[1], Hout, Wout] + output_shape = astensor1d(output_shape) + (output,) = apply(op, weight, inp, output_shape) + else: + (output,) = apply(op, weight, inp) if bias is not None: if amp._enabled: bias = cast_tensors(bias)
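The Hout/Wout arithmetic above is the standard transposed-convolution output-size formula with the new output_padding term added. A quick standalone check in plain Python (deconv_out_size is only an illustrative helper name, not an API introduced by this patch) reproduces the spatial shapes asserted by the new tests further down:

    def deconv_out_size(in_size, kernel, stride, padding, dilation, output_padding):
        # Same arithmetic as the Hout/Wout expressions in the hunk above.
        assert output_padding < stride, "output_padding should be less than stride"
        return (
            (in_size - 1) * stride
            - 2 * padding
            + dilation * (kernel - 1)
            + output_padding
            + 1
        )

    # Shapes from test_conv_transpose2d below: input (20, 16, 50, 100) through
    # ConvTranspose2d(16, 33, (3, 5), output_padding=(1, 2), stride=(2, 3), padding=(4, 2)).
    assert deconv_out_size(50, kernel=3, stride=2, padding=4, dilation=1, output_padding=1) == 94
    assert deconv_out_size(100, kernel=5, stride=3, padding=2, dilation=1, output_padding=2) == 300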
@@ -528,6 +556,7 @@ def conv_transpose3d( bias: Optional[Tensor] = None, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, + output_padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, ) -> Tensor: @@ -544,6 +573,7 @@ def conv_transpose3d( stride: stride of the 3D convolution operation. Default: 1 padding: size of the paddings added to the input on all sides of its spatial dimensions. Only zero-padding is supported. Default: 0 + output_padding: size of paddings appended to output. Default: 0 dilation: dilation of the 3D convolution operation. Default: 1 groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1, @@ -558,6 +588,7 @@ def conv_transpose3d( pad = expand_dhw(padding) stride = expand_dhw(stride) dilate = expand_dhw(dilation) + output_padding = expand_dhw(output_padding) sparse_type = "dense" if groups == 1 else "group" op = builtin.Convolution3DBackwardData( @@ -573,7 +604,42 @@ def conv_transpose3d( strategy=get_execution_strategy(), sparse=sparse_type, ) - (output,) = apply(op, weight, inp) + if output_padding[0] != 0 or output_padding[1] != 0 or output_padding[2] != 0: + assert ( + output_padding[0] < stride[0] + ), "output_padding[0] should be less than stride[0]" + assert ( + output_padding[1] < stride[1] + ), "output_padding[1] should be less than stride[1]" + assert ( + output_padding[2] < stride[2] + ), "output_padding[2] should be less than stride[2]" + Dout = ( + (inp.shape[2] - 1) * stride[0] + - 2 * pad[0] + + dilate[0] * (weight.shape[2] - 1) + + output_padding[0] + + 1 + ) + Hout = ( + (inp.shape[3] - 1) * stride[1] + - 2 * pad[1] + + dilate[1] * (weight.shape[3] - 1) + + output_padding[1] + + 1 + ) + Wout = ( + (inp.shape[4] - 1) * stride[2] + - 2 * pad[2] + + dilate[2] * (weight.shape[4] - 1) + + output_padding[2] + + 1 + ) + output_shape = [inp.shape[0], weight.shape[1], Dout, Hout, Wout] + output_shape = astensor1d(output_shape) + (output,) = apply(op, weight, inp, output_shape) + else: + (output,) = apply(op, weight, inp) if bias is not None: output += bias return output diff --git a/imperative/python/megengine/functional/quantized.py b/imperative/python/megengine/functional/quantized.py index fd047bc5..79f8daa2 100644 --- a/imperative/python/megengine/functional/quantized.py +++ b/imperative/python/megengine/functional/quantized.py @@ -134,6 +134,7 @@ def conv_transpose2d( dtype=None, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, + output_padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, conv_mode="cross_correlation", @@ -156,6 +157,7 @@ def conv_transpose2d( ) pad_h, pad_w = _pair(padding) + output_pad_h, output_pad_w = _pair(output_padding) stride_h, stride_w = _pair_nonzero(stride) dilate_h, dilate_w = _pair_nonzero(dilation) compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) @@ -173,5 +175,30 @@ def conv_transpose2d( compute_mode=compute_mode, mode=conv_mode, ) - (output,) = apply(op, weight, inp) + if output_pad_h != 0 or output_pad_w != 0: + assert ( + output_pad_h < stride_h + ), "output_padding[0] should be less than stride[0]" + assert ( + output_pad_w < stride_w + ), "output_padding[1] should be less than stride[1]" + Hout = ( + (inp.shape[2] - 1) * stride_h + - 2 * pad_h + + dilate_h * (weight.shape[2] - 1) + + output_pad_h + + 1 + ) + Wout = ( + (inp.shape[3] - 1) * stride_w + - 2 * pad_w + + dilate_w * (weight.shape[3] - 1) + + output_pad_w + + 1 + ) + output_shape = [inp.shape[0], weight.shape[1], Hout, Wout] + output_shape = Tensor(output_shape) + (output,) = apply(op, weight, inp, output_shape) + else: + (output,) = apply(op, weight, inp) return output diff --git a/imperative/python/megengine/module/conv.py b/imperative/python/megengine/module/conv.py index 0ac9b394..8188b652 100644 ---
a/imperative/python/megengine/module/conv.py +++ b/imperative/python/megengine/module/conv.py @@ -30,6 +30,7 @@ class _ConvNd(Module): kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], padding: Union[int, Tuple[int, int]], + output_padding: Union[int, Tuple[int, int]], dilation: Union[int, Tuple[int, int]], groups: int, bias: bool = True, @@ -45,6 +46,7 @@ class _ConvNd(Module): self.kernel_size = kernel_size self.stride = stride self.padding = padding + self.output_padding = output_padding self.dilation = dilation self.groups = groups @@ -178,6 +180,7 @@ class Conv1d(_ConvNd): kernel_size, stride, padding, + 0, dilation, groups, bias, @@ -352,6 +355,7 @@ class Conv2d(_ConvNd): kernel_size, stride, padding, + 0, dilation, groups, bias, @@ -505,6 +509,7 @@ class Conv3d(_ConvNd): kernel_size, stride, padding, + 0, dilation, groups, bias, @@ -572,6 +577,7 @@ class ConvTranspose2d(_ConvNd): stride: stride of the 2D convolution operation. Default: 1 padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0 + output_padding: size of paddings appended to output. Default: 0 dilation: dilation of the 2D convolution operation. Default: 1 groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1, @@ -591,6 +597,8 @@ class ConvTranspose2d(_ConvNd): * ``bias`` usually has shape ``(1, out_channels, *1)`` """ + output_padding = 0 + def __init__( self, in_channels: int, @@ -598,6 +606,7 @@ class ConvTranspose2d(_ConvNd): kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, + output_padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, @@ -608,6 +617,7 @@ class ConvTranspose2d(_ConvNd): kernel_size = _pair_nonzero(kernel_size) stride = _pair_nonzero(stride) padding = _pair(padding) + output_padding = _pair(output_padding) dilation = _pair_nonzero(dilation) self.conv_mode = conv_mode self.compute_mode = compute_mode @@ -617,6 +627,7 @@ class ConvTranspose2d(_ConvNd): kernel_size, stride, padding, + output_padding, dilation, groups, bias, @@ -656,6 +667,7 @@ class ConvTranspose2d(_ConvNd): bias, self.stride, self.padding, + self.output_padding, self.dilation, self.groups, self.conv_mode, @@ -817,6 +829,7 @@ class DeformableConv2d(_ConvNd): kernel_size, stride, padding, + 0, dilation, groups, bias, @@ -889,6 +902,7 @@ class ConvTranspose3d(_ConvNd): stride: stride of the 3D convolution operation. Default: 1 padding: size of the paddings added to the input on all sides of its spatial dimensions. Only zero-padding is supported. Default: 0 + output_padding: size of paddings appended to output. Default: 0 dilation: dilation of the 3D convolution operation. Default: 1 groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. 
When ``groups`` is not 1, @@ -902,6 +916,8 @@ class ConvTranspose3d(_ConvNd): * ``bias`` usually has shape ``(1, out_channels, *1)`` """ + output_padding = 0 + def __init__( self, in_channels: int, @@ -909,6 +925,7 @@ class ConvTranspose3d(_ConvNd): kernel_size: Union[int, Tuple[int, int, int]], stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, + output_padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, bias: bool = True, @@ -923,6 +940,7 @@ class ConvTranspose3d(_ConvNd): kernel_size=kernel_size, stride=stride, padding=padding, + output_padding=output_padding, dilation=dilation, groups=groups, bias=bias, @@ -956,5 +974,11 @@ class ConvTranspose3d(_ConvNd): def forward(self, inp): return conv_transpose3d( - inp, self.weight, self.bias, self.stride, self.padding, self.dilation, + inp, + self.weight, + self.bias, + self.stride, + self.padding, + self.output_padding, + self.dilation, ) diff --git a/imperative/python/megengine/module/qat/conv.py b/imperative/python/megengine/module/qat/conv.py index 0780eaef..e3281106 100644 --- a/imperative/python/megengine/module/qat/conv.py +++ b/imperative/python/megengine/module/qat/conv.py @@ -74,6 +74,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule): float_module.kernel_size, float_module.stride, float_module.padding, + float_module.output_padding, float_module.dilation, float_module.groups, float_module.bias is not None, diff --git a/imperative/python/megengine/module/quantized/conv.py b/imperative/python/megengine/module/quantized/conv.py index aad7bfe3..0fc8142e 100644 --- a/imperative/python/megengine/module/quantized/conv.py +++ b/imperative/python/megengine/module/quantized/conv.py @@ -138,6 +138,8 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): dtype: data type of the output, should be qint8. 
""" + output_padding = 0 + def __init__( self, in_channels: int, @@ -145,6 +147,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, + output_padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, @@ -159,6 +162,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): kernel_size=kernel_size, stride=stride, padding=padding, + output_padding=output_padding, dilation=dilation, groups=groups, bias=bias, @@ -180,6 +184,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): qat_module.kernel_size, qat_module.stride, qat_module.padding, + qat_module.output_padding, qat_module.dilation, qat_module.groups, qat_module.bias is not None, @@ -212,6 +217,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): dtype=self.output_dtype, stride=self.stride, padding=self.padding, + output_padding=self.output_padding, dilation=self.dilation, groups=self.groups, conv_mode=self.conv_mode, diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index f7838b3b..87db37cb 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -18,7 +18,8 @@ from megengine.core._trace_option import use_symbolic_shape from megengine.core.autodiff.grad import Grad from megengine.core.tensor.utils import make_shape_tuple from megengine.device import get_device_count -from megengine.module import LayerNorm +from megengine.jit.tracing import trace +from megengine.module import ConvTranspose2d, ConvTranspose3d, LayerNorm _assert_allclose = partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6) @@ -1374,3 +1375,37 @@ def test_local_conv2d(stride, padding, dilation, ksize, groups): ) ref = local_conv2d_np(data, weight, stride, padding, dilation) np.testing.assert_almost_equal(output.numpy(), ref, 5) + + +def test_conv_transpose2d(): + m = ConvTranspose2d( + 16, 33, (3, 5), output_padding=(1, 2), stride=(2, 3), padding=(4, 2) + ) + + @trace(symbolic=True) + def fwd(inp: Tensor): + return m(inp) + + input = Tensor(np.random.rand(20, 16, 50, 100)) + output = fwd(input) + output_shape = Tensor(output.shape) + np.testing.assert_equal( + output_shape.numpy(), np.array([20, 33, 94, 300], dtype=np.int32) + ) + + +def test_conv_transpose3d(): + m = ConvTranspose3d( + 16, 33, (3, 5, 2), output_padding=(2, 1, 1), stride=(3, 2, 2), padding=(0, 4, 2) + ) + + @trace(symbolic=True) + def fwd(inp: Tensor): + return m(inp) + + input = Tensor(np.random.rand(20, 16, 10, 50, 100)) + output = fwd(input) + output_shape = Tensor(output.shape) + np.testing.assert_equal( + output_shape.numpy(), np.array([20, 33, 32, 96, 197], dtype=np.int32) + ) diff --git a/imperative/src/impl/ops/convolution.cpp b/imperative/src/impl/ops/convolution.cpp index 1a6823b9..dc622509 100644 --- a/imperative/src/impl/ops/convolution.cpp +++ b/imperative/src/impl/ops/convolution.cpp @@ -5,6 +5,7 @@ #include "../op_trait.h" #include "megbrain/imperative/ops/autogen.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h" +#include "megbrain/opr/tensor_gen.h" namespace mgb { namespace imperative { @@ -152,8 +153,11 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { inputs[0], inputs[1], conv.param(), conv.policy(), config); } else { mgb_assert(inputs.size() == 3); + 
auto* src_for_shape = + opr::Alloc::make(inputs[2], inputs[0]->dtype(), {}).node(); return opr::ConvolutionBackwardData::make( - inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); + inputs[0], inputs[1], src_for_shape, conv.param(), conv.policy(), + config); } } @@ -168,6 +172,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( if (filter.ndim && diff.ndim) { // deduce_layout won't override existing dtype dnn_opr.opr().deduce_layout(filter, diff, output_layout); + if (inputs.size() == 3) { + if (!inputs[2].value.empty()) { + cg::copy_tensor_value_to_shape(output_layout, inputs[2].value); + output_layout.init_contiguous_stride(); + } else { + output_layout.ndim = 0; + } + } } return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0}; } @@ -185,8 +197,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor( return output_descs[0].layout; } else { TensorLayout out_layout{inputs[0]->dtype()}; - dnn_opr.op()->deduce_layout( - inputs[0]->layout(), inputs[1]->layout(), out_layout); + if (inputs.size() == 3) { + cg::copy_tensor_value_to_shape( + out_layout, inputs[2]->get_value().proxy_to_default_cpu()); + out_layout.init_contiguous_stride(); + } else { + dnn_opr.op()->deduce_layout( + inputs[0]->layout(), inputs[1]->layout(), out_layout); + } return out_layout; } }();
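The hunks below extend the 3D backward-data op in the same way. A minimal eager-mode usage sketch of that 3D path (mirroring test_conv_transpose3d above, and assuming the eager apply_on_physical_tensor path behaves like the traced test):

    import numpy as np
    import megengine as mge
    from megengine.module import ConvTranspose3d

    # Same configuration as test_conv_transpose3d, executed eagerly (no trace).
    m = ConvTranspose3d(
        16, 33, (3, 5, 2), output_padding=(2, 1, 1), stride=(3, 2, 2), padding=(0, 4, 2)
    )
    out = m(mge.Tensor(np.random.rand(20, 16, 10, 50, 100).astype("float32")))
    np.testing.assert_equal(out.numpy().shape, (20, 33, 32, 96, 197))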
@@ -263,50 +278,74 @@ namespace convolution3d_backward_data { std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { mgb_assert( - inputs.size() == 2, - "inputs num of conv_transpose3d should be 2 but you give %zu", + inputs.size() == 2 || inputs.size() == 3, + "inputs num of conv_transpose3d should be 2 or 3 but you give %zu", inputs.size()); - auto&& op_def = def.cast_final_safe<Convolution3DBackwardData>(); - auto&& weight = inputs[0]; + auto&& conv3dbwd = def.cast_final_safe<Convolution3DBackwardData>(); + DnnOprHelper<megdnn::Convolution3DBackwardData> dnn_opr(conv3dbwd.param()); + auto&& filter = inputs[0]; auto&& diff = inputs[1]; - if (!(weight.layout.ndim && diff.layout.ndim)) { - return {{{TensorLayout{weight.layout.dtype}, weight.comp_node}}, false}; + + if (!(filter.layout.ndim && diff.layout.ndim)) { + return {{{TensorLayout{filter.layout.dtype}, filter.comp_node}}, false}; } - DnnOprHelper<megdnn::Convolution3DBackwardData> dnn_opr(op_def.param()); - auto oup_layout = dnn_opr.deduce_layout(weight.layout, diff.layout); - return {{{oup_layout, weight.comp_node}}, true}; + + TensorLayout output_layout = dnn_opr.deduce_layout(filter.layout, diff.layout); + if (filter.layout.ndim && diff.layout.ndim) { + if (inputs.size() == 3) { + if (!inputs[2].value.empty()) { + cg::copy_tensor_value_to_shape(output_layout, inputs[2].value); + output_layout.init_contiguous_stride(); + } else { + output_layout.ndim = 0; + } + } + } + return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0}; } SmallVector<TensorPtr> apply_on_physical_tensor( const OpDef& def, const SmallVector<TensorPtr>& inputs, SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { - auto&& conv = def.cast_final_safe<Convolution3DBackwardData>(); - auto cn = inputs[0]->comp_node(); - - auto&& wlayout = inputs[0]->layout(); - auto&& dlayout = inputs[1]->layout(); - - DnnOprCaller<megdnn::Convolution3DBackwardData> dnn_op( - cn, conv.param(), conv.policy()); - - auto oup_layout = [&] { + auto&& conv3dbwd = def.cast_final_safe<Convolution3DBackwardData>(); + CompNode cn = inputs[0]->comp_node(); + DnnOprCaller<megdnn::Convolution3DBackwardData> dnn_opr( + cn, conv3dbwd.param(), conv3dbwd.policy()); + auto out_layout = [&] { if (validated) { return output_descs[0].layout; } else { - return dnn_op.deduce_layout(wlayout, dlayout); + TensorLayout out_layout{inputs[0]->dtype()}; + dnn_opr.op()->deduce_layout( + inputs[0]->layout(), inputs[1]->layout(), out_layout); + if (inputs.size() == 3) { + cg::copy_tensor_value_to_shape( + out_layout, inputs[2]->get_value().proxy_to_default_cpu()); + out_layout.init_contiguous_stride(); + } + return out_layout; } }(); - auto oup = Tensor::make(oup_layout, cn); - dnn_op.exec_fastrun(inputs[0], inputs[1], oup); - return {oup}; + auto out = Tensor::make(out_layout, cn); + dnn_opr.exec_fastrun(inputs[0], inputs[1], out); + return {out}; } auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& conv = static_cast<const Convolution3DBackwardData&>(def); OperatorNodeConfig config{conv.make_name()}; - mgb_assert(inputs.size() == 2); - return opr::Convolution3DBackwardData::make( - inputs[0], inputs[1], conv.param(), conv.policy(), config); + if (inputs.size() == 2) { + return opr::Convolution3DBackwardData::make( + inputs[0], inputs[1], conv.param(), conv.policy(), config); + } else { + mgb_assert(inputs.size() == 3); + // The output shape is calculated in advance and given as input + auto* src_for_shape = + opr::Alloc::make(inputs[2], inputs[0]->dtype(), {}).node(); + return opr::Convolution3DBackwardData::make( + inputs[0], inputs[1], src_for_shape, conv.param(), conv.policy(), + config); + } } OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
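For completeness, the functional API exercises the same machinery directly; a minimal sketch of the 2D case (eager mode assumed; the deconv weight is laid out as (in_channels, out_channels, kh, kw), as implied by the weight.shape[1] / weight.shape[2:] uses in the Python hunks above):

    import numpy as np
    import megengine.functional as F
    from megengine import Tensor

    # Functional counterpart of test_conv_transpose2d: same stride/padding/output_padding,
    # so the output spatial size follows the
    # (in - 1) * stride - 2 * pad + dilation * (k - 1) + output_padding + 1 formula.
    inp = Tensor(np.random.rand(20, 16, 50, 100).astype("float32"))
    weight = Tensor(np.random.rand(16, 33, 3, 5).astype("float32"))
    out = F.conv_transpose2d(
        inp, weight, stride=(2, 3), padding=(4, 2), output_padding=(1, 2)
    )
    np.testing.assert_equal(out.numpy().shape, (20, 33, 94, 300))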