
feat(mge/module): add __repr__ method for qat and quantized module

GitOrigin-RevId: 0d78cbab9a
release-1.4
Megvii Engine Team
parent commit bb94f2aa3d
6 changed files with 92 additions and 0 deletions
  1. imperative/python/megengine/module/qat/module.py (+3, -0)
  2. imperative/python/megengine/module/quantized/module.py (+3, -0)
  3. imperative/python/test/unit/quantization/test_module.py (+8, -0)
  4. imperative/python/test/unit/quantization/test_observer.py (+8, -0)
  5. imperative/python/test/unit/quantization/test_op.py (+8, -0)
  6. imperative/python/test/unit/quantization/test_repr.py (+62, -0)

imperative/python/megengine/module/qat/module.py (+3, -0)

@@ -39,6 +39,9 @@ class QATModule(Module):
        self.weight_fake_quant = None  # type: FakeQuantize
        self.act_fake_quant = None  # type: FakeQuantize

    def __repr__(self):
        return "QAT." + super().__repr__()

    def set_qconfig(self, qconfig: QConfig):
        r"""
        Set quantization related configs with ``qconfig``, including

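The added method simply prefixes the inherited representation. A minimal, framework-free sketch of the same pattern (the Base and QATWrapper names here are hypothetical, not MegEngine classes):

# Hypothetical classes illustrating the prefixing pattern used above:
# the subclass reuses the parent's __repr__ and only prepends a scheme tag.
class Base:
    def __repr__(self):
        return "Linear(in_features=3, out_features=3)"


class QATWrapper(Base):
    def __repr__(self):
        return "QAT." + super().__repr__()


print(QATWrapper())  # prints: QAT.Linear(in_features=3, out_features=3)

Because the prefix is applied on top of Module.__repr__, nested submodules keep their full structured printout and only the class-name line changes, as the test added below confirms.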

imperative/python/megengine/module/quantized/module.py (+3, -0)

@@ -22,6 +22,9 @@ class QuantizedModule(Module):
            raise ValueError("quantized module only support inference.")
        return super().__call__(*inputs, **kwargs)

    def __repr__(self):
        return "Quantized." + super().__repr__()

    @classmethod
    @abstractmethod
    def from_qat_module(cls, qat_module: QATModule):

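For context, a short usage sketch showing where the new prefixes become visible when a network is printed; it only reuses the modules and quantization helpers exercised by the test added below (the Net class is illustrative):

import megengine.module as M
from megengine.quantization import quantize, quantize_qat


class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.conv_bn = M.ConvBnRelu2d(3, 3, 3)

    def forward(self, x):
        return x


net = Net()
print(net)          # (conv_bn): ConvBnRelu2d(...)
quantize_qat(net)   # converts children to QAT modules in place
print(net)          # (conv_bn): QAT.ConvBnRelu2d(...)
quantize(net)       # converts children to quantized modules in place
print(net)          # (conv_bn): Quantized.ConvBnRelu2d(...)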

imperative/python/test/unit/quantization/test_module.py (+8, -0)

@@ -1,3 +1,11 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from functools import partial

import numpy as np


imperative/python/test/unit/quantization/test_observer.py (+8, -0)

@@ -1,3 +1,11 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import platform

import numpy as np


imperative/python/test/unit/quantization/test_op.py (+8, -0)

@@ -1,3 +1,11 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import numpy as np
import pytest



imperative/python/test/unit/quantization/test_repr.py (+62, -0)

@@ -0,0 +1,62 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import megengine.module as M
from megengine.quantization import quantize, quantize_qat


def test_repr():
    class Net(M.Module):
        def __init__(self):
            super().__init__()
            self.conv_bn = M.ConvBnRelu2d(3, 3, 3)
            self.linear = M.Linear(3, 3)

        def forward(self, x):
            return x

    net = Net()
    ground_truth = (
        "Net(\n"
        "  (conv_bn): ConvBnRelu2d(\n"
        "    (conv): Conv2d(3, 3, kernel_size=(3, 3))\n"
        "    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
        "  )\n"
        "  (linear): Linear(in_features=3, out_features=3, bias=True)\n"
        ")"
    )
    assert net.__repr__() == ground_truth
    quantize_qat(net)
    ground_truth = (
        "Net(\n"
        "  (conv_bn): QAT.ConvBnRelu2d(\n"
        "    (conv): Conv2d(3, 3, kernel_size=(3, 3))\n"
        "    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
        "    (act_observer): ExponentialMovingAverageObserver()\n"
        "    (act_fake_quant): FakeQuantize()\n"
        "    (weight_observer): MinMaxObserver()\n"
        "    (weight_fake_quant): FakeQuantize()\n"
        "  )\n"
        "  (linear): QAT.Linear(\n"
        "    in_features=3, out_features=3, bias=True\n"
        "    (act_observer): ExponentialMovingAverageObserver()\n"
        "    (act_fake_quant): FakeQuantize()\n"
        "    (weight_observer): MinMaxObserver()\n"
        "    (weight_fake_quant): FakeQuantize()\n"
        "  )\n"
        ")"
    )
    assert net.__repr__() == ground_truth
    quantize(net)
    ground_truth = (
        "Net(\n"
        "  (conv_bn): Quantized.ConvBnRelu2d(3, 3, kernel_size=(3, 3))\n"
        "  (linear): Quantized.Linear()\n"
        ")"
    )
    assert net.__repr__() == ground_truth

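The new assertions can be exercised in isolation by pointing pytest at imperative/python/test/unit/quantization/test_repr.py.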