
feat(mge/quantization): support distributed qat

GitOrigin-RevId: c915c843b8
release-1.1
Megvii Engine Team, 4 years ago
parent commit 33e8879af4
5 changed files with 103 additions and 4 deletions:

  1. imperative/python/megengine/quantization/__init__.py (+1, -0)
  2. imperative/python/megengine/quantization/observer.py (+39, -0)
  3. imperative/python/megengine/quantization/qconfig.py (+11, -0)
  4. imperative/python/test/unit/module/test_batchnorm.py (+0, -4)
  5. imperative/python/test/unit/quantization/test_observer.py (+52, -0)

imperative/python/megengine/quantization/__init__.py (+1, -0)

@@ -15,6 +15,7 @@ from .qconfig import (
     ema_fakequant_qconfig,
     ema_lowbit_fakequant_qconfig,
     min_max_fakequant_qconfig,
+    sync_ema_fakequant_qconfig,
     tqt_quant_qconfig,
 )
 from .utils import QuantMode

imperative/python/megengine/quantization/observer.py (+39, -0)

@@ -12,6 +12,8 @@ import numpy as np
 
 from .. import functional as F
 from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
+from ..distributed import WORLD, get_rank, is_distributed
+from ..functional.distributed import all_reduce_max, all_reduce_min
 from ..module import Module
 from ..tensor import Tensor
 from .utils import QuantMode, Round, get_qparam_dict
@@ -123,6 +125,21 @@ class MinMaxObserver(Observer):
         return x_orig
 
 
+class SyncMinMaxObserver(MinMaxObserver):
+    def forward(self, x_orig):
+        if self.enabled:
+            x = x_orig.detach()
+            if is_distributed():
+                min_x = all_reduce_min(x.min(), WORLD)
+                max_x = all_reduce_max(x.max(), WORLD)
+            else:
+                min_x = x.min()
+                max_x = x.max()
+            self.min_val._reset(F.minimum(self.min_val, min_x))
+            self.max_val._reset(F.maximum(self.max_val, max_x))
+        return x_orig
+
+
 class ExponentialMovingAverageObserver(MinMaxObserver):
     def __init__(
         self,
@@ -157,6 +174,28 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
         return x_orig
 
 
+class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
+    def forward(self, x_orig):
+        if self.enabled:
+            x = x_orig.detach()
+            if is_distributed():
+                min_x = all_reduce_min(x.min(), WORLD)
+                max_x = all_reduce_max(x.max(), WORLD)
+            else:
+                min_x = x.min()
+                max_x = x.max()
+            self.min_val._reset(
+                self.min_val * self.runtime_momentum
+                + (1 - self.runtime_momentum) * min_x
+            )
+            self.max_val._reset(
+                self.max_val * self.runtime_momentum
+                + (1 - self.runtime_momentum) * max_x
+            )
+            self.runtime_momentum = self.momentum
+        return x_orig
+
+
 class HistogramObserver(MinMaxObserver):
     def __init__(
         self,
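
For intuition, here is a minimal pure-Python sketch (not part of the commit) of the per-step update the two observers above perform. The lists stand in for the per-rank statistics that all_reduce_min/all_reduce_max aggregate across WORLD; the concrete numbers, prior state, and momentum are hypothetical:

    # illustration only: per-rank batch statistics (hypothetical values)
    rank_mins = [0.10, 0.30]   # x.min() on each rank
    rank_maxs = [0.80, 0.95]   # x.max() on each rank

    # all_reduce_min / all_reduce_max collapse these to global extrema
    min_x, max_x = min(rank_mins), max(rank_maxs)

    # SyncMinMaxObserver: running extrema over every batch seen so far
    min_val, max_val = float("inf"), float("-inf")  # fresh observer state
    min_val, max_val = min(min_val, min_x), max(max_val, max_x)

    # SyncExponentialMovingAverageObserver: EMA of the global extrema
    runtime_momentum = 0.9           # hypothetical value after the first step
    prev_min, prev_max = 0.20, 0.90  # hypothetical previous EMA state
    ema_min = prev_min * runtime_momentum + (1 - runtime_momentum) * min_x
    ema_max = prev_max * runtime_momentum + (1 - runtime_momentum) * max_x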


imperative/python/megengine/quantization/qconfig.py (+11, -0)

@@ -13,6 +13,8 @@ from .observer import (
     ExponentialMovingAverageObserver,
     HistogramObserver,
     MinMaxObserver,
+    SyncExponentialMovingAverageObserver,
+    SyncMinMaxObserver,
 )


@@ -92,6 +94,15 @@ ema_fakequant_qconfig = QConfig(
     act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
 )
 
+sync_ema_fakequant_qconfig = QConfig(
+    weight_observer=partial(SyncMinMaxObserver, dtype="qint8", narrow_range=True),
+    act_observer=partial(
+        SyncExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
+    ),
+    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
+    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
+)
+
 ema_lowbit_fakequant_qconfig = QConfig(
     weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False),
     act_observer=partial(
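
To enable distributed QAT, the new config would be passed to MegEngine's QAT conversion in place of ema_fakequant_qconfig. A minimal sketch, assuming a user-defined Net module and that quantize_qat accepts a qconfig keyword as in the non-distributed flow:

    import megengine.distributed as dist
    from megengine.quantization import quantize_qat
    from megengine.quantization.qconfig import sync_ema_fakequant_qconfig

    # inside each worker, after dist.init_process_group(...) has set up WORLD
    net = Net()  # hypothetical float model
    qat_net = quantize_qat(net, qconfig=sync_ema_fakequant_qconfig)
    # observers now all-reduce their min/max statistics across ranks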


imperative/python/test/unit/module/test_batchnorm.py (+0, -4)

@@ -143,7 +143,6 @@ def test_batchnorm():
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
-@pytest.mark.isolated_distributed
 def test_syncbn1d():
     nr_chan = 8
     data_shape = (3, nr_chan, 4)
@@ -234,7 +233,6 @@ def test_batchnorm2d():
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
-@pytest.mark.isolated_distributed
 def test_syncbn2d():
     nr_chan = 8
     data_shape = (3, nr_chan, 16, 16)
@@ -305,7 +303,6 @@ def test_batchnorm_no_stats():
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
-@pytest.mark.isolated_distributed
 def test_syncbn_no_stats():
     nr_chan = 8
     data_shape = (3, nr_chan, 4)
@@ -354,7 +351,6 @@ def test_batchnorm2d_no_stats():
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
-@pytest.mark.isolated_distributed
 def test_syncbn2d_no_stats():
     nr_chan = 8
     data_shape = (3, nr_chan, 16, 16)


imperative/python/test/unit/quantization/test_observer.py (+52, -0)

@@ -0,0 +1,52 @@
+import multiprocessing as mp
+import platform
+
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.distributed as dist
+import megengine.quantization.observer as ob
+from megengine.distributed.helper import get_device_count_by_fork
+
+
+def test_min_max_observer():
+    x = np.random.rand(3, 3, 3, 3).astype("float32")
+    np_min, np_max = x.min(), x.max()
+    x = mge.tensor(x)
+    m = ob.MinMaxObserver()
+    m(x)
+    assert m.min_val == np_min and m.max_val == np_max
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
+)
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
+def test_sync_min_max_observer():
+    x = np.random.rand(6, 3, 3, 3).astype("float32")
+    np_min, np_max = x.min(), x.max()
+    world_size = 2
+    port = dist.get_free_ports(1)[0]
+    server = dist.Server(port)
+
+    def worker(rank, slc):
+        dist.init_process_group("localhost", port, world_size, rank, rank)
+        m = ob.SyncMinMaxObserver()
+        y = mge.tensor(x[slc])
+        m(y)
+        assert m.min_val == np_min and m.max_val == np_max
+
+    procs = []
+    for rank in range(world_size):
+        slc = slice(rank * 3, (rank + 1) * 3)
+        p = mp.Process(target=worker, args=(rank, slc), daemon=True)
+        p.start()
+        procs.append(p)
+    for p in procs:
+        p.join(20)
+        assert p.exitcode == 0
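
The harness above could also cover the EMA variant. A hedged sketch (not part of this commit) reusing the same two-process setup, under the assumption that the running momentum starts at zero so the first forward adopts the globally reduced statistics, as the final self.runtime_momentum = self.momentum line in observer.py suggests:

    # hypothetical extension: same harness, checking the EMA observer
    def worker_ema(rank, slc):
        dist.init_process_group("localhost", port, world_size, rank, rank)
        m = ob.SyncExponentialMovingAverageObserver()
        y = mge.tensor(x[slc])
        m(y)
        # every rank should agree on the global min/max of the full batch
        np.testing.assert_allclose(m.min_val.numpy(), np_min, rtol=1e-6)
        np.testing.assert_allclose(m.max_val.numpy(), np_max, rtol=1e-6)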
