@@ -35,6 +35,7 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
                 param().padding_val, stream); \
     }
     MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+    MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
 #undef cb
 }
@@ -60,7 +60,8 @@ __global__ void paddingConst_kernel(
             params.src_stride[dim].divisor();
             */
         }
-        dst[out_index] = in_src_valid_area ? src[in_index] : padding_val;
+        dst[out_index] =
+                in_src_valid_area ? src[in_index] : static_cast<T>(padding_val);
     }
 }
@@ -256,6 +257,7 @@ void padding_backward_proxy(
         const float_t padding_val, cudaStream_t stream);
 #define cb(DType) INST(typename DTypeTrait<DType>::ctype)
 MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
 #undef cb
 #undef INST
@@ -171,7 +171,7 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
     switch (param().padding_mode) {
         case param::Padding::PaddingMode::CONSTANT:
 #define cb(DType) \
-    if (src.layout.dtype == DType()) { \
+    if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
         using T = typename DTypeTrait<DType>::ctype; \
         MEGDNN_DISPATCH_CPU_KERN_OPR(exec_const_internal<T>( \
                 src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params, \
@@ -179,28 +179,31 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
         return; \
     }
             MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
 #undef cb
             break;
         case param::Padding::PaddingMode::REPLICATE:
 #define cb(DType) \
-    if (src.layout.dtype == DType()) { \
+    if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
         using T = typename DTypeTrait<DType>::ctype; \
         MEGDNN_DISPATCH_CPU_KERN_OPR(exec_replicate_internal<T>( \
                 src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params)); \
         return; \
     }
             MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
 #undef cb
             break;
         case param::Padding::PaddingMode::REFLECT:
 #define cb(DType) \
-    if (src.layout.dtype == DType()) { \
+    if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
         using T = typename DTypeTrait<DType>::ctype; \
         MEGDNN_DISPATCH_CPU_KERN_OPR(exec_reflect_internal<T>( \
                 src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params)); \
         return; \
     }
             MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
 #undef cb
             break;
         default:
@@ -101,6 +101,36 @@ TEST_F(CUDA, PADDING_REFLECT2) {
                     4, 1, 6, 3, 6, 1, 6, 3})});
 }
 
+TEST_F(CUDA, PADDING_REFLECT2_QUANTIZED) {
+    Checker<Padding> checker(handle_cuda(), false);
+    param::Padding param;
+    param.padding_mode = param::Padding::PaddingMode::REFLECT;
+    param.front_offset_dim0 = 2;
+    param.front_offset_dim1 = 1;
+    param.front_offset_dim2 = 0;
+    param.front_offset_dim3 = 0;
+    param.front_offset_dim4 = 0;
+    param.front_offset_dim5 = 0;
+    param.front_offset_dim6 = 0;
+    param.back_offset_dim0 = 0;
+    param.back_offset_dim1 = 2;
+    param.back_offset_dim2 = 0;
+    param.back_offset_dim3 = 0;
+    param.back_offset_dim4 = 0;
+    param.back_offset_dim5 = 0;
+    param.back_offset_dim6 = 0;
+    checker.set_param(param).exect(
+            Testcase{
+                    TensorValue(
+                            {3, 3}, dtype::QuantizedS8(), {1, 2, 3, 4, 5, 6, 7, 8, 9}),
+                    {}},
+            Testcase{{}, TensorValue(
+                                 {5, 6}, dtype::QuantizedS8(),
+                                 {8, 7, 8, 9, 8, 7, 5, 4, 5, 6, 5, 4, 2, 1, 2,
+                                  3, 2, 1, 5, 4, 5, 6, 5, 4, 8, 7, 8, 9, 8, 7})});
+}
+
 TEST_F(CUDA, PADDING_REPLICATE) {
     Checker<Padding> checker(handle_cuda(), false);
     param::Padding param;
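The expected tensor above follows standard reflect padding with front offsets (2, 1) and back offsets (0, 2) on a 3x3 input. As an illustrative cross-check only (not part of the diff), the same 5x6 result can be reproduced with NumPy's reflect mode:

    import numpy as np

    src = np.arange(1, 10).reshape(3, 3)                  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    out = np.pad(src, ((2, 0), (1, 2)), mode="reflect")   # shape (5, 6)
    # out.ravel() -> 8, 7, 8, 9, 8, 7, 5, 4, 5, 6, 5, 4, 2, 1, 2, 3, 2, 1,
    #                5, 4, 5, 6, 5, 4, 8, 7, 8, 9, 8, 7   (matches the Testcase values)

The same expected values are reused in the NAIVE test below.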
@@ -83,6 +83,36 @@ TEST_F(NAIVE, PADDING_REFLECT) {
                     {10}, dtype::Float32(), {3, 2, 1, 2, 3, 4, 5, 4, 3, 2})});
 }
 
+TEST_F(NAIVE, PADDING_REFLECT2) {
+    Checker<Padding> checker(handle(), false);
+    param::Padding param;
+    param.padding_mode = param::Padding::PaddingMode::REFLECT;
+    param.front_offset_dim0 = 2;
+    param.front_offset_dim1 = 1;
+    param.front_offset_dim2 = 0;
+    param.front_offset_dim3 = 0;
+    param.front_offset_dim4 = 0;
+    param.front_offset_dim5 = 0;
+    param.front_offset_dim6 = 0;
+    param.back_offset_dim0 = 0;
+    param.back_offset_dim1 = 2;
+    param.back_offset_dim2 = 0;
+    param.back_offset_dim3 = 0;
+    param.back_offset_dim4 = 0;
+    param.back_offset_dim5 = 0;
+    param.back_offset_dim6 = 0;
+    checker.set_param(param).exect(
+            Testcase{
+                    TensorValue(
+                            {3, 3}, dtype::QuantizedS8(), {1, 2, 3, 4, 5, 6, 7, 8, 9}),
+                    {}},
+            Testcase{{}, TensorValue(
+                                 {5, 6}, dtype::QuantizedS8(),
+                                 {8, 7, 8, 9, 8, 7, 5, 4, 5, 6, 5, 4, 2, 1, 2,
+                                  3, 2, 1, 5, 4, 5, 6, 5, 4, 8, 7, 8, 9, 8, 7})});
+}
+
 TEST_F(NAIVE, PADDING_REPLICATE) {
     Checker<Padding> checker(handle(), false);
     param::Padding param;
@@ -18,6 +18,7 @@ from ..functional import (
     conv_transpose3d,
     deformable_conv2d,
     local_conv2d,
+    pad,
     relu,
 )
 from ..tensor import Parameter
@@ -126,7 +127,7 @@ class Conv1d(_ConvNd):
         kernel_size: size of weight on spatial dimensions.
         stride: stride of the 1D convolution operation.
         padding: size of the paddings added to the input on both sides of its
-            spatial dimensions. Only zero-padding is supported. Default: 0
+            spatial dimensions. Default: 0
         dilation: dilation of the 1D convolution operation. Default: 1
         groups: number of groups to divide input and output channels into,
             so as to perform a "grouped convolution". When ``groups`` is not 1,
@@ -139,6 +140,8 @@ class Conv1d(_ConvNd):
             placed on the precision of intermediate results. When set to "float32",
             "float32" would be used for accumulator and intermediate result, but only
             effective when input and output are of float16 dtype.
+        padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
+            Refer to :class:`~.module.padding.Pad` for more information.
 
     Note:
         * ``weight`` usually has shape ``(out_channels, in_channels, kernel_size)`` ,
@@ -177,6 +180,7 @@ class Conv1d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "cross_correlation",
         compute_mode: str = "default",
+        padding_mode: str = "zeros",
         **kwargs
     ):
         kernel_size = kernel_size
@@ -185,6 +189,7 @@ class Conv1d(_ConvNd):
         dilation = dilation
         self.conv_mode = conv_mode
         self.compute_mode = compute_mode
+        self.padding_mode = padding_mode
         super().__init__(
             in_channels,
             out_channels,
@@ -223,7 +228,27 @@ class Conv1d(_ConvNd):
         # Assume format is NCH(W=1)
         return (1, self.out_channels, 1)
 
+    def get_pad_witdth(self):
+        return ((0, 0), (0, 0), (self.padding, self.padding))
+
     def calc_conv(self, inp, weight, bias):
+        assert self.padding_mode in [
+            "zeros",
+            "reflect",
+            "replicate",
+        ]
+        if self.padding_mode != "zeros":
+            return conv1d(
+                pad(inp, self.get_pad_witdth(), self.padding_mode),
+                weight,
+                bias,
+                self.stride,
+                0,
+                self.dilation,
+                self.groups,
+                self.conv_mode,
+                self.compute_mode,
+            )
         return conv1d(
             inp,
             weight,
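For any padding_mode other than "zeros", calc_conv above pre-pads the NCH input along its last dimension with the functional pad and then runs conv1d with padding 0. A minimal usage sketch of the new keyword (hypothetical shapes and values, assuming the public megengine APIs):

    import numpy as np
    import megengine as mge
    import megengine.module as M

    x = mge.tensor(np.random.randn(2, 3, 16).astype("float32"))       # NCH input
    m = M.Conv1d(3, 8, kernel_size=3, padding=1, padding_mode="reflect")
    y = m(x)          # get_pad_witdth() yields ((0, 0), (0, 0), (1, 1)) for padding=1
    print(y.shape)    # (2, 8, 16): same spatial size as with zero padding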
@@ -287,7 +312,7 @@ class Conv2d(_ConvNd):
         ``(kernel_size, kernel_size)``.
         stride: stride of the 2D convolution operation. Default: 1
         padding: size of the paddings added to the input on both sides of its
-            spatial dimensions. Only zero-padding is supported. Default: 0
+            spatial dimensions. Default: 0
         dilation: dilation of the 2D convolution operation. Default: 1
         groups: number of groups into which the input and output channels are divided,
             so as to perform a ``grouped convolution``. When ``groups`` is not 1,
@@ -300,6 +325,8 @@ class Conv2d(_ConvNd):
             placed on the precision of intermediate results. When set to "float32",
             "float32" would be used for accumulator and intermediate result, but only
             effective when input and output are of float16 dtype.
+        padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
+            Refer to :class:`~.module.padding.Pad` for more information.
 
     Note:
         * ``weight`` usually has shape ``(out_channels, in_channels, height, width)`` ,
@@ -338,6 +365,7 @@ class Conv2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "cross_correlation",
         compute_mode: str = "default",
+        padding_mode: str = "zeros",
         **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
@@ -346,6 +374,7 @@ class Conv2d(_ConvNd):
         dilation = _pair_nonzero(dilation)
         self.conv_mode = conv_mode
         self.compute_mode = compute_mode
+        self.padding_mode = padding_mode
         super().__init__(
             in_channels,
             out_channels,
@@ -384,7 +413,32 @@ class Conv2d(_ConvNd):
         # Assume format is NCHW
         return (1, self.out_channels, 1, 1)
 
+    def get_pad_witdth(self):
+        return (
+            (0, 0),
+            (0, 0),
+            (self.padding[0], self.padding[0]),
+            (self.padding[1], self.padding[1]),
+        )
+
     def calc_conv(self, inp, weight, bias):
+        assert self.padding_mode in [
+            "zeros",
+            "reflect",
+            "replicate",
+        ]
+        if self.padding_mode != "zeros":
+            return conv2d(
+                pad(inp, self.get_pad_witdth(), self.padding_mode),
+                weight,
+                bias,
+                self.stride,
+                0,
+                self.dilation,
+                self.groups,
+                self.conv_mode,
+                self.compute_mode,
+            )
         return conv2d(
             inp,
             weight,
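Conv2d mirrors the Conv1d change: with a non-"zeros" padding_mode the input is padded explicitly and the convolution itself runs with padding 0. A rough equivalence sketch (illustrative only; pad is the functional pad imported in the hunks above, and the call mirrors what calc_conv does):

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    import megengine.module as M

    x = mge.tensor(np.random.randn(1, 3, 8, 8).astype("float32"))
    m = M.Conv2d(3, 4, kernel_size=3, padding=1, padding_mode="replicate")

    y1 = m(x)
    pad_width = ((0, 0), (0, 0), (1, 1), (1, 1))    # what get_pad_witdth() returns here
    y2 = F.conv2d(F.nn.pad(x, pad_width, "replicate"), m.weight, m.bias,
                  stride=1, padding=0)
    np.testing.assert_allclose(y1.numpy(), y2.numpy(), rtol=1e-5)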
@@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module):
         momentum=0.9,
         affine=True,
         track_running_stats=True,
+        padding_mode: str = "zeros",
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -44,6 +45,7 @@ class _ConvBnActivation2d(Module):
             bias,
             conv_mode,
             compute_mode,
+            padding_mode,
             **kwargs,
         )
         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
@@ -38,6 +38,7 @@ class Conv2d(Float.Conv2d, QATModule):
             float_module.bias is not None,
             float_module.conv_mode,
             float_module.compute_mode,
+            float_module.padding_mode,
             name=float_module.name,
         )
         qat_module.weight = float_module.weight
@@ -147,6 +147,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
             float_module.conv.bias is not None,
             float_module.conv.conv_mode,
             float_module.conv.compute_mode,
+            padding_mode=float_module.conv.padding_mode,
             name=float_module.name,
         )
         qat_module.conv.weight = float_module.conv.weight
@@ -11,7 +11,7 @@ import numpy as np
 from ... import module as Float
 from ...core.tensor import dtype
-from ...functional.nn import conv_bias_activation
+from ...functional.nn import conv_bias_activation, pad
 from ...functional.quantized import conv_transpose2d
 from ...tensor import Parameter
 from ..qat import conv as QAT
@@ -38,6 +38,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
         conv_mode: str = "cross_correlation",
         compute_mode: str = "default",
         dtype=None,
+        padding_mode: str = "zeros",
         **kwargs
     ):
         super().__init__(
@@ -51,13 +52,33 @@ class Conv2d(Float.Conv2d, QuantizedModule):
             True,
             conv_mode,
             compute_mode,
+            padding_mode,
         )
         self.output_dtype = dtype
 
     def calc_conv_quantized(self, inp, nonlinear_mode="identity"):
+        assert self.padding_mode in [
+            "zeros",
+            "reflect",
+            "replicate",
+        ]
         inp_scale = dtype.get_scale(inp.dtype)
         w_scale = dtype.get_scale(self.weight.dtype)
         bias_scale = inp_scale * w_scale
+        if self.padding_mode != "zeros":
+            return conv_bias_activation(
+                pad(inp, self.get_pad_witdth(), self.padding_mode),
+                self.weight,
+                self.bias.astype(dtype.qint32(bias_scale)),
+                self.output_dtype,
+                self.stride,
+                0,
+                self.dilation,
+                self.groups,
+                conv_mode=self.conv_mode,
+                compute_mode=self.compute_mode,
+                nonlinear_mode=nonlinear_mode,
+            )
         return conv_bias_activation(
             inp,
             self.weight,
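In the quantized module the pad is applied to the already-quantized input, which is why the MegDNN padding kernels above gained quantized dtype support, while the bias is stored as qint32 with the product of the input and weight scales. A tiny illustration of that scale bookkeeping (made-up numbers):

    inp_scale, w_scale = 0.05, 0.02        # scales of the quantized input and weight
    bias_scale = inp_scale * w_scale       # 0.001, the qint32 bias scale used above
    # a qint32 bias value of 37 therefore represents 37 * 0.001 = 0.037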
@@ -88,6 +109,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
             qat_module.dilation,
             qat_module.groups,
             dtype=output_dtype,
+            padding_mode=qat_module.padding_mode,
             name=qat_module.name,
         )
         weight = qat_module.weight.astype(qat_module.get_weight_dtype())
@@ -31,6 +31,7 @@ class _ConvBnActivation2d(Conv2d):
             qat_module.conv.groups,
             dtype=output_dtype,
             name=qat_module.name,
+            padding_mode=qat_module.conv.padding_mode,
         )
         w_fold, b_fold = qat_module.fold_weight_bias(
             qat_module.bn.running_mean, qat_module.bn.running_var
@@ -126,6 +126,9 @@ def convbn2d_module_loader(expr):
     module = expr.inputs[0].owner
     if not hasattr(module.bn, "param_dim"):
         module.bn.param_dim = "dim_1c11"
+    module = expr.inputs[0].owner
+    if not hasattr(module.conv, "padding_mode"):
+        module.conv.padding_mode = "zeros"
 
 
 @register_opdef_loader(BatchNorm)
@@ -162,3 +165,36 @@ def tensor_gen_func_loader(expr):
     else:
         device = None
     expr.set_args_kwargs(shape, dtype=dtype, device=device)
+
+
+@register_functional_loader(("megengine.functional.nn", "pad"))
+def pad_func_loader(expr):
+    if "pad_witdth" in expr.kwargs:
+        kwargs = expr.kwargs
+        kwargs["pad_width"] = kwargs.pop("pad_witdth")
+        expr.set_args_kwargs(*expr.args, **kwargs)
+
+
+@register_module_loader(
+    ("megengine.module.conv", "Conv1d"),
+    ("megengine.module.conv", "Conv2d"),
+    ("megengine.module.conv", "ConvRelu2d"),
+    ("megengine.module.qat.conv", "Conv2d"),
+    ("megengine.module.qat.conv", "ConvRelu2d"),
+    ("megengine.module.quantized.conv", "Conv2d"),
+    ("megengine.module.quantized.conv", "ConvRelu2d"),
+)
+def conv2d_module_loader(expr):
+    module = expr.inputs[0].owner
+    if not hasattr(module, "padding_mode"):
+        module.padding_mode = "zeros"
+
+
+@register_module_loader(
+    ("megengine.module.quantized.conv_bn", "ConvBn2d"),
+    ("megengine.module.quantized.conv_bn", "ConvBnRelu2d"),
+)
+def quantized_convbn2d_module_loader(expr):
+    module = expr.inputs[0].owner
+    if not hasattr(module, "padding_mode"):
+        module.padding_mode = "zeros"
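pad_func_loader is a compatibility shim: earlier serialized graphs recorded the functional pad call with the misspelled keyword "pad_witdth", so it is renamed to "pad_width" when an old dump is loaded. Roughly what happens to the stored kwargs (hypothetical values):

    kwargs = {"pad_witdth": ((0, 0), (0, 0), (1, 1), (1, 1))}   # as saved by an old dump
    if "pad_witdth" in kwargs:
        kwargs["pad_width"] = kwargs.pop("pad_witdth")
    # the expr is then re-bound with the corrected keyword via set_args_kwargs

The two module loaders following it give modules serialized before this change a default padding_mode of "zeros".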
@@ -60,7 +60,18 @@ def test_qat_convbn2d():
     )
 
 
-def test_qat_conv():
+@pytest.mark.parametrize(
+    "padding, padding_mode",
+    [
+        (0, "zeros"),
+        ((1, 2), "zeros"),
+        (3, "reflect"),
+        ((1, 2), "reflect"),
+        (4, "replicate"),
+        ((1, 2), "replicate"),
+    ],
+)
+def test_qat_conv(padding, padding_mode):
     in_channels = 32
     out_channels = 64
@@ -72,7 +83,13 @@ def test_qat_conv():
             self.quant = QuantStub()
             self.dequant = DequantStub()
             self.conv = Conv2d(
-                in_channels, out_channels, kernel_size, groups=groups, bias=bias
+                in_channels,
+                out_channels,
+                kernel_size,
+                groups=groups,
+                bias=bias,
+                padding=padding,
+                padding_mode=padding_mode,
             )
             self.conv_relu = ConvRelu2d(
                 out_channels, in_channels, kernel_size, groups=groups, bias=bias
@@ -236,11 +236,16 @@ def test_linear():
 @pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
-def test_conv(module):
-    normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
+@pytest.mark.parametrize("padding_mode", ["zeros", "reflect", "replicate"])
+def test_conv(module, padding_mode):
+    normal_net = getattr(Float, module)(
+        3, 3, 3, 1, 1, 1, bias=True, padding_mode=padding_mode
+    )
     normal_net.eval()
-    qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
+    qat_net = getattr(QAT, module)(
+        3, 3, 3, 1, 1, 1, bias=True, padding_mode=padding_mode
+    )
     qat_net.eval()
 
     disable_observer(qat_net)