Browse Source

fix(mge/functional): fix tensor split

GitOrigin-RevId: 0a112ab0bd
release-1.2
Megvii Engine Team 4 years ago
parent
commit
5c7d48cdb9
3 changed files with 78 additions and 37 deletions
  1. +1
    -1
      imperative/python/megengine/functional/elemwise.py
  2. +54
    -33
      imperative/python/megengine/functional/tensor.py
  3. +23
    -3
      imperative/python/test/unit/functional/test_tensor.py

+ 1
- 1
imperative/python/megengine/functional/elemwise.py View File

@@ -158,7 +158,7 @@ def div(x, y):

def floor_div(x, y):
"""Element-wise `floor(x / y)`."""
# NOTE(review): this file's hunk is reported as +1 / -1, so the two return
# statements below are the old and the new version of the same line, shown
# without their diff markers. By diff ordering the first one (FLOOR_DIVIDE)
# is presumably the removed line and the second (FLOOR_DIV) the added one;
# confirm against the repository before relying on this.
return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIVIDE)
return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIV)


def neg(x):


+ 54
- 33
imperative/python/megengine/functional/tensor.py View File

@@ -28,7 +28,7 @@ from ..core.tensor.utils import (
)
from ..device import get_default_device
from ..tensor import Tensor
from .elemwise import ceil
from .elemwise import ceil, floor_div

__all__ = [
"arange",
@@ -324,52 +324,73 @@ def split(inp, nsplits_or_sections, axis=0):

.. testcode::

import os
import numpy as np
from megengine import tensor
import megengine.functional as F

x = tensor(np.random.random((2,3,4,5)), dtype=np.float32)
out = F.split(x, 2, axis=3)
print(out[0].numpy().shape, out[1].numpy().shape)
x = tensor(np.random.random((10, 20)), dtype=np.float32)
y = F.split(x, 3)
z = F.split(x, [6, 17], axis=1)
if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):
print([tuple(i.shape.numpy().tolist()) for i in y])
print([tuple(i.shape.numpy().tolist()) for i in z])
else:
print([i.shape for i in y])
print([i.shape for i in z])

Outputs:

.. testoutput::

(2, 3, 4, 3) (2, 3, 4, 2)
[(4, 20), (3, 20), (3, 20)]
[(10, 6), (10, 11), (10, 3)]

"""
sub_tensors = []
sections = []

def swapaxis(inp, src, dst):
if src == dst:
return inp
shape = [i for i in range(inp.ndim)]
shape[src] = dst
shape[dst] = src
return inp.transpose(shape)

inp = swapaxis(inp, 0, axis)

if isinstance(nsplits_or_sections, int):
incr_step = ceil(inp.shape[0] / nsplits_or_sections)
nsplits = nsplits_or_sections
while nsplits > 0:
nsplits -= 1
sections.append(incr_step.astype("int32"))
incr_step += nsplits_or_sections
else:
sections = nsplits_or_sections
ndim = len(inp.shape)
if axis >= ndim:
raise ValueError("Invalid axis {}".format(axis))

st = 0
for se in sections:
sub_tensors.append(swapaxis(inp[st:se], axis, 0))
st = se
Ntotal = inp.shape[axis]

if st < inp.shape[0]:
sub_tensors.append(swapaxis(inp[st:], axis, 0))
try:
Nsections = len(nsplits_or_sections) + 1
is_array = True
except TypeError:
Nsections = int(nsplits_or_sections)
is_array = False

if is_array:
div_points = [0] + list(nsplits_or_sections) + [Ntotal]
for i in range(1, len(div_points)):
if div_points[i - 1] >= div_points[i]:
raise ValueError(
"Invalid nsplits_or_secions: {}".format(nsplits_or_sections)
)
else: # scalar
if Nsections <= 0:
raise ValueError("Number sections must be larger than 0")
if Nsections > Ntotal:
raise ValueError(
"The size {} at dim {} cannot be split into {} sections".format(
Ntotal, axis, Nsections
)
)
div_points = [0] + [
floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
]
for i in range(2, Nsections + 1):
div_points[i] = div_points[i - 1] + div_points[i]

sub_tensors = []
for i in range(Nsections):
l = div_points[i]
r = div_points[i + 1]
slices = tuple(
[slice(None)] * axis + [slice(l, r)] + [slice(None)] * (ndim - axis - 1)
)
sub_tensors.append(inp[slices])
return sub_tensors




+ 23
- 3
imperative/python/test/unit/functional/test_tensor.py View File

@@ -77,14 +77,34 @@ def test_stack():

def test_split():
# Compares F.split with np.split on a random (2, 3, 4, 5) float32 array
# split along axis 3, then checks two inputs that must raise ValueError.
# NOTE(review): this is a rendered diff hunk (+23 / -3) whose +/- markers
# were lost, so old and new lines are interleaved. The lines flagged below
# as "presumably removed" are inferred from the hunk statistics and from
# variable usage; confirm against the repository.
data = np.random.random((2, 3, 4, 5)).astype(np.float32)
# Presumably removed (old test): mge_out2 is only used by the old
# comparison further down, and mge_out1 is reassigned right after.
mge_out1 = F.split(tensor(data), 2, axis=3)
mge_out2 = F.split(tensor(data), [3, 5], axis=3)
inp = tensor(data)

# Split once by section count (2) and once by an explicit split index ([3]).
mge_out0 = F.split(inp, 2, axis=3)
mge_out1 = F.split(inp, [3], axis=3)

# NumPy reference; only np_out[0] and np_out[1] are compared below.
np_out = np.split(data, [3, 5], axis=3)

# Presumably removed (old test): the only statement that reads mge_out2.
np.testing.assert_equal(mge_out1[0].numpy(), mge_out2[0].numpy())
assert len(mge_out0) == 2
assert len(mge_out1) == 2

np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])

np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])

# No axis is passed, so the default axis is used; 4 sections are requested
# from a dimension of size 2 and a ValueError is expected. `assert False`
# raises AssertionError, which the except clause does not catch, so the
# test fails if no ValueError is raised.
try:
F.split(inp, 4)
assert False
except ValueError as e:
pass

# Split points that are not strictly increasing must be rejected with this
# exact message (the spelling "secions" matches the string raised by split).
try:
F.split(inp, [3, 3, 5], axis=3)
assert False
except ValueError as e:
assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"


def test_reshape():
x = np.arange(6, dtype="float32")


Loading…
Cancel
Save