
feat(mge/quantize): add batch matmul activation module for inference

GitOrigin-RevId: f6dbd1f4c0
release-1.1
Megvii Engine Team, 4 years ago
commit 442058aead
9 changed files with 412 additions and 8 deletions
  1. +60  -2   imperative/python/megengine/functional/quantized.py
  2. +1   -0   imperative/python/megengine/module/__init__.py
  3. +67  -0   imperative/python/megengine/module/batch_matmul_activation.py
  4. +1   -0   imperative/python/megengine/module/qat/__init__.py
  5. +30  -0   imperative/python/megengine/module/qat/batch_matmul_activation.py
  6. +1   -0   imperative/python/megengine/module/quantized/__init__.py
  7. +76  -0   imperative/python/megengine/module/quantized/batch_matmul_activation.py
  8. +59  -4   imperative/python/test/unit/functional/test_functional.py
  9. +117 -2   imperative/python/test/unit/module/test_qat.py

+60 -2  imperative/python/megengine/functional/quantized.py

@@ -45,8 +45,6 @@ def conv_bias_activation(
:param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
'CROSS_CORRELATION'
:param dtype: support for ``np.dtype``, Default: np.int8
:param scale: scale if use quantization, Default: 0.0
:param zero_point: scale if use quantization quint8, Default: 0.0
:type compute_mode: string or
:class:`P.Convolution.ComputeMode`.
:param compute_mode: when set to "DEFAULT", no special requirements will be
@@ -75,3 +73,63 @@ def conv_bias_activation(
)
(outputs,) = apply(op, inp, weight, bias)
return outputs


def batch_conv_bias_activation(
    inp: Tensor,
    weight: Tensor,
    bias: Tensor,
    dtype=None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    nonlinear_mode="IDENTITY",
    conv_mode="CROSS_CORRELATION",
    compute_mode="DEFAULT",
) -> Tensor:
    """
    Batched convolution-bias operation with activation, for inference only.

    :param inp: feature map of the convolution operation.
    :param weight: convolution kernel in batched form (one kernel per sample).
    :param bias: bias added to the result of convolution.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and the shape of weight should be `(groups, out_channel // groups,
        in_channels // groups, height, width)`.
    :type conv_mode: string or :class:`P.Convolution.Mode`.
    :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
        'CROSS_CORRELATION'
    :param dtype: support for ``np.dtype``, Default: np.int8
    :type compute_mode: string or
        :class:`P.Convolution.ComputeMode`.
    :param compute_mode: when set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype.

    """
    ph, pw = _pair(padding)
    sh, sw = _pair_nonzero(stride)
    dh, dw = _pair_nonzero(dilation)
    sparse_type = "DENSE" if groups == 1 else "GROUP"
    op = builtin.BatchConvBiasForward(
        stride_h=sh,
        stride_w=sw,
        pad_h=ph,
        pad_w=pw,
        dilate_h=dh,
        dilate_w=dw,
        dtype=dtype,
        format="NCHW",
        strategy=get_conv_execution_strategy(),
        nonlineMode=nonlinear_mode,
        mode=conv_mode,
        compute_mode=compute_mode,
        sparse=sparse_type,
    )
    (outputs,) = apply(op, inp, weight, bias)
    return outputs

+1 -0  imperative/python/megengine/module/__init__.py

@@ -9,6 +9,7 @@


from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
from .batch_matmul_activation import BatchMatMulActivation
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
from .concat import Concat
from .conv import Conv1d, Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d


+67 -0  imperative/python/megengine/module/batch_matmul_activation.py

@@ -0,0 +1,67 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from ..functional import matmul, relu
from ..tensor import Parameter
from . import init
from .module import Module


class BatchMatMulActivation(Module):
    r"""
    Batched MatMul with activation (only ReLU supported), no transpose anywhere.
    """

    def __init__(
        self,
        batch: int,
        in_features: int,
        out_features: int,
        bias: bool = True,
        nonlinear_mode="IDENTITY",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.batch = batch
        self.out_features = out_features
        self.in_features = in_features
        w_shape = (batch, out_features, in_features)
        self.weight = Parameter(np.zeros(w_shape, dtype=np.float32))
        self.bias = None
        if bias:
            b_shape = (out_features,)
            self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
        self.nonlinear_mode = nonlinear_mode
        self.reset_parameters()

    def _get_fanin(self):
        return self.in_features

    def reset_parameters(self) -> None:
        fanin = self._get_fanin()
        std = np.sqrt(1 / fanin)
        init.normal_(self.weight, 0.0, std)
        if self.bias is not None:
            init.zeros_(self.bias)

    def _calc_linear(self, x, weight, bias):
        res = matmul(weight, x)
        if self.bias is not None:
            res += bias
        if self.nonlinear_mode == "RELU":
            res = relu(res)
        return res

    def forward(self, x):
        return self._calc_linear(x, self.weight, self.bias)

    def _module_info_string(self) -> str:
        return "batch={}, in_features={}, out_features={}, bias={}".format(
            self.batch, self.in_features, self.out_features, self.bias is not None
        )
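A quick usage sketch of the float module; the shapes below are illustrative assumptions and the bias is omitted so the example does not depend on how the (out_features,) bias broadcasts against the last axis.

import numpy as np
import megengine as mge
from megengine.module import BatchMatMulActivation

batch, in_features, out_features, n = 4, 8, 16, 10
m = BatchMatMulActivation(batch, in_features, out_features, bias=False, nonlinear_mode="RELU")

# The module computes weight @ x per batch element: (batch, out, in) @ (batch, in, n).
x = mge.tensor(np.random.randn(batch, in_features, n).astype(np.float32))
y = m(x)  # (batch, out_features, n), with ReLU applied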

+1 -0  imperative/python/megengine/module/qat/__init__.py

@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .batch_matmul_activation import BatchMatMulActivation
from .concat import Concat
from .conv import Conv2d, ConvRelu2d
from .conv_bn import ConvBn2d, ConvBnRelu2d


+30 -0  imperative/python/megengine/module/qat/batch_matmul_activation.py

@@ -0,0 +1,30 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ...quantization.utils import fake_quant_bias
from .. import batch_matmul_activation as Float
from .module import QATModule


class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):
    def forward(self, inp):
        w_qat = self.apply_quant_weight(self.weight)
        b_qat = fake_quant_bias(self.bias, inp, w_qat)
        return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))

    @classmethod
    def from_float_module(cls, float_module: Float.BatchMatMulActivation):
        qat_module = cls(
            float_module.batch,
            float_module.in_features,
            float_module.out_features,
            float_module.bias is not None,
        )
        qat_module.weight = float_module.weight
        qat_module.bias = float_module.bias
        return qat_module
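For reference, the QAT wrapper is normally produced by quantize_qat on a parent module; the sketch below calls from_float_module directly, which is an assumption about standalone use, and a qconfig still has to be attached before fake quantization actually takes effect.

from megengine.module import BatchMatMulActivation as FloatBMM
from megengine.module.qat import BatchMatMulActivation as QATBMM

float_m = FloatBMM(batch=4, in_features=8, out_features=4, bias=True)
qat_m = QATBMM.from_float_module(float_m)  # reuses the float module's weight and bias tensors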

+1 -0  imperative/python/megengine/module/quantized/__init__.py

@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .batch_matmul_activation import BatchMatMulActivation
from .concat import Concat
from .conv import Conv2d, ConvRelu2d
from .conv_bn import ConvBn2d, ConvBnRelu2d


+76 -0  imperative/python/megengine/module/quantized/batch_matmul_activation.py

@@ -0,0 +1,76 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

import numpy as np

from ... import module as Float
from ...core.tensor import dtype
from ...functional import expand_dims, squeeze
from ...functional.quantized import batch_conv_bias_activation
from ...tensor import Parameter
from ..qat import batch_matmul_activation as QAT
from .module import QuantizedModule


class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule):
    def __init__(
        self,
        batch: int,
        in_features: int,
        out_features: int,
        bias: bool = True,
        nonlinear_mode="IDENTITY",
        dtype=None,
        **kwargs
    ):
        super().__init__(batch, in_features, out_features, bias, **kwargs)
        self.output_dtype = dtype

    def calc_bmm_quantized(self, inp):
        inp_scale = dtype.get_scale(inp.dtype)
        w_scale = dtype.get_scale(self.weight.dtype)
        bias_scale = inp_scale * w_scale
        inp = expand_dims(inp, [-1])
        res = batch_conv_bias_activation(
            inp,
            self.weight,
            self.bias.astype(dtype.qint32(bias_scale)),
            dtype=self.output_dtype,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            nonlinear_mode=self.nonlinear_mode,
        )
        return squeeze(res, -1)

    @classmethod
    def from_qat_module(cls, qat_module: QAT.BatchMatMulActivation):
        output_dtype = qat_module.get_activation_dtype()
        qbmm = cls(
            qat_module.batch,
            qat_module.in_features,
            qat_module.out_features,
            qat_module.bias is not None,
            dtype=output_dtype,
        )
        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
        weight = expand_dims(weight, [-1, -2])
        qbmm.weight = Parameter(weight.numpy())
        if qat_module.bias is not None:
            bias = qat_module.bias.reshape((1, qbmm.out_features, 1, 1))
            qbmm.bias = Parameter(bias.numpy())
        else:
            qbmm.bias = Parameter(
                np.zeros((1, qbmm.out_features, 1, 1), dtype=np.float32)
            )
        return qbmm

    def forward(self, inp):
        return self.calc_bmm_quantized(inp)
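A hedged sketch of the intended float -> QAT -> quantized path, mirroring the tests later in this commit; `net` (a module containing a BatchMatMulActivation) and `inputs` are placeholders assumed to exist, and qconfig handling is left to quantize_qat.

from megengine.quantization.quantize import quantize, quantize_qat

qat_net = quantize_qat(net, inplace=False)  # swap in the QAT module and attach observers
# ... calibration / QAT training happens here ...
qnet = quantize(qat_net, inplace=False)     # swap in this quantized, inference-only module
qnet.eval()
out = qnet(inputs)                          # internally runs batch_conv_bias_activation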

+59 -4  imperative/python/test/unit/functional/test_functional.py

@@ -20,6 +20,7 @@ from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple
from megengine.distributed.helper import get_device_count_by_fork




def test_where():
@@ -420,7 +421,9 @@ def test_nms():
np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))




@pytest.mark.skip(reason="cuda does not support nchw int8")
@pytest.mark.skipif(
get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8"
)
def test_conv_bias():
inp_scale = 1.5
w_scale = 2.5
@@ -446,7 +449,7 @@ def test_conv_bias():
nonlinear_mode="IDENTITY",
):
inp_v = np.random.normal(size=(N, IC, IH, IW))
w_v = np.random.normal(size=(OC, IC, KW, KW))
w_v = np.random.normal(size=(OC, IC, KH, KW))
b_v = np.random.normal(size=(1, OC, 1, 1))
inp_scale = dtype.get_scale(inp_dtype)
w_scale = dtype.get_scale(w_dtype)
@@ -486,13 +489,12 @@ def test_conv_bias():
inp = convert_to_nchw4(inp)
w = convert_to_nchw4(w)
b = convert_to_nchw4(b)
return F.nn.conv_bias_activation(
return F.quantized.conv_bias_activation(
inp,
w,
b,
stride=(SH, SW),
padding=(PH, PW),
format=format,
dtype=out_dtype,
nonlinear_mode=nonlinear_mode,
)
@@ -522,6 +524,59 @@ def test_conv_bias():
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")




@pytest.mark.skipif(
    get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
)
def test_batch_conv_bias():
    inp_scale = 1.5
    w_scale = 2.5
    outp_scale = 1.5
    inp_dtype = dtype.qint8(inp_scale)
    w_dtype = dtype.qint8(w_scale)
    b_dtype = dtype.qint32(inp_scale * w_scale)
    out_dtype = dtype.qint8(outp_scale)

    def run(
        N, IC, OC, IH, IW, KH, KW, PH, PW, SH, SW, has_bias=True,
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(N, OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
        b_scale = dtype.get_scale(b_dtype)

        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_int8 = tensor(inpv, dtype=inp_dtype)
        w_int8 = Parameter(wv, dtype=w_dtype)
        b_int32 = Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_int8.astype("float32")
        w_fp32 = w_int8.astype("float32")
        b_fp32 = b_int32.astype("float32")

        def run_batch_conv_bias(inp, w, b):
            b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
            result = F.quantized.batch_conv_bias_activation(
                inp, w, b, stride=(SH, SW), padding=(PH, PW), dtype=out_dtype,
            )
            return result.astype("float32")

        expected = F.conv2d(inp_fp32, w_fp32[0], b_fp32 if has_bias else None)[0]
        expected = expected.astype(out_dtype).astype("float32")
        expected = F.flatten(expected)

        result = run_batch_conv_bias(inp_int8, w_int8, b_int32)
        result = F.flatten(result)

        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)

    run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)


def test_zero_stride_numpy_array():
inp = np.random.randn(3, 224, 224).astype(np.float32)
inp = inp[np.newaxis, :]


+117 -2  imperative/python/test/unit/module/test_qat.py

@@ -1,9 +1,15 @@
import io
from itertools import product


import numpy as np
import pytest


from megengine import tensor
import megengine.utils.comp_graph_tools as cgtools
from megengine import jit, tensor
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional import expand_dims
from megengine.module import (
BatchMatMulActivation,
Conv2d,
ConvBn2d,
ConvRelu2d,
@@ -11,7 +17,12 @@ from megengine.module import (
Module,
QuantStub,
)
from megengine.quantization.quantize import disable_fake_quant, quantize_qat
from megengine.quantization.quantize import (
disable_fake_quant,
enable_fake_quant,
quantize,
quantize_qat,
)




def test_qat_convbn2d():
@@ -88,3 +99,107 @@ def test_qat_conv():
qat_net.eval()
qat_outputs = qat_net(inputs)
np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())


@pytest.mark.skipif(
    get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
)
def test_qat_batchmatmul_activation():
    batch = 4
    in_features = 8
    out_features = 4

    class TestNet(Module):
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            out = self.quant(inp)
            out = self.batch_mm(out)
            out = self.dequant(out)
            return out

    inputs = tensor(
        np.random.randn(batch, in_features, out_features).astype(np.float32)
    )
    for bias in (True, False):
        net = TestNet(bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)
        normal_outputs = net(inputs)
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())

        net.eval()
        normal_outputs = net(inputs)
        qat_net.eval()
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())


@pytest.mark.skip(reason="FIXME: abnormal exit")
def test_quantize_batchmatmul_activation():
    batch = 4
    in_features = 8
    out_features = 4

    class TestNet(Module):
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            out = self.quant(inp)
            out = self.batch_mm(out)
            out = expand_dims(out, -1)
            out = self.dequant(out)
            return out

    inputs = tensor(
        np.random.randn(batch, in_features, out_features).astype(np.float32)
    )
    for bias in (True, False):
        net = TestNet(bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)
        normal_outputs = net(inputs)
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())

        net.eval()
        normal_outputs = net(inputs)
        qat_net.eval()
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())

        enable_fake_quant(qat_net)
        qat_outputs = qat_net(inputs)
        qnet = quantize(qat_net, inplace=False)
        qnet.eval()
        quantize_outputs = qnet(inputs)
        np.testing.assert_allclose(
            qat_outputs.numpy(), quantize_outputs.numpy(), atol=1e-6
        )

        @jit.trace(capture_as_const=True)
        def f(x):
            qnet.eval()
            return qnet(x)

        f(inputs)
        file = io.BytesIO()
        f.dump(file, enable_nchw4=True)
        file.seek(0)
        dumped_outputs = cgtools.load_and_inference(file, [inputs])[0]
        np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6)
