GitOrigin-RevId: c915c843b8
release-1.1
@@ -15,6 +15,7 @@ from .qconfig import (
     ema_fakequant_qconfig,
     ema_lowbit_fakequant_qconfig,
     min_max_fakequant_qconfig,
+    sync_ema_fakequant_qconfig,
     tqt_quant_qconfig,
 )
 from .utils import QuantMode
@@ -12,6 +12,8 @@ import numpy as np
 from .. import functional as F
 from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
+from ..distributed import WORLD, get_rank, is_distributed
+from ..functional.distributed import all_reduce_max, all_reduce_min
 from ..module import Module
 from ..tensor import Tensor
 from .utils import QuantMode, Round, get_qparam_dict
@@ -123,6 +125,21 @@ class MinMaxObserver(Observer):
         return x_orig


+class SyncMinMaxObserver(MinMaxObserver):
+    def forward(self, x_orig):
+        if self.enabled:
+            x = x_orig.detach()
+            if is_distributed():
+                min_x = all_reduce_min(x.min(), WORLD)
+                max_x = all_reduce_max(x.max(), WORLD)
+            else:
+                min_x = x.min()
+                max_x = x.max()
+            self.min_val._reset(F.minimum(self.min_val, min_x))
+            self.max_val._reset(F.maximum(self.max_val, max_x))
+        return x_orig
+
+
 class ExponentialMovingAverageObserver(MinMaxObserver):
     def __init__(
         self,
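The hunk above adds SyncMinMaxObserver: each rank takes its local min/max, then all_reduce_min / all_reduce_max over the WORLD group replace them with the global extrema, so every rank derives the same scale and zero_point. A minimal sketch of that reduction in plain numpy (no MegEngine; the shard count and shapes are illustrative):

    import numpy as np

    # Stand-ins for the activation shards two ranks would each observe.
    shards = [np.random.rand(3, 3).astype("float32") for _ in range(2)]

    # all_reduce_min / all_reduce_max compute exactly these reductions
    # of the per-rank extrema.
    global_min = min(s.min() for s in shards)
    global_max = max(s.max() for s in shards)

    full = np.concatenate([s.ravel() for s in shards])
    assert global_min == full.min() and global_max == full.max()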
@@ -157,6 +174,28 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
         return x_orig


+class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
+    def forward(self, x_orig):
+        if self.enabled:
+            x = x_orig.detach()
+            if is_distributed():
+                min_x = all_reduce_min(x.min(), WORLD)
+                max_x = all_reduce_max(x.max(), WORLD)
+            else:
+                min_x = x.min()
+                max_x = x.max()
+            self.min_val._reset(
+                self.min_val * self.runtime_momentum
+                + (1 - self.runtime_momentum) * min_x
+            )
+            self.max_val._reset(
+                self.max_val * self.runtime_momentum
+                + (1 - self.runtime_momentum) * max_x
+            )
+            self.runtime_momentum = self.momentum
+        return x_orig
+
+
 class HistogramObserver(MinMaxObserver):
     def __init__(
         self,
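SyncExponentialMovingAverageObserver applies the same all-reduce, then folds the result into the EMA update inherited from ExponentialMovingAverageObserver: stat = runtime_momentum * stat + (1 - runtime_momentum) * batch_stat. Resetting runtime_momentum to momentum afterwards implies the first observation is adopted outright, assuming the parent initializes runtime_momentum to 0 (that initialization is not shown in this diff). A worked example with an illustrative momentum of 0.9:

    momentum = 0.9
    old_max, batch_max = 1.0, 2.0
    new_max = old_max * momentum + (1 - momentum) * batch_max
    assert abs(new_max - 1.1) < 1e-6  # stats drift slowly toward each new batch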
@@ -13,6 +13,8 @@ from .observer import (
     ExponentialMovingAverageObserver,
     HistogramObserver,
     MinMaxObserver,
+    SyncExponentialMovingAverageObserver,
+    SyncMinMaxObserver,
 )
@@ -92,6 +94,15 @@ ema_fakequant_qconfig = QConfig(
     act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
 )

+sync_ema_fakequant_qconfig = QConfig(
+    weight_observer=partial(SyncMinMaxObserver, dtype="qint8", narrow_range=True),
+    act_observer=partial(
+        SyncExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
+    ),
+    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
+    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
+)
+
 ema_lowbit_fakequant_qconfig = QConfig(
     weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False),
     act_observer=partial(
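A hedged usage sketch for the new qconfig (not part of the diff): quantize_qat and its qconfig keyword follow the existing MegEngine QAT entry point, and Net is a hypothetical float module built from a fused ConvBn2d layer.

    import megengine.module as M
    import megengine.quantization as Q
    from megengine.quantization.quantize import quantize_qat

    # A toy float module; quantize_qat swaps its layers for QAT variants.
    class Net(M.Module):
        def __init__(self):
            super().__init__()
            self.conv = M.ConvBn2d(3, 8, 3)

        def forward(self, x):
            return self.conv(x)

    net = Net()
    # In multi-process training, the Sync* observers all-reduce their
    # min/max so every rank ends up with identical quantization ranges.
    quantize_qat(net, qconfig=Q.sync_ema_fakequant_qconfig)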
@@ -143,7 +143,6 @@ def test_batchnorm():
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
-@pytest.mark.isolated_distributed
 def test_syncbn1d():
     nr_chan = 8
     data_shape = (3, nr_chan, 4)
@@ -234,7 +233,6 @@ def test_batchnorm2d():
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
-@pytest.mark.isolated_distributed
 def test_syncbn2d():
     nr_chan = 8
     data_shape = (3, nr_chan, 16, 16)
@@ -305,7 +303,6 @@ def test_batchnorm_no_stats():
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
-@pytest.mark.isolated_distributed
 def test_syncbn_no_stats():
     nr_chan = 8
     data_shape = (3, nr_chan, 4)
@@ -354,7 +351,6 @@ def test_batchnorm2d_no_stats():
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
-@pytest.mark.isolated_distributed
 def test_syncbn2d_no_stats():
     nr_chan = 8
     data_shape = (3, nr_chan, 16, 16)
@@ -0,0 +1,52 @@
+import multiprocessing as mp
+import platform
+
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.distributed as dist
+import megengine.quantization.observer as ob
+from megengine.distributed.helper import get_device_count_by_fork
+
+
+def test_min_max_observer():
+    x = np.random.rand(3, 3, 3, 3).astype("float32")
+    np_min, np_max = x.min(), x.max()
+    x = mge.tensor(x)
+    m = ob.MinMaxObserver()
+    m(x)
+    assert m.min_val == np_min and m.max_val == np_max
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
+)
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
+def test_sync_min_max_observer():
+    x = np.random.rand(6, 3, 3, 3).astype("float32")
+    np_min, np_max = x.min(), x.max()
+    world_size = 2
+    port = dist.get_free_ports(1)[0]
+    server = dist.Server(port)
+
+    def worker(rank, slc):
+        dist.init_process_group("localhost", port, world_size, rank, rank)
+        m = ob.SyncMinMaxObserver()
+        y = mge.tensor(x[slc])
+        m(y)
+        assert m.min_val == np_min and m.max_val == np_max
+
+    procs = []
+    for rank in range(world_size):
+        slc = slice(rank * 3, (rank + 1) * 3)
+        p = mp.Process(target=worker, args=(rank, slc,), daemon=True)
+        p.start()
+        procs.append(p)
+    for p in procs:
+        p.join(20)
+        assert p.exitcode == 0