GitOrigin-RevId: 7245e669a7
@@ -48,8 +48,8 @@ class Softmax(Module):
     """
-    def __init__(self, axis=None):
-        super().__init__()
+    def __init__(self, axis=None, **kwargs):
+        super().__init__(**kwargs)
         self.axis = axis

     def forward(self, inputs):
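
Every hunk in this change follows the same pattern: the subclass __init__ grows a
**kwargs parameter and forwards it to Module.__init__, so keyword arguments understood
by the base class no longer raise a TypeError at construction. A minimal sketch of what
this enables, assuming the base Module.__init__ accepts and stores a name keyword
(hypothetical usage, not taken from this diff):

    import megengine.module as M

    # Extra keyword arguments now flow through Softmax.__init__
    # into Module.__init__ instead of raising a TypeError.
    sm = M.Softmax(axis=1, name="cls_softmax")
    print(sm.name)  # -> "cls_softmax", assuming Module stores the keyword
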
@@ -167,8 +167,8 @@ class PReLU(Module):
     """
-    def __init__(self, num_parameters: int = 1, init: float = 0.25):
-        super().__init__()
+    def __init__(self, num_parameters: int = 1, init: float = 0.25, **kwargs):
+        super().__init__(**kwargs)
         self.num_parameters = num_parameters
         if num_parameters > 1:
             # Assume format is NCHW
@@ -225,8 +225,8 @@ class LeakyReLU(Module):
     """
-    def __init__(self, negative_slope: float = 0.01):
-        super().__init__()
+    def __init__(self, negative_slope: float = 0.01, **kwargs):
+        super().__init__(**kwargs)
         self.negative_slope = negative_slope

     def forward(self, inputs):
@@ -15,10 +15,8 @@ from .module import Module
 class _AdaptivePoolNd(Module):
-    def __init__(
-        self, oshp: Union[Tuple[int, int], int, Tensor],
-    ):
-        super(_AdaptivePoolNd, self).__init__()
+    def __init__(self, oshp: Union[Tuple[int, int], int, Tensor], **kwargs):
+        super(_AdaptivePoolNd, self).__init__(**kwargs)
         self.oshp = oshp

     @abstractmethod
@@ -26,8 +26,9 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
+        **kwargs
     ):
-        super(_BatchNorm, self).__init__()
+        super(_BatchNorm, self).__init__(**kwargs)
         self.num_features = num_features
         self.eps = eps
         self.momentum = momentum
@@ -151,9 +152,10 @@ class SyncBatchNorm(_BatchNorm):
         track_running_stats=True,
         freeze=False,
         group: Optional[Group] = WORLD,
+        **kwargs
     ) -> None:
         super().__init__(
-            num_features, eps, momentum, affine, track_running_stats, freeze
+            num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
         )
         self.group = group
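
SyncBatchNorm shows the pass-through working across two levels: its own **kwargs ride
along with the positional arguments into _BatchNorm, which in turn forwards them to
Module. A self-contained sketch of the same chain (hypothetical class names, standalone
Python, not MegEngine's actual implementation):

    class Base:                      # stands in for Module
        def __init__(self, name=None):
            self.name = name

    class Mid(Base):                 # stands in for _BatchNorm
        def __init__(self, num_features, **kwargs):
            super().__init__(**kwargs)
            self.num_features = num_features

    class Leaf(Mid):                 # stands in for SyncBatchNorm
        def __init__(self, num_features, group=None, **kwargs):
            super().__init__(num_features, **kwargs)
            self.group = group

    leaf = Leaf(64, name="bn0")
    assert leaf.name == "bn0"        # the keyword reached Base
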
@@ -37,8 +37,9 @@ class _ConvNd(Module):
         dilation: Union[int, Tuple[int, int]],
         groups: int,
         bias: bool = True,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         if in_channels % groups != 0:
             raise ValueError("in_channels must be divisible by groups")
         if out_channels % groups != 0:
@@ -176,6 +177,7 @@ class Conv1d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = kernel_size
         stride = stride
@@ -192,6 +194,7 @@ class Conv1d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )

     def _get_fanin(self):
@@ -334,6 +337,7 @@ class Conv2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
         stride = _pair_nonzero(stride)
@@ -350,6 +354,7 @@ class Conv2d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )

     def _get_fanin(self):
@@ -444,6 +449,7 @@ class ConvTranspose2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
         stride = _pair_nonzero(stride)
@@ -460,6 +466,7 @@ class ConvTranspose2d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )

     def _get_fanin(self):
@@ -536,6 +543,7 @@ class LocalConv2d(Conv2d):
         dilation: Union[int, Tuple[int, int]] = 1,
         groups: int = 1,
         conv_mode: str = "CROSS_CORRELATION",
+        **kwargs
     ):
         self.input_height = input_height
         self.input_width = input_width
@@ -548,6 +556,7 @@ class LocalConv2d(Conv2d):
             dilation,
             groups,
             bias=False,
+            **kwargs,
         )

     def _infer_weight_shape(self):
@@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module):
         momentum=0.9,
         affine=True,
         track_running_stats=True,
+        **kwargs
     ):
         super().__init__()
         self.conv = Conv2d(
@@ -43,6 +44,7 @@ class _ConvBnActivation2d(Module):
             bias,
             conv_mode,
             compute_mode,
+            **kwargs,
         )
         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
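
Note that _ConvBnActivation2d is the one place where super().__init__() is left without
**kwargs: the extra keywords are routed into the inner Conv2d submodule instead, so a
keyword such as the assumed name would attach to the convolution child rather than to
the fused wrapper itself. A hedged sketch, assuming the concrete ConvBn2d subclass
forwards its keywords the same way (hypothetical usage):

    import megengine.module as M

    # The keyword lands on the child convolution, not on the wrapper.
    block = M.ConvBn2d(3, 16, 3, name="stem")
    print(block.conv.name)   # "stem"; the wrapper keeps its default name
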
@@ -20,8 +20,8 @@ class Dropout(Module):
     :param drop_prob: The probability to drop (set to zero) each single element
     """
-    def __init__(self, drop_prob=0.0):
-        super().__init__()
+    def __init__(self, drop_prob=0.0, **kwargs):
+        super().__init__(**kwargs)
         self.drop_prob = drop_prob

     def forward(self, inputs):
@@ -72,8 +72,8 @@ class Elemwise(Module):
     * "NOT": bool unary: ~x
     """
-    def __init__(self, method):
-        super().__init__()
+    def __init__(self, method, **kwargs):
+        super().__init__(**kwargs)
         self.method = method

     def forward(self, *inps):
@@ -64,8 +64,9 @@ class Embedding(Module):
         norm_type: Optional[float] = None,
         initial_weight: Parameter = None,
         freeze: bool = False,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         if padding_idx is not None:
             raise ValueError("Not support padding index now.")
         if max_norm is not None or norm_type is not None:
@@ -19,10 +19,8 @@ class TensorrtRuntimeSubgraph(Module):
     See :func:`~.tensorrt_runtime_opr` for more details.
     """
-    def __init__(
-        self, data,
-    ):
-        super(TensorrtRuntimeSubgraph, self).__init__()
+    def __init__(self, data, **kwargs):
+        super(TensorrtRuntimeSubgraph, self).__init__(**kwargs)
         self._data = data

     @property
@@ -20,8 +20,8 @@ class GroupNorm(Module):
     Reference: https://arxiv.org/pdf/1803.08494.pdf.
     """
-    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
-        super().__init__()
+    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs):
+        super().__init__(**kwargs)
         assert num_channels % num_groups == 0
         self.num_groups = num_groups
         self.num_channels = num_channels
@@ -70,8 +70,8 @@ class InstanceNorm(Module):
     Note that InstanceNorm equals using GroupNorm with num_groups=num_channels.
     """
-    def __init__(self, num_channels, eps=1e-05, affine=True):
-        super().__init__()
+    def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
+        super().__init__(**kwargs)
         self.num_channels = num_channels
         self.eps = eps
         self.affine = affine
@@ -114,8 +114,8 @@ class LayerNorm(Module):
     Note that LayerNorm equals using GroupNorm with num_groups=1.
     """
-    def __init__(self, num_channels, eps=1e-05, affine=True):
-        super().__init__()
+    def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
+        super().__init__(**kwargs)
         self.num_channels = num_channels
         self.eps = eps
         self.affine = affine
@@ -19,8 +19,9 @@ class _PoolNd(Module):
         kernel_size: Union[int, Tuple[int, int]],
         stride: Union[int, Tuple[int, int]] = None,
         padding: Union[int, Tuple[int, int]] = 0,
+        **kwargs
     ):
-        super(_PoolNd, self).__init__()
+        super(_PoolNd, self).__init__(**kwargs)
         self.kernel_size = kernel_size
         self.stride = stride or kernel_size
         self.padding = padding
@@ -46,8 +46,8 @@ class Sequential(Module):
         pred1 = net1(data)
     """
-    def __init__(self, *args):
-        super().__init__()
+    def __init__(self, *args, **kwargs):
+        super().__init__(**kwargs)
         self.layer_keys = []
         if len(args) == 1 and isinstance(args[0], OrderedDict):
             for key, module in args[0].items():
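
Because Sequential consumes its submodules through *args, any extra keyword argument is
necessarily keyword-only and falls through to Module.__init__. A short usage sketch,
again assuming the base class accepts a name keyword (hypothetical usage):

    import megengine.module as M

    net = M.Sequential(M.Linear(32, 16), M.ReLU(), name="head")
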