|
|
@@ -37,8 +37,9 @@ class _ConvNd(Module): |
|
|
|
dilation: Union[int, Tuple[int, int]], |
|
|
|
groups: int, |
|
|
|
bias: bool = True, |
|
|
|
**kwargs |
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
super().__init__(**kwargs) |
|
|
|
if in_channels % groups != 0: |
|
|
|
raise ValueError("in_channels must be divisible by groups") |
|
|
|
if out_channels % groups != 0: |
|
|
@@ -176,6 +177,7 @@ class Conv1d(_ConvNd): |
|
|
|
bias: bool = True, |
|
|
|
conv_mode: str = "CROSS_CORRELATION", |
|
|
|
compute_mode: str = "DEFAULT", |
|
|
|
**kwargs |
|
|
|
): |
|
|
|
kernel_size = kernel_size |
|
|
|
stride = stride |
|
|
@@ -192,6 +194,7 @@ class Conv1d(_ConvNd): |
|
|
|
dilation, |
|
|
|
groups, |
|
|
|
bias, |
|
|
|
**kwargs, |
|
|
|
) |
|
|
|
|
|
|
|
def _get_fanin(self): |
|
|
@@ -334,6 +337,7 @@ class Conv2d(_ConvNd): |
|
|
|
bias: bool = True, |
|
|
|
conv_mode: str = "CROSS_CORRELATION", |
|
|
|
compute_mode: str = "DEFAULT", |
|
|
|
**kwargs |
|
|
|
): |
|
|
|
kernel_size = _pair_nonzero(kernel_size) |
|
|
|
stride = _pair_nonzero(stride) |
|
|
@@ -350,6 +354,7 @@ class Conv2d(_ConvNd): |
|
|
|
dilation, |
|
|
|
groups, |
|
|
|
bias, |
|
|
|
**kwargs, |
|
|
|
) |
|
|
|
|
|
|
|
def _get_fanin(self): |
|
|
@@ -444,6 +449,7 @@ class ConvTranspose2d(_ConvNd): |
|
|
|
bias: bool = True, |
|
|
|
conv_mode: str = "CROSS_CORRELATION", |
|
|
|
compute_mode: str = "DEFAULT", |
|
|
|
**kwargs |
|
|
|
): |
|
|
|
kernel_size = _pair_nonzero(kernel_size) |
|
|
|
stride = _pair_nonzero(stride) |
|
|
@@ -460,6 +466,7 @@ class ConvTranspose2d(_ConvNd): |
|
|
|
dilation, |
|
|
|
groups, |
|
|
|
bias, |
|
|
|
**kwargs, |
|
|
|
) |
|
|
|
|
|
|
|
def _get_fanin(self): |
|
|
@@ -536,6 +543,7 @@ class LocalConv2d(Conv2d): |
|
|
|
dilation: Union[int, Tuple[int, int]] = 1, |
|
|
|
groups: int = 1, |
|
|
|
conv_mode: str = "CROSS_CORRELATION", |
|
|
|
**kwargs |
|
|
|
): |
|
|
|
self.input_height = input_height |
|
|
|
self.input_width = input_width |
|
|
@@ -548,6 +556,7 @@ class LocalConv2d(Conv2d): |
|
|
|
dilation, |
|
|
|
groups, |
|
|
|
bias=False, |
|
|
|
**kwargs, |
|
|
|
) |
|
|
|
|
|
|
|
def _infer_weight_shape(self): |
|
|
|