From 68cde8734e735361d40a0a6cf48b920c0cb8ff45 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 11 Jan 2022 14:29:41 +0800 Subject: [PATCH] fix(mge/imperative): support broadcast with None GitOrigin-RevId: dd330a2a1dc603ea52a655b350ee1a421015e7d7 --- .../python/megengine/core/tensor/array_method.py | 34 +++++++++++++++++++++- .../python/test/unit/functional/test_tensor.py | 28 ++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index fe15d2bf..a7c086c8 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -99,7 +99,39 @@ def _transpose(data, axes): def _broadcast(inp, shape): - shape = astensor1d(shape, inp, dtype="int32", device=inp.device) + auto_infer = False + if isinstance(shape, (list, tuple)): + shape_tuple = list(shape) + for i, s in enumerate(shape_tuple): + if isinstance(s, type(None)): + if s is None: + right = i - len(shape_tuple) + inp_shape = inp._tuple_shape + if len(inp_shape) + right >= 0: + shape_tuple[right] = list(inp_shape)[right] + auto_infer = True + continue + else: + raise ValueError("invalided Broadcast shape") + else: + raise ValueError( + "expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format( + i, s + ) + ) + if s < 0: + raise ValueError( + "expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format( + i, s + ) + ) + if auto_infer: + shape = tuple(shape_tuple) + try: + shape_tuple = make_shape_tuple(shape) + except ValueError: + shape_tuple = shape + shape = astensor1d(shape_tuple, inp, dtype="int32", device=inp.device) (result,) = apply(builtin.Broadcast(), inp, shape) return result diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index bc764537..8f4f8d25 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -223,6 +223,34 @@ def test_reshape(is_varnode): np.testing.assert_equal(yy.numpy(), y) +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_broadcast_auto_infer(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + x = np.random.random((1, 2, 3)).astype(np.float32) + xx = make_tensor(x, network) + + for shape in [ + (1, 2, 3), + (1, None, 3), + ]: + yy = F.broadcast_to(xx, shape) + np.testing.assert_equal(yy.numpy(), x) + + with pytest.raises(ValueError): + F.broadcast_to(xx, (1, -1, 3)) + + with pytest.raises(ValueError): + F.broadcast_to(xx, (None, 1, 2, 3)) + + F.broadcast_to(xx, (1, None, 2, 3)) + t = tensor(2, dtype=np.int32) + F.broadcast_to(xx, (t, None, 2, 3)) + + @pytest.mark.parametrize("is_trace", [True, False]) def test_reshape_on_empty_tensor(is_trace): input1_shape = (100, 0, 1)