
feat(mge/parampack): add parampack in allreduce callback

GitOrigin-RevId: 73d53eeba1
Tags: v1.0.0-rc1
Author: Megvii Engine Team, 4 years ago
Parent commit: b5016b9d29
7 changed files with 173 additions and 115 deletions:

  1. imperative/python/megengine/autodiff/grad_manager.py   (+47, -11)
  2. imperative/python/megengine/distributed/helper.py      (+70, -5)
  3. imperative/python/megengine/distributed/server.py      (+1, -19)
  4. imperative/python/megengine/distributed/util.py        (+19, -0)
  5. imperative/python/megengine/functional/__init__.py     (+2, -1)
  6. imperative/python/megengine/functional/param_pack.py   (+34, -0)
  7. imperative/python/megengine/optimizer/param_pack.py    (+0, -79)

imperative/python/megengine/autodiff/grad_manager.py (+47, -11)
@@ -1,19 +1,39 @@
+from collections import defaultdict
 from contextlib import contextmanager

 from ..core.autodiff.grad import Grad
+from ..distributed.util import Future
 from ..tensor import tensor

+backwarding_grad_manager = None
+
+
+def get_backwarding_grad_manager():
+    return backwarding_grad_manager
+

 class GradManager:
     def __init__(self):
-        self._call_back_pair = []
+        self._call_back_dict = defaultdict(list)
+        self._param_dict = dict()
         self._recording = False
         self._grad = None
+        self._after_backward_callback = []
+        self._gradients = dict()

-    def register(self, params, callbacks=None):
-        self._call_back_pair.append([list(params), callbacks or []])
+    def register(self, params, callbacks=[]):
+        for p in params:
+            self._param_dict[id(p)] = p
+            for cb in callbacks:
+                self._call_back_dict[id(p)].append(cb)
+
+    def register_after_backward_callback(self, callback):
+        self._after_backward_callback.append(callback)

     def backward(self, ys, dys=None):
+        global backwarding_grad_manager
+        cache = backwarding_grad_manager
+        backwarding_grad_manager = self
         if not self._recording:
             raise RuntimeError(
                 "no computation history. "
@@ -29,8 +49,20 @@ class GradManager:
             dys = [dys]
         try:
             self._grad(ys, dys)
+            for callback in self._after_backward_callback:
+                callback()
+            for p, grad in self._gradients.items():
+                if isinstance(grad, Future):
+                    grad = grad.get()
+                param = self._param_dict[p]
+                if getattr(param, "grad", None) is None:
+                    param.grad = grad
+                else:
+                    param.grad += grad
         finally:
             self._grad = None
+            self._gradients = dict()
+            backwarding_grad_manager = cache

     def record(self):
         @contextmanager
@@ -41,20 +73,24 @@ class GradManager:
         try:
             self._recording = True
             self._grad = grad
-            for params, callbacks in self._call_back_pair:
-                for p in params:
+            for param_id in self._param_dict.keys():
+                param_wrapper = self._param_dict[param_id]
+                callbacks = self._call_back_dict[param_id]

-                    def callback(param, grad, callbacks=callbacks, p=p):
-                        ret = grad
-                        for cb in callbacks:
-                            ret = cb(param, ret)
-                        p.grad = ret
+                def callback(
+                    param, grad, callbacks=callbacks, p=param_wrapper, gm=self
+                ):
+                    ret = grad
+                    for cb in callbacks:
+                        ret = cb(param, ret)
+                    gm._gradients[id(p)] = ret

-                    grad.wrt(p, callback=callback)
+                grad.wrt(param_wrapper, callback=callback)
             with grad:
                 yield
         finally:
             self._recording = False
             self._grad = None
+            self._gradients = dict()

         return recorder()
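Taken together, these hunks change GradManager from writing p.grad eagerly inside each wrt-callback to staging results in self._gradients and resolving them, including Future placeholders, once backward finishes. A minimal usage sketch of that flow; it is illustrative only (the tensor value and the use of F.sum are assumptions, not code from this commit):

import megengine.functional as F
from megengine import tensor
from megengine.autodiff.grad_manager import GradManager

w = tensor([1.0, 2.0])   # stand-in for a trainable Parameter
gm = GradManager()
gm.register([w])         # no callbacks: the raw gradient lands in w.grad

with gm.record():
    loss = F.sum(w * w)  # any computation involving registered tensors
    gm.backward(loss)    # runs after-backward callbacks, then fills w.grad

print(w.grad)            # gradient of loss w.r.t. w, i.e. 2 * w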

imperative/python/megengine/distributed/helper.py (+70, -5)

@@ -8,12 +8,33 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import multiprocessing as mp
+from collections import defaultdict
 from typing import Callable

-from megengine.device import get_device_count
+import numpy as np
+
+from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
+from megengine.device import get_default_device, get_device_count
+
+from ..functional.param_pack import get_offsets, pack_allreduce_split
+from ..functional.utils import copy
 from .functional import all_reduce_sum, broadcast
 from .group import WORLD, group_barrier, is_distributed
+from .util import Future
+
+
+class FakeTensor(Future):
+    def device(self):
+        raise RuntimeError("Sorry, this tensor is not ready")
+
+    def numpy(self):
+        raise RuntimeError("Sorry, this tensor is not ready")
+
+    def shape(self):
+        raise RuntimeError("Sorry, this tensor is not ready")
+
+    def dtype(self):
+        raise RuntimeError("Sorry, this tensor is not ready")


 def synchronized(func: Callable):
def synchronized(func: Callable): def synchronized(func: Callable):
@@ -52,14 +73,58 @@ def bcast_params_(params, group):


 class AllreduceCallback:
     def __init__(self, reduce_method, group=WORLD):
+        reduce_method = reduce_method.lower()
+        assert reduce_method in ["sum", "mean"]
         self._reduce_method = reduce_method
         self._group = group
+        self._gm_set = set()
+        self._param_pack_thd = 10 * 1024 * 1024
+        self._reset()
+
+    def _reset(self):
+        self._params = []
+        self._gradients_dict = dict()
+        self._futures_dict = dict()
+        self._packing_list = defaultdict(list)
+        self._packing_size = defaultdict(int)
+
+    def _pack(self, dtype):
+        grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
+        shapes = [p.shape for p in self._packing_list[dtype]]
+        reduced_grads = pack_allreduce_split(
+            grad_list, shapes, self._group, self._reduce_method
+        )
+        for param, grad in zip(self._packing_list[dtype], reduced_grads):
+            self._gradients_dict[param] = grad
+        self._packing_list[dtype] = []
+        self._packing_size[dtype] = 0

     def __call__(self, param, grad):
-        ret = all_reduce_sum(grad, self._group)
-        if self._reduce_method == "MEAN":
-            ret = ret / self._group.size
-        return ret
+        gm = get_backwarding_grad_manager()
+        assert isinstance(gm, GradManager)
+        if gm not in self._gm_set:
+            gm.register_after_backward_callback(self._flush)
+            self._gm_set.add(gm)
+        self._params.append(param)
+        self._futures_dict[param] = FakeTensor(ack=False)
+        self._gradients_dict[param] = grad
+
+        self._packing_list[param.dtype].append(param)
+        self._packing_size[param.dtype] += (
+            int(np.prod(list(param.shape))) * np.dtype(param.dtype).itemsize
+        )
+        if self._packing_size[param.dtype] > self._param_pack_thd:
+            self._pack(param.dtype)
+        return self._futures_dict[param]
+
+    def _flush(self):
+        for dtype in self._packing_list.keys():
+            self._pack(dtype)
+        for param in self._params:
+            grad = self._gradients_dict[param]
+            grad = copy(grad, get_default_device())
+            self._futures_dict[param].set(grad)
+        self._reset()


 make_allreduce_cb = AllreduceCallback
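With this rewrite, the callback no longer all-reduces each gradient eagerly: __call__ buckets parameters per dtype, packs and all-reduces a bucket once its byte size exceeds _param_pack_thd (10 MiB), and returns a FakeTensor future; _flush, registered as an after-backward callback on the current GradManager, packs the remainders and fulfils every future. A hedged wiring sketch under that design; model, batch, and compute_loss are placeholders, not part of this commit:

from megengine.autodiff.grad_manager import GradManager
from megengine.distributed.helper import make_allreduce_cb

# inside each worker process, after the distributed group is set up
gm = GradManager()
cb = make_allreduce_cb("mean")  # average gradients across ranks
gm.register(model.parameters(), callbacks=[cb])  # model: placeholder module

with gm.record():
    loss = compute_loss(model, batch)  # placeholder forward pass
    gm.backward(loss)  # grads come back as futures; _flush resolves them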

imperative/python/megengine/distributed/server.py (+1, -19)

@@ -16,25 +16,7 @@ from xmlrpc.client import ServerProxy
 from xmlrpc.server import SimpleXMLRPCServer

 from ..core._imperative_rt.utils import create_mm_server
-from .util import get_free_ports
-
-
-class Future:
-    def __init__(self, ack=True):
-        self.ready = threading.Event()
-        self.ack = threading.Event() if ack else None
-
-    def set(self, value):
-        self.value = value
-        self.ready.set()
-        if self.ack:
-            self.ack.wait()
-
-    def get(self):
-        self.ready.wait()
-        if self.ack:
-            self.ack.set()
-        return self.value
+from .util import Future, get_free_ports


 class Methods:


imperative/python/megengine/distributed/util.py (+19, -0)

@@ -8,9 +8,28 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import socket
+import threading
 from typing import List


+class Future:
+    def __init__(self, ack=True):
+        self.ready = threading.Event()
+        self.ack = threading.Event() if ack else None
+
+    def set(self, value):
+        self.value = value
+        self.ready.set()
+        if self.ack:
+            self.ack.wait()
+
+    def get(self):
+        self.ready.wait()
+        if self.ack:
+            self.ack.set()
+        return self.value
+
+
 def get_free_ports(num: int) -> List[int]:
     """Get one or more free ports.
     """


imperative/python/megengine/functional/__init__.py (+2, -1)

@@ -7,7 +7,6 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # pylint: disable=redefined-builtin
-from . import distributed
 from .elemwise import *
 from .graph import add_update
 from .loss import (
@@ -27,6 +26,8 @@ from .quantized import conv_bias_activation
 from .tensor import *
 from .utils import accuracy, copy, zero_grad

+from . import distributed  # isort:skip
+
 # delete namespace
 # pylint: disable=undefined-variable
 # del elemwise, graph, loss, math, nn, tensor  # type: ignore[name-defined]

imperative/python/megengine/functional/param_pack.py (+34, -0)

@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+
+from ..functional.distributed import all_reduce_sum
+from ..tensor import Tensor
+from .tensor import param_pack_concat, param_pack_split
+
+
+def get_offsets(shapes):
+    offsets = []
+    offset = 0
+    for shape in shapes:
+        offsets.append(offset)
+        offset += int(np.prod(shape))
+    offsets.append(offset)
+    return offsets
+
+
+def pack_allreduce_split(pack_list, shapes, group, reduce_method):
+    offsets_val = get_offsets(shapes)
+    offsets = Tensor(offsets_val)
+    packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
+    packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
+    if reduce_method == "mean":
+        packed_grads /= group.size
+    grads = param_pack_split(packed_grads, offsets_val, shapes)
+    return grads
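get_offsets turns a list of gradient shapes into cumulative element offsets, which param_pack_concat/param_pack_split use to pack many tensors into one flat buffer for a single all_reduce_sum and split it back afterwards. A small worked example of the offset arithmetic, restating the function so it runs with NumPy alone:

import numpy as np

def get_offsets(shapes):
    offsets = []
    offset = 0
    for shape in shapes:
        offsets.append(offset)         # start of this tensor in the flat buffer
        offset += int(np.prod(shape))  # advance by its element count
    offsets.append(offset)             # total packed length
    return offsets

print(get_offsets([(2, 3), (4,), (5, 5)]))  # [0, 6, 10, 35]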

imperative/python/megengine/optimizer/param_pack.py (+0, -79)

@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import numpy as np
-
-from ..functional import param_pack_concat, param_pack_split
-from ..functional.distributed import all_reduce_sum
-from ..tensor import Tensor
-
-
-def get_offsets(shapes):
-    offsets = []
-    offset = 0
-    for shape in shapes:
-        offsets.append(offset)
-        offset += int(np.prod(shape))
-    offsets.append(offset)
-    return offsets
-
-
-def get_pack_list(param_group, param_pack_thd):
-    pack_list = dict()
-    shape_list = dict()
-    pack_sum = dict()
-    pack_ret, shape_ret = [], []
-    ignore_first = 8
-    ignore_last = 0
-    orders_len = len(param_group["orders"])
-    for i, idx in enumerate(param_group["orders"]):
-        param = param_group["params"][idx]
-        dtype = str(np.dtype(param.dtype))
-        dtype_size = np.dtype(param.dtype).itemsize
-        shape = param.shape
-        if ignore_first > 0:
-            ignore_first -= 1
-            pack_ret.append([idx])
-            shape_ret.append([shape])
-            continue
-        if dtype in pack_list.keys():
-            pack_list[dtype].append(idx)
-            shape_list[dtype].append(shape)
-            pack_sum[dtype] += int(np.prod(shape))
-        else:
-            pack_list[dtype] = [idx]
-            shape_list[dtype] = [shape]
-            pack_sum[dtype] = int(np.prod(shape))
-        if (
-            pack_sum[dtype] * dtype_size > param_pack_thd
-            or i + ignore_last > orders_len
-        ):
-            pack_ret.append(pack_list[dtype])
-            shape_ret.append(shape_list[dtype])
-            pack_list[dtype] = []
-            shape_list[dtype] = []
-            pack_sum[dtype] = 0
-    for key in sorted(pack_list.keys()):
-        if len(pack_list[key]) > 0:
-            pack_ret.append(pack_list[key])
-            shape_ret.append(shape_list[key])
-    return pack_ret, shape_ret
-
-
-def pack_allreduce_split(group, pack, shapes, reduce_method):
-    dist_group = group["dist_group"]
-    grads = [group["grads"][idx] for idx in pack]
-    offsets_val = get_offsets(shapes)
-    offsets = Tensor(offsets_val)
-    packed_grads = param_pack_concat(grads, offsets, offsets_val)
-    packed_grads = all_reduce_sum(packed_grads, dist_group, dist_group.comp_node)
-    if reduce_method == "mean":
-        packed_grads /= dist_group.size
-    grads = param_pack_split(packed_grads, offsets_val, shapes)
-    for i, grad in enumerate(grads):
-        group["grads"][pack[i]] = grad
