|
@@ -28,14 +28,14 @@ from .group import WORLD, Group, group_barrier, is_distributed, override_backend |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def param_pack_split(inp: Tensor, offsets: list, shapes: list): |
|
|
def param_pack_split(inp: Tensor, offsets: list, shapes: list): |
|
|
r"""Returns split tensor to tensor list as offsets and shapes described, |
|
|
|
|
|
|
|
|
r"""Returns split tensor to list of tensors as offsets and shapes described, |
|
|
only used for ``parampack``. |
|
|
only used for ``parampack``. |
|
|
|
|
|
|
|
|
Args: |
|
|
Args: |
|
|
inp: input tensor. |
|
|
inp: input tensor. |
|
|
offsets: offsets of outputs, length of `2 * n`, |
|
|
|
|
|
while n is tensor nums you want to split, |
|
|
|
|
|
format `[begin0, end0, begin1, end1]`. |
|
|
|
|
|
|
|
|
offsets: offsets of outputs, length of ``2 * n``, |
|
|
|
|
|
where ``n`` is the number of tensor you want to split, |
|
|
|
|
|
format ``[begin0, end0, begin1, end1]``. |
|
|
shapes: tensor shapes of outputs. |
|
|
shapes: tensor shapes of outputs. |
|
|
|
|
|
|
|
|
Returns: |
|
|
Returns: |
|
@@ -43,25 +43,14 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list): |
|
|
|
|
|
|
|
|
Examples: |
|
|
Examples: |
|
|
|
|
|
|
|
|
.. testcode:: |
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from megengine import tensor |
|
|
|
|
|
from megengine.distributed.helper import param_pack_split |
|
|
|
|
|
|
|
|
|
|
|
a = tensor(np.ones((10,), np.int32)) |
|
|
|
|
|
b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)]) |
|
|
|
|
|
print(b.numpy()) |
|
|
|
|
|
print(c.numpy()) |
|
|
|
|
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
|
|
|
|
|
|
.. testoutput:: |
|
|
|
|
|
|
|
|
|
|
|
[1] |
|
|
|
|
|
[[1 1 1] |
|
|
|
|
|
[1 1 1] |
|
|
|
|
|
[1 1 1]] |
|
|
|
|
|
|
|
|
>>> a = F.ones(10) |
|
|
|
|
|
>>> b, c = dist.helper.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)]) |
|
|
|
|
|
>>> b |
|
|
|
|
|
Tensor([1.], device=xpux:0) |
|
|
|
|
|
>>> c |
|
|
|
|
|
Tensor([[1. 1. 1.] |
|
|
|
|
|
[1. 1. 1.] |
|
|
|
|
|
[1. 1. 1.]], device=xpux:0) |
|
|
""" |
|
|
""" |
|
|
op = ParamPackSplit() |
|
|
op = ParamPackSplit() |
|
|
op.offsets = offsets |
|
|
op.offsets = offsets |
|
@@ -74,34 +63,22 @@ def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list): |
|
|
r"""Returns concated tensor, only used for ``parampack``. |
|
|
r"""Returns concated tensor, only used for ``parampack``. |
|
|
|
|
|
|
|
|
Args: |
|
|
Args: |
|
|
inps: input tensors. |
|
|
|
|
|
|
|
|
inps: list of input tensors. |
|
|
offsets: device value of offsets. |
|
|
offsets: device value of offsets. |
|
|
offsets_val: offsets of inputs, length of `2 * n`, |
|
|
|
|
|
format `[begin0, end0, begin1, end1]`. |
|
|
|
|
|
|
|
|
offsets_val: offsets of inputs, length of ``2 * n``, |
|
|
|
|
|
format ``[begin0, end0, begin1, end1]``. |
|
|
|
|
|
|
|
|
Returns: |
|
|
Returns: |
|
|
concated tensor. |
|
|
concated tensor. |
|
|
|
|
|
|
|
|
Examples: |
|
|
Examples: |
|
|
|
|
|
|
|
|
.. testcode:: |
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from megengine import tensor |
|
|
|
|
|
from megengine.distributed.helper import param_pack_concat |
|
|
|
|
|
|
|
|
|
|
|
a = tensor(np.ones((1,), np.int32)) |
|
|
|
|
|
b = tensor(np.ones((3, 3), np.int32)) |
|
|
|
|
|
offsets_val = [0, 1, 1, 10] |
|
|
|
|
|
offsets = tensor(offsets_val, np.int32) |
|
|
|
|
|
c = param_pack_concat([a, b], offsets, offsets_val) |
|
|
|
|
|
print(c.numpy()) |
|
|
|
|
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
|
|
|
|
|
|
.. testoutput:: |
|
|
|
|
|
|
|
|
|
|
|
[1 1 1 1 1 1 1 1 1 1] |
|
|
|
|
|
|
|
|
>>> a = F.ones(1) |
|
|
|
|
|
>>> b = F.ones((3, 3)) |
|
|
|
|
|
>>> offsets_val = [0, 1, 1, 10] |
|
|
|
|
|
>>> offsets = Tensor(offsets_val) |
|
|
|
|
|
>>> c = dist.helper.param_pack_concat([a, b], offsets, offsets_val) # doctest: +SKIP |
|
|
|
|
|
Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], device=xpux:0) |
|
|
""" |
|
|
""" |
|
|
op = ParamPackConcat() |
|
|
op = ParamPackConcat() |
|
|
op.offsets = offsets_val |
|
|
op.offsets = offsets_val |
|
|