From bb94f2aa3db54ff45d1884b91583c4a3f1a566c2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 30 Mar 2021 14:45:06 +0800 Subject: [PATCH] feat(mge/module): add __repr__ method for qat and quantized module GitOrigin-RevId: 0d78cbab9a93801c62f0c5d079ec20ccc705ff17 --- imperative/python/megengine/module/qat/module.py | 3 ++ .../python/megengine/module/quantized/module.py | 3 ++ .../python/test/unit/quantization/test_module.py | 8 +++ .../python/test/unit/quantization/test_observer.py | 8 +++ .../python/test/unit/quantization/test_op.py | 8 +++ .../python/test/unit/quantization/test_repr.py | 62 ++++++++++++++++++++++ 6 files changed, 92 insertions(+) create mode 100644 imperative/python/test/unit/quantization/test_repr.py diff --git a/imperative/python/megengine/module/qat/module.py b/imperative/python/megengine/module/qat/module.py index ad66d844..3800e574 100644 --- a/imperative/python/megengine/module/qat/module.py +++ b/imperative/python/megengine/module/qat/module.py @@ -39,6 +39,9 @@ class QATModule(Module): self.weight_fake_quant = None # type: FakeQuantize self.act_fake_quant = None # type: FakeQuantize + def __repr__(self): + return "QAT." + super().__repr__() + def set_qconfig(self, qconfig: QConfig): r""" Set quantization related configs with ``qconfig``, including diff --git a/imperative/python/megengine/module/quantized/module.py b/imperative/python/megengine/module/quantized/module.py index dad477ed..3532375a 100644 --- a/imperative/python/megengine/module/quantized/module.py +++ b/imperative/python/megengine/module/quantized/module.py @@ -22,6 +22,9 @@ class QuantizedModule(Module): raise ValueError("quantized module only support inference.") return super().__call__(*inputs, **kwargs) + def __repr__(self): + return "Quantized." + super().__repr__() + @classmethod @abstractmethod def from_qat_module(cls, qat_module: QATModule): diff --git a/imperative/python/test/unit/quantization/test_module.py b/imperative/python/test/unit/quantization/test_module.py index 542afea7..c30d53f9 100644 --- a/imperative/python/test/unit/quantization/test_module.py +++ b/imperative/python/test/unit/quantization/test_module.py @@ -1,3 +1,11 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + from functools import partial import numpy as np diff --git a/imperative/python/test/unit/quantization/test_observer.py b/imperative/python/test/unit/quantization/test_observer.py index bd306244..691e701c 100644 --- a/imperative/python/test/unit/quantization/test_observer.py +++ b/imperative/python/test/unit/quantization/test_observer.py @@ -1,3 +1,11 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + import platform import numpy as np diff --git a/imperative/python/test/unit/quantization/test_op.py b/imperative/python/test/unit/quantization/test_op.py index 53500751..a6a10a91 100644 --- a/imperative/python/test/unit/quantization/test_op.py +++ b/imperative/python/test/unit/quantization/test_op.py @@ -1,3 +1,11 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + import numpy as np import pytest diff --git a/imperative/python/test/unit/quantization/test_repr.py b/imperative/python/test/unit/quantization/test_repr.py new file mode 100644 index 00000000..fb0dd9b2 --- /dev/null +++ b/imperative/python/test/unit/quantization/test_repr.py @@ -0,0 +1,62 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +import megengine.module as M +from megengine.quantization import quantize, quantize_qat + + +def test_repr(): + class Net(M.Module): + def __init__(self): + super().__init__() + self.conv_bn = M.ConvBnRelu2d(3, 3, 3) + self.linear = M.Linear(3, 3) + + def forward(self, x): + return x + + net = Net() + ground_truth = ( + "Net(\n" + " (conv_bn): ConvBnRelu2d(\n" + " (conv): Conv2d(3, 3, kernel_size=(3, 3))\n" + " (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" + " )\n" + " (linear): Linear(in_features=3, out_features=3, bias=True)\n" + ")" + ) + assert net.__repr__() == ground_truth + quantize_qat(net) + ground_truth = ( + "Net(\n" + " (conv_bn): QAT.ConvBnRelu2d(\n" + " (conv): Conv2d(3, 3, kernel_size=(3, 3))\n" + " (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" + " (act_observer): ExponentialMovingAverageObserver()\n" + " (act_fake_quant): FakeQuantize()\n" + " (weight_observer): MinMaxObserver()\n" + " (weight_fake_quant): FakeQuantize()\n" + " )\n" + " (linear): QAT.Linear(\n" + " in_features=3, out_features=3, bias=True\n" + " (act_observer): ExponentialMovingAverageObserver()\n" + " (act_fake_quant): FakeQuantize()\n" + " (weight_observer): MinMaxObserver()\n" + " (weight_fake_quant): FakeQuantize()\n" + " )\n" + ")" + ) + assert net.__repr__() == ground_truth + quantize(net) + ground_truth = ( + "Net(\n" + " (conv_bn): Quantized.ConvBnRelu2d(3, 3, kernel_size=(3, 3))\n" + " (linear): Quantized.Linear()\n" + ")" + ) + assert net.__repr__() == ground_truth