Browse Source

fix(mge/module): add kwargs param for all modules

GitOrigin-RevId: 7245e669a7
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
47db29aaa2
12 changed files with 42 additions and 31 deletions
  1. +6
    -6
      imperative/python/megengine/module/activation.py
  2. +2
    -4
      imperative/python/megengine/module/adaptive_pooling.py
  3. +4
    -2
      imperative/python/megengine/module/batchnorm.py
  4. +10
    -1
      imperative/python/megengine/module/conv.py
  5. +2
    -0
      imperative/python/megengine/module/conv_bn.py
  6. +2
    -2
      imperative/python/megengine/module/dropout.py
  7. +2
    -2
      imperative/python/megengine/module/elemwise.py
  8. +2
    -1
      imperative/python/megengine/module/embedding.py
  9. +2
    -4
      imperative/python/megengine/module/external.py
  10. +6
    -6
      imperative/python/megengine/module/normalization.py
  11. +2
    -1
      imperative/python/megengine/module/pooling.py
  12. +2
    -2
      imperative/python/megengine/module/sequential.py

+ 6
- 6
imperative/python/megengine/module/activation.py View File

@@ -48,8 +48,8 @@ class Softmax(Module):


""" """


def __init__(self, axis=None):
super().__init__()
def __init__(self, axis=None, **kwargs):
super().__init__(**kwargs)
self.axis = axis self.axis = axis


def forward(self, inputs): def forward(self, inputs):
@@ -167,8 +167,8 @@ class PReLU(Module):


""" """


def __init__(self, num_parameters: int = 1, init: float = 0.25):
super().__init__()
def __init__(self, num_parameters: int = 1, init: float = 0.25, **kwargs):
super().__init__(**kwargs)
self.num_parameters = num_parameters self.num_parameters = num_parameters
if num_parameters > 1: if num_parameters > 1:
# Assume format is NCHW # Assume format is NCHW
@@ -225,8 +225,8 @@ class LeakyReLU(Module):


""" """


def __init__(self, negative_slope: float = 0.01):
super().__init__()
def __init__(self, negative_slope: float = 0.01, **kwargs):
super().__init__(**kwargs)
self.negative_slope = negative_slope self.negative_slope = negative_slope


def forward(self, inputs): def forward(self, inputs):


+ 2
- 4
imperative/python/megengine/module/adaptive_pooling.py View File

@@ -15,10 +15,8 @@ from .module import Module




class _AdaptivePoolNd(Module): class _AdaptivePoolNd(Module):
def __init__(
self, oshp: Union[Tuple[int, int], int, Tensor],
):
super(_AdaptivePoolNd, self).__init__()
def __init__(self, oshp: Union[Tuple[int, int], int, Tensor], **kwargs):
super(_AdaptivePoolNd, self).__init__(**kwargs)
self.oshp = oshp self.oshp = oshp


@abstractmethod @abstractmethod


+ 4
- 2
imperative/python/megengine/module/batchnorm.py View File

@@ -26,8 +26,9 @@ class _BatchNorm(Module):
affine=True, affine=True,
track_running_stats=True, track_running_stats=True,
freeze=False, freeze=False,
**kwargs
): ):
super(_BatchNorm, self).__init__()
super(_BatchNorm, self).__init__(**kwargs)
self.num_features = num_features self.num_features = num_features
self.eps = eps self.eps = eps
self.momentum = momentum self.momentum = momentum
@@ -151,9 +152,10 @@ class SyncBatchNorm(_BatchNorm):
track_running_stats=True, track_running_stats=True,
freeze=False, freeze=False,
group: Optional[Group] = WORLD, group: Optional[Group] = WORLD,
**kwargs
) -> None: ) -> None:
super().__init__( super().__init__(
num_features, eps, momentum, affine, track_running_stats, freeze
num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
) )
self.group = group self.group = group




+ 10
- 1
imperative/python/megengine/module/conv.py View File

@@ -37,8 +37,9 @@ class _ConvNd(Module):
dilation: Union[int, Tuple[int, int]], dilation: Union[int, Tuple[int, int]],
groups: int, groups: int,
bias: bool = True, bias: bool = True,
**kwargs
): ):
super().__init__()
super().__init__(**kwargs)
if in_channels % groups != 0: if in_channels % groups != 0:
raise ValueError("in_channels must be divisible by groups") raise ValueError("in_channels must be divisible by groups")
if out_channels % groups != 0: if out_channels % groups != 0:
@@ -176,6 +177,7 @@ class Conv1d(_ConvNd):
bias: bool = True, bias: bool = True,
conv_mode: str = "CROSS_CORRELATION", conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT", compute_mode: str = "DEFAULT",
**kwargs
): ):
kernel_size = kernel_size kernel_size = kernel_size
stride = stride stride = stride
@@ -192,6 +194,7 @@ class Conv1d(_ConvNd):
dilation, dilation,
groups, groups,
bias, bias,
**kwargs,
) )


def _get_fanin(self): def _get_fanin(self):
@@ -334,6 +337,7 @@ class Conv2d(_ConvNd):
bias: bool = True, bias: bool = True,
conv_mode: str = "CROSS_CORRELATION", conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT", compute_mode: str = "DEFAULT",
**kwargs
): ):
kernel_size = _pair_nonzero(kernel_size) kernel_size = _pair_nonzero(kernel_size)
stride = _pair_nonzero(stride) stride = _pair_nonzero(stride)
@@ -350,6 +354,7 @@ class Conv2d(_ConvNd):
dilation, dilation,
groups, groups,
bias, bias,
**kwargs,
) )


def _get_fanin(self): def _get_fanin(self):
@@ -444,6 +449,7 @@ class ConvTranspose2d(_ConvNd):
bias: bool = True, bias: bool = True,
conv_mode: str = "CROSS_CORRELATION", conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT", compute_mode: str = "DEFAULT",
**kwargs
): ):
kernel_size = _pair_nonzero(kernel_size) kernel_size = _pair_nonzero(kernel_size)
stride = _pair_nonzero(stride) stride = _pair_nonzero(stride)
@@ -460,6 +466,7 @@ class ConvTranspose2d(_ConvNd):
dilation, dilation,
groups, groups,
bias, bias,
**kwargs,
) )


def _get_fanin(self): def _get_fanin(self):
@@ -536,6 +543,7 @@ class LocalConv2d(Conv2d):
dilation: Union[int, Tuple[int, int]] = 1, dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1, groups: int = 1,
conv_mode: str = "CROSS_CORRELATION", conv_mode: str = "CROSS_CORRELATION",
**kwargs
): ):
self.input_height = input_height self.input_height = input_height
self.input_width = input_width self.input_width = input_width
@@ -548,6 +556,7 @@ class LocalConv2d(Conv2d):
dilation, dilation,
groups, groups,
bias=False, bias=False,
**kwargs,
) )


def _infer_weight_shape(self): def _infer_weight_shape(self):


+ 2
- 0
imperative/python/megengine/module/conv_bn.py View File

@@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module):
momentum=0.9, momentum=0.9,
affine=True, affine=True,
track_running_stats=True, track_running_stats=True,
**kwargs
): ):
super().__init__() super().__init__()
self.conv = Conv2d( self.conv = Conv2d(
@@ -43,6 +44,7 @@ class _ConvBnActivation2d(Module):
bias, bias,
conv_mode, conv_mode,
compute_mode, compute_mode,
**kwargs,
) )
self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)




+ 2
- 2
imperative/python/megengine/module/dropout.py View File

@@ -20,8 +20,8 @@ class Dropout(Module):
:param drop_prob: The probability to drop (set to zero) each single element :param drop_prob: The probability to drop (set to zero) each single element
""" """


def __init__(self, drop_prob=0.0):
super().__init__()
def __init__(self, drop_prob=0.0, **kwargs):
super().__init__(**kwargs)
self.drop_prob = drop_prob self.drop_prob = drop_prob


def forward(self, inputs): def forward(self, inputs):


+ 2
- 2
imperative/python/megengine/module/elemwise.py View File

@@ -72,8 +72,8 @@ class Elemwise(Module):
* "NOT": bool unary: ~x * "NOT": bool unary: ~x
""" """


def __init__(self, method):
super().__init__()
def __init__(self, method, **kwargs):
super().__init__(**kwargs)
self.method = method self.method = method


def forward(self, *inps): def forward(self, *inps):


+ 2
- 1
imperative/python/megengine/module/embedding.py View File

@@ -64,8 +64,9 @@ class Embedding(Module):
norm_type: Optional[float] = None, norm_type: Optional[float] = None,
initial_weight: Parameter = None, initial_weight: Parameter = None,
freeze: bool = False, freeze: bool = False,
**kwargs
): ):
super().__init__()
super().__init__(**kwargs)
if padding_idx is not None: if padding_idx is not None:
raise ValueError("Not support padding index now.") raise ValueError("Not support padding index now.")
if max_norm is not None or norm_type is not None: if max_norm is not None or norm_type is not None:


+ 2
- 4
imperative/python/megengine/module/external.py View File

@@ -19,10 +19,8 @@ class TensorrtRuntimeSubgraph(Module):
See :func:`~.tensorrt_runtime_opr` for more details. See :func:`~.tensorrt_runtime_opr` for more details.
""" """


def __init__(
self, data,
):
super(TensorrtRuntimeSubgraph, self).__init__()
def __init__(self, data, **kwargs):
super(TensorrtRuntimeSubgraph, self).__init__(**kwargs)
self._data = data self._data = data


@property @property


+ 6
- 6
imperative/python/megengine/module/normalization.py View File

@@ -20,8 +20,8 @@ class GroupNorm(Module):
Reference: https://arxiv.org/pdf/1803.08494.pdf. Reference: https://arxiv.org/pdf/1803.08494.pdf.
""" """


def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
super().__init__()
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs):
super().__init__(**kwargs)
assert num_channels % num_groups == 0 assert num_channels % num_groups == 0
self.num_groups = num_groups self.num_groups = num_groups
self.num_channels = num_channels self.num_channels = num_channels
@@ -70,8 +70,8 @@ class InstanceNorm(Module):
Note that InstanceNorm equals using GroupNorm with num_groups=num_channels. Note that InstanceNorm equals using GroupNorm with num_groups=num_channels.
""" """


def __init__(self, num_channels, eps=1e-05, affine=True):
super().__init__()
def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
super().__init__(**kwargs)
self.num_channels = num_channels self.num_channels = num_channels
self.eps = eps self.eps = eps
self.affine = affine self.affine = affine
@@ -114,8 +114,8 @@ class LayerNorm(Module):
Note that LayerNorm equals using GroupNorm with num_groups=1. Note that LayerNorm equals using GroupNorm with num_groups=1.
""" """


def __init__(self, num_channels, eps=1e-05, affine=True):
super().__init__()
def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
super().__init__(**kwargs)
self.num_channels = num_channels self.num_channels = num_channels
self.eps = eps self.eps = eps
self.affine = affine self.affine = affine


+ 2
- 1
imperative/python/megengine/module/pooling.py View File

@@ -19,8 +19,9 @@ class _PoolNd(Module):
kernel_size: Union[int, Tuple[int, int]], kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = None, stride: Union[int, Tuple[int, int]] = None,
padding: Union[int, Tuple[int, int]] = 0, padding: Union[int, Tuple[int, int]] = 0,
**kwargs
): ):
super(_PoolNd, self).__init__()
super(_PoolNd, self).__init__(**kwargs)
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.stride = stride or kernel_size self.stride = stride or kernel_size
self.padding = padding self.padding = padding


+ 2
- 2
imperative/python/megengine/module/sequential.py View File

@@ -46,8 +46,8 @@ class Sequential(Module):
pred1 = net1(data) pred1 = net1(data)
""" """


def __init__(self, *args):
super().__init__()
def __init__(self, *args, **kwargs):
super().__init__(**kwargs)
self.layer_keys = [] self.layer_keys = []
if len(args) == 1 and isinstance(args[0], OrderedDict): if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items(): for key, module in args[0].items():


Loading…
Cancel
Save