@@ -57,6 +57,28 @@ def _transpose(data, axes): | |||||
def _broadcast(inp, shape): | def _broadcast(inp, shape): | ||||
def valid_broadcast(src, tar): | |||||
def failed(): | |||||
raise ValueError( | |||||
"the input shape {} can not be broadcasted to target shape {}".format( | |||||
src, tar | |||||
) | |||||
) | |||||
if isinstance(src, (Tensor, TensorWrapperBase)): | |||||
src = src.numpy() | |||||
if isinstance(tar, (Tensor, TensorWrapperBase)): | |||||
tar = tar.numpy() | |||||
if len(src) > len(tar): | |||||
failed() | |||||
for i in range(min(len(src), len(tar))): | |||||
if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]: | |||||
failed() | |||||
valid_broadcast(inp.shape, shape) | |||||
shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) | shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) | ||||
(result,) = apply(builtin.Broadcast(), inp, shape) | (result,) = apply(builtin.Broadcast(), inp, shape) | ||||
return result | return result | ||||
@@ -240,7 +240,7 @@ def test_broadcast(): | |||||
output1_shape = (30, 20, 30) | output1_shape = (30, 20, 30) | ||||
data1 = np.random.random(input1_shape).astype(np.float32) | data1 = np.random.random(input1_shape).astype(np.float32) | ||||
input2_shape = (10, 20) | |||||
input2_shape = (10, 1) | |||||
output2_shape = (20, 10, 20) | output2_shape = (20, 10, 20) | ||||
data2 = np.random.random(input2_shape).astype(np.float32) | data2 = np.random.random(input2_shape).astype(np.float32) | ||||
@@ -253,6 +253,16 @@ def test_broadcast(): | |||||
] | ] | ||||
opr_test(cases, F.broadcast, compare_fn=compare_fn) | opr_test(cases, F.broadcast, compare_fn=compare_fn) | ||||
x = F.ones((2, 1, 3)) | |||||
with pytest.raises(ValueError): | |||||
F.broadcast(x, (2, 3, 4)) | |||||
with pytest.raises(ValueError): | |||||
F.broadcast(x, (4, 1, 3)) | |||||
with pytest.raises(ValueError): | |||||
F.broadcast(x, (1, 3)) | |||||
def test_utils_astensor1d(): | def test_utils_astensor1d(): | ||||
reference = tensor(0) | reference = tensor(0) | ||||
@@ -340,3 +340,20 @@ def test_raise_on_trace(): | |||||
step_count += 1 | step_count += 1 | ||||
assert catch_count == 1 | assert catch_count == 1 | ||||
def test_trace_broadcast(): | |||||
for symbolic in [False, True]: | |||||
set_tensor_shape(True) | |||||
x1 = tensor(np.random.randn(3, 1, 1)) | |||||
x2 = tensor(np.random.randn(1, 4, 1)) | |||||
x3 = tensor(np.random.randn(1, 1, 5)) | |||||
@trace(symbolic=symbolic, capture_as_const=True) | |||||
def f(x): | |||||
y = x.broadcast((3, 4, 5)) | |||||
return y | |||||
f(x1) | |||||
f(x2) | |||||
f(x3) |