GitOrigin-RevId: f6dbd1f4c0
@@ -45,8 +45,6 @@ def conv_bias_activation(
    :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
        'CROSS_CORRELATION'
    :param dtype: support for ``np.dtype``, Default: np.int8
    :param scale: scale if use quantization, Default: 0.0
    :param zero_point: scale if use quantization quint8, Default: 0.0
    :type compute_mode: string or
        :class:`P.Convolution.ComputeMode`.
    :param compute_mode: when set to "DEFAULT", no special requirements will be
@@ -75,3 +73,63 @@ def conv_bias_activation(
    )
    (outputs,) = apply(op, inp, weight, bias)
    return outputs


def batch_conv_bias_activation(
    inp: Tensor,
    weight: Tensor,
    bias: Tensor,
    dtype=None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    nonlinear_mode="IDENTITY",
    conv_mode="CROSS_CORRELATION",
    compute_mode="DEFAULT",
) -> Tensor:
""" | |||
Batch convolution bias with activation operation, only for inference. | |||
:param inp: feature map of the convolution operation. | |||
:param weight: convolution kernel in batched way. | |||
:param bias: bias added to the result of convolution | |||
:param stride: stride of the 2D convolution operation. Default: 1 | |||
:param padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and the shape of weight should be `(groups, out_channel // groups, | |||
in_channels // groups, height, width)`. | |||
:type conv_mode: string or :class:`P.Convolution.Mode`. | |||
:param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
'CROSS_CORRELATION' | |||
:param dtype: support for ``np.dtype``, Default: np.int8 | |||
:type compute_mode: string or | |||
:class:`P.Convolution.ComputeMode`. | |||
:param compute_mode: when set to "DEFAULT", no special requirements will be | |||
placed on the precision of intermediate results. When set to "FLOAT32", | |||
"Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype. | |||
""" | |||
    ph, pw = _pair(padding)
    sh, sw = _pair_nonzero(stride)
    dh, dw = _pair_nonzero(dilation)
    sparse_type = "DENSE" if groups == 1 else "GROUP"
    op = builtin.BatchConvBiasForward(
        stride_h=sh,
        stride_w=sw,
        pad_h=ph,
        pad_w=pw,
        dilate_h=dh,
        dilate_w=dw,
        dtype=dtype,
        format="NCHW",
        strategy=get_conv_execution_strategy(),
        nonlineMode=nonlinear_mode,
        mode=conv_mode,
        compute_mode=compute_mode,
        sparse=sparse_type,
    )
    (outputs,) = apply(op, inp, weight, bias)
    return outputs
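
A minimal usage sketch of the new functional, assuming the quantized dtype helpers exercised in the tests further below. The weight carries a per-sample batch dimension, shape (N, OC, IC, KH, KW), and the bias is a qint32 tensor of shape (1, OC, 1, 1):

# Hedged usage sketch; scales and shapes are illustrative, not prescribed by the API.
import numpy as np
import megengine.functional as F
from megengine import Parameter, tensor
from megengine.core.tensor import dtype

inp_dtype = dtype.qint8(1.5)
w_dtype = dtype.qint8(2.5)
b_dtype = dtype.qint32(1.5 * 2.5)
out_dtype = dtype.qint8(1.5)

inp = tensor(
    dtype.convert_to_qint8(np.random.normal(size=(1, 4, 5, 5)), inp_dtype),
    dtype=inp_dtype,
)
# one 3x3 kernel set per input sample: weight shape (N, OC, IC, KH, KW)
w = Parameter(
    dtype.convert_to_qint8(np.random.normal(size=(1, 4, 4, 3, 3)), w_dtype),
    dtype=w_dtype,
)
b = Parameter(
    dtype.convert_to_qint32(np.random.normal(size=(1, 4, 1, 1)), b_dtype),
    dtype=b_dtype,
)

out = F.quantized.batch_conv_bias_activation(
    inp, w, b, stride=(1, 1), padding=(0, 0), dtype=out_dtype
)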
@@ -9,6 +9,7 @@
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
from .batch_matmul_activation import BatchMatMulActivation
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
from .concat import Concat
from .conv import Conv1d, Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
@@ -0,0 +1,67 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from ..functional import matmul, relu
from ..tensor import Parameter
from . import init
from .module import Module


class BatchMatMulActivation(Module):
    r"""
    Batched MatMul with activation (only ReLU supported), no transpose anywhere.
    """

    def __init__(
        self,
        batch: int,
        in_features: int,
        out_features: int,
        bias: bool = True,
        nonlinear_mode="IDENTITY",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.batch = batch
        self.out_features = out_features
        self.in_features = in_features
        w_shape = (batch, out_features, in_features)
        self.weight = Parameter(np.zeros(w_shape, dtype=np.float32))
        self.bias = None
        if bias:
            b_shape = (out_features,)
            self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
        self.nonlinear_mode = nonlinear_mode
        self.reset_parameters()

    def _get_fanin(self):
        return self.in_features

    def reset_parameters(self) -> None:
        fanin = self._get_fanin()
        std = np.sqrt(1 / fanin)
        init.normal_(self.weight, 0.0, std)
        if self.bias is not None:
            init.zeros_(self.bias)

    def _calc_linear(self, x, weight, bias):
        res = matmul(weight, x)
        if self.bias is not None:
            res += bias
        if self.nonlinear_mode == "RELU":
            res = relu(res)
        return res

    def forward(self, x):
        return self._calc_linear(x, self.weight, self.bias)

    def _module_info_string(self) -> str:
        return "batch={}, in_features={}, out_features={}, bias={}".format(
            self.batch, self.in_features, self.out_features, self.bias is not None
        )
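
A minimal usage sketch of the float module, with shapes assumed from the tests below; matmul(weight, x) implies the input is laid out as (batch, in_features, n):

# Hedged sketch: weight is (batch, out_features, in_features), input (batch, in_features, n).
import numpy as np
from megengine import tensor
from megengine.module import BatchMatMulActivation

m = BatchMatMulActivation(batch=4, in_features=8, out_features=4, nonlinear_mode="RELU")
x = tensor(np.random.randn(4, 8, 4).astype(np.float32))
y = m(x)  # matmul(weight, x) -> (batch, out_features, n), then ReLU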
@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .batch_matmul_activation import BatchMatMulActivation
from .concat import Concat
from .conv import Conv2d, ConvRelu2d
from .conv_bn import ConvBn2d, ConvBnRelu2d
@@ -0,0 +1,30 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...quantization.utils import fake_quant_bias
from .. import batch_matmul_activation as Float
from .module import QATModule


class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):
    def forward(self, inp):
        w_qat = self.apply_quant_weight(self.weight)
        b_qat = fake_quant_bias(self.bias, inp, w_qat)
        return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))

    @classmethod
    def from_float_module(cls, float_module: Float.BatchMatMulActivation):
        qat_module = cls(
            float_module.batch,
            float_module.in_features,
            float_module.out_features,
            float_module.bias is not None,
        )
        qat_module.weight = float_module.weight
        qat_module.bias = float_module.bias
        return qat_module
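
In practice the QAT module is produced by quantize_qat on a network containing the float module, which also attaches the qconfig needed by apply_quant_weight and apply_quant_activation; a hedged sketch of the direct conversion path:

# Hedged sketch: build the QAT counterpart straight from a float module.
from megengine.module import BatchMatMulActivation as FloatBMM
from megengine.module.qat import BatchMatMulActivation as QATBMM

float_mod = FloatBMM(batch=4, in_features=8, out_features=4)
qat_mod = QATBMM.from_float_module(float_mod)  # shares weight and bias with float_mod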
@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .batch_matmul_activation import BatchMatMulActivation
from .concat import Concat
from .conv import Conv2d, ConvRelu2d
from .conv_bn import ConvBn2d, ConvBnRelu2d
@@ -0,0 +1,76 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

import numpy as np

from ... import module as Float
from ...core.tensor import dtype
from ...functional import expand_dims, squeeze
from ...functional.quantized import batch_conv_bias_activation
from ...tensor import Parameter
from ..qat import batch_matmul_activation as QAT
from .module import QuantizedModule


class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule):
    def __init__(
        self,
        batch: int,
        in_features: int,
        out_features: int,
        bias: bool = True,
        nonlinear_mode="IDENTITY",
        dtype=None,
        **kwargs
    ):
        super().__init__(batch, in_features, out_features, bias, **kwargs)
        self.output_dtype = dtype

    def calc_bmm_quantized(self, inp):
        inp_scale = dtype.get_scale(inp.dtype)
        w_scale = dtype.get_scale(self.weight.dtype)
        bias_scale = inp_scale * w_scale
        inp = expand_dims(inp, [-1])
        res = batch_conv_bias_activation(
            inp,
            self.weight,
            self.bias.astype(dtype.qint32(bias_scale)),
            dtype=self.output_dtype,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            nonlinear_mode=self.nonlinear_mode,
        )
        return squeeze(res, -1)

    @classmethod
    def from_qat_module(cls, qat_module: QAT.BatchMatMulActivation):
        output_dtype = qat_module.get_activation_dtype()
        qbmm = cls(
            qat_module.batch,
            qat_module.in_features,
            qat_module.out_features,
            qat_module.bias is not None,
            dtype=output_dtype,
        )
        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
        weight = expand_dims(weight, [-1, -2])
        qbmm.weight = Parameter(weight.numpy())
        if qat_module.bias is not None:
            bias = qat_module.bias.reshape((1, qbmm.out_features, 1, 1))
            qbmm.bias = Parameter(bias.numpy())
        else:
            qbmm.bias = Parameter(
                np.zeros((1, qbmm.out_features, 1, 1), dtype=np.float32)
            )
        return qbmm

    def forward(self, inp):
        return self.calc_bmm_quantized(inp)
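
The quantized module reduces the batched matmul to a batched 1x1 convolution: from_qat_module expands the weight from (batch, out_features, in_features) to (batch, out_features, in_features, 1, 1), and calc_bmm_quantized appends a trailing spatial axis to the input before calling batch_conv_bias_activation, squeezing it off afterwards. A hedged end-to-end sketch, mirroring the tests below, with a single forward pass standing in for real calibration:

# Shape walk through calc_bmm_quantized:
#   inp    (batch, in_features, n)   --expand_dims([-1])-->   (batch, in_features, n, 1)
#   weight (batch, out_features, in_features, 1, 1)           # prepared in from_qat_module
#   conv   (batch, out_features, n, 1)  --squeeze(res, -1)--> (batch, out_features, n)
import numpy as np
from megengine import tensor
from megengine.module import BatchMatMulActivation, DequantStub, Module, QuantStub
from megengine.quantization.quantize import quantize, quantize_qat

class Net(Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.bmm = BatchMatMulActivation(4, 8, 4)
        self.dequant = DequantStub()

    def forward(self, x):
        return self.dequant(self.bmm(self.quant(x)))

x = tensor(np.random.randn(4, 8, 4).astype(np.float32))
qat_net = quantize_qat(Net())  # float -> QAT
qat_net(x)                     # one pass so the observers see real statistics
qnet = quantize(qat_net)       # QAT -> quantized, goes through from_qat_module
qnet.eval()
out = qnet(x)                  # runs calc_bmm_quantized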
@@ -20,6 +20,7 @@ from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple
from megengine.distributed.helper import get_device_count_by_fork


def test_where():
@@ -420,7 +421,9 @@ def test_nms():
    np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))


@pytest.mark.skip(reason="cuda does not support nchw int8")
@pytest.mark.skipif(
    get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8"
)
def test_conv_bias():
    inp_scale = 1.5
    w_scale = 2.5
@@ -446,7 +449,7 @@ def test_conv_bias():
        nonlinear_mode="IDENTITY",
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(OC, IC, KW, KW))
        w_v = np.random.normal(size=(OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
@@ -486,13 +489,12 @@ def test_conv_bias():
            inp = convert_to_nchw4(inp)
            w = convert_to_nchw4(w)
            b = convert_to_nchw4(b)
        return F.nn.conv_bias_activation(
        return F.quantized.conv_bias_activation(
            inp,
            w,
            b,
            stride=(SH, SW),
            padding=(PH, PW),
            format=format,
            dtype=out_dtype,
            nonlinear_mode=nonlinear_mode,
        )
@@ -522,6 +524,59 @@ def test_conv_bias():
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")


@pytest.mark.skipif(
    get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
)
def test_batch_conv_bias():
    inp_scale = 1.5
    w_scale = 2.5
    outp_scale = 1.5
    inp_dtype = dtype.qint8(inp_scale)
    w_dtype = dtype.qint8(w_scale)
    b_dtype = dtype.qint32(inp_scale * w_scale)
    out_dtype = dtype.qint8(outp_scale)

    def run(
        N, IC, OC, IH, IW, KH, KW, PH, PW, SH, SW, has_bias=True,
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(N, OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
        b_scale = dtype.get_scale(b_dtype)

        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_int8 = tensor(inpv, dtype=inp_dtype)
        w_int8 = Parameter(wv, dtype=w_dtype)
        b_int32 = Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_int8.astype("float32")
        w_fp32 = w_int8.astype("float32")
        b_fp32 = b_int32.astype("float32")

        def run_batch_conv_bias(inp, w, b):
            b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
            result = F.quantized.batch_conv_bias_activation(
                inp, w, b, stride=(SH, SW), padding=(PH, PW), dtype=out_dtype,
            )
            return result.astype("float32")

        expected = F.conv2d(inp_fp32, w_fp32[0], b_fp32 if has_bias else None)[0]
        expected = expected.astype(out_dtype).astype("float32")
        expected = F.flatten(expected)

        result = run_batch_conv_bias(inp_int8, w_int8, b_int32)
        result = F.flatten(result)

        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)

    run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)


def test_zero_stride_numpy_array():
    inp = np.random.randn(3, 224, 224).astype(np.float32)
    inp = inp[np.newaxis, :]
@@ -1,9 +1,15 @@
import io
from itertools import product

import numpy as np
import pytest

from megengine import tensor
import megengine.utils.comp_graph_tools as cgtools
from megengine import jit, tensor
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional import expand_dims
from megengine.module import (
    BatchMatMulActivation,
    Conv2d,
    ConvBn2d,
    ConvRelu2d,
@@ -11,7 +17,12 @@ from megengine.module import (
    Module,
    QuantStub,
)
from megengine.quantization.quantize import disable_fake_quant, quantize_qat
from megengine.quantization.quantize import (
    disable_fake_quant,
    enable_fake_quant,
    quantize,
    quantize_qat,
)


def test_qat_convbn2d():
@@ -88,3 +99,107 @@ def test_qat_conv():
        qat_net.eval()
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())


@pytest.mark.skipif(
    get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
)
def test_qat_batchmatmul_activation():
    batch = 4
    in_features = 8
    out_features = 4

    class TestNet(Module):
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            out = self.quant(inp)
            out = self.batch_mm(out)
            out = self.dequant(out)
            return out

    inputs = tensor(
        np.random.randn(batch, in_features, out_features).astype(np.float32)
    )
    for bias in (True, False):
        net = TestNet(bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)
        normal_outputs = net(inputs)
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())

        net.eval()
        normal_outputs = net(inputs)
        qat_net.eval()
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())


@pytest.mark.skip(reason="FIXME: abnormal exit")
def test_quantize_batchmatmul_activation():
    batch = 4
    in_features = 8
    out_features = 4

    class TestNet(Module):
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            out = self.quant(inp)
            out = self.batch_mm(out)
            out = expand_dims(out, -1)
            out = self.dequant(out)
            return out

    inputs = tensor(
        np.random.randn(batch, in_features, out_features).astype(np.float32)
    )
    for bias in (True, False):
        net = TestNet(bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)
        normal_outputs = net(inputs)
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())

        net.eval()
        normal_outputs = net(inputs)
        qat_net.eval()
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())

        enable_fake_quant(qat_net)
        qat_outputs = qat_net(inputs)
        qnet = quantize(qat_net, inplace=False)
        qnet.eval()
        quantize_outputs = qnet(inputs)
        np.testing.assert_allclose(
            qat_outputs.numpy(), quantize_outputs.numpy(), atol=1e-6
        )

        @jit.trace(capture_as_const=True)
        def f(x):
            qnet.eval()
            return qnet(x)

        f(inputs)
        file = io.BytesIO()
        f.dump(file, enable_nchw4=True)
        file.seek(0)
        dumped_outputs = cgtools.load_and_inference(file, [inputs])[0]
        np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6)