Browse Source

fix(init): fix fan_in and fan_out for group conv2d

GitOrigin-RevId: a6f41063f0
tags/v1.8.1.m1
Megvii Engine Team 3 years ago
parent
commit
3159eecadd
2 changed files with 47 additions and 11 deletions
  1. +19
    -10
      imperative/python/megengine/module/init.py
  2. +28
    -1
      imperative/python/test/unit/module/test_init.py

+ 19
- 10
imperative/python/megengine/module/init.py View File

@@ -74,7 +74,7 @@ def calculate_gain(
) -> float: ) -> float:
r"""Returns a recommended gain value (see the table below) for the given nonlinearity r"""Returns a recommended gain value (see the table below) for the given nonlinearity
function. function.
================= ==================================================== ================= ====================================================
nonlinearity gain nonlinearity gain
================= ==================================================== ================= ====================================================
@@ -126,6 +126,11 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
r"""Calculates fan_in / fan_out value for given weight tensor. This function assumes r"""Calculates fan_in / fan_out value for given weight tensor. This function assumes
input tensor is stored in ``NCHW`` format. input tensor is stored in ``NCHW`` format.


Note:
The group conv2d kernel shape in MegEngine is ``(G, O/G, I/G, K, K)``. This
function calculates ``fan_out = O/G * K * K`` as default, but PyTorch uses
``fan_out = O * K * K``.

Args: Args:
tensor: weight tensor in ``NCHW`` format. tensor: weight tensor in ``NCHW`` format.
""" """
@@ -141,6 +146,10 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
fan_in = shape[1] fan_in = shape[1]
fan_out = shape[0] fan_out = shape[0]
else: else:
if ndim >= 5:
# ignore the groups dimension of group conv2d and group conv3d
# FIXME: will be wrong for conv3d
shape = shape[1:]
num_input_fmaps = shape[1] num_input_fmaps = shape[1]
num_output_fmaps = shape[0] num_output_fmaps = shape[0]
receptive_field_size = 1 receptive_field_size = 1
@@ -154,7 +163,7 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
def calculate_correct_fan(tensor: Tensor, mode: str) -> float: def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
r"""Calculates fan_in / fan_out value for given weight tensor, depending on given r"""Calculates fan_in / fan_out value for given weight tensor, depending on given
``mode``. ``mode``.
See :func:`calculate_fan_in_and_fan_out` for details. See :func:`calculate_fan_in_and_fan_out` for details.


Args: Args:
@@ -175,11 +184,11 @@ def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None: def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
r"""Fills tensor with random values sampled from :math:`\mathcal{U}(-a, a)` r"""Fills tensor with random values sampled from :math:`\mathcal{U}(-a, a)`
where where
.. math:: .. math::


a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}} a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}}
Also known as Glorot initialization. Detailed information can be retrieved from Also known as Glorot initialization. Detailed information can be retrieved from
`Understanding the difficulty of training deep feedforward neural networks` - `Understanding the difficulty of training deep feedforward neural networks` -
Glorot, X. & Bengio, Y. (2010). Glorot, X. & Bengio, Y. (2010).
@@ -197,11 +206,11 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None: def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None:
r"""Fills tensor with random values sampled from r"""Fills tensor with random values sampled from
:math:`\mathcal{N}(0, \text{std}^2)` where :math:`\mathcal{N}(0, \text{std}^2)` where
.. math:: .. math::


\text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}} \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}}
Also known as Glorot initialization. Detailed information can be retrieved from Also known as Glorot initialization. Detailed information can be retrieved from
`Understanding the difficulty of training deep feedforward neural networks` - `Understanding the difficulty of training deep feedforward neural networks` -
Glorot, X. & Bengio, Y. (2010). Glorot, X. & Bengio, Y. (2010).
@@ -220,11 +229,11 @@ def msra_uniform_(
) -> None: ) -> None:
r"""Fills tensor with random values sampled from r"""Fills tensor with random values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math:: .. math::


\text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}} \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}}
Detailed information can be retrieved from Detailed information can be retrieved from
`Delving deep into rectifiers: Surpassing human-level performance on ImageNet `Delving deep into rectifiers: Surpassing human-level performance on ImageNet
classification` classification`
@@ -251,11 +260,11 @@ def msra_normal_(
) -> None: ) -> None:
r"""Fills tensor with random values sampled from r"""Fills tensor with random values sampled from
:math:`\mathcal{N}(0, \text{std}^2)` where :math:`\mathcal{N}(0, \text{std}^2)` where
.. math:: .. math::


\text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}} \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}}
Detailed information can be retrieved from Detailed information can be retrieved from
`Delving deep into rectifiers: Surpassing human-level performance on ImageNet `Delving deep into rectifiers: Surpassing human-level performance on ImageNet
classification` classification`


+ 28
- 1
imperative/python/test/unit/module/test_init.py View File

@@ -10,7 +10,7 @@ import numpy as np
import pytest import pytest


from megengine import tensor from megengine import tensor
from megengine.module import Conv2d, Linear
from megengine.module import Conv1d, Conv2d, Conv3d, Linear
from megengine.module.init import calculate_fan_in_and_fan_out, fill_ from megengine.module.init import calculate_fan_in_and_fan_out, fill_




@@ -32,7 +32,34 @@ def test_calculate_fan_in_and_fan_out():
with pytest.raises(ValueError): with pytest.raises(ValueError):
calculate_fan_in_and_fan_out(l.bias) calculate_fan_in_and_fan_out(l.bias)


l = Conv1d(in_channels=2, out_channels=3, kernel_size=5)
fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
assert fanin == 2 * 5
assert fanout == 3 * 5

# FIXME: will be wrong for group conv1d
# l = Conv1d(in_channels=2, out_channels=4, kernel_size=5, groups=2)
# fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
# assert fanin == 2 // 2 * 5
# assert fanout == 4 // 2 * 5

l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7)) l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7))
fanin, fanout = calculate_fan_in_and_fan_out(l.weight) fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
assert fanin == 2 * 5 * 7 assert fanin == 2 * 5 * 7
assert fanout == 3 * 5 * 7 assert fanout == 3 * 5 * 7

l = Conv2d(in_channels=2, out_channels=4, kernel_size=(5, 7), groups=2)
fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
assert fanin == 2 // 2 * 5 * 7
assert fanout == 4 // 2 * 5 * 7

# FIXME: will be wrong for conv3d
# l = Conv3d(in_channels=2, out_channels=3, kernel_size=(5, 7, 9))
# fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
# assert fanin == 2 * 5 * 7 * 9
# assert fanout == 3 * 5 * 7 * 9

l = Conv3d(in_channels=2, out_channels=4, kernel_size=(5, 7, 9), groups=2)
fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
assert fanin == 2 // 2 * 5 * 7 * 9
assert fanout == 4 // 2 * 5 * 7 * 9

Loading…
Cancel
Save