Browse Source

fix(imperative/ops): fix infer_output_attrs_fallible for reshape

GitOrigin-RevId: a93567d79a
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
cbff4d7c1a
2 changed files with 32 additions and 1 deletion
  1. +27
    -0
      imperative/python/test/unit/functional/test_tensor.py
  2. +5
    -1
      imperative/src/impl/ops/broadcast.cpp

+ 27
- 0
imperative/python/test/unit/functional/test_tensor.py View File

@@ -122,6 +122,33 @@ def test_reshape():
    np.testing.assert_equal(yy.numpy(), y)




def test_reshape_shape_inference():
    """Reshape must infer the output shape for every combination of a
    statically-known / runtime-only source shape with a dynamic, static,
    or partially-specified (-1) target shape."""
    # Source tensor whose shape is known when the graph is built.
    src_known = tensor([1, 2, 3, 4], dtype="float32")
    # Broadcasting to a shape computed at runtime (the .sum() result)
    # hides the shape from static inference.
    src_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum())
    # Target shapes: fully dynamic (tensor elements), fully static,
    # and static with an unspecified (-1) axis.
    shp_dynamic = astensor1d((tensor([2]), tensor([2])), src_known)
    shp_static = astensor1d((2, 2), src_known)
    shp_unspec = astensor1d((2, -1), src_known)

    def assert_shape(output, target):
        # output.shape may be a lazily-evaluated tensor; realize it first.
        got = output.shape
        if isinstance(got, tensor):
            got = got.numpy()
        np.testing.assert_equal(got, target)

    def run_reshape(x, target_shape):
        return x.reshape(target_shape)

    # Every (source, target-shape) pairing must yield a (2, 2) result.
    combos = [
        (src_known, shp_dynamic),
        (src_unknown, shp_dynamic),
        (src_known, shp_static),
        (src_known, shp_unspec),
        (src_unknown, shp_static),
        (src_unknown, shp_unspec),
    ]
    cases = [{"input": list(pair), "output": [(2, 2)]} for pair in combos]
    opr_test(cases, run_reshape, compare_fn=assert_shape, test_trace=True)


def test_squeeze():
    x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
    xx = tensor(x)


+ 5
- 1
imperative/src/impl/ops/broadcast.cpp View File

@@ -115,9 +115,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
    }
    mgb_assert(
            tshp.layout.ndim == 1,
-           "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
+           "target shape of Reshape expects ndim=1; got ndim=%lu actually",
            tshp.layout.ndim);


if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
}

    size_t target_ndim = tshp.layout.shape[0];
    out_shape.ndim = target_ndim;
    auto* ptr = tshp.value.ptr<dt_int32>();


Loading…
Cancel
Save