Browse Source

test(opr): add scalar check for opr_test

GitOrigin-RevId: dcfd7ad5d6
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
27346b0b65
3 changed files with 56 additions and 60 deletions
  1. +13
    -9
      imperative/python/test/helpers/utils.py
  2. +10
    -9
      imperative/python/test/unit/functional/test_functional.py
  3. +33
    -42
      imperative/python/test/unit/functional/test_tensor.py

+ 13
- 9
imperative/python/test/helpers/utils.py View File

@@ -11,12 +11,12 @@ from megengine.utils.network_node import VarNode




def _default_compare_fn(x, y): def _default_compare_fn(x, y):
if isinstance(x, np.ndarray):
np.testing.assert_allclose(x, y, rtol=1e-6)
elif isinstance(x, tensor):
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)
else:
np.testing.assert_allclose(get_var_value(x), y, rtol=1e-6)
if isinstance(x, tensor):
x = x.numpy()
elif not isinstance(x, np.ndarray):
x = get_var_value(x)
assert isinstance(x, np.ndarray)
np.testing.assert_allclose(x, y, rtol=1e-6)




def make_tensor(x, network=None, device=None): def make_tensor(x, network=None, device=None):
@@ -69,12 +69,16 @@ def opr_test(


""" """


def check_results(results, expected):
def check_results(results, expected, check_shape=True):
if not isinstance(results, (tuple, list)): if not isinstance(results, (tuple, list)):
results = (results,) results = (results,)
for r, e in zip(results, expected): for r, e in zip(results, expected):
if not isinstance(r, (tensor, VarNode)): if not isinstance(r, (tensor, VarNode)):
r = tensor(r) r = tensor(r)
if check_shape:
r_shape = r.numpy().shape
e_shape = e.shape if isinstance(e, np.ndarray) else ()
assert r_shape == e_shape
compare_fn(r, e) compare_fn(r, e)


def get_param(cases, idx): def get_param(cases, idx):
@@ -127,10 +131,10 @@ def opr_test(


# assume #outputs == 1 # assume #outputs == 1
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0] loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
check_results(loaded_results, outp)
check_results(loaded_results, outp, check_shape=False) # scalar info lost


results = func(*inp_tensor, **kwargs) results = func(*inp_tensor, **kwargs)
check_results(results, outp)
check_results(results, outp, check_shape=(network is None))


if len(cases) == 0: if len(cases) == 0:
raise ValueError("should give one case at least") raise ValueError("should give one case at least")


+ 10
- 9
imperative/python/test/unit/functional/test_functional.py View File

@@ -39,12 +39,6 @@ def test_where():
xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32) xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32) yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)


cases = [
{"input": [maskv0, xv0, yv0]},
{"input": [maskv1, xv1, yv1]},
]
opr_test(cases, F.where, ref_fn=np.where, test_trace=False)

maskv2 = np.array([1, 1, 1], dtype=np.bool_) maskv2 = np.array([1, 1, 1], dtype=np.bool_)
xv2 = np.array([1, 3, 2], dtype=np.float32) xv2 = np.array([1, 3, 2], dtype=np.float32)
yv2 = np.array([5, 6, 9], dtype=np.float32) yv2 = np.array([5, 6, 9], dtype=np.float32)
@@ -53,11 +47,18 @@ def test_where():
xv3 = np.array([1, 3, 2], dtype=np.float32) xv3 = np.array([1, 3, 2], dtype=np.float32)
yv3 = np.array([5, 6, 9], dtype=np.float32) yv3 = np.array([5, 6, 9], dtype=np.float32)


maskv4 = np.array(1, dtype=np.bool_)
xv4 = np.array(1, dtype=np.float32)
yv4 = np.array(0, dtype=np.float32)

cases = [ cases = [
{"input": [maskv0, xv0, yv0]},
{"input": [maskv1, xv1, yv1]},
{"input": [maskv2, xv2, yv2]}, {"input": [maskv2, xv2, yv2]},
{"input": [maskv3, xv3, yv3]}, {"input": [maskv3, xv3, yv3]},
{"input": [maskv4, xv4, yv4]},
] ]
opr_test(cases, F.where, ref_fn=np.where, test_trace=False)
opr_test(cases, F.where, ref_fn=np.where, test_trace=True)




def test_dropout(): def test_dropout():
@@ -618,12 +619,12 @@ def test_binary_cross_entropy():
np.random.seed(123) np.random.seed(123)
data1 = np.random.uniform(size=data1_shape).astype(np.float32) data1 = np.random.uniform(size=data1_shape).astype(np.float32)
label1 = np.random.uniform(size=label1_shape).astype(np.float32) label1 = np.random.uniform(size=label1_shape).astype(np.float32)
expect1 = np.array([0.6361], dtype=np.float32)
expect1 = np.array(0.6361, dtype=np.float32)


np.random.seed(123) np.random.seed(123)
data2 = np.random.uniform(size=data2_shape).astype(np.float32) data2 = np.random.uniform(size=data2_shape).astype(np.float32)
label2 = np.random.uniform(size=label2_shape).astype(np.float32) label2 = np.random.uniform(size=label2_shape).astype(np.float32)
expect2 = np.array([0.6750], dtype=np.float32)
expect2 = np.array(0.6750, dtype=np.float32)


cases = [ cases = [
{"input": [data1, label1], "output": expect1,}, {"input": [data1, label1], "output": expect1,},


+ 33
- 42
imperative/python/test/unit/functional/test_tensor.py View File

@@ -335,18 +335,18 @@ def test_reshape_shape_inference(is_varnode):
source = output.shape source = output.shape
if isinstance(source, tensor): if isinstance(source, tensor):
source = source.numpy() source = source.numpy()
np.testing.assert_equal(source, target)
np.testing.assert_equal(source, target.shape)


def func(x, target_shape): def func(x, target_shape):
return x.reshape(target_shape) return x.reshape(target_shape)


cases = [ cases = [
{"input": [x_shape_known, tshp_unknown], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2),]},
{"input": [x_shape_known, tshp_known], "output": [(2, 2),]},
{"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
{"input": [x_shape_known, tshp_unknown], "output": [np.zeros((2, 2)),]},
{"input": [x_shape_unknown, tshp_unknown], "output": [np.zeros((2, 2)),]},
{"input": [x_shape_known, tshp_known], "output": [np.zeros((2, 2)),]},
{"input": [x_shape_known, tshp_known_unspec], "output": [np.zeros((2, 2)),]},
{"input": [x_shape_unknown, tshp_known], "output": [np.zeros((2, 2)),]},
{"input": [x_shape_unknown, tshp_known_unspec], "output": [np.zeros((2, 2)),]},
] ]
opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network) opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
if is_varnode: if is_varnode:
@@ -533,46 +533,30 @@ def test_flatten(is_varnode):
data0 = np.random.random(data0_shape).astype(np.float32) data0 = np.random.random(data0_shape).astype(np.float32)
data1 = np.random.random(data1_shape).astype(np.float32) data1 = np.random.random(data1_shape).astype(np.float32)


def compare_fn(x, y):
assert x._tuple_shape[0] == y

output0 = (2 * 3 * 4 * 5,)
output1 = (4 * 5 * 6 * 7,)
cases = [ cases = [
{"input": data0, "output": output0},
{"input": data1, "output": output1},
{"input": data0, "output": data0.flatten()},
{"input": data1, "output": data1.flatten()},
] ]
opr_test(cases, F.flatten, compare_fn=compare_fn, network=network)
opr_test(cases, F.flatten, network=network)


output0 = (2, 3 * 4 * 5)
output1 = (4, 5 * 6 * 7)
cases = [ cases = [
{"input": data0, "output": output0},
{"input": data1, "output": output1},
{"input": data0, "output": data0.reshape(2, -1)},
{"input": data1, "output": data1.reshape(4, -1)},
] ]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network)
opr_test(cases, F.flatten, start_axis=1, network=network)


output0 = (2, 3, 4 * 5)
output1 = (4, 5, 6 * 7)
cases = [ cases = [
{"input": data0, "output": output0},
{"input": data1, "output": output1},
{"input": data0, "output": data0.reshape(2, 3, -1)},
{"input": data1, "output": data1.reshape(4, 5, -1)},
] ]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network)
opr_test(cases, F.flatten, start_axis=2, network=network)


output0 = (2, 3 * 4, 5)
output1 = (4, 5 * 6, 7)
cases = [ cases = [
{"input": data0, "output": output0},
{"input": data1, "output": output1},
{"input": data0, "output": data0.reshape(2, -1, 5)},
{"input": data1, "output": data1.reshape(4, -1, 7)},
] ]
opr_test( opr_test(
cases,
F.flatten,
compare_fn=compare_fn,
start_axis=1,
end_axis=2,
network=network,
cases, F.flatten, start_axis=1, end_axis=2, network=network,
) )




@@ -595,15 +579,22 @@ def test_broadcast(is_varnode):
output3_shape = (10, 10) output3_shape = (10, 10)
data3 = np.random.random(input3_shape).astype(np.float32) data3 = np.random.random(input3_shape).astype(np.float32)


def compare_fn(x, y):
assert x._tuple_shape[0] == y

cases = [ cases = [
{"input": [data1, output1_shape], "output": output1_shape},
{"input": [data2, output2_shape], "output": output2_shape},
{"input": [data3, output3_shape], "output": output3_shape},
{
"input": [data1, output1_shape],
"output": np.broadcast_to(data1, output1_shape),
},
{
"input": [data2, output2_shape],
"output": np.broadcast_to(data2, output2_shape),
},
{
"input": [data3, output3_shape],
"output": np.broadcast_to(data3, output3_shape),
},
] ]
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)

opr_test(cases, F.broadcast_to, network=network)


x = F.ones((2, 1, 3)) x = F.ones((2, 1, 3))
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):


Loading…
Cancel
Save