Browse Source

test(mge/utils): cover all test data

GitOrigin-RevId: e676476b9d
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
a83098890b
2 changed files with 37 additions and 33 deletions
  1. +36
    -32
      imperative/python/test/helpers/utils.py
  2. +1
    -1
      imperative/python/test/unit/functional/test_functional.py

+ 36
- 32
imperative/python/test/helpers/utils.py View File

@@ -94,45 +94,49 @@ def opr_test(

return inp, outp

if len(cases) == 0:
raise ValueError("should give one case at least")
def run_index(index):
inp, outp = get_param(cases, index)
inp_tensor = [make_tensor(inpi, network) for inpi in inp]

if not callable(func):
raise ValueError("the input func should be callable")
if test_trace and not network:
copied_inp = inp_tensor.copy()
for symbolic in [False, True]:
traced_func = trace(symbolic=symbolic)(func)

inp, outp = get_param(cases, 0)
inp_tensor = [make_tensor(inpi, network) for inpi in inp]
for _ in range(3):
traced_results = traced_func(*copied_inp, **kwargs)
check_results(traced_results, outp)

if test_trace and not network:
copied_inp = inp_tensor.copy()
for symbolic in [False, True]:
traced_func = trace(symbolic=symbolic)(func)
dumped_func = trace(symbolic=True, capture_as_const=True)(func)
dumped_results = dumped_func(*copied_inp, **kwargs)
check_results(dumped_results, outp)

for _ in range(3):
traced_results = traced_func(*copied_inp, **kwargs)
check_results(traced_results, outp)
file = io.BytesIO()
dump_info = dumped_func.dump(file)
file.seek(0)

dumped_func = trace(symbolic=True, capture_as_const=True)(func)
dumped_results = dumped_func(*copied_inp, **kwargs)
check_results(dumped_results, outp)
# arg_name has pattern arg_xxx, xxx is int value
def take_number(arg_name):
return int(arg_name.split("_")[-1])

file = io.BytesIO()
dump_info = dumped_func.dump(file)
file.seek(0)
input_names = dump_info[4]
inps_np = [i.numpy() for i in copied_inp]
input_names.sort(key=take_number)
inp_dict = dict(zip(input_names, inps_np))
infer_cg = cgtools.GraphInference(file)

# arg_name has pattern arg_xxx, xxx is int value
def take_number(arg_name):
return int(arg_name.split("_")[-1])
# assume #outputs == 1
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
check_results(loaded_results, outp)

input_names = dump_info[4]
inps_np = [i.numpy() for i in copied_inp]
input_names.sort(key=take_number)
inp_dict = dict(zip(input_names, inps_np))
infer_cg = cgtools.GraphInference(file)
results = func(*inp_tensor, **kwargs)
check_results(results, outp)

# assume #outputs == 1
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
check_results(loaded_results, outp)
if len(cases) == 0:
raise ValueError("should give one case at least")

if not callable(func):
raise ValueError("the input func should be callable")

results = func(*inp_tensor, **kwargs)
check_results(results, outp)
for index in range(len(cases)):
run_index(index)

+ 1
- 1
imperative/python/test/unit/functional/test_functional.py View File

@@ -79,7 +79,7 @@ def test_matinv():
opr_test(
cases,
F.matinv,
compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-5),
compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-4),
ref_fn=np.linalg.inv,
)



Loading…
Cancel
Save