|
|
@@ -94,45 +94,49 @@ def opr_test( |
|
|
|
|
|
|
|
return inp, outp |
|
|
|
|
|
|
|
if len(cases) == 0: |
|
|
|
raise ValueError("should give one case at least") |
|
|
|
def run_index(index): |
|
|
|
inp, outp = get_param(cases, index) |
|
|
|
inp_tensor = [make_tensor(inpi, network) for inpi in inp] |
|
|
|
|
|
|
|
if not callable(func): |
|
|
|
raise ValueError("the input func should be callable") |
|
|
|
if test_trace and not network: |
|
|
|
copied_inp = inp_tensor.copy() |
|
|
|
for symbolic in [False, True]: |
|
|
|
traced_func = trace(symbolic=symbolic)(func) |
|
|
|
|
|
|
|
inp, outp = get_param(cases, 0) |
|
|
|
inp_tensor = [make_tensor(inpi, network) for inpi in inp] |
|
|
|
for _ in range(3): |
|
|
|
traced_results = traced_func(*copied_inp, **kwargs) |
|
|
|
check_results(traced_results, outp) |
|
|
|
|
|
|
|
if test_trace and not network: |
|
|
|
copied_inp = inp_tensor.copy() |
|
|
|
for symbolic in [False, True]: |
|
|
|
traced_func = trace(symbolic=symbolic)(func) |
|
|
|
dumped_func = trace(symbolic=True, capture_as_const=True)(func) |
|
|
|
dumped_results = dumped_func(*copied_inp, **kwargs) |
|
|
|
check_results(dumped_results, outp) |
|
|
|
|
|
|
|
for _ in range(3): |
|
|
|
traced_results = traced_func(*copied_inp, **kwargs) |
|
|
|
check_results(traced_results, outp) |
|
|
|
file = io.BytesIO() |
|
|
|
dump_info = dumped_func.dump(file) |
|
|
|
file.seek(0) |
|
|
|
|
|
|
|
dumped_func = trace(symbolic=True, capture_as_const=True)(func) |
|
|
|
dumped_results = dumped_func(*copied_inp, **kwargs) |
|
|
|
check_results(dumped_results, outp) |
|
|
|
# arg_name has pattern arg_xxx, xxx is int value |
|
|
|
def take_number(arg_name): |
|
|
|
return int(arg_name.split("_")[-1]) |
|
|
|
|
|
|
|
file = io.BytesIO() |
|
|
|
dump_info = dumped_func.dump(file) |
|
|
|
file.seek(0) |
|
|
|
input_names = dump_info[4] |
|
|
|
inps_np = [i.numpy() for i in copied_inp] |
|
|
|
input_names.sort(key=take_number) |
|
|
|
inp_dict = dict(zip(input_names, inps_np)) |
|
|
|
infer_cg = cgtools.GraphInference(file) |
|
|
|
|
|
|
|
# arg_name has pattern arg_xxx, xxx is int value |
|
|
|
def take_number(arg_name): |
|
|
|
return int(arg_name.split("_")[-1]) |
|
|
|
# assume #outputs == 1 |
|
|
|
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0] |
|
|
|
check_results(loaded_results, outp) |
|
|
|
|
|
|
|
input_names = dump_info[4] |
|
|
|
inps_np = [i.numpy() for i in copied_inp] |
|
|
|
input_names.sort(key=take_number) |
|
|
|
inp_dict = dict(zip(input_names, inps_np)) |
|
|
|
infer_cg = cgtools.GraphInference(file) |
|
|
|
results = func(*inp_tensor, **kwargs) |
|
|
|
check_results(results, outp) |
|
|
|
|
|
|
|
# assume #outputs == 1 |
|
|
|
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0] |
|
|
|
check_results(loaded_results, outp) |
|
|
|
if len(cases) == 0: |
|
|
|
raise ValueError("should give one case at least") |
|
|
|
|
|
|
|
if not callable(func): |
|
|
|
raise ValueError("the input func should be callable") |
|
|
|
|
|
|
|
results = func(*inp_tensor, **kwargs) |
|
|
|
check_results(results, outp) |
|
|
|
for index in range(len(cases)): |
|
|
|
run_index(index) |