Browse Source

test(mge/quantization): add `quantize_disabled` related test

GitOrigin-RevId: f62ba600c5
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
5c2323529d
3 changed files with 19 additions and 4 deletions
  1. +1
    -0
      python_module/megengine/module/module.py
  2. +1
    -3
      python_module/megengine/quantization/quantize.py
  3. +17
    -1
      python_module/test/unit/quantization/quantize.py

+ 1
- 0
python_module/megengine/module/module.py View File

@@ -318,6 +318,7 @@ class Module(metaclass=ABCMeta):
Set ``module``'s ``quantize_diabled`` attribute and return ``module``. Set ``module``'s ``quantize_diabled`` attribute and return ``module``.
Could be used as a decorator. Could be used as a decorator.
""" """

def fn(module: Module) -> None: def fn(module: Module) -> None:
module.quantize_diabled = value module.quantize_diabled = value




+ 1
- 3
python_module/megengine/quantization/quantize.py View File

@@ -80,9 +80,7 @@ def quantize(module: Module, inplace: bool = True):




def quantize_qat( def quantize_qat(
module: Module,
inplace: bool = True,
qconfig: QConfig = ema_fakequant_qconfig,
module: Module, inplace: bool = True, qconfig: QConfig = ema_fakequant_qconfig,
): ):
r""" r"""
Recursively convert float :class:`~.Module` to :class:`~.QATModule` Recursively convert float :class:`~.Module` to :class:`~.QATModule`


+ 17
- 1
python_module/test/unit/quantization/quantize.py View File

@@ -7,7 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import module as Float from megengine import module as Float
from megengine.module import qat as QAT from megengine.module import qat as QAT
from megengine.quantization.quantize import _get_quantable_module_names
from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat




def test_get_quantable_module_names(): def test_get_quantable_module_names():
@@ -36,3 +36,19 @@ def test_get_quantable_module_names():
and issubclass(value, Float.Module) and issubclass(value, Float.Module)
and value != Float.Module and value != Float.Module
) )


def test_disable_quantize():
class Net(Float.Module):
def __init__(self):
super().__init__()
self.conv = Float.ConvBnRelu2d(3, 3, 3)
self.conv.disable_quantize()

def forward(self, x):
return self.conv(x)

net = Net()
qat_net = quantize_qat(net, inplace=False)
assert isinstance(qat_net.conv, Float.ConvBnRelu2d)
assert isinstance(qat_net.conv.conv, Float.Conv2d)

Loading…
Cancel
Save