GitOrigin-RevId: f6dbd1f4c0
release-1.1
@@ -45,8 +45,6 @@ def conv_bias_activation(
     :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
         'CROSS_CORRELATION'
     :param dtype: support for ``np.dtype``, Default: np.int8
-    :param scale: scale if use quantization, Default: 0.0
-    :param zero_point: scale if use quantization quint8, Default: 0.0
     :type compute_mode: string or
         :class:`P.Convolution.ComputeMode`.
     :param compute_mode: when set to "DEFAULT", no special requirements will be
@@ -75,3 +73,63 @@ def conv_bias_activation(
     )
     (outputs,) = apply(op, inp, weight, bias)
     return outputs
+
+
+def batch_conv_bias_activation(
+    inp: Tensor,
+    weight: Tensor,
+    bias: Tensor,
+    dtype=None,
+    stride: Union[int, Tuple[int, int]] = 1,
+    padding: Union[int, Tuple[int, int]] = 0,
+    dilation: Union[int, Tuple[int, int]] = 1,
+    groups: int = 1,
+    nonlinear_mode="IDENTITY",
+    conv_mode="CROSS_CORRELATION",
+    compute_mode="DEFAULT",
+) -> Tensor:
+    """
+    Batch convolution with bias and activation, for inference only.
+
+    :param inp: feature map of the convolution operation.
+    :param weight: convolution kernels, one set per batch sample.
+    :param bias: bias added to the result of the convolution.
+    :param stride: stride of the 2D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on both sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 2D convolution operation. Default: 1
+    :param groups: number of groups into which the input and output channels are
+        divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+        and the shape of weight should be `(groups, out_channel // groups,
+        in_channels // groups, height, width)`.
+    :type conv_mode: string or :class:`P.Convolution.Mode`.
+    :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
+        'CROSS_CORRELATION'
+    :param dtype: output data type, given as ``np.dtype``. Default: np.int8
+    :type compute_mode: string or :class:`P.Convolution.ComputeMode`.
+    :param compute_mode: when set to "DEFAULT", no special requirements are placed
+        on the precision of intermediate results. When set to "FLOAT32", Float32
+        is used for the accumulator and intermediate results, but this is only
+        effective when input and output are of Float16 dtype.
+    """
+    ph, pw = _pair(padding)
+    sh, sw = _pair_nonzero(stride)
+    dh, dw = _pair_nonzero(dilation)
+    sparse_type = "DENSE" if groups == 1 else "GROUP"
+    op = builtin.BatchConvBiasForward(
+        stride_h=sh,
+        stride_w=sw,
+        pad_h=ph,
+        pad_w=pw,
+        dilate_h=dh,
+        dilate_w=dw,
+        dtype=dtype,
+        format="NCHW",
+        strategy=get_conv_execution_strategy(),
+        nonlineMode=nonlinear_mode,
+        mode=conv_mode,
+        compute_mode=compute_mode,
+        sparse=sparse_type,
+    )
+    (outputs,) = apply(op, inp, weight, bias)
+    return outputs
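
The functional form mirrors `conv_bias_activation`, except that `weight` carries one kernel set per input sample, shape `(N, OC, IC, KH, KW)` in the dense case. A minimal usage sketch; scales and shapes here are illustrative assumptions, not values from this diff:

    import numpy as np

    import megengine.functional as F
    from megengine import Parameter, tensor
    from megengine.core.tensor import dtype

    inp_dtype = dtype.qint8(1.5)  # assumed scales, for illustration only
    w_dtype = dtype.qint8(2.5)
    b_dtype = dtype.qint32(1.5 * 2.5)

    # N=2 samples, each convolved with its own (OC=8, IC=4, 3x3) kernel set
    inp = tensor(
        dtype.convert_to_qint8(np.random.normal(size=(2, 4, 16, 16)), inp_dtype),
        dtype=inp_dtype,
    )
    weight = Parameter(
        dtype.convert_to_qint8(np.random.normal(size=(2, 8, 4, 3, 3)), w_dtype),
        dtype=w_dtype,
    )
    bias = Parameter(
        dtype.convert_to_qint32(np.random.normal(size=(1, 8, 1, 1)), b_dtype),
        dtype=b_dtype,
    )

    out = F.quantized.batch_conv_bias_activation(
        inp, weight, bias, dtype=dtype.qint8(1.0), nonlinear_mode="RELU"
    )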
@@ -9,6 +9,7 @@
 from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
 from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
+from .batch_matmul_activation import BatchMatMulActivation
 from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 from .concat import Concat
 from .conv import Conv1d, Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
@@ -0,0 +1,67 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+
+from ..functional import matmul, relu
+from ..tensor import Parameter
+from . import init
+from .module import Module
+
+
+class BatchMatMulActivation(Module):
+    r"""
+    Batched MatMul with activation (only ReLU is supported), no transpose anywhere.
+    """
+
+    def __init__(
+        self,
+        batch: int,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        nonlinear_mode="IDENTITY",
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.batch = batch
+        self.out_features = out_features
+        self.in_features = in_features
+        w_shape = (batch, out_features, in_features)
+        self.weight = Parameter(np.zeros(w_shape, dtype=np.float32))
+        self.bias = None
+        if bias:
+            b_shape = (out_features,)
+            self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
+        self.nonlinear_mode = nonlinear_mode
+        self.reset_parameters()
+
+    def _get_fanin(self):
+        return self.in_features
+
+    def reset_parameters(self) -> None:
+        fanin = self._get_fanin()
+        std = np.sqrt(1 / fanin)
+        init.normal_(self.weight, 0.0, std)
+        if self.bias is not None:
+            init.zeros_(self.bias)
+
+    def _calc_linear(self, x, weight, bias):
+        res = matmul(weight, x)
+        if self.bias is not None:
+            res += bias
+        if self.nonlinear_mode == "RELU":
+            res = relu(res)
+        return res
+
+    def forward(self, x):
+        return self._calc_linear(x, self.weight, self.bias)
+
+    def _module_info_string(self) -> str:
+        return "batch={}, in_features={}, out_features={}, bias={}".format(
+            self.batch, self.in_features, self.out_features, self.bias is not None
+        )
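
For orientation, the module's contract in one hedged sketch (the input layout follows the tests added below, where the trailing dimension happens to equal out_features; note the bias of shape `(out_features,)` broadcasts against that trailing axis):

    import numpy as np

    from megengine import tensor
    from megengine.module import BatchMatMulActivation

    m = BatchMatMulActivation(batch=4, in_features=8, out_features=4, nonlinear_mode="RELU")
    x = tensor(np.random.randn(4, 8, 4).astype(np.float32))
    y = m(x)  # out[b] = relu(W[b] @ x[b] + bias), shape (4, 4, 4)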
@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .batch_matmul_activation import BatchMatMulActivation
 from .concat import Concat
 from .conv import Conv2d, ConvRelu2d
 from .conv_bn import ConvBn2d, ConvBnRelu2d
@@ -0,0 +1,30 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ...quantization.utils import fake_quant_bias
+from .. import batch_matmul_activation as Float
+from .module import QATModule
+
+
+class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):
+    def forward(self, inp):
+        w_qat = self.apply_quant_weight(self.weight)
+        b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))
+
+    @classmethod
+    def from_float_module(cls, float_module: Float.BatchMatMulActivation):
+        qat_module = cls(
+            float_module.batch,
+            float_module.in_features,
+            float_module.out_features,
+            float_module.bias is not None,
+            float_module.nonlinear_mode,  # carry the activation mode over
+        )
+        qat_module.weight = float_module.weight
+        qat_module.bias = float_module.bias
+        return qat_module
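
In practice this class is not instantiated by hand: `quantize_qat` swaps a float `BatchMatMulActivation` for it via `from_float_module`, sharing parameters by reference. A hedged sketch of that conversion, using the import paths this diff wires up:

    from megengine.module import BatchMatMulActivation as FloatBMM
    from megengine.module.qat import BatchMatMulActivation as QATBMM

    float_mod = FloatBMM(4, 8, 4)
    qat_mod = QATBMM.from_float_module(float_mod)
    assert qat_mod.weight is float_mod.weight  # shared, not copied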
@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .batch_matmul_activation import BatchMatMulActivation
 from .concat import Concat
 from .conv import Conv2d, ConvRelu2d
 from .conv_bn import ConvBn2d, ConvBnRelu2d
@@ -0,0 +1,76 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from typing import Tuple, Union
+
+import numpy as np
+
+from ... import module as Float
+from ...core.tensor import dtype
+from ...functional import expand_dims, squeeze
+from ...functional.quantized import batch_conv_bias_activation
+from ...tensor import Parameter
+from ..qat import batch_matmul_activation as QAT
+from .module import QuantizedModule
+
+
+class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule):
+    def __init__(
+        self,
+        batch: int,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        nonlinear_mode="IDENTITY",
+        dtype=None,
+        **kwargs
+    ):
+        # forward nonlinear_mode so the float base does not silently reset it
+        super().__init__(
+            batch, in_features, out_features, bias, nonlinear_mode, **kwargs
+        )
+        self.output_dtype = dtype
+
+    def calc_bmm_quantized(self, inp):
+        inp_scale = dtype.get_scale(inp.dtype)
+        w_scale = dtype.get_scale(self.weight.dtype)
+        bias_scale = inp_scale * w_scale
+        inp = expand_dims(inp, [-1])
+        res = batch_conv_bias_activation(
+            inp,
+            self.weight,
+            self.bias.astype(dtype.qint32(bias_scale)),
+            dtype=self.output_dtype,
+            stride=1,
+            padding=0,
+            dilation=1,
+            groups=1,
+            nonlinear_mode=self.nonlinear_mode,
+        )
+        return squeeze(res, -1)
+
+    @classmethod
+    def from_qat_module(cls, qat_module: QAT.BatchMatMulActivation):
+        output_dtype = qat_module.get_activation_dtype()
+        qbmm = cls(
+            qat_module.batch,
+            qat_module.in_features,
+            qat_module.out_features,
+            qat_module.bias is not None,
+            qat_module.nonlinear_mode,  # carry the activation mode over
+            dtype=output_dtype,
+        )
+        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
+        weight = expand_dims(weight, [-1, -2])
+        qbmm.weight = Parameter(weight.numpy())
+        if qat_module.bias is not None:
+            bias = qat_module.bias.reshape((1, qbmm.out_features, 1, 1))
+            qbmm.bias = Parameter(bias.numpy())
+        else:
+            qbmm.bias = Parameter(
+                np.zeros((1, qbmm.out_features, 1, 1), dtype=np.float32)
+            )
+        return qbmm
+
+    def forward(self, inp):
+        return self.calc_bmm_quantized(inp)
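
The expand_dims/squeeze pair works because a batched matmul is exactly a batched 1x1 convolution: the input gains a trailing unit spatial axis and each `(OC, IC)` weight slice becomes an `(OC, IC, 1, 1)` kernel, which is why `from_qat_module` expands the weight with `[-1, -2]`. A numpy sketch of the identity (shapes assumed from the module; plain numpy, not MegEngine code):

    import numpy as np

    B, OC, IC, M = 4, 16, 8, 10
    w = np.random.randn(B, OC, IC).astype(np.float32)
    x = np.random.randn(B, IC, M).astype(np.float32)

    mm = w @ x  # batched matmul, shape (B, OC, M)

    x4 = x[:, :, :, None]        # expand_dims(inp, [-1])        -> (B, IC, M, 1)
    w5 = w[:, :, :, None, None]  # expand_dims(weight, [-1, -2]) -> (B, OC, IC, 1, 1)
    # 1x1 batch conv, stride 1, no padding: sum over IC and the unit kernel axes
    conv = np.einsum("bihw,boikl->bohw", x4, w5)

    np.testing.assert_allclose(mm, conv[..., 0], rtol=1e-5)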
@@ -20,6 +20,7 @@ from megengine import Parameter, Tensor, is_cuda_available, tensor
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.autodiff.grad import Grad
 from megengine.core.tensor.utils import make_shape_tuple
+from megengine.distributed.helper import get_device_count_by_fork


 def test_where():
@@ -420,7 +421,9 @@ def test_nms():
     np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))


-@pytest.mark.skip(reason="cuda does not support nchw int8")
+@pytest.mark.skipif(
+    get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8"
+)
 def test_conv_bias():
     inp_scale = 1.5
     w_scale = 2.5
@@ -446,7 +449,7 @@ def test_conv_bias():
         nonlinear_mode="IDENTITY",
     ):
         inp_v = np.random.normal(size=(N, IC, IH, IW))
-        w_v = np.random.normal(size=(OC, IC, KW, KW))
+        w_v = np.random.normal(size=(OC, IC, KH, KW))
         b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
         w_scale = dtype.get_scale(w_dtype)
@@ -486,13 +489,12 @@ def test_conv_bias():
             inp = convert_to_nchw4(inp)
             w = convert_to_nchw4(w)
             b = convert_to_nchw4(b)
-        return F.nn.conv_bias_activation(
+        return F.quantized.conv_bias_activation(
             inp,
             w,
             b,
             stride=(SH, SW),
             padding=(PH, PW),
-            format=format,
             dtype=out_dtype,
             nonlinear_mode=nonlinear_mode,
         )
@@ -522,6 +524,59 @@ def test_conv_bias():
     run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")


+@pytest.mark.skipif(
+    get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
+)
+def test_batch_conv_bias():
+    inp_scale = 1.5
+    w_scale = 2.5
+    outp_scale = 1.5
+    inp_dtype = dtype.qint8(inp_scale)
+    w_dtype = dtype.qint8(w_scale)
+    b_dtype = dtype.qint32(inp_scale * w_scale)
+    out_dtype = dtype.qint8(outp_scale)
+
+    def run(
+        N, IC, OC, IH, IW, KH, KW, PH, PW, SH, SW, has_bias=True,
+    ):
+        inp_v = np.random.normal(size=(N, IC, IH, IW))
+        w_v = np.random.normal(size=(N, OC, IC, KH, KW))
+        b_v = np.random.normal(size=(1, OC, 1, 1))
+        inp_scale = dtype.get_scale(inp_dtype)
+        w_scale = dtype.get_scale(w_dtype)
+        b_scale = dtype.get_scale(b_dtype)
+
+        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
+        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
+        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
+
+        inp_int8 = tensor(inpv, dtype=inp_dtype)
+        w_int8 = Parameter(wv, dtype=w_dtype)
+        b_int32 = Parameter(bv, dtype=b_dtype)
+
+        inp_fp32 = inp_int8.astype("float32")
+        w_fp32 = w_int8.astype("float32")
+        b_fp32 = b_int32.astype("float32")
+
+        def run_batch_conv_bias(inp, w, b):
+            b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
+            result = F.quantized.batch_conv_bias_activation(
+                inp, w, b, stride=(SH, SW), padding=(PH, PW), dtype=out_dtype,
+            )
+            return result.astype("float32")
+
+        expected = F.conv2d(inp_fp32, w_fp32[0], b_fp32 if has_bias else None)[0]
+        expected = expected.astype(out_dtype).astype("float32")
+        expected = F.flatten(expected)
+
+        result = run_batch_conv_bias(inp_int8, w_int8, b_int32)
+        result = F.flatten(result)
+
+        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
+
+    run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)
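
The tolerance is deliberate: qint8 with scale s stores round(x / s), so dequantized outputs lie on a grid of step s, and atol=outp_scale lets the int8 and float32 paths land on adjacent grid points. A small illustration with assumed round-to-nearest semantics:

    import numpy as np

    s = 1.5                   # outp_scale above
    x = np.float32(2.0)
    q = s * np.round(x / s)   # nearest representable value: 1.5
    assert abs(q - x) <= s    # within one quantization step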

 def test_zero_stride_numpy_array():
     inp = np.random.randn(3, 224, 224).astype(np.float32)
     inp = inp[np.newaxis, :]
@@ -1,9 +1,15 @@
+import io
 from itertools import product

 import numpy as np
+import pytest

-from megengine import tensor
+import megengine.utils.comp_graph_tools as cgtools
+from megengine import jit, tensor
+from megengine.distributed.helper import get_device_count_by_fork
+from megengine.functional import expand_dims
 from megengine.module import (
+    BatchMatMulActivation,
     Conv2d,
     ConvBn2d,
     ConvRelu2d,
@@ -11,7 +17,12 @@ from megengine.module import (
     Module,
     QuantStub,
 )
-from megengine.quantization.quantize import disable_fake_quant, quantize_qat
+from megengine.quantization.quantize import (
+    disable_fake_quant,
+    enable_fake_quant,
+    quantize,
+    quantize_qat,
+)


 def test_qat_convbn2d():
@@ -88,3 +99,107 @@ def test_qat_conv():
         qat_net.eval()
         qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
+
+
+@pytest.mark.skipif(
+    get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
+)
+def test_qat_batchmatmul_activation():
+    batch = 4
+    in_features = 8
+    out_features = 4
+
+    class TestNet(Module):
+        def __init__(self, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.batch_mm = BatchMatMulActivation(
+                batch, in_features, out_features, bias=bias
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.batch_mm(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(
+        np.random.randn(batch, in_features, out_features).astype(np.float32)
+    )
+    for bias in (True, False):
+        net = TestNet(bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
+
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
+
+
+@pytest.mark.skip(reason="FIXME: abnormal exit")
+def test_quantize_batchmatmul_activation():
+    batch = 4
+    in_features = 8
+    out_features = 4
+
+    class TestNet(Module):
+        def __init__(self, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.batch_mm = BatchMatMulActivation(
+                batch, in_features, out_features, bias=bias
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.batch_mm(out)
+            out = expand_dims(out, -1)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(
+        np.random.randn(batch, in_features, out_features).astype(np.float32)
+    )
+    for bias in (True, False):
+        net = TestNet(bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
+
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
+
+        enable_fake_quant(qat_net)
+        qat_outputs = qat_net(inputs)
+        qnet = quantize(qat_net, inplace=False)
+        qnet.eval()
+        quantize_outputs = qnet(inputs)
+        np.testing.assert_allclose(
+            qat_outputs.numpy(), quantize_outputs.numpy(), atol=1e-6
+        )
+
+        @jit.trace(capture_as_const=True)
+        def f(x):
+            qnet.eval()
+            return qnet(x)
+
+        f(inputs)
+        file = io.BytesIO()
+        f.dump(file, enable_nchw4=True)
+        file.seek(0)
+        dumped_outputs = cgtools.load_and_inference(file, [inputs])[0]
+        np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6)