GitOrigin-RevId: dd330a2a1d
tags/v1.8.0
@@ -99,7 +99,39 @@ def _transpose(data, axes): | |||||
def _broadcast(inp, shape): | def _broadcast(inp, shape): | ||||
shape = astensor1d(shape, inp, dtype="int32", device=inp.device) | |||||
auto_infer = False | |||||
if isinstance(shape, (list, tuple)): | |||||
shape_tuple = list(shape) | |||||
for i, s in enumerate(shape_tuple): | |||||
if isinstance(s, type(None)): | |||||
if s is None: | |||||
right = i - len(shape_tuple) | |||||
inp_shape = inp._tuple_shape | |||||
if len(inp_shape) + right >= 0: | |||||
shape_tuple[right] = list(inp_shape)[right] | |||||
auto_infer = True | |||||
continue | |||||
else: | |||||
raise ValueError("invalided Broadcast shape") | |||||
else: | |||||
raise ValueError( | |||||
"expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format( | |||||
i, s | |||||
) | |||||
) | |||||
if s < 0: | |||||
raise ValueError( | |||||
"expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format( | |||||
i, s | |||||
) | |||||
) | |||||
if auto_infer: | |||||
shape = tuple(shape_tuple) | |||||
try: | |||||
shape_tuple = make_shape_tuple(shape) | |||||
except ValueError: | |||||
shape_tuple = shape | |||||
shape = astensor1d(shape_tuple, inp, dtype="int32", device=inp.device) | |||||
(result,) = apply(builtin.Broadcast(), inp, shape) | (result,) = apply(builtin.Broadcast(), inp, shape) | ||||
return result | return result | ||||
@@ -223,6 +223,34 @@ def test_reshape(is_varnode): | |||||
np.testing.assert_equal(yy.numpy(), y) | np.testing.assert_equal(yy.numpy(), y) | ||||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||||
def test_broadcast_auto_infer(is_varnode): | |||||
if is_varnode: | |||||
network = Network() | |||||
else: | |||||
network = None | |||||
x = np.random.random((1, 2, 3)).astype(np.float32) | |||||
xx = make_tensor(x, network) | |||||
for shape in [ | |||||
(1, 2, 3), | |||||
(1, None, 3), | |||||
]: | |||||
yy = F.broadcast_to(xx, shape) | |||||
np.testing.assert_equal(yy.numpy(), x) | |||||
with pytest.raises(ValueError): | |||||
F.broadcast_to(xx, (1, -1, 3)) | |||||
with pytest.raises(ValueError): | |||||
F.broadcast_to(xx, (None, 1, 2, 3)) | |||||
F.broadcast_to(xx, (1, None, 2, 3)) | |||||
t = tensor(2, dtype=np.int32) | |||||
F.broadcast_to(xx, (t, None, 2, 3)) | |||||
@pytest.mark.parametrize("is_trace", [True, False]) | @pytest.mark.parametrize("is_trace", [True, False]) | ||||
def test_reshape_on_empty_tensor(is_trace): | def test_reshape_on_empty_tensor(is_trace): | ||||
input1_shape = (100, 0, 1) | input1_shape = (100, 0, 1) | ||||