|
|
@@ -335,18 +335,18 @@ def test_reshape_shape_inference(is_varnode): |
|
|
|
source = output.shape |
|
|
|
if isinstance(source, tensor): |
|
|
|
source = source.numpy() |
|
|
|
np.testing.assert_equal(source, target) |
|
|
|
np.testing.assert_equal(source, target.shape) |
|
|
|
|
|
|
|
def func(x, target_shape): |
|
|
|
return x.reshape(target_shape) |
|
|
|
|
|
|
|
cases = [ |
|
|
|
{"input": [x_shape_known, tshp_unknown], "output": [(2, 2),]}, |
|
|
|
{"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2),]}, |
|
|
|
{"input": [x_shape_known, tshp_known], "output": [(2, 2),]}, |
|
|
|
{"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2),]}, |
|
|
|
{"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]}, |
|
|
|
{"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, |
|
|
|
{"input": [x_shape_known, tshp_unknown], "output": [np.zeros((2, 2)),]}, |
|
|
|
{"input": [x_shape_unknown, tshp_unknown], "output": [np.zeros((2, 2)),]}, |
|
|
|
{"input": [x_shape_known, tshp_known], "output": [np.zeros((2, 2)),]}, |
|
|
|
{"input": [x_shape_known, tshp_known_unspec], "output": [np.zeros((2, 2)),]}, |
|
|
|
{"input": [x_shape_unknown, tshp_known], "output": [np.zeros((2, 2)),]}, |
|
|
|
{"input": [x_shape_unknown, tshp_known_unspec], "output": [np.zeros((2, 2)),]}, |
|
|
|
] |
|
|
|
opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network) |
|
|
|
if is_varnode: |
|
|
@@ -533,46 +533,30 @@ def test_flatten(is_varnode): |
|
|
|
data0 = np.random.random(data0_shape).astype(np.float32) |
|
|
|
data1 = np.random.random(data1_shape).astype(np.float32) |
|
|
|
|
|
|
|
def compare_fn(x, y): |
|
|
|
assert x._tuple_shape[0] == y |
|
|
|
|
|
|
|
output0 = (2 * 3 * 4 * 5,) |
|
|
|
output1 = (4 * 5 * 6 * 7,) |
|
|
|
cases = [ |
|
|
|
{"input": data0, "output": output0}, |
|
|
|
{"input": data1, "output": output1}, |
|
|
|
{"input": data0, "output": data0.flatten()}, |
|
|
|
{"input": data1, "output": data1.flatten()}, |
|
|
|
] |
|
|
|
opr_test(cases, F.flatten, compare_fn=compare_fn, network=network) |
|
|
|
opr_test(cases, F.flatten, network=network) |
|
|
|
|
|
|
|
output0 = (2, 3 * 4 * 5) |
|
|
|
output1 = (4, 5 * 6 * 7) |
|
|
|
cases = [ |
|
|
|
{"input": data0, "output": output0}, |
|
|
|
{"input": data1, "output": output1}, |
|
|
|
{"input": data0, "output": data0.reshape(2, -1)}, |
|
|
|
{"input": data1, "output": data1.reshape(4, -1)}, |
|
|
|
] |
|
|
|
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network) |
|
|
|
opr_test(cases, F.flatten, start_axis=1, network=network) |
|
|
|
|
|
|
|
output0 = (2, 3, 4 * 5) |
|
|
|
output1 = (4, 5, 6 * 7) |
|
|
|
cases = [ |
|
|
|
{"input": data0, "output": output0}, |
|
|
|
{"input": data1, "output": output1}, |
|
|
|
{"input": data0, "output": data0.reshape(2, 3, -1)}, |
|
|
|
{"input": data1, "output": data1.reshape(4, 5, -1)}, |
|
|
|
] |
|
|
|
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network) |
|
|
|
opr_test(cases, F.flatten, start_axis=2, network=network) |
|
|
|
|
|
|
|
output0 = (2, 3 * 4, 5) |
|
|
|
output1 = (4, 5 * 6, 7) |
|
|
|
cases = [ |
|
|
|
{"input": data0, "output": output0}, |
|
|
|
{"input": data1, "output": output1}, |
|
|
|
{"input": data0, "output": data0.reshape(2, -1, 5)}, |
|
|
|
{"input": data1, "output": data1.reshape(4, -1, 7)}, |
|
|
|
] |
|
|
|
opr_test( |
|
|
|
cases, |
|
|
|
F.flatten, |
|
|
|
compare_fn=compare_fn, |
|
|
|
start_axis=1, |
|
|
|
end_axis=2, |
|
|
|
network=network, |
|
|
|
cases, F.flatten, start_axis=1, end_axis=2, network=network, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
@@ -595,15 +579,22 @@ def test_broadcast(is_varnode): |
|
|
|
output3_shape = (10, 10) |
|
|
|
data3 = np.random.random(input3_shape).astype(np.float32) |
|
|
|
|
|
|
|
def compare_fn(x, y): |
|
|
|
assert x._tuple_shape[0] == y |
|
|
|
|
|
|
|
cases = [ |
|
|
|
{"input": [data1, output1_shape], "output": output1_shape}, |
|
|
|
{"input": [data2, output2_shape], "output": output2_shape}, |
|
|
|
{"input": [data3, output3_shape], "output": output3_shape}, |
|
|
|
{ |
|
|
|
"input": [data1, output1_shape], |
|
|
|
"output": np.broadcast_to(data1, output1_shape), |
|
|
|
}, |
|
|
|
{ |
|
|
|
"input": [data2, output2_shape], |
|
|
|
"output": np.broadcast_to(data2, output2_shape), |
|
|
|
}, |
|
|
|
{ |
|
|
|
"input": [data3, output3_shape], |
|
|
|
"output": np.broadcast_to(data3, output3_shape), |
|
|
|
}, |
|
|
|
] |
|
|
|
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network) |
|
|
|
|
|
|
|
opr_test(cases, F.broadcast_to, network=network) |
|
|
|
|
|
|
|
x = F.ones((2, 1, 3)) |
|
|
|
with pytest.raises(RuntimeError): |
|
|
|