Browse Source

fix(mge/dump): fix dump_with_testcase_mge Varnode type mismatch

GitOrigin-RevId: 05618e5ac5
release-1.4
Megvii Engine Team 4 years ago
parent
commit
71e007c310
1 changed files with 21 additions and 30 deletions
  1. +21
    -30
      sdk/load-and-run/dump_with_testcase_mge.py

+ 21
- 30
sdk/load-and-run/dump_with_testcase_mge.py View File

@@ -20,7 +20,6 @@ import megengine.core.tensor.megbrain_graph as G
from megengine import tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.core.tensor.megbrain_graph import VarNode
from megengine.utils import comp_graph_tools as cgtools

logger = mge.get_logger(__name__)
@@ -268,8 +267,8 @@ def make_feeds(args):

def assert_equal(expect, real, **kwargs):
op = builtin.AssertEqual(**kwargs)
(res,) = G.apply_normal_varnode(op, expect, real)
return G.VarNode(res)
(res,) = apply(op, expect, real)
return res

verbose = not args.silent

@@ -284,8 +283,8 @@ def make_feeds(args):
# insert assert opr to check expect and real.
outputs_new.append(
assert_equal(
G.VarNode(expect_get),
G.VarNode(i),
expect_get,
i,
verbose=verbose,
maxerr=args.maxerr,
)
@@ -297,29 +296,26 @@ def make_feeds(args):


def optimize_for_inference(args, outputs):
args_map = {
"enable_io16xc32": "f16_io_f32_comp",
"enable_ioc16": "f16_io_comp",
"enable_hwcd4": "use_nhwcd4",
"enable_nchw4": "use_nchw4",
"enable_nchw88": "use_nchw88",
"enable_nchw44": "use_nchw44",
"enable_nchw44_dot": "use_nchw44_dot",
"enable_nchw32": "use_nchw32",
"enable_chwn4": "use_chwn4",
"enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity",
"enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z",
}
args_list = [
"enable_io16xc32",
"enable_ioc16",
"enable_hwcd4",
"enable_nchw4",
"enable_nchw88",
"enable_nchw44",
"enable_nchw44_dot",
"enable_nchw32",
"enable_chwn4",
"enable_fuse_conv_bias_nonlinearity",
"enable_fuse_conv_bias_with_z",
]
kwargs = {}
for k, v in args_map.items():
for k in args_list:
if getattr(args, k):
assert (
args.optimize_for_inference
), "optimize_for_inference should be set when {} is given".format(k)
kwargs[v] = True
kwargs[k] = True

if args.optimize_for_inference:
outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)]
outputs = G.optimize_for_inference(outputs, **kwargs)

return outputs

@@ -476,7 +472,6 @@ def main():

output_mgbvars = feeds["outputs"]
output_mgbvars = optimize_for_inference(args, output_mgbvars)
output_mgbvars = [var._node for var in output_mgbvars]

inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy")
inputs = sorted((i.name, i.dtype) for i in inputs)
@@ -491,12 +486,8 @@ def main():
with open(args.output, "wb") as fout:
fout.write(b"mgbtest0")
fout.write(struct.pack("I", len(feeds["testcases"])))
if isinstance(output_mgbvars, dict):
wrap_output_vars = dict([(i, VarNode(j)) for i, j in output_mgbvars])
else:
wrap_output_vars = [VarNode(i) for i in output_mgbvars]
dump_content, stat = G.dump_graph(
wrap_output_vars,
output_mgbvars,
append_json=True,
strip_info_file=strip_info_file,
**sereg_kwargs,


Loading…
Cancel
Save