Browse Source

fix(mge/tensor): fix valid_broadcast

GitOrigin-RevId: 562b7664e2
release-1.1
Megvii Engine Team 4 years ago
parent
commit
9f4bffbd00
3 changed files with 19 additions and 4 deletions
  1. +1
    -1
      imperative/python/megengine/core/tensor/indexing.py
  2. +3
    -3
      imperative/python/megengine/core/tensor/tensor_wrapper.py
  3. +15
    -0
      imperative/python/test/unit/test_tracing.py

+ 1
- 1
imperative/python/megengine/core/tensor/indexing.py View File

@@ -173,7 +173,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
item.append(True) item.append(True)
v = get_index(v) v = get_index(v)
assert np.issubdtype(v.dtype, np.integer) or np.issubdtype( assert np.issubdtype(v.dtype, np.integer) or np.issubdtype(
v.dtype, np.bool
v.dtype, np.bool_
), "var type in the subscript must be int or bool" ), "var type in the subscript must be int or bool"
tensors.append(v) tensors.append(v)




+ 3
- 3
imperative/python/megengine/core/tensor/tensor_wrapper.py View File

@@ -65,10 +65,10 @@ def _broadcast(inp, shape):
) )
) )


if isinstance(src, (Tensor, TensorWrapperBase)):
if isinstance(src, (TensorBase, TensorWrapperBase)):
src = src.numpy() src = src.numpy()


if isinstance(tar, (Tensor, TensorWrapperBase)):
if isinstance(tar, (TensorBase, TensorWrapperBase)):
tar = tar.numpy() tar = tar.numpy()


if len(src) > len(tar): if len(src) > len(tar):
@@ -78,8 +78,8 @@ def _broadcast(inp, shape):
if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]: if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]:
failed() failed()


valid_broadcast(inp.shape, shape)
shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
valid_broadcast(inp.shape, shape)
(result,) = apply(builtin.Broadcast(), inp, shape) (result,) = apply(builtin.Broadcast(), inp, shape)
return result return result




+ 15
- 0
imperative/python/test/unit/test_tracing.py View File

@@ -379,3 +379,18 @@ def test_trace_nms():
f(*make_inputs(10)) f(*make_inputs(10))
f(*make_inputs(20)) f(*make_inputs(20))
f(*make_inputs(30)) f(*make_inputs(30))


def test_trace_valid_broadcast():
set_tensor_shape(True)
x1 = tensor(np.random.randn(1, 1))
x2 = tensor(np.random.randn(1, 2))
shape = (tensor([2]), tensor([2]))

@trace(symbolic=False)
def f(x, shape):
y = F.broadcast_to(x, shape)
return y

f(x1, shape)
f(x2, shape)

Loading…
Cancel
Save