Browse Source

test(mge/quantization): add `quantize_disabled` related test

GitOrigin-RevId: f62ba600c5
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
5c2323529d
3 changed files with 19 additions and 4 deletions
  1. +1
    -0
      python_module/megengine/module/module.py
  2. +1
    -3
      python_module/megengine/quantization/quantize.py
  3. +17
    -1
      python_module/test/unit/quantization/quantize.py

+ 1
- 0
python_module/megengine/module/module.py View File

@@ -318,6 +318,7 @@ class Module(metaclass=ABCMeta):
Set ``module``'s ``quantize_diabled`` attribute and return ``module``.
Could be used as a decorator.
"""

def fn(module: Module) -> None:
module.quantize_diabled = value



+ 1
- 3
python_module/megengine/quantization/quantize.py View File

@@ -80,9 +80,7 @@ def quantize(module: Module, inplace: bool = True):


def quantize_qat(
module: Module,
inplace: bool = True,
qconfig: QConfig = ema_fakequant_qconfig,
module: Module, inplace: bool = True, qconfig: QConfig = ema_fakequant_qconfig,
):
r"""
Recursively convert float :class:`~.Module` to :class:`~.QATModule`


+ 17
- 1
python_module/test/unit/quantization/quantize.py View File

@@ -7,7 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import module as Float
from megengine.module import qat as QAT
from megengine.quantization.quantize import _get_quantable_module_names
from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat


def test_get_quantable_module_names():
@@ -36,3 +36,19 @@ def test_get_quantable_module_names():
and issubclass(value, Float.Module)
and value != Float.Module
)


def test_disable_quantize():
class Net(Float.Module):
def __init__(self):
super().__init__()
self.conv = Float.ConvBnRelu2d(3, 3, 3)
self.conv.disable_quantize()

def forward(self, x):
return self.conv(x)

net = Net()
qat_net = quantize_qat(net, inplace=False)
assert isinstance(qat_net.conv, Float.ConvBnRelu2d)
assert isinstance(qat_net.conv.conv, Float.Conv2d)

Loading…
Cancel
Save