|
|
@@ -20,7 +20,6 @@ import megengine.core.tensor.megbrain_graph as G |
|
|
|
from megengine import tensor |
|
|
|
from megengine.core._imperative_rt.core2 import apply |
|
|
|
from megengine.core.ops import builtin |
|
|
|
from megengine.core.tensor.megbrain_graph import VarNode |
|
|
|
from megengine.utils import comp_graph_tools as cgtools |
|
|
|
|
|
|
|
logger = mge.get_logger(__name__) |
|
|
@@ -268,8 +267,8 @@ def make_feeds(args): |
|
|
|
|
|
|
|
def assert_equal(expect, real, **kwargs): |
|
|
|
op = builtin.AssertEqual(**kwargs) |
|
|
|
(res,) = G.apply_normal_varnode(op, expect, real) |
|
|
|
return G.VarNode(res) |
|
|
|
(res,) = apply(op, expect, real) |
|
|
|
return res |
|
|
|
|
|
|
|
verbose = not args.silent |
|
|
|
|
|
|
@@ -284,8 +283,8 @@ def make_feeds(args): |
|
|
|
# insert assert opr to check expect and real. |
|
|
|
outputs_new.append( |
|
|
|
assert_equal( |
|
|
|
G.VarNode(expect_get), |
|
|
|
G.VarNode(i), |
|
|
|
expect_get, |
|
|
|
i, |
|
|
|
verbose=verbose, |
|
|
|
maxerr=args.maxerr, |
|
|
|
) |
|
|
@@ -297,29 +296,26 @@ def make_feeds(args): |
|
|
|
|
|
|
|
|
|
|
|
def optimize_for_inference(args, outputs): |
|
|
|
args_map = { |
|
|
|
"enable_io16xc32": "f16_io_f32_comp", |
|
|
|
"enable_ioc16": "f16_io_comp", |
|
|
|
"enable_hwcd4": "use_nhwcd4", |
|
|
|
"enable_nchw4": "use_nchw4", |
|
|
|
"enable_nchw88": "use_nchw88", |
|
|
|
"enable_nchw44": "use_nchw44", |
|
|
|
"enable_nchw44_dot": "use_nchw44_dot", |
|
|
|
"enable_nchw32": "use_nchw32", |
|
|
|
"enable_chwn4": "use_chwn4", |
|
|
|
"enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity", |
|
|
|
"enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z", |
|
|
|
} |
|
|
|
args_list = [ |
|
|
|
"enable_io16xc32", |
|
|
|
"enable_ioc16", |
|
|
|
"enable_hwcd4", |
|
|
|
"enable_nchw4", |
|
|
|
"enable_nchw88", |
|
|
|
"enable_nchw44", |
|
|
|
"enable_nchw44_dot", |
|
|
|
"enable_nchw32", |
|
|
|
"enable_chwn4", |
|
|
|
"enable_fuse_conv_bias_nonlinearity", |
|
|
|
"enable_fuse_conv_bias_with_z", |
|
|
|
] |
|
|
|
kwargs = {} |
|
|
|
for k, v in args_map.items(): |
|
|
|
for k in args_list: |
|
|
|
if getattr(args, k): |
|
|
|
assert ( |
|
|
|
args.optimize_for_inference |
|
|
|
), "optimize_for_inference should be set when {} is given".format(k) |
|
|
|
kwargs[v] = True |
|
|
|
kwargs[k] = True |
|
|
|
|
|
|
|
if args.optimize_for_inference: |
|
|
|
outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)] |
|
|
|
outputs = G.optimize_for_inference(outputs, **kwargs) |
|
|
|
|
|
|
|
return outputs |
|
|
|
|
|
|
@@ -476,7 +472,6 @@ def main(): |
|
|
|
|
|
|
|
output_mgbvars = feeds["outputs"] |
|
|
|
output_mgbvars = optimize_for_inference(args, output_mgbvars) |
|
|
|
output_mgbvars = [var._node for var in output_mgbvars] |
|
|
|
|
|
|
|
inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") |
|
|
|
inputs = sorted((i.name, i.dtype) for i in inputs) |
|
|
@@ -491,12 +486,8 @@ def main(): |
|
|
|
with open(args.output, "wb") as fout: |
|
|
|
fout.write(b"mgbtest0") |
|
|
|
fout.write(struct.pack("I", len(feeds["testcases"]))) |
|
|
|
if isinstance(output_mgbvars, dict): |
|
|
|
wrap_output_vars = dict([(i, VarNode(j)) for i, j in output_mgbvars]) |
|
|
|
else: |
|
|
|
wrap_output_vars = [VarNode(i) for i in output_mgbvars] |
|
|
|
dump_content, stat = G.dump_graph( |
|
|
|
wrap_output_vars, |
|
|
|
output_mgbvars, |
|
|
|
append_json=True, |
|
|
|
strip_info_file=strip_info_file, |
|
|
|
**sereg_kwargs, |
|
|
|