@@ -173,7 +173,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||||
item.append(True) | item.append(True) | ||||
v = get_index(v) | v = get_index(v) | ||||
assert np.issubdtype(v.dtype, np.integer) or np.issubdtype( | assert np.issubdtype(v.dtype, np.integer) or np.issubdtype( | ||||
v.dtype, np.bool | |||||
v.dtype, np.bool_ | |||||
), "var type in the subscript must be int or bool" | ), "var type in the subscript must be int or bool" | ||||
tensors.append(v) | tensors.append(v) | ||||
@@ -65,10 +65,10 @@ def _broadcast(inp, shape): | |||||
) | ) | ||||
) | ) | ||||
if isinstance(src, (Tensor, TensorWrapperBase)): | |||||
if isinstance(src, (TensorBase, TensorWrapperBase)): | |||||
src = src.numpy() | src = src.numpy() | ||||
if isinstance(tar, (Tensor, TensorWrapperBase)): | |||||
if isinstance(tar, (TensorBase, TensorWrapperBase)): | |||||
tar = tar.numpy() | tar = tar.numpy() | ||||
if len(src) > len(tar): | if len(src) > len(tar): | ||||
@@ -78,8 +78,8 @@ def _broadcast(inp, shape): | |||||
if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]: | if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]: | ||||
failed() | failed() | ||||
valid_broadcast(inp.shape, shape) | |||||
shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) | shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) | ||||
valid_broadcast(inp.shape, shape) | |||||
(result,) = apply(builtin.Broadcast(), inp, shape) | (result,) = apply(builtin.Broadcast(), inp, shape) | ||||
return result | return result | ||||
@@ -379,3 +379,18 @@ def test_trace_nms(): | |||||
f(*make_inputs(10)) | f(*make_inputs(10)) | ||||
f(*make_inputs(20)) | f(*make_inputs(20)) | ||||
f(*make_inputs(30)) | f(*make_inputs(30)) | ||||
def test_trace_valid_broadcast(): | |||||
set_tensor_shape(True) | |||||
x1 = tensor(np.random.randn(1, 1)) | |||||
x2 = tensor(np.random.randn(1, 2)) | |||||
shape = (tensor([2]), tensor([2])) | |||||
@trace(symbolic=False) | |||||
def f(x, shape): | |||||
y = F.broadcast_to(x, shape) | |||||
return y | |||||
f(x1, shape) | |||||
f(x2, shape) |