Browse Source

feat(mge): remove param_pack_* from functional

GitOrigin-RevId: a5fe25be8c
release-1.1
Megvii Engine Team 4 years ago
parent
commit
09241a1ff7
5 changed files with 124 additions and 132 deletions
  1. +102
    -1
      imperative/python/megengine/distributed/helper.py
  2. +0
    -34
      imperative/python/megengine/functional/param_pack.py
  3. +0
    -80
      imperative/python/megengine/functional/tensor.py
  4. +22
    -1
      imperative/python/test/unit/distributed/test_distributed.py
  5. +0
    -16
      imperative/python/test/unit/functional/test_tensor.py

+ 102
- 1
imperative/python/megengine/distributed/helper.py View File

@@ -17,13 +17,114 @@ import numpy as np
from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
from megengine.device import get_default_device, get_device_count from megengine.device import get_default_device, get_device_count


from ..functional.param_pack import get_offsets, pack_allreduce_split
from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..core.tensor.core import apply
from ..functional.utils import copy from ..functional.utils import copy
from ..tensor import Tensor
from ..utils.future import Future from ..utils.future import Future
from .functional import all_reduce_sum, broadcast from .functional import all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed from .group import WORLD, Group, group_barrier, is_distributed




def param_pack_split(inp: Tensor, offsets: list, shapes: list):
    r"""
    Splits ``inp`` into a list of tensors laid out as described by
    ``offsets`` and ``shapes``; only used for ``parampack``.

    :param inp: input tensor.
    :param offsets: offsets of outputs, length of `2 * n`,
        where n is the number of tensors to split out,
        format `[begin0, end0, begin1, end1]`.
    :param shapes: tensor shapes of outputs.
    :return: splitted tensors.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        from megengine.distributed.helper import param_pack_split

        a = tensor(np.ones((10,), np.int32))
        b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
        print(b.numpy())
        print(c.numpy())

    Outputs:

    .. testoutput::

        [1]
        [[1 1 1]
         [1 1 1]
         [1 1 1]]

    """
    # Configure the builtin op through its attributes, then dispatch it.
    split_op = ParamPackSplit()
    split_op.offsets = offsets
    split_op.shapes = shapes
    return apply(split_op, inp)


def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
    r"""
    Concatenates ``inps`` into a single tensor, only used for ``parampack``.

    :param inps: input tensors.
    :param offsets: device value of offsets.
    :param offsets_val: offsets of inputs, length of `2 * n`,
        format `[begin0, end0, begin1, end1]`.
    :return: concated tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        from megengine.distributed.helper import param_pack_concat

        a = tensor(np.ones((1,), np.int32))
        b = tensor(np.ones((3, 3), np.int32))
        offsets_val = [0, 1, 1, 10]
        offsets = tensor(offsets_val, np.int32)
        c = param_pack_concat([a, b], offsets, offsets_val)
        print(c.numpy())

    Outputs:

    .. testoutput::

        [1 1 1 1 1 1 1 1 1 1]

    """
    concat_op = ParamPackConcat()
    concat_op.offsets = offsets_val
    # The offsets tensor is passed as the last op input, after all ``inps``.
    outputs = apply(concat_op, *inps, offsets)
    return outputs[0]


def get_offsets(shapes):
    """
    Computes flat ``[begin0, end0, begin1, end1, ...]`` offsets for tensors of
    the given shapes laid out back to back.

    :param shapes: iterable of tensor shapes.
    :return: list of length ``2 * len(shapes)`` holding begin/end pairs.
    """
    bounds = []
    cursor = 0
    for shape in shapes:
        numel = int(np.prod(shape))
        # Each tensor occupies [cursor, cursor + numel) in the packed buffer.
        bounds.extend((cursor, cursor + numel))
        cursor += numel
    return bounds


def pack_allreduce_split(pack_list, shapes, group, reduce_method):
    """
    Packs the given tensors into one, sum-all-reduces the packed tensor over
    ``group`` and splits the result back according to ``shapes``.

    :param pack_list: tensors to pack; forwarded to ``param_pack_concat``.
    :param shapes: shapes used both to compute the offsets and to split the
        reduced tensor.
    :param group: communication group; its ``comp_node`` and ``size``
        attributes are read.
    :param reduce_method: when equal to ``"mean"`` the summed tensor is
        divided by ``group.size``; any other value leaves the sum as is.
    :return: the result of ``param_pack_split`` on the reduced tensor.
    """
    host_offsets = get_offsets(shapes)
    device_offsets = Tensor(host_offsets)
    merged = param_pack_concat(pack_list, device_offsets, host_offsets)
    merged = all_reduce_sum(merged, group, group.comp_node)
    if reduce_method == "mean":
        # Kept as an in-place division to match the original statement.
        merged /= group.size
    return param_pack_split(merged, host_offsets, shapes)


class TensorFuture(Future): class TensorFuture(Future):
def device(self): def device(self):
raise "Sorry, this tensor is not ready" raise "Sorry, this tensor is not ready"


+ 0
- 34
imperative/python/megengine/functional/param_pack.py View File

@@ -1,34 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from ..tensor import Tensor
from .distributed import all_reduce_sum
from .tensor import param_pack_concat, param_pack_split


def get_offsets(shapes):
    """
    Returns ``[begin0, end0, begin1, end1, ...]`` offsets for tensors of the
    given shapes placed contiguously one after another.

    :param shapes: iterable of tensor shapes.
    :return: list of begin/end pairs, two entries per shape.
    """
    result = []
    start = 0
    for shape in shapes:
        stop = start + int(np.prod(shape))
        result += [start, stop]
        # The next tensor begins where this one ends.
        start = stop
    return result


def pack_allreduce_split(pack_list, shapes, group, reduce_method):
    """
    Concatenates ``pack_list`` into one tensor, sum-all-reduces it over
    ``group`` and splits the reduced tensor back using ``shapes``.

    :param pack_list: tensors to pack; forwarded to ``param_pack_concat``.
    :param shapes: shapes used to derive offsets and to split the result.
    :param group: communication group; ``group.comp_node`` and
        ``group.size`` are read.
    :param reduce_method: ``"mean"`` divides the summed tensor by
        ``group.size``; any other value returns the sum.
    :return: the result of ``param_pack_split`` on the reduced tensor.
    """
    offset_values = get_offsets(shapes)
    offset_tensor = Tensor(offset_values)
    packed = param_pack_concat(pack_list, offset_tensor, offset_values)
    packed = all_reduce_sum(packed, group, group.comp_node)
    if reduce_method == "mean":
        # Kept as an in-place division to match the original statement.
        packed /= group.size
    return param_pack_split(packed, offset_values, shapes)

+ 0
- 80
imperative/python/megengine/functional/tensor.py View File

@@ -46,8 +46,6 @@ __all__ = [
"linspace", "linspace",
"ones", "ones",
"ones_like", "ones_like",
"param_pack_concat",
"param_pack_split",
"reshape", "reshape",
"split", "split",
"squeeze", "squeeze",
@@ -975,81 +973,3 @@ def arange(
if np.dtype(dtype) == np.int32: if np.dtype(dtype) == np.int32:
return result.astype(dtype) return result.astype(dtype)
return result return result


def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor:
    r"""
    Splits ``inp`` into a list of tensors laid out as described by
    ``offsets`` and ``shapes``; only used for ``parampack``.

    :param inp: input tensor.
    :param offsets: offsets of outputs, length of `2 * n`,
        where n is the number of tensors to split out,
        format `[begin0, end0, begin1, end1]`.
    :param shapes: tensor shapes of outputs.
    :return: splitted tensors.

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F
        from megengine import tensor

        a = tensor(np.ones((10,), np.int32))
        b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
        print(b.numpy())
        print(c.numpy())

    Outputs:

    .. testoutput::

        [1]
        [[1 1 1]
         [1 1 1]
         [1 1 1]]

    """
    # Configure the builtin op through its attributes, then dispatch it.
    split_op = builtin.ParamPackSplit()
    split_op.offsets = offsets
    split_op.shapes = shapes
    return apply(split_op, inp)


def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor:
    r"""
    Concatenates ``inps`` into a single tensor, only used for ``parampack``.

    :param inps: input tensors.
    :param offsets: device value of offsets.
    :param offsets_val: offsets of inputs, length of `2 * n`,
        format `[begin0, end0, begin1, end1]`.
    :return: concated tensor.

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F
        from megengine import tensor

        a = tensor(np.ones((1,), np.int32))
        b = tensor(np.ones((3, 3), np.int32))
        offsets_val = [0, 1, 1, 10]
        offsets = tensor(offsets_val, np.int32)
        c = F.param_pack_concat([a, b], offsets, offsets_val)
        print(c.numpy())

    Outputs:

    .. testoutput::

        [1 1 1 1 1 1 1 1 1 1]

    """
    concat_op = builtin.ParamPackConcat()
    concat_op.offsets = offsets_val
    # The offsets tensor is passed as the last op input, after all ``inps``.
    outputs = apply(concat_op, *inps, offsets)
    return outputs[0]

+ 22
- 1
imperative/python/test/unit/distributed/test_distributed.py View File

@@ -10,12 +10,17 @@ import multiprocessing as mp
import platform import platform
import queue import queue


import numpy as np
import pytest import pytest


import megengine as mge import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit
from megengine.distributed.helper import get_device_count_by_fork
from megengine.distributed.helper import (
get_device_count_by_fork,
param_pack_concat,
param_pack_split,
)




def _assert_q_empty(q): def _assert_q_empty(q):
@@ -195,3 +200,19 @@ def test_oprmm_hashable():
rhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit()) rhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
assert lhs == rhs assert lhs == rhs
assert hash(lhs) == hash(rhs) assert hash(lhs) == hash(rhs)


def test_param_pack_split():
    """Check that ``param_pack_split`` returns the right slices and shapes."""
    # Distinct values: with an all-ones input a wrong offset, order or shape
    # could not be told apart from a correct result.
    a = mge.Tensor(np.arange(10, dtype=np.int32))
    b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
    # The first output covers [0, 1), i.e. the slice ``[:1]`` (not element 1).
    np.testing.assert_array_equal(b.numpy(), a.numpy()[:1])
    np.testing.assert_array_equal(c.numpy(), a.numpy()[1:].reshape(3, 3))


def test_param_pack_concat():
    """Check that ``param_pack_concat`` lays inputs out in order."""
    # Distinct values so that misordered or misplaced elements are detected;
    # all-ones inputs would compare equal under any permutation.
    a = mge.Tensor(np.array([9], dtype=np.int32))
    b = mge.Tensor(np.arange(9, dtype=np.int32).reshape(3, 3))
    offsets_val = [0, 1, 1, 10]
    offsets = mge.Tensor(offsets_val, np.int32)
    c = param_pack_concat([a, b], offsets, offsets_val)
    expected = np.concatenate([a.numpy(), b.numpy().flatten()])
    np.testing.assert_array_equal(c.numpy(), expected)

+ 0
- 16
imperative/python/test/unit/functional/test_tensor.py View File

@@ -359,19 +359,3 @@ def test_copy_d2h():
def test_copy_d2d(): def test_copy_d2d():
copy_test("gpu0", "gpu1") copy_test("gpu0", "gpu1")
copy_test("gpu0:0", "gpu0:1") copy_test("gpu0:0", "gpu0:1")


def test_param_pack_split():
    """Check that ``F.param_pack_split`` returns the right slices and shapes."""
    # Distinct values: with an all-ones input a wrong offset, order or shape
    # could not be told apart from a correct result.
    a = tensor(np.arange(10, dtype=np.int32))
    b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
    # The first output covers [0, 1), i.e. the slice ``[:1]`` (not element 1).
    np.testing.assert_array_equal(b.numpy(), a.numpy()[:1])
    np.testing.assert_array_equal(c.numpy(), a.numpy()[1:].reshape(3, 3))


def test_param_pack_concat():
    """Check that ``F.param_pack_concat`` lays inputs out in order."""
    # Distinct values so that misordered or misplaced elements are detected;
    # all-ones inputs would compare equal under any permutation.
    a = tensor(np.array([9], dtype=np.int32))
    b = tensor(np.arange(9, dtype=np.int32).reshape(3, 3))
    offsets_val = [0, 1, 1, 10]
    offsets = tensor(offsets_val, np.int32)
    c = F.param_pack_concat([a, b], offsets, offsets_val)
    expected = np.concatenate([a.numpy(), b.numpy().flatten()])
    np.testing.assert_array_equal(c.numpy(), expected)

Loading…
Cancel
Save