diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index b70e2da1..78360fac 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -28,14 +28,14 @@ from .group import WORLD, Group, group_barrier, is_distributed, override_backend def param_pack_split(inp: Tensor, offsets: list, shapes: list): - r"""Returns split tensor to tensor list as offsets and shapes described, + r"""Returns split tensor to list of tensors as offsets and shapes described, only used for ``parampack``. Args: inp: input tensor. - offsets: offsets of outputs, length of `2 * n`, - while n is tensor nums you want to split, - format `[begin0, end0, begin1, end1]`. + offsets: offsets of outputs, length of ``2 * n``, + where ``n`` is the number of tensor you want to split, + format ``[begin0, end0, begin1, end1]``. shapes: tensor shapes of outputs. Returns: @@ -43,25 +43,14 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list): Examples: - .. testcode:: - - import numpy as np - from megengine import tensor - from megengine.distributed.helper import param_pack_split - - a = tensor(np.ones((10,), np.int32)) - b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)]) - print(b.numpy()) - print(c.numpy()) - - Outputs: - - .. testoutput:: - - [1] - [[1 1 1] - [1 1 1] - [1 1 1]] + >>> a = F.ones(10) + >>> b, c = dist.helper.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)]) + >>> b + Tensor([1.], device=xpux:0) + >>> c + Tensor([[1. 1. 1.] + [1. 1. 1.] + [1. 1. 1.]], device=xpux:0) """ op = ParamPackSplit() op.offsets = offsets @@ -74,34 +63,22 @@ def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list): r"""Returns concated tensor, only used for ``parampack``. Args: - inps: input tensors. + inps: list of input tensors. offsets: device value of offsets. - offsets_val: offsets of inputs, length of `2 * n`, - format `[begin0, end0, begin1, end1]`. + offsets_val: offsets of inputs, length of ``2 * n``, + format ``[begin0, end0, begin1, end1]``. Returns: concated tensor. Examples: - .. testcode:: - - import numpy as np - from megengine import tensor - from megengine.distributed.helper import param_pack_concat - - a = tensor(np.ones((1,), np.int32)) - b = tensor(np.ones((3, 3), np.int32)) - offsets_val = [0, 1, 1, 10] - offsets = tensor(offsets_val, np.int32) - c = param_pack_concat([a, b], offsets, offsets_val) - print(c.numpy()) - - Outputs: - - .. testoutput:: - - [1 1 1 1 1 1 1 1 1 1] + >>> a = F.ones(1) + >>> b = F.ones((3, 3)) + >>> offsets_val = [0, 1, 1, 10] + >>> offsets = Tensor(offsets_val) + >>> c = dist.helper.param_pack_concat([a, b], offsets, offsets_val) # doctest: +SKIP + Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], device=xpux:0) """ op = ParamPackConcat() op.offsets = offsets_val