GitOrigin-RevId: 0d78cbab9a
release-1.4
@@ -39,6 +39,9 @@ class QATModule(Module): | |||||
self.weight_fake_quant = None # type: FakeQuantize | self.weight_fake_quant = None # type: FakeQuantize | ||||
self.act_fake_quant = None # type: FakeQuantize | self.act_fake_quant = None # type: FakeQuantize | ||||
def __repr__(self): | |||||
return "QAT." + super().__repr__() | |||||
def set_qconfig(self, qconfig: QConfig): | def set_qconfig(self, qconfig: QConfig): | ||||
r""" | r""" | ||||
Set quantization related configs with ``qconfig``, including | Set quantization related configs with ``qconfig``, including | ||||
@@ -22,6 +22,9 @@ class QuantizedModule(Module): | |||||
raise ValueError("quantized module only support inference.") | raise ValueError("quantized module only support inference.") | ||||
return super().__call__(*inputs, **kwargs) | return super().__call__(*inputs, **kwargs) | ||||
def __repr__(self): | |||||
return "Quantized." + super().__repr__() | |||||
@classmethod | @classmethod | ||||
@abstractmethod | @abstractmethod | ||||
def from_qat_module(cls, qat_module: QATModule): | def from_qat_module(cls, qat_module: QATModule): | ||||
@@ -1,3 +1,11 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from functools import partial | from functools import partial | ||||
import numpy as np | import numpy as np | ||||
@@ -1,3 +1,11 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import platform | import platform | ||||
import numpy as np | import numpy as np | ||||
@@ -1,3 +1,11 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
@@ -0,0 +1,62 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import megengine.module as M | |||||
from megengine.quantization import quantize, quantize_qat | |||||
def test_repr(): | |||||
class Net(M.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.conv_bn = M.ConvBnRelu2d(3, 3, 3) | |||||
self.linear = M.Linear(3, 3) | |||||
def forward(self, x): | |||||
return x | |||||
net = Net() | |||||
ground_truth = ( | |||||
"Net(\n" | |||||
" (conv_bn): ConvBnRelu2d(\n" | |||||
" (conv): Conv2d(3, 3, kernel_size=(3, 3))\n" | |||||
" (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" | |||||
" )\n" | |||||
" (linear): Linear(in_features=3, out_features=3, bias=True)\n" | |||||
")" | |||||
) | |||||
assert net.__repr__() == ground_truth | |||||
quantize_qat(net) | |||||
ground_truth = ( | |||||
"Net(\n" | |||||
" (conv_bn): QAT.ConvBnRelu2d(\n" | |||||
" (conv): Conv2d(3, 3, kernel_size=(3, 3))\n" | |||||
" (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" | |||||
" (act_observer): ExponentialMovingAverageObserver()\n" | |||||
" (act_fake_quant): FakeQuantize()\n" | |||||
" (weight_observer): MinMaxObserver()\n" | |||||
" (weight_fake_quant): FakeQuantize()\n" | |||||
" )\n" | |||||
" (linear): QAT.Linear(\n" | |||||
" in_features=3, out_features=3, bias=True\n" | |||||
" (act_observer): ExponentialMovingAverageObserver()\n" | |||||
" (act_fake_quant): FakeQuantize()\n" | |||||
" (weight_observer): MinMaxObserver()\n" | |||||
" (weight_fake_quant): FakeQuantize()\n" | |||||
" )\n" | |||||
")" | |||||
) | |||||
assert net.__repr__() == ground_truth | |||||
quantize(net) | |||||
ground_truth = ( | |||||
"Net(\n" | |||||
" (conv_bn): Quantized.ConvBnRelu2d(3, 3, kernel_size=(3, 3))\n" | |||||
" (linear): Quantized.Linear()\n" | |||||
")" | |||||
) | |||||
assert net.__repr__() == ground_truth |