@@ -234,15 +234,10 @@ def setitem(tensor, index, value):
    try_result = try_condtake(tensor, index)
    if len(try_result) == 2:
        index = try_result[1]
        if index.shape[0] == 0:
            return tensor
        tensor = tensor.reshape(-1)
    if not isinstance(value, Tensor):
        (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)()
    tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
    for v in tensors:
        if len(v.shape) > 0 and v.shape[0] == 0:
            return tensor
    if use_subtensor:
        op = builtin.Subtensor(items=items)
    else:
@@ -250,19 +245,17 @@ def setitem(tensor, index, value):
    (tmp_result,) = apply(op, tensor, *tensors)

    # XXX: broadcast can always be applied even if shapes are equal
    if make_shape_tuple(value.shape) != make_shape_tuple(tmp_result.shape):
        for i in range(min(len(value.shape), len(tmp_result.shape))):
            if (
                value.shape[-i - 1] != 1
                and value.shape[-i - 1] != tmp_result.shape[-i - 1]
            ):
                raise ValueError(
                    "cannot copy tensor with shape {} to subtensor with shape {}".format(
                        value.shape, tmp_result.shape
                    )
        for i in range(min(len(value.shape), len(tmp_result.shape))):
            if (value.shape[-i - 1] != 1) & (
                value.shape[-i - 1] != tmp_result.shape[-i - 1]
            ):
                raise ValueError(
                    "cannot copy tensor with shape {} to subtensor with shape {}".format(
                        value.shape, tmp_result.shape
                    )
        value = value._broadcast(tmp_result.shape)
                )
        value = value._broadcast(tmp_result.shape)

    if use_subtensor:
        op = builtin.SetSubtensor(items=items)
    else:
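# Aside (not part of the patch above): a minimal standalone sketch of the
# trailing-dimension rule that the shape-check loop enforces before calling
# value._broadcast(tmp_result.shape). The helper name `can_broadcast_to` is
# hypothetical and used only for illustration.
def can_broadcast_to(src_shape, dst_shape):
    # Walk the shapes from the rightmost dimension; each source dim must be
    # either 1 (broadcastable) or equal to the corresponding destination dim.
    for i in range(min(len(src_shape), len(dst_shape))):
        if src_shape[-i - 1] != 1 and src_shape[-i - 1] != dst_shape[-i - 1]:
            return False
    return True


assert can_broadcast_to((1, 4), (3, 4))      # 1 broadcasts up to 3
assert not can_broadcast_to((2, 4), (3, 4))  # 2 vs 3: the "cannot copy" case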