
fix(mge/module): add kwargs param for all modules

Megvii Engine Team, 4 years ago
commit 47db29aaa2 (tags/v1.3.0)
GitOrigin-RevId: 7245e669a7
12 changed files with 42 additions and 31 deletions:

  1. imperative/python/megengine/module/activation.py (+6 -6)
  2. imperative/python/megengine/module/adaptive_pooling.py (+2 -4)
  3. imperative/python/megengine/module/batchnorm.py (+4 -2)
  4. imperative/python/megengine/module/conv.py (+10 -1)
  5. imperative/python/megengine/module/conv_bn.py (+2 -0)
  6. imperative/python/megengine/module/dropout.py (+2 -2)
  7. imperative/python/megengine/module/elemwise.py (+2 -2)
  8. imperative/python/megengine/module/embedding.py (+2 -1)
  9. imperative/python/megengine/module/external.py (+2 -4)
  10. imperative/python/megengine/module/normalization.py (+6 -6)
  11. imperative/python/megengine/module/pooling.py (+2 -1)
  12. imperative/python/megengine/module/sequential.py (+2 -2)
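
In effect, keyword-only options understood by the Module base class can now be threaded through every subclass constructor. A minimal usage sketch, assuming Module.__init__ accepts a name keyword (the base-class signature is not shown in this diff):

    import megengine.module as M

    # Hypothetical: `name` is assumed to be a Module.__init__ keyword;
    # each subclass now forwards unrecognized keywords to the base class.
    act = M.LeakyReLU(negative_slope=0.1, name="act0")
    drop = M.Dropout(drop_prob=0.5, name="drop0")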

imperative/python/megengine/module/activation.py (+6 -6)

@@ -48,8 +48,8 @@ class Softmax(Module):
 
     """
 
-    def __init__(self, axis=None):
-        super().__init__()
+    def __init__(self, axis=None, **kwargs):
+        super().__init__(**kwargs)
         self.axis = axis
 
     def forward(self, inputs):
@@ -167,8 +167,8 @@ class PReLU(Module):
 
     """
 
-    def __init__(self, num_parameters: int = 1, init: float = 0.25):
-        super().__init__()
+    def __init__(self, num_parameters: int = 1, init: float = 0.25, **kwargs):
+        super().__init__(**kwargs)
         self.num_parameters = num_parameters
         if num_parameters > 1:
             # Assume format is NCHW
@@ -225,8 +225,8 @@ class LeakyReLU(Module):
 
     """
 
-    def __init__(self, negative_slope: float = 0.01):
-        super().__init__()
+    def __init__(self, negative_slope: float = 0.01, **kwargs):
+        super().__init__(**kwargs)
         self.negative_slope = negative_slope
 
     def forward(self, inputs):
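
Before this change, passing any base-class keyword to these constructors raised a TypeError (unexpected keyword argument), since each subclass signature accepted only its own parameters; with **kwargs, unrecognized keywords are now forwarded to Module.__init__.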


imperative/python/megengine/module/adaptive_pooling.py (+2 -4)

@@ -15,10 +15,8 @@ from .module import Module
 
 
 class _AdaptivePoolNd(Module):
-    def __init__(
-        self, oshp: Union[Tuple[int, int], int, Tensor],
-    ):
-        super(_AdaptivePoolNd, self).__init__()
+    def __init__(self, oshp: Union[Tuple[int, int], int, Tensor], **kwargs):
+        super(_AdaptivePoolNd, self).__init__(**kwargs)
         self.oshp = oshp
 
     @abstractmethod


imperative/python/megengine/module/batchnorm.py (+4 -2)

@@ -26,8 +26,9 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
+        **kwargs
     ):
-        super(_BatchNorm, self).__init__()
+        super(_BatchNorm, self).__init__(**kwargs)
         self.num_features = num_features
         self.eps = eps
         self.momentum = momentum
@@ -151,9 +152,10 @@ class SyncBatchNorm(_BatchNorm):
         track_running_stats=True,
         freeze=False,
         group: Optional[Group] = WORLD,
+        **kwargs
     ) -> None:
         super().__init__(
-            num_features, eps, momentum, affine, track_running_stats, freeze
+            num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
         )
         self.group = group



imperative/python/megengine/module/conv.py (+10 -1)

@@ -37,8 +37,9 @@ class _ConvNd(Module):
         dilation: Union[int, Tuple[int, int]],
         groups: int,
         bias: bool = True,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         if in_channels % groups != 0:
             raise ValueError("in_channels must be divisible by groups")
         if out_channels % groups != 0:
@@ -176,6 +177,7 @@ class Conv1d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = kernel_size
         stride = stride
@@ -192,6 +194,7 @@ class Conv1d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )
 
     def _get_fanin(self):
@@ -334,6 +337,7 @@ class Conv2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
         stride = _pair_nonzero(stride)
@@ -350,6 +354,7 @@ class Conv2d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )
 
     def _get_fanin(self):
@@ -444,6 +449,7 @@ class ConvTranspose2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
         stride = _pair_nonzero(stride)
@@ -460,6 +466,7 @@ class ConvTranspose2d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )
 
     def _get_fanin(self):
@@ -536,6 +543,7 @@ class LocalConv2d(Conv2d):
         dilation: Union[int, Tuple[int, int]] = 1,
         groups: int = 1,
         conv_mode: str = "CROSS_CORRELATION",
+        **kwargs
     ):
         self.input_height = input_height
         self.input_width = input_width
@@ -548,6 +556,7 @@ class LocalConv2d(Conv2d):
             dilation,
             groups,
             bias=False,
+            **kwargs,
         )
 
     def _infer_weight_shape(self):


imperative/python/megengine/module/conv_bn.py (+2 -0)

@@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module):
         momentum=0.9,
         affine=True,
         track_running_stats=True,
+        **kwargs
     ):
         super().__init__()
         self.conv = Conv2d(
@@ -43,6 +44,7 @@ class _ConvBnActivation2d(Module):
             bias,
             conv_mode,
             compute_mode,
+            **kwargs,
         )
         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
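
Note the asymmetry in this file: _ConvBnActivation2d leaves its own super().__init__() call unchanged and instead forwards **kwargs to the inner Conv2d, so extra keywords reach the convolution submodule rather than the fused module itself; the BatchNorm2d built on the next line receives none of them.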



imperative/python/megengine/module/dropout.py (+2 -2)

@@ -20,8 +20,8 @@ class Dropout(Module):
     :param drop_prob: The probability to drop (set to zero) each single element
     """
 
-    def __init__(self, drop_prob=0.0):
-        super().__init__()
+    def __init__(self, drop_prob=0.0, **kwargs):
+        super().__init__(**kwargs)
         self.drop_prob = drop_prob
 
     def forward(self, inputs):


imperative/python/megengine/module/elemwise.py (+2 -2)

@@ -72,8 +72,8 @@ class Elemwise(Module):
     * "NOT": bool unary: ~x
     """
 
-    def __init__(self, method):
-        super().__init__()
+    def __init__(self, method, **kwargs):
+        super().__init__(**kwargs)
         self.method = method
 
     def forward(self, *inps):


imperative/python/megengine/module/embedding.py (+2 -1)

@@ -64,8 +64,9 @@ class Embedding(Module):
         norm_type: Optional[float] = None,
         initial_weight: Parameter = None,
         freeze: bool = False,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         if padding_idx is not None:
             raise ValueError("Not support padding index now.")
         if max_norm is not None or norm_type is not None:


imperative/python/megengine/module/external.py (+2 -4)

@@ -19,10 +19,8 @@ class TensorrtRuntimeSubgraph(Module):
     See :func:`~.tensorrt_runtime_opr` for more details.
     """
 
-    def __init__(
-        self, data,
-    ):
-        super(TensorrtRuntimeSubgraph, self).__init__()
+    def __init__(self, data, **kwargs):
+        super(TensorrtRuntimeSubgraph, self).__init__(**kwargs)
         self._data = data
 
     @property


imperative/python/megengine/module/normalization.py (+6 -6)

@@ -20,8 +20,8 @@ class GroupNorm(Module):
     Reference: https://arxiv.org/pdf/1803.08494.pdf.
     """
 
-    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
-        super().__init__()
+    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs):
+        super().__init__(**kwargs)
         assert num_channels % num_groups == 0
         self.num_groups = num_groups
         self.num_channels = num_channels
@@ -70,8 +70,8 @@ class InstanceNorm(Module):
     Note that InstanceNorm equals using GroupNorm with num_groups=num_channels.
     """
 
-    def __init__(self, num_channels, eps=1e-05, affine=True):
-        super().__init__()
+    def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
+        super().__init__(**kwargs)
         self.num_channels = num_channels
         self.eps = eps
         self.affine = affine
@@ -114,8 +114,8 @@ class LayerNorm(Module):
     Note that LayerNorm equals using GroupNorm with num_groups=1.
     """
 
-    def __init__(self, num_channels, eps=1e-05, affine=True):
-        super().__init__()
+    def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
+        super().__init__(**kwargs)
         self.num_channels = num_channels
         self.eps = eps
         self.affine = affine


imperative/python/megengine/module/pooling.py (+2 -1)

@@ -19,8 +19,9 @@ class _PoolNd(Module):
         kernel_size: Union[int, Tuple[int, int]],
         stride: Union[int, Tuple[int, int]] = None,
         padding: Union[int, Tuple[int, int]] = 0,
+        **kwargs
     ):
-        super(_PoolNd, self).__init__()
+        super(_PoolNd, self).__init__(**kwargs)
         self.kernel_size = kernel_size
         self.stride = stride or kernel_size
         self.padding = padding


imperative/python/megengine/module/sequential.py (+2 -2)

@@ -46,8 +46,8 @@ class Sequential(Module):
         pred1 = net1(data)
     """
 
-    def __init__(self, *args):
-        super().__init__()
+    def __init__(self, *args, **kwargs):
+        super().__init__(**kwargs)
         self.layer_keys = []
         if len(args) == 1 and isinstance(args[0], OrderedDict):
             for key, module in args[0].items():
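
Since Sequential consumes its positional arguments as layers, base-class options can only arrive by keyword. A hedged sketch, again assuming a name keyword on Module.__init__:

    import megengine.module as M

    # Positional args fill *args as layers; the assumed `name` keyword
    # travels through **kwargs to Module.__init__.
    net = M.Sequential(M.Linear(8, 4), M.ReLU(), name="head")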

