
feat(mge/quantization): add quantization interface

GitOrigin-RevId: 4fe1233ec3
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent commit 8c110c3942
23 changed files with 1418 additions and 6 deletions
  1. +2 -0 python_module/megengine/functional/__init__.py
  2. +84 -0 python_module/megengine/functional/quantized.py
  3. +35 -0 python_module/megengine/functional/tensor.py
  4. +5 -1 python_module/megengine/module/__init__.py
  5. +27 -0 python_module/megengine/module/concat.py
  6. +6 -3 python_module/megengine/module/conv.py
  7. +168 -0 python_module/megengine/module/conv_bn_relu.py
  8. +95 -0 python_module/megengine/module/elemwise.py
  9. +93 -1 python_module/megengine/module/module.py
  10. +34 -0 python_module/megengine/module/quant_dequant.py
  11. +11 -0 python_module/megengine/module/quantized/__init__.py
  12. +45 -0 python_module/megengine/module/quantized/concat.py
  13. +114 -0 python_module/megengine/module/quantized/conv_bn_relu.py
  14. +59 -0 python_module/megengine/module/quantized/elemwise.py
  15. +61 -0 python_module/megengine/module/quantized/quant_dequant.py
  16. +1 -0 python_module/megengine/module/sequential.py
  17. +11 -0 python_module/megengine/quantization/__init__.py
  18. +48 -0 python_module/megengine/quantization/fake_quant.py
  19. +193 -0 python_module/megengine/quantization/observer.py
  20. +82 -0 python_module/megengine/quantization/qconfig.py
  21. +113 -0 python_module/megengine/quantization/quantize.py
  22. +23 -0 python_module/megengine/quantization/utils.py
  23. +108 -1 python_module/test/unit/functional/test_functional.py

+2 -0 python_module/megengine/functional/__init__.py

@@ -74,12 +74,14 @@ from .nn import (
softmax,
warp_perspective,
)
from .quantized import conv_bias_activation
from .sort import argsort, sort, top_k
from .tensor import (
add_axis,
arange,
broadcast_to,
concat,
cond_take,
dimshuffle,
gather,
linspace,


+84 -0 python_module/megengine/functional/quantized.py

@@ -0,0 +1,84 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines
from typing import Tuple, Union

from .. import _internal as mgb
from ..core import Tensor, wrap_io_tensor
from ..utils.types import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy


@wrap_io_tensor
def conv_bias_activation(
inp: Tensor,
weight: Tensor,
bias: Tensor,
dtype=None,
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
nonlinear_mode="IDENTITY",
conv_mode="CROSS_CORRELATION",
compute_mode="DEFAULT",
) -> Tensor:
""" convolution bias with activation operation, only for inference.

:param inp: The feature map of the convolution operation
:param weight: The convolution kernel
:param bias: The bias added to the result of convolution
:param stride: Stride of the 2D convolution operation. Default: 1
:param padding: Size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: Dilation of the 2D convolution operation. Default: 1
:param groups: number of groups to divide input and output channels into,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be ``(groups, out_channel // groups,
in_channels // groups, height, width)``.
:type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
:param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
'CROSS_CORRELATION'.
:param dtype: the output data type (np.dtype). Default: np.int8.
:param scale: the scale used for quantization. Default: 0.0.
:param zero_point: the zero point used for quint8 quantization. Default: 0.0.
:type compute_mode: string or
:class:`mgb.opr_param_defs.Convolution.ComputeMode`
:param compute_mode: When set to 'DEFAULT', no special requirements will be
placed on the precision of intermediate results. When set to 'FLOAT32',
Float32 would be used for accumulator and intermediate result, but only
effective when input and output are of Float16 dtype.

"""
ph, pw = _pair(padding)
sh, sw = _pair_nonzero(stride)
dh, dw = _pair_nonzero(dilation)
sparse_type = "DENSE" if groups == 1 else "GROUP"
res = mgb.opr.conv_bias_activation(
inp,
weight,
bias,
compute_mode=compute_mode,
dtype=dtype,
strategy=get_conv_execution_strategy(),
nonlineMode=nonlinear_mode,
sparse=sparse_type,
format="NCHW",
pad_h=ph,
pad_w=pw,
stride_h=sh,
stride_w=sw,
dilate_h=dh,
dilate_w=dw,
mode=conv_mode,
)
return res
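
A minimal usage sketch of the function above (shapes and scales are illustrative; the dtype helpers mirror the ones exercised by test_conv_bias at the bottom of this commit):

import numpy as np
import megengine._internal as mgb
import megengine.functional as F
from megengine import Parameter, tensor

inp_dtype = mgb.dtype.qint8(0.01)        # input scale
w_dtype = mgb.dtype.qint8(0.02)          # weight scale
b_dtype = mgb.dtype.qint32(0.01 * 0.02)  # bias scale = input scale * weight scale
out_dtype = mgb.dtype.qint8(0.1)         # output scale

inp = tensor(
    mgb.dtype.convert_to_qint8(np.random.normal(size=(1, 4, 24, 24)) * 0.01, inp_dtype),
    dtype=inp_dtype,
)
w = Parameter(
    mgb.dtype.convert_to_qint8(np.random.normal(size=(8, 4, 1, 1)) * 0.02, w_dtype),
    dtype=w_dtype,
)
b = Parameter(
    mgb.dtype.convert_to_qint32(np.random.normal(size=(1, 8, 1, 1)) * 0.01 * 0.02, b_dtype),
    dtype=b_dtype,
)

out = F.conv_bias_activation(inp, w, b, dtype=out_dtype, nonlinear_mode="RELU")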

+35 -0 python_module/megengine/functional/tensor.py

@@ -359,6 +359,41 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
return out


@wrap_io_tensor
def cond_take(mask: Tensor, x: Tensor, val=1) -> Tensor:
r"""
Takes elements from data where the given condition holds on the mask. This operator has two outputs: the first is the elements taken, and the second is the corresponding indices; both are 1-dimensional. High-dimensional input is flattened first.

:param mask: condition parameter; must have the same shape as data
:param x: input tensor from which to take elements
:param val: value to be compared to by mode

Examples:

.. testcode::

import numpy as np

from megengine import tensor
import megengine.functional as F
mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
x = tensor(np.array([[1, np.inf], [np.nan, 4]],
dtype=np.float32))
v, index = F.cond_take(mask, x, 1)
print(v, index)

Outputs:

.. testoutput::

Tensor([1. 4.]) Tensor([0 3], dtype=int32)

"""

v, index = mgb.opr.cond_take(
x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=val
)
return v, index


def shapeof(x: Tensor, axis=None):
r"""
The shape of input tensor.


+5 -1 python_module/megengine/module/__init__.py

@@ -8,12 +8,16 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .batchnorm import BatchNorm1d, BatchNorm2d
from .concat import Concat
from .conv import Conv2d, ConvTranspose2d
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .dropout import Dropout
from .elemwise import Elemwise
from .embedding import Embedding
from .identity import Identity
from .linear import Linear
from .module import Module
from .module import Module, QATModule
from .parampack import ParamPack
from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub
from .sequential import Sequential

+27 -0 python_module/megengine/module/concat.py

@@ -0,0 +1,27 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable

from .. import functional as F
from ..core.tensor import Tensor
from .module import QATModule


class Concat(QATModule):
r"""
A :class:`~.QATModule` wrapping functional concat; replace ``concat`` calls with
this module to support ``qat`` and ``quantized`` modes.
"""

def forward(self, inps: Iterable[Tensor], axis: int = 0):
return F.concat(inps, axis)

def forward_qat(self, inps: Iterable[Tensor], axis: int = 0):
return self.apply_fakequant_with_observer(
self.forward(inps, axis), self.act_fake_quant, self.act_observer
)

+6 -3 python_module/megengine/module/conv.py

@@ -182,11 +182,11 @@ class Conv2d(_ConvNd):
# Assume format is NCHW
return (1, self.out_channels, 1, 1)

def forward(self, inp):
def calc_conv(self, inp, weight, bias):
return conv2d(
inp,
self.weight,
self.bias,
weight,
bias,
self.stride,
self.padding,
self.dilation,
@@ -195,6 +195,9 @@ class Conv2d(_ConvNd):
self.compute_mode,
)

def forward(self, inp):
return self.calc_conv(inp, self.weight, self.bias)


class ConvTranspose2d(_ConvNd):
r"""Applies a 2D transposed convolution over an input tensor.


+168 -0 python_module/megengine/module/conv_bn_relu.py

@@ -0,0 +1,168 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

from ..core import ones, zeros
from ..functional import flatten, relu, sqrt, sum
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import QATModule


class _ConvBn2d(QATModule):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT",
eps=1e-5,
momentum=0.9,
affine=True,
track_running_stats=True,
freeze_bn=False,
):
super().__init__()
self.conv = Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
conv_mode,
compute_mode,
)
self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
self.freeze_bn = freeze_bn

def update_bn_stats(self):
self.freeze_bn = False
return self

def freeze_bn_stats(self):
self.freeze_bn = True
return self

def get_bn_gamma_beta(self):
if self.bn.weight is None:
gamma = ones((self.bn.num_features), dtype="float32")
else:
gamma = self.bn.weight

if self.bn.bias is None:
beta = zeros((self.bn.num_features), dtype="float32")
else:
beta = self.bn.bias

return gamma, beta

def get_batch_mean_var(self, inp):
def _sum_channel(inp, axis=0, keepdims=True):
if isinstance(axis, int):
out = sum(inp, axis=axis, keepdims=keepdims)
elif isinstance(axis, tuple):
for idx, elem in enumerate(axis):
out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
return out

sum1 = _sum_channel(inp, (0, 2, 3))
sum2 = _sum_channel(inp ** 2, (0, 2, 3))
reduce_size = inp.shapeof().prod() / inp.shapeof(1)
batch_mean = sum1 / reduce_size
batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1)

return batch_mean, batch_var

def fold_weight_bias(self, bn_mean, bn_var):
# get fold bn conv param
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
# b_fold = gamma * (b - bn_mean) / bn_std + beta
gamma, beta = self.get_bn_gamma_beta()
b = self.conv.bias
if b is None:
b = zeros(self.conv._infer_bias_shape(), dtype="float32")
if bn_mean is None:
bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
if bn_var is None:
bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
if self.conv.groups == 1:
w_fold = (
self.conv.weight
* gamma.reshape(-1, 1, 1, 1)
* bn_istd.reshape(-1, 1, 1, 1)
)
else:
w_fold = (
self.conv.weight
* gamma.reshape(self.conv.groups, -1, 1, 1, 1)
* bn_istd.reshape(self.conv.groups, -1, 1, 1, 1)
)
b_fold = flatten(beta) + (
flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd)
)
b_fold = b_fold.reshape(self.conv._infer_bias_shape())

return w_fold, b_fold

def calc_conv_bn_qat(self, inp):
# TODO: use pytorch method as
conv = self.conv(inp)
self.bn(conv)

if self.training:
bn_mean, bn_var = self.get_batch_mean_var(conv)
else:
bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var)
w_qat = self.apply_fakequant_with_observer(
w_fold, self.weight_fake_quant, self.weight_observer
)
return self.conv.calc_conv(inp, w_qat, b_fold)


class ConvBn2d(_ConvBn2d):
r"""
A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat`` mode
and ``normal`` mode.
"""

def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
)

def forward(self, inp):
return self.bn(self.conv(inp))


class ConvBnRelu2d(_ConvBn2d):
r"""
A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat``
mode and ``normal`` mode.
"""

def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
)

def forward(self, inp):
return relu(self.bn(self.conv(inp)))
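
The folding in fold_weight_bias rests on the identity w_fold = gamma * W / bn_std and b_fold = beta + gamma * (b - bn_mean) / bn_std; a self-contained NumPy check of that identity for the groups == 1, 1x1-kernel, zero-conv-bias case (all names local to this sketch):

import numpy as np

x = np.random.randn(2, 3, 8, 8).astype("float32")
w = np.random.randn(4, 3, 1, 1).astype("float32")
gamma = np.random.rand(4).astype("float32") + 0.5
beta = np.random.randn(4).astype("float32")
mean = np.random.randn(4).astype("float32")
var = np.random.rand(4).astype("float32") + 0.1
eps = 1e-5


def conv1x1(x, w, b):
    # a 1x1 convolution is just a matrix multiply over the channel axis
    y = np.einsum("nchw,oc->nohw", x, w[:, :, 0, 0])
    return y + b.reshape(1, -1, 1, 1)


istd = 1.0 / np.sqrt(var + eps)
w_fold = w * (gamma * istd).reshape(-1, 1, 1, 1)
b_fold = beta - gamma * mean * istd  # conv bias taken as zero here

bn_out = (conv1x1(x, w, np.zeros(4, "float32")) - mean.reshape(1, -1, 1, 1)) * (
    gamma * istd
).reshape(1, -1, 1, 1) + beta.reshape(1, -1, 1, 1)
assert np.allclose(bn_out, conv1x1(x, w_fold, b_fold), atol=1e-4)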

+95 -0 python_module/megengine/module/elemwise.py

@@ -0,0 +1,95 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import _internal as mgb
from ..core import Tensor, wrap_io_tensor
from ..core.graph import _use_default_if_none
from .module import QATModule


@wrap_io_tensor
def _elemwise_func(mode, *inputs, **kwargs) -> Tensor:
if all(isinstance(i, (int, float)) for i in inputs):
device, comp_graph = _use_default_if_none(None, None)
ret = mgb.opr.elemwise(
*inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs
)
return ret.inferred_value[0]
return mgb.opr.elemwise(*inputs, mode=mode, **kwargs)


class Elemwise(QATModule):
r"""
A :class:`~.QATModule` wrapping an elemwise operator; replace the functional
operator with this module to support ``qat`` and ``normal`` modes.

:param method: the elemwise method; supports the strings listed below.
It performs the normal float elemwise operation in ``normal`` mode.

* "ADD": a + b
* "FUSE_ADD_RELU": max(x+y, 0)
* "MUL": x * y
* "MIN": min(x, y)
* "MAX": max(x, y)
* "SUB": x - y
* "TRUE_DIV": x / y
* "FUSE_ADD_SIGMOID": sigmoid(x + y)
* "FUSE_ADD_TANH": tanh(x + y)
* "RELU": x > 0 ? x : 0
* "ABS": x > 0 ? x : -x
* "SIGMOID": sigmoid(x)
* "EXP": exp(x)
* "TANH": tanh(x)
* "FUSE_MUL_ADD3": x * y + z
* "FAST_TANH": fast_tanh(x)
* "NEGATE": -x
* "ACOS": acos(x)
* "ASIN": asin(x)
* "CEIL": ceil(x)
* "COS": cos(x)
* "EXPM1": expm1(x)
* "FLOOR": floor(x)
* "LOG": log(x)
* "LOG1P": log1p(x)
* "SIN": sin(x)
* "ROUND": round(x)
* "ERF": erf(x)
* "ERFINV": erfinv(x)
* "ERFC": erfc(x)
* "ERFCINV": erfcinv(x)
* "ABS_GRAD": abs_grad
* "FLOOR_DIV": floor_div
* "MOD": mod
* "SIGMOID_GRAD": sigmoid_grad
* "SWITCH_GT0": switch_gt0
* "TANH_GRAD": tanh_grad
* "LT": lt
* "LEQ": leq
* "EQ": eq
* "POW": pow
* "LOG_SUM_EXP": log_sum_exp
* "FAST_TANH_GRAD": fast_tanh_grad
* "ATAN2": atan2
* "COND_LEQ_MOV": cond_leq_mov
* "H_SWISH": h_swish
* "FUSE_ADD_H_SWISH": h_swish(x+y)
* "H_SWISH_GRAD": h_swish_grad
"""

_elemwise_mode_type = mgb.opr_param_defs.Elemwise.Mode

def __init__(self, method):
super().__init__()
self.method = self._elemwise_mode_type.convert(method)

def forward(self, *inps):
return _elemwise_func(self.method, *inps)

def forward_qat(self, *inps):
return self.apply_fakequant_with_observer(
self.forward(*inps), self.act_fake_quant, self.act_observer,
)
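
For reference, a small float-mode usage sketch of this module (values chosen so the fused add+relu is visible):

import numpy as np
from megengine import tensor
from megengine.module import Elemwise

add_relu = Elemwise("FUSE_ADD_RELU")  # computes max(x + y, 0)
x = tensor(np.array([-2.0, 1.0], dtype="float32"))
y = tensor(np.array([1.0, 1.0], dtype="float32"))
out = add_relu(x, y)                  # float path in normal mode -> [0., 2.]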

+93 -1 python_module/megengine/module/module.py

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -8,6 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np
@@ -442,3 +442,95 @@ class Module(metaclass=ABCMeta):
loaded.append(k)

return set(loaded), set(skipped)


class QATModule(Module):
r"""
Base class of quantization-related modules. Adds extra forward methods
:meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for
``qat`` (quantization-aware training) mode and ``quantized`` mode respectively.

Use :meth:`~.QATModule.set_qat_mode` to switch between ``QAT`` and ``DISABLED`` mode,
and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode,
which is irreversible.

If you want to recursively switch modes for all QATModules in a network, use
the functions in :mod:`~.quantization.quantize`.
"""

class QATMode(Enum):
DISABLED = 1
QAT = 2
CALIBRATION = 3

def __init__(self):
from ..quantization import (
QConfig,
FakeQuantize,
Observer,
) # pylint: disable=all

super().__init__()

self.quantizing = self.QATMode.DISABLED
self.scale = None

self.inp_observer = None # type: Observer
self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer

self.weight_fake_quant = None # type: FakeQuantize
self.bias_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize

def set_qconfig(self, qconfig: "QConfig"):
self.inp_observer = qconfig.inp_observer()
self.weight_observer = qconfig.weight_observer()
self.act_observer = qconfig.act_observer()

self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
self.bias_fake_quant = qconfig.bias_fake_quant()
self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)

def apply_observer(self, target: Tensor, obs: "Observer"):
return obs(target)

def apply_fakequant_with_observer(
self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
):
oup = self.apply_observer(target, obs)
return fq(oup, obs.scale, obs.zero_point)

def set_qat_mode(self, mode: QATMode):
r"""
Change the ``self.quantizing`` mode; available values are ``QATMode.DISABLED``,
``QATMode.QAT`` and ``QATMode.CALIBRATION``.
"""
if not isinstance(mode, self.QATMode):
raise TypeError("mode must be QATMode Enum type")
self.quantizing = mode

def to_quantized(self):
r"""
Return a new :class:`~.Module` with quantized parameters of ``self``
according to scale and zero_point in ``self.xxx_observer``.
"""
raise NotImplementedError(
"Use megengine.quantization.quantize to register the method."
)

@abstractmethod
def forward_qat(self, *args, **kwargs):
r"""
Forward method for ``qat`` mode.
"""

def __call__(self, *args, **kwargs):
if self.quantizing == self.QATMode.QAT:
return self.forward_qat(*args, **kwargs)
elif self.quantizing == self.QATMode.CALIBRATION:
# TODO implement the CALIBRATION
assert False
return None
else:
return self.forward(*args, **kwargs)
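
To make the dispatch above concrete, here is a hypothetical minimal subclass following the same pattern as the QATModules in this commit (the module itself is an invented example, not part of the change):

import megengine.functional as F
from megengine.module import QATModule


class FakeQuantReLU(QATModule):
    """Hypothetical QATModule: plain relu in normal mode; in qat mode the
    activation is observed and fake-quantized via the hooks that
    set_qconfig installs."""

    def forward(self, inp):
        return F.relu(inp)

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            self.forward(inp), self.act_fake_quant, self.act_observer
        )


# after set_qat_mode(QATModule.QATMode.QAT), __call__ routes to forward_qat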

+34 -0 python_module/megengine/module/quant_dequant.py

@@ -0,0 +1,34 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .module import QATModule


class QuantStub(QATModule):
r"""
A helper :class:`~.QATModule` that performs the quantize operation on its input.
"""

def forward(self, inp):
return inp

def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
inp, self.act_fake_quant, self.act_observer
)


class DequantStub(QATModule):
r"""
A helper :class:`~.QATModule` that performs the dequantize operation on its input.
"""

def forward(self, inp):
return inp

def forward_qat(self, inp):
return inp
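
In practice the two stubs bracket the region of a network that should be trained with fake quantization and later converted; a hedged sketch (layer sizes arbitrary):

import megengine.module as M


class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.quant = M.QuantStub()      # fake-quantizes the float input in qat mode
        self.body = M.ConvBnRelu2d(3, 8, 3, padding=1)
        self.dequant = M.DequantStub()  # marks the return to float32

    def forward(self, x):
        return self.dequant(self.body(self.quant(x)))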

+11 -0 python_module/megengine/module/quantized/__init__.py

@@ -0,0 +1,11 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .quant_dequant import DequantStub, QuantStub

+45 -0 python_module/megengine/module/quantized/concat.py

@@ -0,0 +1,45 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable

from ... import _internal as mgb
from ... import functional as F
from ... import module as Float
from ...core.tensor import Tensor
from ...quantization.utils import register_method_to_class
from ..module import Module


class Concat(Module):
r"""
A :class:`~.Module` to do quantized concat, inference only.
"""

def __init__(self):
super().__init__()
self.scale = 1.0
self.zero_point = 0.0
self.output_dtype = mgb.dtype.qint8(self.scale)

def forward(self, inps: Iterable[Tensor], axis: int = 0):
if self.training:
raise ValueError("quantized module only support inference.")
new_inps = (x.astype(self.output_dtype) for x in inps)
return F.concat(new_inps, axis)


@register_method_to_class(Float.Concat)
def to_quantized(float_module):
r"""
Replaces :class:`~.module.QATModule`'s ``to_quantized`` method;
implemented here to avoid a circular import.
"""
qmod = Concat()
qmod.output_dtype = float_module.act_observer.get_dtype()
qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
return qmod

+114 -0 python_module/megengine/module/quantized/conv_bn_relu.py

@@ -0,0 +1,114 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial
from typing import Tuple, Union

import megengine._internal as mgb

from ... import module as Float
from ...core import Parameter
from ...functional import conv_bias_activation
from ...module import Conv2d
from ...quantization.utils import register_method_to_class


class _ConvBnActivation2d(Conv2d):
r"""Applies a 2D convolution over an quantized input tensor, inference only.

The parameter is same with :class: `~.Conv2d`
"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT",
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
True,
conv_mode,
compute_mode,
)
self.scale = 1.0
self.zero_point = 0.0
self.output_dtype = mgb.dtype.qint8(self.scale)
self.weight = self.weight.astype(self.output_dtype)
self.bias = self.bias.astype(mgb.dtype.qint32(self.scale))

def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"):
inp_scale = mgb.dtype.get_scale(inp.dtype)
w_scale = mgb.dtype.get_scale(self.weight.dtype)
bias_scale = inp_scale * w_scale
return conv_bias_activation(
inp,
self.weight,
self.bias.astype(mgb.dtype.qint32(bias_scale)),
self.output_dtype,
self.stride,
self.padding,
self.dilation,
self.groups,
conv_mode=self.conv_mode,
compute_mode=self.compute_mode,
nonlinear_mode=nonlinear_mode,
)


class ConvBn2d(_ConvBnActivation2d):
def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


class ConvBnRelu2d(_ConvBnActivation2d):
def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return self.calc_conv_quantized(inp, nonlinear_mode="RELU")


def to_quantized(quantized_class, float_module):
qconv = quantized_class(
float_module.conv.in_channels,
float_module.conv.out_channels,
float_module.conv.kernel_size,
float_module.conv.stride,
float_module.conv.padding,
float_module.conv.dilation,
float_module.conv.groups,
)
w_fold, b_fold = float_module.fold_weight_bias(
float_module.bn.running_mean, float_module.bn.running_var
)
weight = w_fold.astype(float_module.weight_observer.get_dtype())
qconv.output_dtype = float_module.act_observer.get_dtype()
qconv.weight = Parameter(weight.numpy())
qconv.bias = Parameter(b_fold.numpy())
qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams()

return qconv


# Replaces :class:`~.module.QATModule`'s ``to_quantized`` method;
# implemented here to avoid a circular import.
register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d))
register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d))

+59 -0 python_module/megengine/module/quantized/elemwise.py

@@ -0,0 +1,59 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...core import Tensor, wrap_io_tensor
from ...core.graph import _use_default_if_none
from ...quantization.utils import register_method_to_class
from ..module import Module


@wrap_io_tensor
def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor:
if all(isinstance(i, (int, float)) for i in inputs):
device, comp_graph = _use_default_if_none(None, None)
ret = mgb.opr.elemwise_multi_type(
*inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs,
)
return ret.inferred_value[0]
return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs)


class Elemwise(Module):
r"""
Quantized module for elemwise operators, inference only.

:param method: the elemwise method; supported strings are listed in :class:`~.module.elemwise.Elemwise`.
It performs the quantized operation with the specified output quantized dtype.
"""

_elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode

def __init__(self, method):
super().__init__()
self.method = self._elemwise_multi_type_mode.convert("Q" + method)
self.scale = 1.0
self.zero_point = 0.0
self.output_dtype = mgb.dtype.qint8(self.scale)

def forward(self, *inps):
if self.training:
raise ValueError("quantized module only support inference.")
return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype)


@register_method_to_class(Float.Elemwise)
def to_quantized(float_module):
r"""
Replaces :class:`~.module.QATModule`'s ``to_quantized`` method;
implemented here to avoid a circular import.
"""
qmod = Elemwise(float_module.method.name)
qmod.output_dtype = float_module.act_observer.get_dtype()
qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
return qmod

+61 -0 python_module/megengine/module/quantized/quant_dequant.py

@@ -0,0 +1,61 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...quantization.utils import register_method_to_class
from ..module import Module


class QuantStub(Module):
r"""
A helper module that quantizes its input, inference only.
"""

def __init__(self):
super().__init__()
self.scale = 1.0
self.zero_point = 0.0
self.output_dtype = mgb.dtype.qint8(self.scale)

def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return inp.astype(self.output_dtype)


class DequantStub(Module):
r"""
A helper module that dequantizes its input, inference only.
"""

def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return inp.astype("float32")


@register_method_to_class(Float.QuantStub)
def to_quantized(float_module):
r"""
Replaces :class:`~.module.QATModule`'s ``to_quantized`` method;
implemented here to avoid a circular import.
"""
qmod = QuantStub()
qmod.output_dtype = float_module.act_observer.get_dtype()
qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
return qmod


@register_method_to_class(Float.DequantStub)
def to_quantized(float_module):
r"""
Replaces :class:`~.module.QATModule`'s ``to_quantized`` method;
implemented here to avoid a circular import.
"""
qmod = DequantStub()
return qmod

+1 -0 python_module/megengine/module/sequential.py

@@ -68,6 +68,7 @@ class Sequential(Module):

def __setitem__(self, idx, module):
key = self.layer_keys[idx]
self.layer_values[idx] = module
return setattr(self, key, module)

def __delitem__(self, idx):


+11 -0 python_module/megengine/quantization/__init__.py

@@ -0,0 +1,11 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .fake_quant import FakeQuantize
from .observer import Observer
from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig
from .quantize import quantize, quantize_qat

+48 -0 python_module/megengine/quantization/fake_quant.py

@@ -0,0 +1,48 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import functional as F
from .._internal.dtype import _metadata_dict
from ..module import Module
from .observer import Round


class FakeQuantize(Module):
r"""
A module that quantizes and dequantizes its input according to the observer's scale and zero_point.
"""

def __init__(self, dtype: str, enable: bool = True):
super().__init__()
if dtype not in _metadata_dict.keys():
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _metadata_dict.keys()
)
)
self.dtype = dtype
self.qmin = _metadata_dict[dtype].qmin
self.qmax = _metadata_dict[dtype].qmax
self.enabled = enable

def enable(self):
self.enabled = True

def disable(self):
self.enabled = False

def forward(self, inp, scale, zero_point):
if self.enabled:
# Quant
oup = Round()(inp / scale) + zero_point
# clip
oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
# DeQuant
oup = (oup - zero_point) * scale
return oup

return inp
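
The forward pass is the usual quantize-clip-dequantize round trip; a plain NumPy walk-through, assuming the qint8 range [-128, 127]:

import numpy as np

scale, zero_point = 0.1, 0.0
qmin, qmax = -128, 127                # assumed qint8 range
x = np.array([-13.0, 0.07, 3.14159])

q = np.round(x / scale) + zero_point  # quantize
q = np.clip(q, qmin, qmax)            # clip into the representable range
x_fq = (q - zero_point) * scale       # dequantize
print(x_fq)                           # [-12.8  0.1  3.1]: clipping and rounding error made visible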

+193 -0 python_module/megengine/quantization/observer.py

@@ -0,0 +1,193 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod

import numpy as np

from .. import functional as F
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, ones, tensor, zeros
from ..module import Module


class Round(Function):
def forward(self, x):
return x.round()

def backward(self, output_grads):
return output_grads


class Observer(Module):
r"""
A base class for Observer modules.

:param dtype: a string indicating which dtype to collect scale and zero_point for
"""

def __init__(self, dtype="qint8"):
super().__init__()
if dtype not in _metadata_dict.keys():
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _metadata_dict.keys()
)
)
self.dtype = dtype
self.qmin = _metadata_dict[dtype].qmin
self.qmax = _metadata_dict[dtype].qmax
self.zero_point, self.scale = None, None
self.enabled = True

def get_dtype(self):
scale, zero_point = self.get_qparams()
numpy_scale = None if scale is None else scale.numpy()[0]
numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)

def enable(self):
self.enabled = True

def disable(self):
self.enabled = False

@abstractmethod
def forward(self, x):
pass

@abstractmethod
def get_qparams(self, **kwargs):
pass


class IdentityObserver(Observer):
r"""
A test Observer that always returns scale 1 and zero_point 0.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.zero_point = zeros((1), dtype="float32")
self.scale = ones((1), dtype="float32")

def forward(self, x):
return x

def get_qparams(self):
return self.scale, self.zero_point


class MinMaxObserver(Observer):
def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
super().__init__(*args, **kwargs)
self.symmetric = symmetric
if self.symmetric:
# assert qmin + qmax == -1, 'when reduce_range, qmin + qmax should equal -1'
self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)

self.min_val = Buffer(0.0, dtype=np.float32)
self.max_val = Buffer(0.0, dtype=np.float32)
self.scale_limit = eps
# flag is used by cond_take; it starts as first_flag and is set to not_flag afterwards
self.first_flag = Buffer(np.array([1, 0], dtype=np.int32))
self.not_flag = Buffer(np.array([0, 1], dtype=np.int32))

def set_min_max(self, tmp_min, tmp_max):
# FIXME: cond_take will destroy shape, use reshape to reset shape
tmp_min = tmp_min.reshape(1)
tmp_max = tmp_max.reshape(1)
if self.training:
F.zero_grad(
F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
)
F.zero_grad(
F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
)
F.zero_grad(
F.add_update(
self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0
)
)

# FIXME: add_update is applied after the whole trace procedure in `symbolic=True`
# mode. So use tmp_min/tmp_max to calc and save scale/zero_point for further
# calculation in FakeQuant.
self.set_scale_zero_point(tmp_min, tmp_max)

def set_scale_zero_point(self, tmp_min, tmp_max):
if self.symmetric:
symmetric_max_vals = F.maximum(-tmp_min, tmp_max)
# use maximum to avoid a too-small scale at the beginning
self.scale = F.maximum(
symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
)
# zero_point = self.zero_point
else:
# use maximum to avoid a too-small scale at the beginning
self.scale = F.maximum(
(tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit
)
# calculate zero_point
self.zero_point = self.qmin - Round()((tmp_min / self.scale))

def get_qparams(self):
# scale and zero_point are runtime tensors rather than Buffers,
# so they need to be recalculated if min_val and max_val are loaded.
if self.scale is None:
self.set_scale_zero_point(self.min_val, self.max_val)

return self.scale, self.zero_point

def forward(self, x_orig):
if self.enabled:
# stop gradient
x = F.zero_grad(x_orig)
# find max and min
tmp_min, _ = F.cond_take(
self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())])
)
tmp_max, _ = F.cond_take(
self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())])
)
self.set_min_max(tmp_min, tmp_max)
return x_orig


class ExponentialMovingAverageObserver(MinMaxObserver):
def __init__(self, momentum=0.9, *args, **kwargs):
super().__init__(*args, **kwargs)
self.momentum = Buffer(momentum)

def set_momentum(self, momentum):
self.momentum.set_value(momentum)

def forward(self, x_orig):
if self.enabled:
# stop gradient
x = F.zero_grad(x_orig)
# Exponential Moving Average
tmp_min, _ = F.cond_take(
self.first_flag,
F.concat(
[
x.min(),
self.momentum * self.min_val + (1 - self.momentum) * x.min(),
]
),
)
tmp_max, _ = F.cond_take(
self.first_flag,
F.concat(
[
x.max(),
self.momentum * self.max_val + (1 - self.momentum) * x.max(),
]
),
)
self.set_min_max(tmp_min, tmp_max)
return x_orig
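
The scale/zero_point rules in set_scale_zero_point reduce to a few lines of arithmetic; a NumPy sketch of both branches, again assuming the qint8 range [-128, 127]:

import numpy as np

qmin, qmax = -128, 127
x = np.array([-0.4, 0.1, 2.1], dtype="float32")
tmp_min, tmp_max = float(x.min()), float(x.max())

# symmetric: scale covers the larger-magnitude side, zero_point is fixed
scale_sym = max(-tmp_min, tmp_max) / ((qmax - qmin) / 2)  # 2.1 / 127.5 ~= 0.0165
zero_point_sym = (qmin + qmax + 1) // 2                   # 0

# asymmetric: the full [min, max] range is mapped onto [qmin, qmax]
scale_asym = (tmp_max - tmp_min) / (qmax - qmin)          # 2.5 / 255 ~= 0.0098
zero_point_asym = qmin - round(tmp_min / scale_asym)      # -128 + 41 = -87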

+82 -0 python_module/megengine/quantization/qconfig.py

@@ -0,0 +1,82 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial

from ..module import Module
from .fake_quant import FakeQuantize
from .observer import ExponentialMovingAverageObserver, MinMaxObserver


class QConfig:
"""
A config class indicating how to quantize a :class:`~.QATModule`'s
``activation``, ``weight`` and ``bias``, and which ``fake_quant`` to apply
to each of them.

See :meth:`~.QATModule.set_qconfig` for detailed usage.

:param inp_observer: interface to instantiate an :class:`~.Observer` indicating
how to collect scales and zero_point of input.
:param weight_observer: similar to ``inp_observer`` but toward weight.
:param act_observer: similar to ``inp_observer`` but toward activation.
:param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
how to do the fake_quant calculation. Can be invoked multiple times to get a
different instance for each target tensor, for finer control over enabling and disabling.
:param bias_fake_quant: similar to ``fake_quant``, but ``dtype`` usually needs to be
set in advance, since the bias dtype cannot be inferred from an observer.

Examples:

.. code-block::

# Default EMA QConfig for QAT.
ema_fakequant_qconfig = QConfig(
inp_observer=ExponentialMovingAverageObserver,
weight_observer=MinMaxObserver,
act_observer=ExponentialMovingAverageObserver,
fake_quant=FakeQuantize,
bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)
"""

def __init__(
self, act_observer, weight_observer, inp_observer, fake_quant, bias_fake_quant,
):
if (
isinstance(act_observer, Module)
or isinstance(weight_observer, Module)
or isinstance(inp_observer, Module)
):
raise ValueError(
"QConfig must not receive observer instance, please pass observer"
" class generator using `partial(Observer, ...)` instead. Use"
" partial(MyObserver, x=1) to override arguments to constructor if needed"
)
self.act_observer = act_observer
self.weight_observer = weight_observer
self.inp_observer = inp_observer
self.fake_quant = fake_quant
self.bias_fake_quant = bias_fake_quant


# Default QAT QConfigs
min_max_fakequant_qconfig = QConfig(
inp_observer=MinMaxObserver,
weight_observer=MinMaxObserver,
act_observer=MinMaxObserver,
fake_quant=FakeQuantize,
bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)

ema_fakequant_qconfig = QConfig(
inp_observer=ExponentialMovingAverageObserver,
weight_observer=MinMaxObserver,
act_observer=ExponentialMovingAverageObserver,
fake_quant=FakeQuantize,
bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)
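
Custom configs follow the same recipe as the two defaults; a sketch that swaps observers in via partial (the particular observer choices are illustrative):

from functools import partial

from megengine.quantization import FakeQuantize, QConfig
from megengine.quantization.observer import (
    ExponentialMovingAverageObserver,
    MinMaxObserver,
)

my_qconfig = QConfig(
    inp_observer=partial(MinMaxObserver, symmetric=False),
    weight_observer=MinMaxObserver,
    act_observer=partial(ExponentialMovingAverageObserver, momentum=0.99),
    fake_quant=FakeQuantize,
    bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)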

+113 -0 python_module/megengine/quantization/quantize.py

@@ -0,0 +1,113 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy

from ..module import Module, QATModule, Sequential, quantized
from .qconfig import QConfig, ema_fakequant_qconfig


def quantize(module: Module, inplace=True):
r"""
Recursively convert `module` to ``quantized`` mode by replacing each
:class:`~.QATModule` submodule with its quantized counterpart.

:param module: root module to convert recursively.
:param inplace: whether to convert ``module`` in place. Default: True
"""

if not inplace:
module = deepcopy(module)

def is_qat_module(obj):
return isinstance(obj, QATModule)

# no need to pass prefix and get pure key of parent Module.
for key, submodule, parent in module._flatten(
with_key=True, with_parent=True, predicate=is_qat_module
):
if isinstance(parent, Sequential):
# cannot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = submodule.to_quantized()
else:
setattr(parent, key.split(".")[-1], submodule.to_quantized())


def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
r"""
Recursively convert `module` to ``qat`` mode through :meth:`~.Module.apply`
and set its qconfig accordingly.

:param module: root module to convert recursively.
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
Default: :any:`~.qconfig.ema_fakequant_qconfig`.
"""

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_qat_mode(QATModule.QATMode.QAT)
mod.set_qconfig(qconfig)

module.apply(fn)


def disable_fake_quant(module: Module):
r"""
Recursively disable fake quantization for all QATModules in `module` through :meth:`~.Module.apply`.

:param module: root module to recursively disable fake quantization on.
"""

def fn(mod):
if isinstance(mod, QATModule):
mod.act_fake_quant.disable()
mod.weight_fake_quant.disable()
mod.inp_fake_quant.disable()

module.apply(fn)


def disable_observer(module: Module):
r"""
Recursively disable observers for all QATModules in `module` through :meth:`~.Module.apply`.

:param module: root module to recursively disable observers on.
"""

def fn(mod):
if isinstance(mod, QATModule):
mod.act_observer.disable()

module.apply(fn)


def enable_fake_quant(module: Module):
r"""
Recursively enable fake quantization for all QATModules in `module` through :meth:`~.Module.apply`.

:param module: root module to recursively enable fake quantization on.
"""

def fn(mod):
if isinstance(mod, QATModule):
mod.act_fake_quant.enable()
mod.weight_fake_quant.enable()
mod.inp_fake_quant.enable()

module.apply(fn)


def enable_observer(module: Module):
r"""
Recursively enable observers for all QATModules in `module` through :meth:`~.Module.apply`.

:param module: root module to recursively enable observers on.
"""

def fn(mod):
if isinstance(mod, QATModule):
mod.act_observer.enable()

module.apply(fn)
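
Putting the pieces together, the intended workflow is roughly the sketch below (the model topology is illustrative and the training loop is elided):

import megengine.module as M
from megengine.quantization import quantize, quantize_qat

model = M.Sequential(M.QuantStub(), M.ConvBnRelu2d(3, 8, 3, padding=1), M.DequantStub())

quantize_qat(model)  # switch every QATModule to QAT mode, default ema_fakequant_qconfig
# ... fine-tune / calibrate with fake quantization here ...
quantize(model)      # in-place, irreversible conversion to inference-only quantized modules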

+23 -0 python_module/megengine/quantization/utils.py

@@ -0,0 +1,23 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from functools import partial, update_wrapper, wraps


def register_method_to_class(cls):
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)

if isinstance(func, partial):
update_wrapper(func, func.func)
setattr(cls, func.__name__, wrapper)
return func

return decorator
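
A toy demonstration of the decorator's effect (class and function names are hypothetical):

from megengine.quantization.utils import register_method_to_class


class Foo:
    pass


@register_method_to_class(Foo)
def greet(self):
    return "hello from {}".format(type(self).__name__)


print(Foo().greet())  # "hello from Foo" -- greet is now a method of Foo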

+108 -1 python_module/test/unit/functional/test_functional.py

@@ -7,10 +7,12 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
from helpers import opr_test

import megengine._internal as mgb
import megengine.functional as F
from megengine import Buffer, jit, tensor
from megengine import Buffer, Parameter, is_cuda_available, jit, tensor
from megengine.test import assertTensorClose


@@ -332,3 +334,108 @@ def test_binary_cross_entropy():
{"input": [data2, label2], "output": expect2,},
]
opr_test(cases, F.binary_cross_entropy, compare_fn=compare_fn)


@pytest.mark.skip
def test_conv_bias():
inp_scale = 0.01
w_scale = 0.02
outp_scale = 0.1
inp_dtype = mgb.dtype.qint8(inp_scale)
w_dtype = mgb.dtype.qint8(w_scale)
b_dtype = mgb.dtype.qint32(inp_scale * w_scale)
out_dtype = mgb.dtype.qint8(outp_scale)

def run(
N,
IC,
OC,
IH,
IW,
KH,
KW,
PH,
PW,
SH,
SW,
has_bias=True,
nonlinear_mode="IDENTITY",
):
inp_v = np.random.normal(size=(N, IC, IH, IW))
w_v = np.random.normal(size=(OC, IC, KH, KW))
b_v = np.random.normal(size=(1, OC, 1, 1))
inp_scale = mgb.dtype.get_scale(inp_dtype)
w_scale = mgb.dtype.get_scale(w_dtype)
b_scale = mgb.dtype.get_scale(b_dtype)

inpv = mgb.dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
wv = mgb.dtype.convert_to_qint8(w_v * w_scale, w_dtype)
bv = mgb.dtype.convert_to_qint32(b_v * b_scale, b_dtype)

inp_int8 = tensor(inpv, dtype=inp_dtype)
w_int8 = Parameter(wv, dtype=w_dtype)
b_int32 = Parameter(bv, dtype=b_dtype)

inp_fp32 = inp_int8.astype("float32")
w_fp32 = w_int8.astype("float32")
b_fp32 = b_int32.astype("float32")

jit.trace.enabled = True
b_symbolic = True

def convert_to_nchw4(var):
return var.reshape(
var.shapeof(0), var.shapeof(1) // 4, 4, var.shapeof(2), var.shapeof(3)
).dimshuffle(0, 1, 3, 4, 2)

@jit.trace(symbolic=b_symbolic)
def run_conv2d(inp, w, b):
O = F.conv2d(
inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
)
if nonlinear_mode == "RELU":
return F.relu(O)
else:
return O

@jit.trace(symbolic=b_symbolic)
def run_conv_bias(inp, w, b, format="NCHW"):
b = b if has_bias else np.zeros_like(b)
if format == "NCHW4":
inp = convert_to_nchw4(inp)
w = convert_to_nchw4(w)
b = F.flatten(b)
return F.conv_bias_activation(
inp,
w,
b,
stride=(SH, SW),
padding=(PH, PW),
dtype=out_dtype,
nonlinear_mode=nonlinear_mode,
)

format = "NCHW4" if is_cuda_available() else "NCHW"

expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
expected = expected.astype(out_dtype).astype("float32")
result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
"float32"
)
if format == "NCHW4":
result = result.dimshuffle(0, 1, 4, 2, 3)
expected = F.flatten(expected)
result = F.flatten(result)
assertTensorClose(result.numpy(), expected.numpy())

if not is_cuda_available():
run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)

run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)

run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
