Browse Source

fix(mge/functional): fix conv* dtype promotion

GitOrigin-RevId: 3f03790cfc
release-1.5
Megvii Engine Team 3 years ago
parent
commit
3f3a256e0f
3 changed files with 75 additions and 8 deletions
  1. +41
    -5
      imperative/python/megengine/functional/nn.py
  2. +3
    -3
      imperative/python/megengine/module/conv.py
  3. +31
    -0
      imperative/python/test/unit/module/test_conv.py

+ 41
- 5
imperative/python/megengine/functional/nn.py View File

@@ -9,7 +9,7 @@
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
from typing import Optional, Sequence, Tuple, Union from typing import Optional, Sequence, Tuple, Union


from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise from ..core.ops.builtin import BatchNorm, Elemwise
from ..core.ops.special import Const from ..core.ops.special import Const
@@ -157,6 +157,12 @@ def conv1d(
if amp._enabled: if amp._enabled:
compute_mode = "float32" compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias) inp, weight, bias = cast_tensors(inp, weight, bias)
else:
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)


inp = expand_dims(inp, 3) inp = expand_dims(inp, 3)
weight = expand_dims(weight, 3) weight = expand_dims(weight, 3)
@@ -211,7 +217,7 @@ def conv2d(
:param padding: size of the paddings added to the input on both sides of its :param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1 :param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided,
:param groups: number of groups into which the input and output channels are divided,
so as to perform a ``grouped convolution``. When ``groups`` is not 1, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``, ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be ``(groups, out_channel // groups, and the shape of weight should be ``(groups, out_channel // groups,
@@ -234,6 +240,12 @@ def conv2d(
if amp._enabled: if amp._enabled:
compute_mode = "float32" compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias) inp, weight, bias = cast_tensors(inp, weight, bias)
else:
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)


stride_h, stride_w = expand_hw(stride) stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding) pad_h, pad_w = expand_hw(padding)
@@ -297,6 +309,12 @@ def conv3d(
stride = _triple_nonzero(stride) stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation) dilate = _triple_nonzero(dilation)


dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)

sparse_type = "dense" if groups == 1 else "group" sparse_type = "dense" if groups == 1 else "group"
op = builtin.Convolution3D( op = builtin.Convolution3D(
pad_d=pad[D], pad_d=pad[D],
@@ -341,7 +359,7 @@ def conv_transpose2d(
:param padding: size of the paddings added to the input on both sides of its :param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1 :param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided,
:param groups: number of groups into which the input and output channels are divided,
so as to perform a ``grouped convolution``. When ``groups`` is not 1, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by groups, ``in_channels`` and ``out_channels`` must be divisible by groups,
and the shape of weight should be ``(groups, in_channels // groups, and the shape of weight should be ``(groups, in_channels // groups,
@@ -364,6 +382,12 @@ def conv_transpose2d(
if amp._enabled: if amp._enabled:
compute_mode = "float32" compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias) inp, weight, bias = cast_tensors(inp, weight, bias)
else:
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)


if groups != 1: if groups != 1:
raise NotImplementedError("group transposed conv2d is not supported yet.") raise NotImplementedError("group transposed conv2d is not supported yet.")
@@ -413,7 +437,7 @@ def deformable_conv2d(
:param padding: size of the paddings added to the input on both sides of its :param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1 :param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided,
:param groups: number of groups into which the input and output channels are divided,
so as to perform a ``grouped convolution``. When ``groups`` is not 1, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by groups, ``in_channels`` and ``out_channels`` must be divisible by groups,
and the shape of weight should be ``(groups, out_channel // groups, and the shape of weight should be ``(groups, out_channel // groups,
@@ -482,6 +506,12 @@ def local_conv2d(
pad_h, pad_w = expand_hw(padding) pad_h, pad_w = expand_hw(padding)
dilate_h, dilate_w = expand_hw(dilation) dilate_h, dilate_w = expand_hw(dilation)


dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)

op = builtin.GroupLocal( op = builtin.GroupLocal(
stride_h=stride_h, stride_h=stride_h,
stride_w=stride_w, stride_w=stride_w,
@@ -507,7 +537,7 @@ def conv_transpose3d(
dilation: Union[int, Tuple[int, int, int]] = 1, dilation: Union[int, Tuple[int, int, int]] = 1,
) -> Tensor: ) -> Tensor:
""" """
3D transposed convolution operation. Only support the case that groups = 1
3D transposed convolution operation. Only support the case that groups = 1
and conv_mode = "cross_correlation". and conv_mode = "cross_correlation".


Refer to :class:`~.ConvTranspose3d` for more information. Refer to :class:`~.ConvTranspose3d` for more information.
@@ -527,6 +557,12 @@ def conv_transpose3d(
stride = _triple_nonzero(stride) stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation) dilate = _triple_nonzero(dilation)


dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)

op = builtin.Convolution3DBackwardData( op = builtin.Convolution3DBackwardData(
pad_d=pad[D], pad_d=pad[D],
pad_h=pad[H], pad_h=pad[H],


+ 3
- 3
imperative/python/megengine/module/conv.py View File

@@ -269,11 +269,11 @@ class Conv2d(_ConvNd):
output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` where output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` where


.. math:: .. math::
\text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} -
\text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} -
\text{dilation[0]} * (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1 \rfloor \text{dilation[0]} * (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1 \rfloor


.. math:: .. math::
\text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} -
\text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} -
\text{dilation[1]} * (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1 \rfloor \text{dilation[1]} * (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1 \rfloor


When `groups == in_channels` and `out_channels == K * in_channels`, When `groups == in_channels` and `out_channels == K * in_channels`,
@@ -939,7 +939,7 @@ class ConvTranspose3d(_ConvNd):
ichl = self.in_channels ichl = self.in_channels
ochl = self.out_channels ochl = self.out_channels
kt, kh, kw = self.kernel_size kt, kh, kw = self.kernel_size
return (ochl, ichl, kt, kh, kw)
return (ichl, ochl, kt, kh, kw)


def _infer_bias_shape(self): def _infer_bias_shape(self):
# Assume format is NCTHW # Assume format is NCTHW


+ 31
- 0
imperative/python/test/unit/module/test_conv.py View File

@@ -9,11 +9,41 @@
import itertools import itertools


import numpy as np import numpy as np
import pytest


import megengine.module as M
from megengine import Parameter, tensor from megengine import Parameter, tensor
from megengine.functional.debug_param import (
get_execution_strategy,
set_execution_strategy,
)
from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d




@pytest.fixture
def reproducible():
old = get_execution_strategy()
set_execution_strategy("HEURISTIC_REPRODUCIBLE")
yield
set_execution_strategy(old)


# NOTE: test in module for convenience. should really test in functional
@pytest.mark.parametrize(
"name",
["Conv1d", "Conv2d", "Conv3d", "ConvTranspose2d", "ConvTranspose3d", "LocalConv2d"],
)
def test_conv_dtype_promotion(name, reproducible):
N, Ci, Co, K = 2, 16, 32, 3
S = (7,) * int(name[-2])
if "Local" in name:
m = getattr(M, name)(Ci, Co, *S, K)
else:
m = getattr(M, name)(Ci, Co, K)
x = tensor(np.random.random(size=(N, Ci) + S).astype("float16"))
np.testing.assert_equal(m(x).numpy(), m(x.astype("float32")).numpy())


def test_conv_transpose2d(): def test_conv_transpose2d():
SH, SW = 3, 1 SH, SW = 3, 1
PH, PW = 2, 0 PH, PW = 2, 0
@@ -163,6 +193,7 @@ def test_conv_transpose3d():
) )
out_np = out_np[:, :, PD : OD - PD, PH : OH - PH, PW : OW - PW] out_np = out_np[:, :, PD : OD - PD, PH : OH - PH, PW : OW - PW]


assert conv_transpose3d.weight.numpy().shape == weight.shape
conv_transpose3d.weight = Parameter(weight) conv_transpose3d.weight = Parameter(weight)
out_meg = conv_transpose3d.forward(tensor(inp)) out_meg = conv_transpose3d.forward(tensor(inp))




Loading…
Cancel
Save