|
|
@@ -0,0 +1,62 @@ |
|
|
|
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") |
|
|
|
# |
|
|
|
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. |
|
|
|
# |
|
|
|
# Unless required by applicable law or agreed to in writing, |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
|
|
|
|
import megengine.module as M |
|
|
|
from megengine.quantization import quantize, quantize_qat |
|
|
|
|
|
|
|
|
|
|
|
def test_repr(): |
|
|
|
class Net(M.Module): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.conv_bn = M.ConvBnRelu2d(3, 3, 3) |
|
|
|
self.linear = M.Linear(3, 3) |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
return x |
|
|
|
|
|
|
|
net = Net() |
|
|
|
ground_truth = ( |
|
|
|
"Net(\n" |
|
|
|
" (conv_bn): ConvBnRelu2d(\n" |
|
|
|
" (conv): Conv2d(3, 3, kernel_size=(3, 3))\n" |
|
|
|
" (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" |
|
|
|
" )\n" |
|
|
|
" (linear): Linear(in_features=3, out_features=3, bias=True)\n" |
|
|
|
")" |
|
|
|
) |
|
|
|
assert net.__repr__() == ground_truth |
|
|
|
quantize_qat(net) |
|
|
|
ground_truth = ( |
|
|
|
"Net(\n" |
|
|
|
" (conv_bn): QAT.ConvBnRelu2d(\n" |
|
|
|
" (conv): Conv2d(3, 3, kernel_size=(3, 3))\n" |
|
|
|
" (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" |
|
|
|
" (act_observer): ExponentialMovingAverageObserver()\n" |
|
|
|
" (act_fake_quant): FakeQuantize()\n" |
|
|
|
" (weight_observer): MinMaxObserver()\n" |
|
|
|
" (weight_fake_quant): FakeQuantize()\n" |
|
|
|
" )\n" |
|
|
|
" (linear): QAT.Linear(\n" |
|
|
|
" in_features=3, out_features=3, bias=True\n" |
|
|
|
" (act_observer): ExponentialMovingAverageObserver()\n" |
|
|
|
" (act_fake_quant): FakeQuantize()\n" |
|
|
|
" (weight_observer): MinMaxObserver()\n" |
|
|
|
" (weight_fake_quant): FakeQuantize()\n" |
|
|
|
" )\n" |
|
|
|
")" |
|
|
|
) |
|
|
|
assert net.__repr__() == ground_truth |
|
|
|
quantize(net) |
|
|
|
ground_truth = ( |
|
|
|
"Net(\n" |
|
|
|
" (conv_bn): Quantized.ConvBnRelu2d(3, 3, kernel_size=(3, 3))\n" |
|
|
|
" (linear): Quantized.Linear()\n" |
|
|
|
")" |
|
|
|
) |
|
|
|
assert net.__repr__() == ground_truth |