@@ -35,6 +35,7 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
                param().padding_val, stream); \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
    MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
}
@@ -60,7 +60,8 @@ __global__ void paddingConst_kernel(
            params.src_stride[dim].divisor();
            */
        }
        dst[out_index] = in_src_valid_area ? src[in_index] : padding_val;
        dst[out_index] =
                in_src_valid_area ? src[in_index] : static_cast<T>(padding_val);
    }
}
@@ -256,6 +257,7 @@ void padding_backward_proxy(
        const float_t padding_val, cudaStream_t stream);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
#undef INST
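
With quantized dtypes added to the dispatch and instantiation macros above (and padding_val now
explicitly cast to the element type T, which matters once T can be a quantized wrapper type), the
same CUDA padding kernels serve both computing and quantized tensors. As a user-level sketch of the
operator these kernels implement, here is the functional padding call with the same offsets the
tests below use; it sticks to float32 input, since the Python helpers for constructing a quantized
tensor vary by MegEngine version:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.arange(1, 10, dtype="float32").reshape(3, 3))
    # pad_width holds one (front, back) pair per dimension:
    # two rows in front, one column in front and two columns in back.
    y = F.nn.pad(x, pad_width=((2, 0), (1, 2)), mode="reflect")
    print(y.shape)  # (5, 6)
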
@@ -171,7 +171,7 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
    switch (param().padding_mode) {
        case param::Padding::PaddingMode::CONSTANT:
#define cb(DType) \
    if (src.layout.dtype == DType()) { \
    if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
        using T = typename DTypeTrait<DType>::ctype; \
        MEGDNN_DISPATCH_CPU_KERN_OPR(exec_const_internal<T>( \
                src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params, \
@@ -179,28 +179,31 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
        return; \
    }
            MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
            break;
        case param::Padding::PaddingMode::REPLICATE:
#define cb(DType) \
    if (src.layout.dtype == DType()) { \
    if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
        using T = typename DTypeTrait<DType>::ctype; \
        MEGDNN_DISPATCH_CPU_KERN_OPR(exec_replicate_internal<T>( \
                src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params)); \
        return; \
    }
            MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
            break;
        case param::Padding::PaddingMode::REFLECT:
#define cb(DType) \
    if (src.layout.dtype == DType()) { \
    if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
        using T = typename DTypeTrait<DType>::ctype; \
        MEGDNN_DISPATCH_CPU_KERN_OPR(exec_reflect_internal<T>( \
                src.layout.ndim, n, src.ptr<T>(), dst.ptr<T>(), params)); \
        return; \
    }
            MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
            break;
        default:
@@ -101,6 +101,36 @@ TEST_F(CUDA, PADDING_REFLECT2) {
                            4, 1, 6, 3, 6, 1, 6, 3})});
}

TEST_F(CUDA, PADDING_REFLECT2_QUANTIZED) {
    Checker<Padding> checker(handle_cuda(), false);
    param::Padding param;
    param.padding_mode = param::Padding::PaddingMode::REFLECT;
    param.front_offset_dim0 = 2;
    param.front_offset_dim1 = 1;
    param.front_offset_dim2 = 0;
    param.front_offset_dim3 = 0;
    param.front_offset_dim4 = 0;
    param.front_offset_dim5 = 0;
    param.front_offset_dim6 = 0;
    param.back_offset_dim0 = 0;
    param.back_offset_dim1 = 2;
    param.back_offset_dim2 = 0;
    param.back_offset_dim3 = 0;
    param.back_offset_dim4 = 0;
    param.back_offset_dim5 = 0;
    param.back_offset_dim6 = 0;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {3, 3}, dtype::QuantizedS8(), {1, 2, 3, 4, 5, 6, 7, 8, 9}),
                    {}},
            Testcase{
                    {},
                    TensorValue(
                            {5, 6}, dtype::QuantizedS8(),
                            {8, 7, 8, 9, 8, 7,
                             5, 4, 5, 6, 5, 4,
                             2, 1, 2, 3, 2, 1,
                             5, 4, 5, 6, 5, 4,
                             8, 7, 8, 9, 8, 7})});
}
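
The expected 5x6 tensor in this test (and in the NAIVE counterpart below) follows mechanically from
reflect padding with front offsets (2, 1) and back offsets (0, 2). A quick NumPy cross-check,
assuming NumPy's "reflect" mode matches the operator's REFLECT semantics (no edge duplication):

    import numpy as np

    a = np.arange(1, 10, dtype=np.int8).reshape(3, 3)
    ref = np.pad(a, ((2, 0), (1, 2)), mode="reflect")
    # ref row by row: [8 7 8 9 8 7], [5 4 5 6 5 4], [2 1 2 3 2 1],
    #                 [5 4 5 6 5 4], [8 7 8 9 8 7] -- the values listed above.
    print(ref)
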
TEST_F(CUDA, PADDING_REPLICATE) {
    Checker<Padding> checker(handle_cuda(), false);
    param::Padding param;
@@ -83,6 +83,36 @@ TEST_F(NAIVE, PADDING_REFLECT) {
                            {10}, dtype::Float32(), {3, 2, 1, 2, 3, 4, 5, 4, 3, 2})});
}

TEST_F(NAIVE, PADDING_REFLECT2) {
    Checker<Padding> checker(handle(), false);
    param::Padding param;
    param.padding_mode = param::Padding::PaddingMode::REFLECT;
    param.front_offset_dim0 = 2;
    param.front_offset_dim1 = 1;
    param.front_offset_dim2 = 0;
    param.front_offset_dim3 = 0;
    param.front_offset_dim4 = 0;
    param.front_offset_dim5 = 0;
    param.front_offset_dim6 = 0;
    param.back_offset_dim0 = 0;
    param.back_offset_dim1 = 2;
    param.back_offset_dim2 = 0;
    param.back_offset_dim3 = 0;
    param.back_offset_dim4 = 0;
    param.back_offset_dim5 = 0;
    param.back_offset_dim6 = 0;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {3, 3}, dtype::QuantizedS8(), {1, 2, 3, 4, 5, 6, 7, 8, 9}),
                    {}},
            Testcase{
                    {},
                    TensorValue(
                            {5, 6}, dtype::QuantizedS8(),
                            {8, 7, 8, 9, 8, 7,
                             5, 4, 5, 6, 5, 4,
                             2, 1, 2, 3, 2, 1,
                             5, 4, 5, 6, 5, 4,
                             8, 7, 8, 9, 8, 7})});
}

TEST_F(NAIVE, PADDING_REPLICATE) {
    Checker<Padding> checker(handle(), false);
    param::Padding param;
@@ -18,6 +18,7 @@ from ..functional import (
    conv_transpose3d,
    deformable_conv2d,
    local_conv2d,
    pad,
    relu,
)
from ..tensor import Parameter
@@ -126,7 +127,7 @@ class Conv1d(_ConvNd):
        kernel_size: size of weight on spatial dimensions.
        stride: stride of the 1D convolution operation.
        padding: size of the paddings added to the input on both sides of its
            spatial dimensions. Only zero-padding is supported. Default: 0
            spatial dimensions. Default: 0
        dilation: dilation of the 1D convolution operation. Default: 1
        groups: number of groups to divide input and output channels into,
            so as to perform a "grouped convolution". When ``groups`` is not 1,
@@ -139,6 +140,8 @@ class Conv1d(_ConvNd):
            placed on the precision of intermediate results. When set to "float32",
            "float32" would be used for accumulator and intermediate result, but only
            effective when input and output are of float16 dtype.
        padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
            Refer to :class:`~.module.padding.Pad` for more information.

    Note:
        * ``weight`` usually has shape ``(out_channels, in_channels, kernel_size)`` ,
@@ -177,6 +180,7 @@ class Conv1d(_ConvNd):
        bias: bool = True,
        conv_mode: str = "cross_correlation",
        compute_mode: str = "default",
        padding_mode: str = "zeros",
        **kwargs
    ):
        kernel_size = kernel_size
@@ -185,6 +189,7 @@ class Conv1d(_ConvNd):
        dilation = dilation
        self.conv_mode = conv_mode
        self.compute_mode = compute_mode
        self.padding_mode = padding_mode
        super().__init__(
            in_channels,
            out_channels,
@@ -223,7 +228,27 @@ class Conv1d(_ConvNd):
        # Assume format is NCH(W=1)
        return (1, self.out_channels, 1)

    def get_pad_witdth(self):
        return ((0, 0), (0, 0), (self.padding, self.padding))

    def calc_conv(self, inp, weight, bias):
        assert self.padding_mode in [
            "zeros",
            "reflect",
            "replicate",
        ]
        if self.padding_mode != "zeros":
            return conv1d(
                pad(inp, self.get_pad_witdth(), self.padding_mode),
                weight,
                bias,
                self.stride,
                0,
                self.dilation,
                self.groups,
                self.conv_mode,
                self.compute_mode,
            )
        return conv1d(
            inp,
            weight,
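
For any mode other than "zeros", calc_conv simply pre-pads the input with the functional pad and
then convolves with padding=0, so the module call and the manual pad-then-conv form should agree. A
small sketch of that equivalence (shapes chosen arbitrarily for illustration):

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    import megengine.module as M

    x = mge.tensor(np.random.randn(2, 3, 16).astype("float32"))
    conv = M.Conv1d(3, 8, kernel_size=3, padding=1, padding_mode="reflect")

    y1 = conv(x)
    # Same computation by hand: reflect-pad the spatial dim, then convolve with padding=0.
    x_pad = F.nn.pad(x, ((0, 0), (0, 0), (1, 1)), mode="reflect")
    y2 = F.conv1d(x_pad, conv.weight, conv.bias, stride=1, padding=0)
    np.testing.assert_allclose(y1.numpy(), y2.numpy(), rtol=1e-5, atol=1e-5)
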
@@ -287,7 +312,7 @@ class Conv2d(_ConvNd):
            ``(kernel_size, kernel_size)``.
        stride: stride of the 2D convolution operation. Default: 1
        padding: size of the paddings added to the input on both sides of its
            spatial dimensions. Only zero-padding is supported. Default: 0
            spatial dimensions. Default: 0
        dilation: dilation of the 2D convolution operation. Default: 1
        groups: number of groups into which the input and output channels are divided,
            so as to perform a ``grouped convolution``. When ``groups`` is not 1,
@@ -300,6 +325,8 @@ class Conv2d(_ConvNd):
            placed on the precision of intermediate results. When set to "float32",
            "float32" would be used for accumulator and intermediate result, but only
            effective when input and output are of float16 dtype.
        padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
            Refer to :class:`~.module.padding.Pad` for more information.

    Note:
        * ``weight`` usually has shape ``(out_channels, in_channels, height, width)`` ,
@@ -338,6 +365,7 @@ class Conv2d(_ConvNd):
        bias: bool = True,
        conv_mode: str = "cross_correlation",
        compute_mode: str = "default",
        padding_mode: str = "zeros",
        **kwargs
    ):
        kernel_size = _pair_nonzero(kernel_size)
@@ -346,6 +374,7 @@ class Conv2d(_ConvNd):
        dilation = _pair_nonzero(dilation)
        self.conv_mode = conv_mode
        self.compute_mode = compute_mode
        self.padding_mode = padding_mode
        super().__init__(
            in_channels,
            out_channels,
@@ -384,7 +413,32 @@ class Conv2d(_ConvNd):
        # Assume format is NCHW
        return (1, self.out_channels, 1, 1)

    def get_pad_witdth(self):
        return (
            (0, 0),
            (0, 0),
            (self.padding[0], self.padding[0]),
            (self.padding[1], self.padding[1]),
        )

    def calc_conv(self, inp, weight, bias):
        assert self.padding_mode in [
            "zeros",
            "reflect",
            "replicate",
        ]
        if self.padding_mode != "zeros":
            return conv2d(
                pad(inp, self.get_pad_witdth(), self.padding_mode),
                weight,
                bias,
                self.stride,
                0,
                self.dilation,
                self.groups,
                self.conv_mode,
                self.compute_mode,
            )
        return conv2d(
            inp,
            weight,
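
Conv2d follows the same pattern; get_pad_witdth (the helper name keeps the misspelling used
throughout this change, which the traced-module loader below also accounts for) pads only the two
spatial axes of an NCHW input and leaves batch and channel dimensions alone. For example, with
per-axis padding:

    import megengine.module as M

    conv = M.Conv2d(3, 16, kernel_size=3, padding=(1, 2), padding_mode="replicate")
    print(conv.get_pad_witdth())
    # ((0, 0), (0, 0), (1, 1), (2, 2))  -- (N, C, H, W): only H and W are padded
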
@@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module):
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        padding_mode: str = "zeros",
        **kwargs
    ):
        super().__init__(**kwargs)
@@ -44,6 +45,7 @@ class _ConvBnActivation2d(Module):
            bias,
            conv_mode,
            compute_mode,
            padding_mode,
            **kwargs,
        )
        self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
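
Since _ConvBnActivation2d just forwards padding_mode to its inner Conv2d, the fused ConvBn modules
pick the option up as well; a quick hedged sketch:

    import megengine.module as M

    m = M.ConvBnRelu2d(3, 16, 3, padding=1, padding_mode="reflect")
    assert m.conv.padding_mode == "reflect"
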
@@ -38,6 +38,7 @@ class Conv2d(Float.Conv2d, QATModule):
            float_module.bias is not None,
            float_module.conv_mode,
            float_module.compute_mode,
            float_module.padding_mode,
            name=float_module.name,
        )
        qat_module.weight = float_module.weight
@@ -147,6 +147,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
            float_module.conv.bias is not None,
            float_module.conv.conv_mode,
            float_module.conv.compute_mode,
            padding_mode=float_module.conv.padding_mode,
            name=float_module.name,
        )
        qat_module.conv.weight = float_module.conv.weight
@@ -11,7 +11,7 @@ import numpy as np
from ... import module as Float
from ...core.tensor import dtype
from ...functional.nn import conv_bias_activation
from ...functional.nn import conv_bias_activation, pad
from ...functional.quantized import conv_transpose2d
from ...tensor import Parameter
from ..qat import conv as QAT
@@ -38,6 +38,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
        conv_mode: str = "cross_correlation",
        compute_mode: str = "default",
        dtype=None,
        padding_mode: str = "zeros",
        **kwargs
    ):
        super().__init__(
@@ -51,13 +52,33 @@ class Conv2d(Float.Conv2d, QuantizedModule):
            True,
            conv_mode,
            compute_mode,
            padding_mode,
        )
        self.output_dtype = dtype

    def calc_conv_quantized(self, inp, nonlinear_mode="identity"):
        assert self.padding_mode in [
            "zeros",
            "reflect",
            "replicate",
        ]
        inp_scale = dtype.get_scale(inp.dtype)
        w_scale = dtype.get_scale(self.weight.dtype)
        bias_scale = inp_scale * w_scale
        if self.padding_mode != "zeros":
            return conv_bias_activation(
                pad(inp, self.get_pad_witdth(), self.padding_mode),
                self.weight,
                self.bias.astype(dtype.qint32(bias_scale)),
                self.output_dtype,
                self.stride,
                0,
                self.dilation,
                self.groups,
                conv_mode=self.conv_mode,
                compute_mode=self.compute_mode,
                nonlinear_mode=nonlinear_mode,
            )
        return conv_bias_activation(
            inp,
            self.weight,
@@ -88,6 +109,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
            qat_module.dilation,
            qat_module.groups,
            dtype=output_dtype,
            padding_mode=qat_module.padding_mode,
            name=qat_module.name,
        )
        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
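
With from_float_module (and the quantized from_qat_module above) now passing padding_mode through,
the attribute survives conversion between the float, QAT and quantized variants of the module. A
minimal sketch of the float-to-QAT step, which is exactly the code path changed above:

    import megengine.module as M
    from megengine.module import qat as QAT

    float_conv = M.Conv2d(3, 16, 3, padding=2, padding_mode="replicate")
    qat_conv = QAT.Conv2d.from_float_module(float_conv)
    assert qat_conv.padding_mode == "replicate"
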
@@ -31,6 +31,7 @@ class _ConvBnActivation2d(Conv2d):
            qat_module.conv.groups,
            dtype=output_dtype,
            name=qat_module.name,
            padding_mode=qat_module.conv.padding_mode,
        )
        w_fold, b_fold = qat_module.fold_weight_bias(
            qat_module.bn.running_mean, qat_module.bn.running_var
@@ -126,6 +126,9 @@ def convbn2d_module_loader(expr):
    module = expr.inputs[0].owner
    if not hasattr(module.bn, "param_dim"):
        module.bn.param_dim = "dim_1c11"
    module = expr.inputs[0].owner
    if not hasattr(module.conv, "padding_mode"):
        module.conv.padding_mode = "zeros"


@register_opdef_loader(BatchNorm)
@@ -162,3 +165,36 @@ def tensor_gen_func_loader(expr):
    else:
        device = None
    expr.set_args_kwargs(shape, dtype=dtype, device=device)


@register_functional_loader(("megengine.functional.nn", "pad"))
def pad_func_loader(expr):
    if "pad_witdth" in expr.kwargs:
        kwargs = expr.kwargs
        kwargs["pad_width"] = kwargs.pop("pad_witdth")
        expr.set_args_kwargs(*expr.args, **kwargs)


@register_module_loader(
    ("megengine.module.conv", "Conv1d"),
    ("megengine.module.conv", "Conv2d"),
    ("megengine.module.conv", "ConvRelu2d"),
    ("megengine.module.qat.conv", "Conv2d"),
    ("megengine.module.qat.conv", "ConvRelu2d"),
    ("megengine.module.quantized.conv", "Conv2d"),
    ("megengine.module.quantized.conv", "ConvRelu2d"),
)
def conv2d_module_loader(expr):
    module = expr.inputs[0].owner
    if not hasattr(module, "padding_mode"):
        module.padding_mode = "zeros"


@register_module_loader(
    ("megengine.module.quantized.conv_bn", "ConvBn2d"),
    ("megengine.module.quantized.conv_bn", "ConvBnRelu2d"),
)
def quantized_convbn2d_module_loader(expr):
    module = expr.inputs[0].owner
    if not hasattr(module, "padding_mode"):
        module.padding_mode = "zeros"
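
These loaders are backward-compatibility shims for TracedModules serialized by older MegEngine
releases: convolutions dumped before this change carry no padding_mode attribute, and recorded pad
calls may still use the old "pad_witdth" keyword. A hedged sketch of the situation they handle (the
file name and network layout below are hypothetical):

    import megengine as mge

    # "old_net.tm" stands for a TracedModule dumped by a release that predates
    # padding_mode; the loaders above fill in the default during load.
    old_net = mge.load("old_net.tm")
    print(old_net.conv.padding_mode)  # "zeros"
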
@@ -60,7 +60,18 @@ def test_qat_convbn2d():
    )


def test_qat_conv():
@pytest.mark.parametrize(
    "padding, padding_mode",
    [
        (0, "zeros"),
        ((1, 2), "zeros"),
        (3, "reflect"),
        ((1, 2), "reflect"),
        (4, "replicate"),
        ((1, 2), "replicate"),
    ],
)
def test_qat_conv(padding, padding_mode):
    in_channels = 32
    out_channels = 64
@@ -72,7 +83,13 @@ def test_qat_conv():
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv = Conv2d(
                in_channels, out_channels, kernel_size, groups=groups, bias=bias
                in_channels,
                out_channels,
                kernel_size,
                groups=groups,
                bias=bias,
                padding=padding,
                padding_mode=padding_mode,
            )
            self.conv_relu = ConvRelu2d(
                out_channels, in_channels, kernel_size, groups=groups, bias=bias
@@ -236,11 +236,16 @@ def test_linear():
@pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
def test_conv(module):
    normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
@pytest.mark.parametrize("padding_mode", ["zeros", "reflect", "replicate"])
def test_conv(module, padding_mode):
    normal_net = getattr(Float, module)(
        3, 3, 3, 1, 1, 1, bias=True, padding_mode=padding_mode
    )
    normal_net.eval()
    qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
    qat_net = getattr(QAT, module)(
        3, 3, 3, 1, 1, 1, bias=True, padding_mode=padding_mode
    )
    qat_net.eval()
    disable_observer(qat_net)