Browse Source

fix(mge/imperative): support broadcast with None

GitOrigin-RevId: dd330a2a1d
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
68cde8734e
2 changed files with 61 additions and 1 deletions
  1. +33
    -1
      imperative/python/megengine/core/tensor/array_method.py
  2. +28
    -0
      imperative/python/test/unit/functional/test_tensor.py

+ 33
- 1
imperative/python/megengine/core/tensor/array_method.py View File

@@ -99,7 +99,39 @@ def _transpose(data, axes):




def _broadcast(inp, shape): def _broadcast(inp, shape):
shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
auto_infer = False
if isinstance(shape, (list, tuple)):
shape_tuple = list(shape)
for i, s in enumerate(shape_tuple):
if isinstance(s, type(None)):
if s is None:
right = i - len(shape_tuple)
inp_shape = inp._tuple_shape
if len(inp_shape) + right >= 0:
shape_tuple[right] = list(inp_shape)[right]
auto_infer = True
continue
else:
raise ValueError("invalided Broadcast shape")
else:
raise ValueError(
"expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format(
i, s
)
)
if s < 0:
raise ValueError(
"expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format(
i, s
)
)
if auto_infer:
shape = tuple(shape_tuple)
try:
shape_tuple = make_shape_tuple(shape)
except ValueError:
shape_tuple = shape
shape = astensor1d(shape_tuple, inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), inp, shape) (result,) = apply(builtin.Broadcast(), inp, shape)
return result return result




+ 28
- 0
imperative/python/test/unit/functional/test_tensor.py View File

@@ -223,6 +223,34 @@ def test_reshape(is_varnode):
np.testing.assert_equal(yy.numpy(), y) np.testing.assert_equal(yy.numpy(), y)




@pytest.mark.parametrize("is_varnode", [True, False])
def test_broadcast_auto_infer(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = np.random.random((1, 2, 3)).astype(np.float32)
xx = make_tensor(x, network)

for shape in [
(1, 2, 3),
(1, None, 3),
]:
yy = F.broadcast_to(xx, shape)
np.testing.assert_equal(yy.numpy(), x)

with pytest.raises(ValueError):
F.broadcast_to(xx, (1, -1, 3))

with pytest.raises(ValueError):
F.broadcast_to(xx, (None, 1, 2, 3))

F.broadcast_to(xx, (1, None, 2, 3))
t = tensor(2, dtype=np.int32)
F.broadcast_to(xx, (t, None, 2, 3))


@pytest.mark.parametrize("is_trace", [True, False]) @pytest.mark.parametrize("is_trace", [True, False])
def test_reshape_on_empty_tensor(is_trace): def test_reshape_on_empty_tensor(is_trace):
input1_shape = (100, 0, 1) input1_shape = (100, 0, 1)


Loading…
Cancel
Save