|
@@ -28,7 +28,7 @@ from ..core.tensor.utils import ( |
|
|
) |
|
|
) |
|
|
from ..device import get_default_device |
|
|
from ..device import get_default_device |
|
|
from ..tensor import Tensor |
|
|
from ..tensor import Tensor |
|
|
from .elemwise import ceil |
|
|
|
|
|
|
|
|
from .elemwise import ceil, floor_div |
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
__all__ = [ |
|
|
"arange", |
|
|
"arange", |
|
@@ -324,52 +324,73 @@ def split(inp, nsplits_or_sections, axis=0): |
|
|
|
|
|
|
|
|
.. testcode:: |
|
|
.. testcode:: |
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import numpy as np |
|
|
import numpy as np |
|
|
from megengine import tensor |
|
|
from megengine import tensor |
|
|
import megengine.functional as F |
|
|
import megengine.functional as F |
|
|
|
|
|
|
|
|
x = tensor(np.random.random((2,3,4,5)), dtype=np.float32) |
|
|
|
|
|
out = F.split(x, 2, axis=3) |
|
|
|
|
|
print(out[0].numpy().shape, out[1].numpy().shape) |
|
|
|
|
|
|
|
|
x = tensor(np.random.random((10, 20)), dtype=np.float32) |
|
|
|
|
|
y = F.split(x, 3) |
|
|
|
|
|
z = F.split(x, [6, 17], axis=1) |
|
|
|
|
|
|
|
|
|
|
|
if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"): |
|
|
|
|
|
print([tuple(i.shape.numpy().tolist()) for i in y]) |
|
|
|
|
|
print([tuple(i.shape.numpy().tolist()) for i in z]) |
|
|
|
|
|
else: |
|
|
|
|
|
print([i.shape for i in y]) |
|
|
|
|
|
print([i.shape for i in z]) |
|
|
|
|
|
|
|
|
Outputs: |
|
|
Outputs: |
|
|
|
|
|
|
|
|
.. testoutput:: |
|
|
.. testoutput:: |
|
|
|
|
|
|
|
|
(2, 3, 4, 3) (2, 3, 4, 2) |
|
|
|
|
|
|
|
|
[(4, 20), (3, 20), (3, 20)] |
|
|
|
|
|
[(10, 6), (10, 11), (10, 3)] |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
sub_tensors = [] |
|
|
|
|
|
sections = [] |
|
|
|
|
|
|
|
|
|
|
|
def swapaxis(inp, src, dst):
    """Return ``inp`` with the axes ``src`` and ``dst`` exchanged.

    :param inp: object exposing ``ndim`` and ``transpose`` (e.g. a tensor).
    :param src: index of the first axis to exchange.
    :param dst: index of the second axis to exchange.
    :return: ``inp`` itself when ``src == dst``; otherwise the result of
        ``inp.transpose`` called with the exchanged axis order.
    """
    if src == dst:
        # Nothing to exchange; hand the input back untouched.
        return inp
    # Start from the identity axis order, then exchange the two entries.
    perm = list(range(inp.ndim))
    perm[src] = dst
    perm[dst] = src
    return inp.transpose(perm)
|
|
|
|
|
|
|
|
|
|
|
inp = swapaxis(inp, 0, axis) |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(nsplits_or_sections, int): |
|
|
|
|
|
incr_step = ceil(inp.shape[0] / nsplits_or_sections) |
|
|
|
|
|
nsplits = nsplits_or_sections |
|
|
|
|
|
while nsplits > 0: |
|
|
|
|
|
nsplits -= 1 |
|
|
|
|
|
sections.append(incr_step.astype("int32")) |
|
|
|
|
|
incr_step += nsplits_or_sections |
|
|
|
|
|
else: |
|
|
|
|
|
sections = nsplits_or_sections |
|
|
|
|
|
|
|
|
ndim = len(inp.shape) |
|
|
|
|
|
if axis >= ndim: |
|
|
|
|
|
raise ValueError("Invalid axis {}".format(axis)) |
|
|
|
|
|
|
|
|
st = 0 |
|
|
|
|
|
for se in sections: |
|
|
|
|
|
sub_tensors.append(swapaxis(inp[st:se], axis, 0)) |
|
|
|
|
|
st = se |
|
|
|
|
|
|
|
|
Ntotal = inp.shape[axis] |
|
|
|
|
|
|
|
|
if st < inp.shape[0]: |
|
|
|
|
|
sub_tensors.append(swapaxis(inp[st:], axis, 0)) |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
Nsections = len(nsplits_or_sections) + 1 |
|
|
|
|
|
is_array = True |
|
|
|
|
|
except TypeError: |
|
|
|
|
|
Nsections = int(nsplits_or_sections) |
|
|
|
|
|
is_array = False |
|
|
|
|
|
|
|
|
|
|
|
if is_array: |
|
|
|
|
|
div_points = [0] + list(nsplits_or_sections) + [Ntotal] |
|
|
|
|
|
for i in range(1, len(div_points)): |
|
|
|
|
|
if div_points[i - 1] >= div_points[i]: |
|
|
|
|
|
raise ValueError( |
|
|
|
|
|
"Invalid nsplits_or_secions: {}".format(nsplits_or_sections) |
|
|
|
|
|
) |
|
|
|
|
|
else: # scalar |
|
|
|
|
|
if Nsections <= 0: |
|
|
|
|
|
raise ValueError("Number sections must be larger than 0") |
|
|
|
|
|
if Nsections > Ntotal: |
|
|
|
|
|
raise ValueError( |
|
|
|
|
|
"The size {} at dim {} cannot be split into {} sections".format( |
|
|
|
|
|
Ntotal, axis, Nsections |
|
|
|
|
|
) |
|
|
|
|
|
) |
|
|
|
|
|
div_points = [0] + [ |
|
|
|
|
|
floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) |
|
|
|
|
|
] |
|
|
|
|
|
for i in range(2, Nsections + 1): |
|
|
|
|
|
div_points[i] = div_points[i - 1] + div_points[i] |
|
|
|
|
|
|
|
|
|
|
|
sub_tensors = [] |
|
|
|
|
|
for i in range(Nsections): |
|
|
|
|
|
l = div_points[i] |
|
|
|
|
|
r = div_points[i + 1] |
|
|
|
|
|
slices = tuple( |
|
|
|
|
|
[slice(None)] * axis + [slice(l, r)] + [slice(None)] * (ndim - axis - 1) |
|
|
|
|
|
) |
|
|
|
|
|
sub_tensors.append(inp[slices]) |
|
|
return sub_tensors |
|
|
return sub_tensors |
|
|
|
|
|
|
|
|
|
|
|
|
|
|