|
|
@@ -7,7 +7,7 @@ |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
from megengine import module as Float |
|
|
|
from megengine.module import qat as QAT |
|
|
|
from megengine.quantization.quantize import _get_quantable_module_names |
|
|
|
from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat |
|
|
|
|
|
|
|
|
|
|
|
def test_get_quantable_module_names(): |
|
|
@@ -36,3 +36,19 @@ def test_get_quantable_module_names(): |
|
|
|
and issubclass(value, Float.Module) |
|
|
|
and value != Float.Module |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def test_disable_quantize(): |
|
|
|
class Net(Float.Module): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.conv = Float.ConvBnRelu2d(3, 3, 3) |
|
|
|
self.conv.disable_quantize() |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
return self.conv(x) |
|
|
|
|
|
|
|
net = Net() |
|
|
|
qat_net = quantize_qat(net, inplace=False) |
|
|
|
assert isinstance(qat_net.conv, Float.ConvBnRelu2d) |
|
|
|
assert isinstance(qat_net.conv.conv, Float.Conv2d) |