GitOrigin-RevId: c915c843b8
release-1.1
@@ -15,6 +15,7 @@ from .qconfig import ( | |||
ema_fakequant_qconfig, | |||
ema_lowbit_fakequant_qconfig, | |||
min_max_fakequant_qconfig, | |||
sync_ema_fakequant_qconfig, | |||
tqt_quant_qconfig, | |||
) | |||
from .utils import QuantMode |
@@ -12,6 +12,8 @@ import numpy as np | |||
from .. import functional as F | |||
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype | |||
from ..distributed import WORLD, get_rank, is_distributed | |||
from ..functional.distributed import all_reduce_max, all_reduce_min | |||
from ..module import Module | |||
from ..tensor import Tensor | |||
from .utils import QuantMode, Round, get_qparam_dict | |||
@@ -123,6 +125,21 @@ class MinMaxObserver(Observer): | |||
return x_orig | |||
class SyncMinMaxObserver(MinMaxObserver):
    """Distributed-aware variant of :class:`MinMaxObserver`.

    When running inside a distributed process group, the per-rank min/max
    are all-reduced over the ``WORLD`` group so that every rank tracks the
    same global statistics; otherwise it behaves like plain min/max.
    """

    def forward(self, x_orig):
        # BUG FIX: was ``if self.enable:`` — that is the bound *method*
        # (always truthy), so disabling the observer had no effect. The
        # sibling SyncExponentialMovingAverageObserver tests the boolean
        # flag ``self.enabled``; use the same attribute here.
        if self.enabled:
            # Statistics must not propagate gradients.
            x = x_orig.detach()
            if is_distributed():
                min_x = all_reduce_min(x.min(), WORLD)
                max_x = all_reduce_max(x.max(), WORLD)
            else:
                min_x = x.min()
                max_x = x.max()
            # Running extrema: only ever widen the observed range.
            self.min_val._reset(F.minimum(self.min_val, min_x))
            self.max_val._reset(F.maximum(self.max_val, max_x))
        # Observers are transparent: the input passes through unchanged.
        return x_orig
class ExponentialMovingAverageObserver(MinMaxObserver): | |||
def __init__( | |||
self, | |||
@@ -157,6 +174,28 @@ class ExponentialMovingAverageObserver(MinMaxObserver): | |||
return x_orig | |||
class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
    """Distributed-aware variant of :class:`ExponentialMovingAverageObserver`.

    The per-rank min/max are all-reduced over the ``WORLD`` group before
    being folded into the exponential moving average, so all ranks keep
    identical statistics.
    """

    def forward(self, x_orig):
        if self.enabled:
            # Statistics must not propagate gradients.
            x = x_orig.detach()
            # BUG FIX: was ``if is_distributed:`` — the bare function
            # reference is always truthy, so the single-process fallback
            # was unreachable and the collective ops would run without a
            # process group. Call it, as SyncMinMaxObserver does.
            if is_distributed():
                min_x = all_reduce_min(x.min(), WORLD)
                max_x = all_reduce_max(x.max(), WORLD)
            else:
                min_x = x.min()
                max_x = x.max()
            # EMA update: blend previous running value with the new extremum.
            self.min_val._reset(
                self.min_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * min_x
            )
            self.max_val._reset(
                self.max_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * max_x
            )
            # After the first update switch to the configured momentum
            # (mirrors the parent ExponentialMovingAverageObserver).
            self.runtime_momentum = self.momentum
        # Observers are transparent: the input passes through unchanged.
        return x_orig
class HistogramObserver(MinMaxObserver): | |||
def __init__( | |||
self, | |||
@@ -13,6 +13,8 @@ from .observer import ( | |||
ExponentialMovingAverageObserver, | |||
HistogramObserver, | |||
MinMaxObserver, | |||
SyncExponentialMovingAverageObserver, | |||
SyncMinMaxObserver, | |||
) | |||
@@ -92,6 +94,15 @@ ema_fakequant_qconfig = QConfig( | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||
) | |||
# QConfig for distributed QAT: weights use the synchronized min/max
# observer and activations the synchronized EMA observer, so every rank
# derives the same quantization parameters. Dtypes/ranges mirror
# ema_fakequant_qconfig.
sync_ema_fakequant_qconfig = QConfig(
    weight_observer=partial(SyncMinMaxObserver, dtype="qint8", narrow_range=True),
    act_observer=partial(
        SyncExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
    ),
    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
)
ema_lowbit_fakequant_qconfig = QConfig( | |||
weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False), | |||
act_observer=partial( | |||
@@ -143,7 +143,6 @@ def test_batchnorm(): | |||
@pytest.mark.skipif( | |||
platform.system() == "Darwin", reason="do not imp GPU mode at macos now" | |||
) | |||
@pytest.mark.isolated_distributed | |||
def test_syncbn1d(): | |||
nr_chan = 8 | |||
data_shape = (3, nr_chan, 4) | |||
@@ -234,7 +233,6 @@ def test_batchnorm2d(): | |||
@pytest.mark.skipif( | |||
platform.system() == "Darwin", reason="do not imp GPU mode at macos now" | |||
) | |||
@pytest.mark.isolated_distributed | |||
def test_syncbn2d(): | |||
nr_chan = 8 | |||
data_shape = (3, nr_chan, 16, 16) | |||
@@ -305,7 +303,6 @@ def test_batchnorm_no_stats(): | |||
@pytest.mark.skipif( | |||
platform.system() == "Darwin", reason="do not imp GPU mode at macos now" | |||
) | |||
@pytest.mark.isolated_distributed | |||
def test_syncbn_no_stats(): | |||
nr_chan = 8 | |||
data_shape = (3, nr_chan, 4) | |||
@@ -354,7 +351,6 @@ def test_batchnorm2d_no_stats(): | |||
@pytest.mark.skipif( | |||
platform.system() == "Darwin", reason="do not imp GPU mode at macos now" | |||
) | |||
@pytest.mark.isolated_distributed | |||
def test_syncbn2d_no_stats(): | |||
nr_chan = 8 | |||
data_shape = (3, nr_chan, 16, 16) | |||
@@ -0,0 +1,52 @@ | |||
import multiprocessing as mp | |||
import platform | |||
import numpy as np | |||
import pytest | |||
import megengine as mge | |||
import megengine.distributed as dist | |||
import megengine.quantization.observer as ob | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
def test_min_max_observer():
    """MinMaxObserver must record exactly the min and max of the data it sees."""
    data = np.random.rand(3, 3, 3, 3).astype("float32")
    expected_min, expected_max = data.min(), data.max()
    observer = ob.MinMaxObserver()
    observer(mge.tensor(data))
    assert observer.min_val == expected_min and observer.max_val == expected_max
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_sync_min_max_observer():
    """Each of two ranks observes half the data; after the all-reduce both
    must report the global min/max of the full array."""
    data = np.random.rand(6, 3, 3, 3).astype("float32")
    expected_min, expected_max = data.min(), data.max()
    world_size = 2
    port = dist.get_free_ports(1)[0]
    server = dist.Server(port)  # noqa: F841 — must stay alive for the workers

    def worker(rank, slc):
        dist.init_process_group("localhost", port, world_size, rank, rank)
        observer = ob.SyncMinMaxObserver()
        observer(mge.tensor(data[slc]))
        assert observer.min_val == expected_min and observer.max_val == expected_max

    workers = [
        mp.Process(
            target=worker,
            args=(rank, slice(rank * 3, (rank + 1) * 3)),
            daemon=True,
        )
        for rank in range(world_size)
    ]
    for proc in workers:
        proc.start()
    for proc in workers:
        proc.join(20)
        assert proc.exitcode == 0