GitOrigin-RevId: 5dce356452
release-1.7
@@ -13,10 +13,17 @@ import itertools | |||||
import json | import json | ||||
import os | import os | ||||
import pickle | import pickle | ||||
import re | |||||
import struct | |||||
from typing import Any | from typing import Any | ||||
import cv2 | |||||
import numpy as np | import numpy as np | ||||
from megengine.logger import get_logger | |||||
from .. import tensor | |||||
from ..core import _imperative_rt as rt | |||||
from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata | from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata | ||||
from ..core._imperative_rt.core2 import Tensor as RawTensor | from ..core._imperative_rt.core2 import Tensor as RawTensor | ||||
from ..core._imperative_rt.core2 import ( | from ..core._imperative_rt.core2 import ( | ||||
@@ -38,12 +45,15 @@ from ..core._wrap import as_device | |||||
from ..core.ops.builtin import BatchNorm, OpDef | from ..core.ops.builtin import BatchNorm, OpDef | ||||
from ..core.tensor import megbrain_graph as G | from ..core.tensor import megbrain_graph as G | ||||
from ..core.tensor.utils import setscalar | from ..core.tensor.utils import setscalar | ||||
from ..utils import comp_graph_tools as cgtools | |||||
from ..utils.naming import AutoNaming | from ..utils.naming import AutoNaming | ||||
from ..utils.profiler import is_profiling | from ..utils.profiler import is_profiling | ||||
from .dtr_config import DTRConfig | from .dtr_config import DTRConfig | ||||
from .graph_opt_config import GraphOptimizationConfig | from .graph_opt_config import GraphOptimizationConfig | ||||
from .sublinear_memory_config import SublinearMemoryConfig | from .sublinear_memory_config import SublinearMemoryConfig | ||||
logger = get_logger(__name__) | |||||
def _input_node_use_static_shape(): | def _input_node_use_static_shape(): | ||||
return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None | return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None | ||||
@@ -692,6 +702,289 @@ class trace: | |||||
self._process_outputs(outputs) | self._process_outputs(outputs) | ||||
return outputs | return outputs | ||||
def _make_feed( | |||||
self, | |||||
graph, | |||||
outputs, | |||||
input_data, | |||||
repeat, | |||||
silent, | |||||
no_assert, | |||||
maxerr, | |||||
resize_input, | |||||
input_transform, | |||||
): | |||||
def auto_reformat_image(path, data, dst_shape): | |||||
"""reformat image to target shape | |||||
:param data: image data as numpy array | |||||
:param dst_shape: target shape | |||||
""" | |||||
dim3_format = False # required input format does not contain batch | |||||
hwc_format = False # required input format is NHWC | |||||
if not dst_shape: # input tensor shape is not predefined | |||||
if len(data.shape) == 2: | |||||
chl = 1 | |||||
h = data.shape[0] | |||||
w = data.shape[1] | |||||
else: | |||||
assert ( | |||||
len(data.shape) == 3 | |||||
), "Input image must be of dimension 2 or 3" | |||||
h, w, chl = data.shape | |||||
dst_shape = (1, chl, h, w) | |||||
if len(dst_shape) == 3: | |||||
dst_shape = (1,) + dst_shape | |||||
dim3_format = True | |||||
assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape) | |||||
chl = dst_shape[1] | |||||
if chl in [1, 3]: | |||||
n, c, h, w = dst_shape | |||||
dst_shape = (n, h, w, c) | |||||
else: | |||||
chl = dst_shape[3] | |||||
assert chl in [ | |||||
1, | |||||
3, | |||||
], "can not infer input format from shape: {}".format(dst_shape) | |||||
hwc_format = True | |||||
# dst_shape has now been normalized to NHWC format | |||||
if resize_input: | |||||
h, w = dst_shape[1:3] | |||||
data = cv2.resize(data, (w, h)) | |||||
logger.info("input {} resized to {}".format(path, data.shape)) | |||||
if chl == 1: | |||||
data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) | |||||
data = data[:, :, np.newaxis] | |||||
assert data.ndim == 3 | |||||
data = data[np.newaxis] | |||||
# data normalized to NHWC format | |||||
if not hwc_format: | |||||
data = np.transpose(data, (0, 3, 1, 2)) | |||||
if dim3_format: | |||||
data = np.squeeze(data, 0) | |||||
return data | |||||
def read_input_data(dst_shape, dtype, path): | |||||
def check_shape_equal(dst_shape, data_shape): | |||||
if len(dst_shape): | |||||
assert len(data_shape) == len( | |||||
dst_shape | |||||
), "input/data shapes mismatch: {} vs {}".format( | |||||
dst_shape, data_shape | |||||
) | |||||
if data_shape[1:] != dst_shape[1:]: | |||||
logger.warning( | |||||
"dst_shape is {}; data_shape is {}".format( | |||||
dst_shape, data_shape | |||||
) | |||||
) | |||||
if path.startswith("#"): | |||||
assert not resize_input | |||||
assert not input_transform | |||||
spec = path | |||||
m = re.match( | |||||
r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec | |||||
) | |||||
assert m, "bad spec {}".format(spec) | |||||
rng_min = float(m.group(1)) | |||||
rng_max = float(m.group(2)) | |||||
if m.group(3): | |||||
shape_str = m.group(3) | |||||
try: | |||||
shape = shape_str[1:].split(",") | |||||
if shape[-1].strip() == "...": | |||||
shape = shape[:-1] | |||||
shape.extend(list(dst_shape[len(shape) :])) | |||||
data_shape = tuple(map(int, shape)) | |||||
except ValueError as e: | |||||
raise ValueError("bad spec {}: {}".format(spec, e.args)) | |||||
else: | |||||
data_shape = dst_shape | |||||
check_shape_equal(dst_shape, data_shape) | |||||
return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) | |||||
# try to load image | |||||
data = cv2.imread(path, cv2.IMREAD_COLOR) | |||||
if data is None: | |||||
assert not resize_input | |||||
data = np.load(path) | |||||
assert isinstance(data, np.ndarray) | |||||
else: | |||||
# load image succeeds, so we expect input format is image format | |||||
data = auto_reformat_image(path, data, dst_shape) | |||||
data = np.repeat(data, repeat, axis=0) | |||||
if repeat > 1: | |||||
logger.info( | |||||
"repeat input for {} times, data shape is {}".format( | |||||
repeat, data.shape | |||||
) | |||||
) | |||||
check_shape_equal(dst_shape, data.shape) | |||||
if input_transform: | |||||
data = eval(input_transform, {"data": data, "np": np}) | |||||
return data | |||||
def gen_one_testcase(inputs, spec): | |||||
paths = spec.split(";") | |||||
if len(paths) != len(inputs): | |||||
if len(paths) == 1 and paths[0].startswith("#"): | |||||
paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] | |||||
assert len(paths) == len( | |||||
inputs | |||||
), "required inputs: {}; data paths: {}".format(inputs.keys(), paths) | |||||
if len(paths) == 1 and ":" not in paths[0]: | |||||
paths[0] = next(iter(inputs.keys())) + ":" + paths[0] | |||||
ret = {} | |||||
for path in paths: | |||||
var, path = path.split(":") | |||||
ret[var] = read_input_data(inputs[var].shape, inputs[var].dtype, path) | |||||
return ret | |||||
inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") | |||||
inputs = {i.name: i for i in inputs} | |||||
if not no_assert: | |||||
replace_varmap = {} | |||||
inp_map = {} | |||||
# replace var use InputNode | |||||
for name, var in inputs.items(): | |||||
inp = G.InputNode( | |||||
device="xpux", dtype=var.dtype, shape=var.shape, graph=graph | |||||
) | |||||
replace_varmap[var] = inp.outputs[0]._node | |||||
inp_map[name] = inp | |||||
new = cgtools.replace_vars(outputs, replace_varmap) | |||||
if isinstance(new, rt.VarNode): | |||||
new = list(new) | |||||
output_nodes = [G.OutputNode(var) for var in new] | |||||
func = graph.compile(*[node.outputs[0]._node for node in output_nodes]) | |||||
def make_dev_tensor(value, dtype=None, device=None): | |||||
return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||||
def calculate(*args, **kwargs): | |||||
output_val = [] | |||||
# set inputs value | |||||
for name, var in inputs.items(): | |||||
val = kwargs.pop(name, None) | |||||
assert val is not None, "miss input name{}".format(name) | |||||
dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux") | |||||
inp_map[name].set_value(dev_tensor) | |||||
func.execute() | |||||
for res in output_nodes: | |||||
output_val.append(res.get_value().numpy()) | |||||
return output_val | |||||
def expect_name(var): | |||||
return "{}:expect".format(var.name) | |||||
testcases = [] | |||||
np.set_printoptions(precision=2, threshold=4, suppress=True) | |||||
data_list = [] | |||||
for item in input_data: | |||||
if item.startswith("@"): | |||||
with open(item[1:], "r") as f: | |||||
data_list.extend( | |||||
[line.rstrip() for line in f if line.rstrip() != ""] | |||||
) | |||||
else: | |||||
data_list.append(item) | |||||
for inp_spec in data_list: | |||||
cur_testcase = gen_one_testcase(inputs, inp_spec) | |||||
assert len(cur_testcase) == len( | |||||
inputs | |||||
), "required inputs: {}; given data: {}".format( | |||||
inputs.keys(), cur_testcase.keys() | |||||
) | |||||
if not no_assert: | |||||
outputs_get = calculate(**cur_testcase) | |||||
for var, val in zip(outputs, outputs_get): | |||||
cur_testcase[expect_name(var)] = val | |||||
logger.info( | |||||
"generate test groundtruth: var={} shape={} range=({}, {})" | |||||
" mean={} var={}".format( | |||||
var, | |||||
val.shape, | |||||
val.min(), | |||||
val.max(), | |||||
np.mean(val), | |||||
np.var(val), | |||||
) | |||||
) | |||||
testcases.append(cur_testcase) | |||||
logger.info( | |||||
"add testcase: \n {}".format( | |||||
"\n ".join( | |||||
"{}: shape={} dtype={} range=({:.2f},{:.2f}) " | |||||
"mean={:.2f} sd={:.2f}".format( | |||||
k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v) | |||||
) | |||||
for k, v in sorted(cur_testcase.items()) | |||||
) | |||||
) | |||||
) | |||||
if not no_assert: | |||||
def expect_shp(var): | |||||
ret = var.shape | |||||
if ret: | |||||
return ret | |||||
return testcases[0][expect_name(var)].shape | |||||
def assert_equal(expect, real, **kwargs): | |||||
op = AssertEqual(**kwargs) | |||||
(res,) = G.apply_normal_varnode(op, expect, real) | |||||
return res._node | |||||
verbose = not silent | |||||
outputs_new = [] | |||||
for i in outputs: | |||||
device = rt.CompNode("xpux") | |||||
dtype = i.dtype | |||||
name = expect_name(i) | |||||
shape = expect_shp(i) | |||||
# make expect output as one input of model. | |||||
expect_get = rt.make_h2d(graph, device, dtype, shape, name) | |||||
# insert assert opr to check expect and real. | |||||
outputs_new.append( | |||||
assert_equal(expect_get, i, verbose=verbose, maxerr=maxerr,) | |||||
) | |||||
inputs[expect_name(i)] = expect_get | |||||
outputs = outputs_new | |||||
return {"outputs": outputs, "testcases": testcases} | |||||
def dump( | def dump( | ||||
self, | self, | ||||
file, | file, | ||||
@@ -708,6 +1001,13 @@ class trace: | |||||
optimize_for_inference=True, | optimize_for_inference=True, | ||||
user_info: Any = None, | user_info: Any = None, | ||||
enable_metadata: bool = True, | enable_metadata: bool = True, | ||||
input_data=None, | |||||
repeat=1, | |||||
silent=False, | |||||
no_assert=False, | |||||
maxerr=1e-4, | |||||
resize_input=False, | |||||
input_transform=None, | |||||
**kwargs | **kwargs | ||||
): | ): | ||||
r"""Serializes trace to file system. | r"""Serializes trace to file system. | ||||
@@ -738,6 +1038,27 @@ class trace: | |||||
will skip all optimize options if this is False. Default: True | will skip all optimize options if this is False. Default: True | ||||
user_info: any type object, which will be pickled to bytes. | user_info: any type object, which will be pickled to bytes. | ||||
enable_metadata: whether to save metadata into output file. | enable_metadata: whether to save metadata into output file. | ||||
input_data: input test data and current network output would be used as groundtruth. | |||||
The format is "var0:file0;var1:file1..." to specify data files for input vars. | |||||
It can also be "#rand(min,max,shape...)" for generating random input data, for | |||||
example, "#rand(0,255)", "#rand(0,255,1,3,224,224)" or "#rand(0, 255, 1, ...)" | |||||
where `...` means the remaining part of the original shape. If the shape is not | |||||
specified, the shape of corresponding input tensors in the network will be used. | |||||
If there is only one input var, its name can be omitted. Each data file can either | |||||
be an image which can be loaded by opencv, or a pickled numpy.ndarray. This option | |||||
can be given multiple times to add multiple testcases. If you start the data | |||||
with the letter @, the rest should be a filename, and each line in the file should | |||||
be a single datum in the format described above. *NOTE* If `input_data` is not None, | |||||
you can only use load-and-run to run the output file. | |||||
repeat: how many times the input image is repeated. Useful when running benchmark for | |||||
batch size other than one. Have no effect on randomly generated input data. | |||||
silent: whether set verbose to False in assert_equal opr. | |||||
no_assert: whether insert assert_equal opr to check result; this option is useful for | |||||
benchmarking. | |||||
maxerr: max error for assert_equal check during runtime. | |||||
resize_input: whether resize input image to fit input var shape. | |||||
input_transform: a python expression to transform the input data. | |||||
Example: data / np.std(data) | |||||
Keyword Arguments: | Keyword Arguments: | ||||
@@ -778,6 +1099,8 @@ class trace: | |||||
input for inference on nvidia backend(this optimization pass will | input for inference on nvidia backend(this optimization pass will | ||||
result in mismatch of the precision of output of training and | result in mismatch of the precision of output of training and | ||||
inference) | inference) | ||||
* enable_fuse_preprocess: whether to fuse astype\pad_channel\dimshuffle and | |||||
etc opr | |||||
""" | """ | ||||
if not self._capture_as_const: | if not self._capture_as_const: | ||||
raise ValueError( | raise ValueError( | ||||
@@ -892,8 +1215,28 @@ class trace: | |||||
v.name = output_names[i] | v.name = output_names[i] | ||||
dest_vars.append(v) | dest_vars.append(v) | ||||
dest_vars = [i._node for i in dest_vars] | |||||
if input_data is not None: | |||||
feeds = self._make_feed( | |||||
graph, | |||||
dest_vars, | |||||
input_data, | |||||
repeat, | |||||
silent, | |||||
no_assert, | |||||
maxerr, | |||||
resize_input, | |||||
input_transform, | |||||
) | |||||
assert ( | |||||
isinstance(feeds, dict) and feeds["testcases"] | |||||
), "testcases can not be empty" | |||||
dest_vars = feeds["outputs"] | |||||
if optimize_for_inference: | if optimize_for_inference: | ||||
dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs) | dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs) | ||||
dest_vars = [i._node for i in dest_vars] | |||||
metadata = SerializationMetadata() | metadata = SerializationMetadata() | ||||
if enable_metadata: | if enable_metadata: | ||||
@@ -910,6 +1253,9 @@ class trace: | |||||
if keep_opr_priority: | if keep_opr_priority: | ||||
graph._set_priority_to_id(dest_vars) | graph._set_priority_to_id(dest_vars) | ||||
if input_data is not None: | |||||
file.write(b"mgbtest0") | |||||
file.write(struct.pack("I", len(feeds["testcases"]))) | |||||
dump_content, dump_info = G.dump_graph( | dump_content, dump_info = G.dump_graph( | ||||
dest_vars, | dest_vars, | ||||
keep_var_name=keep_var_name, | keep_var_name=keep_var_name, | ||||
@@ -921,6 +1267,34 @@ class trace: | |||||
metadata=metadata, | metadata=metadata, | ||||
) | ) | ||||
file.write(dump_content) | file.write(dump_content) | ||||
if input_data is not None: | |||||
inputs = cgtools.get_dep_vars(dest_vars, "Host2DeviceCopy") | |||||
inputs = sorted((i.name, i.dtype) for i in inputs) | |||||
def make_dev_tensor(value, dtype=None, device=None): | |||||
return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||||
for testcase in feeds["testcases"]: | |||||
assert isinstance(testcase, dict) | |||||
cg = G.Graph() | |||||
output_mgbvars = [] | |||||
for name, dtype in inputs: | |||||
output_mgbvars.append( | |||||
cg.make_const( | |||||
make_dev_tensor( | |||||
testcase.pop(name), dtype=dtype, device="cpux" | |||||
) | |||||
) | |||||
) | |||||
assert not testcase, "extra inputs provided in testcase: {}".format( | |||||
testcase.keys() | |||||
) | |||||
dump_content, _ = G.dump_graph( | |||||
output_mgbvars, strip_info_file=strip_info_file, append_json=True, | |||||
) | |||||
file.write(dump_content) | |||||
return dump_info | return dump_info | ||||
def _process_inputs(self, *args, **kwargs): | def _process_inputs(self, *args, **kwargs): | ||||
@@ -287,6 +287,16 @@ def test_dump_backward_graph(): | |||||
np.testing.assert_equal(results[1], dx0) | np.testing.assert_equal(results[1], dx0) | ||||
def test_dump_with_testcase(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def f(x): | |||||
return exp(x) | |||||
f(tensor(1.0)) | |||||
file = io.BytesIO() | |||||
f.dump(file, input_data=["#rand(0, 255, 1)"]) | |||||
@pytest.mark.parametrize("trace_mode", [False, True]) | @pytest.mark.parametrize("trace_mode", [False, True]) | ||||
def test_trace_profiler(trace_mode): | def test_trace_profiler(trace_mode): | ||||
@trace(symbolic=trace_mode, profiling=True) | @trace(symbolic=trace_mode, profiling=True) | ||||
@@ -1,535 +0,0 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import argparse | |||||
import os | |||||
import re | |||||
import struct | |||||
import cv2 | |||||
import numpy as np | |||||
import megengine as mge | |||||
import megengine.core._imperative_rt as rt | |||||
import megengine.core.tensor.megbrain_graph as G | |||||
from megengine import tensor | |||||
from megengine.core._imperative_rt.core2 import apply | |||||
from megengine.core.ops import builtin | |||||
from megengine.utils import comp_graph_tools as cgtools | |||||
logger = mge.get_logger(__name__) | |||||
def auto_reformat_image(args, path, data, dst_shape): | |||||
"""reformat image to target shape | |||||
:param data: image data as numpy array | |||||
:param dst_shape: target shape | |||||
""" | |||||
dim3_format = False # required input format does not contain batch | |||||
hwc_format = False # required input format is NHWC | |||||
if not dst_shape: # input tensor shape is not predefined | |||||
if len(data.shape) == 2: | |||||
chl = 1 | |||||
h = data.shape[0] | |||||
w = data.shape[1] | |||||
else: | |||||
assert len(data.shape) == 3, "Input image must be of dimension 2 or 3" | |||||
h, w, chl = data.shape | |||||
dst_shape = (1, chl, h, w) | |||||
if len(dst_shape) == 3: | |||||
dst_shape = (1,) + dst_shape | |||||
dim3_format = True | |||||
assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape) | |||||
chl = dst_shape[1] | |||||
if chl in [1, 3]: | |||||
n, c, h, w = dst_shape | |||||
dst_shape = (n, h, w, c) | |||||
else: | |||||
chl = dst_shape[3] | |||||
assert chl in [1, 3], "can not infer input format from shape: {}".format( | |||||
dst_shape | |||||
) | |||||
hwc_format = True | |||||
# dst_shape has now been normalized to NHWC format | |||||
if args.resize_input: | |||||
h, w = dst_shape[1:3] | |||||
data = cv2.resize(data, (w, h)) | |||||
logger.info("input {} resized to {}".format(path, data.shape)) | |||||
if chl == 1: | |||||
data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) | |||||
data = data[:, :, np.newaxis] | |||||
assert data.ndim == 3 | |||||
data = data[np.newaxis] | |||||
# data normalized to NHWC format | |||||
if not hwc_format: | |||||
data = np.transpose(data, (0, 3, 1, 2)) | |||||
if dim3_format: | |||||
data = np.squeeze(data, 0) | |||||
return data | |||||
def read_input_data(args, dst_shape, dtype, path, repeat): | |||||
def check_shape_equal(dst_shape, data_shape): | |||||
if len(dst_shape): | |||||
assert len(data_shape) == len( | |||||
dst_shape | |||||
), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape) | |||||
if data_shape[1:] != dst_shape[1:]: | |||||
logger.warning( | |||||
"dst_shape is {}; data_shape is {}".format(dst_shape, data_shape) | |||||
) | |||||
if path.startswith("#"): | |||||
assert not args.resize_input | |||||
assert not args.input_transform | |||||
spec = path | |||||
m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec) | |||||
assert m, "bad spec {}".format(spec) | |||||
rng_min = float(m.group(1)) | |||||
rng_max = float(m.group(2)) | |||||
if m.group(3): | |||||
shape_str = m.group(3) | |||||
try: | |||||
shape = shape_str[1:].split(",") | |||||
if shape[-1].strip() == "...": | |||||
shape = shape[:-1] | |||||
shape.extend(list(dst_shape[len(shape) :])) | |||||
data_shape = tuple(map(int, shape)) | |||||
except ValueError as e: | |||||
raise ValueError("bad spec {}: {}".format(spec, e.args)) | |||||
else: | |||||
data_shape = dst_shape | |||||
check_shape_equal(dst_shape, data_shape) | |||||
return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) | |||||
# try to load image | |||||
data = cv2.imread(path, cv2.IMREAD_COLOR) | |||||
if data is None: | |||||
assert not args.resize_input | |||||
data = np.load(path) | |||||
assert isinstance(data, np.ndarray) | |||||
else: | |||||
# load image succeeds, so we expect input format is image format | |||||
data = auto_reformat_image(args, path, data, dst_shape) | |||||
data = np.repeat(data, repeat, axis=0) | |||||
if repeat > 1: | |||||
logger.info( | |||||
"repeat input for {} times, data shape is {}".format(repeat, data.shape) | |||||
) | |||||
check_shape_equal(dst_shape, data.shape) | |||||
if args.input_transform: | |||||
data = eval(args.input_transform, {"data": data, "np": np}) | |||||
return data | |||||
def gen_one_testcase(args, inputs, spec): | |||||
paths = spec.split(";") | |||||
if len(paths) != len(inputs): | |||||
if len(paths) == 1 and paths[0].startswith("#"): | |||||
paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] | |||||
assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format( | |||||
inputs.keys(), paths | |||||
) | |||||
if len(paths) == 1 and ":" not in paths[0]: | |||||
paths[0] = next(iter(inputs.keys())) + ":" + paths[0] | |||||
ret = {} | |||||
for path in paths: | |||||
var, path = path.split(":") | |||||
if args.repeat: | |||||
repeat = args.repeat | |||||
else: | |||||
repeat = 1 | |||||
ret[var] = read_input_data( | |||||
args, inputs[var].shape, inputs[var].dtype, path, repeat | |||||
) | |||||
return ret | |||||
def make_feeds(args): | |||||
ret = G.load_graph(args.input) | |||||
cg_rt, outputs = ret.graph, ret.output_vars_list | |||||
inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") | |||||
inputs = {i.name: i for i in inputs} | |||||
if not args.no_assert: | |||||
replace_varmap = {} | |||||
inp_map = {} | |||||
# replace var use InputNode | |||||
for name, var in inputs.items(): | |||||
inp = G.InputNode( | |||||
device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt | |||||
) | |||||
replace_varmap[var] = inp.outputs[0] | |||||
inp_map[name] = inp | |||||
new = cgtools.replace_vars(outputs, replace_varmap) | |||||
if isinstance(new, rt.VarNode): | |||||
new = list(new) | |||||
output_nodes = [G.OutputNode(var) for var in new] | |||||
func = cg_rt.compile([node.outputs[0] for node in output_nodes]) | |||||
def make_dev_tensor(value, dtype=None, device=None): | |||||
return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||||
def calculate(*args, **kwargs): | |||||
output_val = [] | |||||
# set inputs value | |||||
for name, var in inputs.items(): | |||||
val = kwargs.pop(name, None) | |||||
assert val is not None, "miss input name{}".format(name) | |||||
dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux") | |||||
inp_map[name].set_value(dev_tensor) | |||||
func.execute() | |||||
for res in output_nodes: | |||||
output_val.append(res.get_value().numpy()) | |||||
return output_val | |||||
def expect_name(var): | |||||
return "{}:expect".format(var.name) | |||||
testcases = [] | |||||
np.set_printoptions(precision=2, threshold=4, suppress=True) | |||||
data_list = [] | |||||
for item in args.data: | |||||
if item.startswith("@"): | |||||
with open(item[1:], "r") as f: | |||||
data_list.extend([line.rstrip() for line in f if line.rstrip() != ""]) | |||||
else: | |||||
data_list.append(item) | |||||
for inp_spec in data_list: | |||||
cur_testcase = gen_one_testcase(args, inputs, inp_spec) | |||||
assert len(cur_testcase) == len( | |||||
inputs | |||||
), "required inputs: {}; given data: {}".format( | |||||
inputs.keys(), cur_testcase.keys() | |||||
) | |||||
if not args.no_assert: | |||||
outputs_get = calculate(**cur_testcase) | |||||
for var, val in zip(outputs, outputs_get): | |||||
cur_testcase[expect_name(var)] = val | |||||
logger.info( | |||||
"generate test groundtruth: var={} shape={} range=({}, {})" | |||||
" mean={} var={}".format( | |||||
var, val.shape, val.min(), val.max(), np.mean(val), np.var(val) | |||||
) | |||||
) | |||||
testcases.append(cur_testcase) | |||||
logger.info( | |||||
"add testcase: \n {}".format( | |||||
"\n ".join( | |||||
"{}: shape={} dtype={} range=({:.2f},{:.2f}) " | |||||
"mean={:.2f} sd={:.2f}".format( | |||||
k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v) | |||||
) | |||||
for k, v in sorted(cur_testcase.items()) | |||||
) | |||||
) | |||||
) | |||||
if not args.no_assert: | |||||
def expect_shp(var): | |||||
ret = var.shape | |||||
if ret: | |||||
return ret | |||||
return testcases[0][expect_name(var)].shape | |||||
def assert_equal(expect, real, **kwargs): | |||||
op = builtin.AssertEqual(**kwargs) | |||||
(res,) = apply(op, expect, real) | |||||
return res | |||||
verbose = not args.silent | |||||
outputs_new = [] | |||||
for i in outputs: | |||||
device = rt.CompNode("xpux") | |||||
dtype = i.dtype | |||||
name = expect_name(i) | |||||
shape = expect_shp(i) | |||||
# make expect output as one input of model. | |||||
expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name) | |||||
# insert assert opr to check expect and real. | |||||
outputs_new.append( | |||||
assert_equal( | |||||
expect_get, | |||||
i, | |||||
verbose=verbose, | |||||
maxerr=args.maxerr, | |||||
) | |||||
) | |||||
inputs[expect_name(i)] = expect_get | |||||
outputs = outputs_new | |||||
return {"outputs": outputs, "testcases": testcases} | |||||
def optimize_for_inference(args, outputs): | |||||
args_list = [ | |||||
"enable_io16xc32", | |||||
"enable_ioc16", | |||||
"enable_hwcd4", | |||||
"enable_nchw4", | |||||
"enable_nchw88", | |||||
"enable_nchw44", | |||||
"enable_nchw44_dot", | |||||
"enable_nchw32", | |||||
"enable_chwn4", | |||||
"enable_fuse_conv_bias_nonlinearity", | |||||
"enable_fuse_conv_bias_with_z", | |||||
"enable_fuse_preprocess", | |||||
] | |||||
kwargs = {} | |||||
for k in args_list: | |||||
if getattr(args, k): | |||||
kwargs[k] = True | |||||
if args.optimize_for_inference: | |||||
outputs = G.optimize_for_inference(outputs, **kwargs) | |||||
return outputs | |||||
def main(): | |||||
parser = argparse.ArgumentParser( | |||||
description="Pack computing graph, input values and expected output " | |||||
"values into one file for checking correctness. README.md gives more " | |||||
"details on the usage", | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
) | |||||
parser.add_argument("input", help="MegEngine dumped model file") | |||||
parser.add_argument("-o", "--output", help="output file", required=True) | |||||
parser.add_argument( | |||||
"-d", | |||||
"--data", | |||||
default=[], | |||||
action="append", | |||||
required=True, | |||||
help="Given input test data when input file is a network, " | |||||
"and current network output would be used as groundtruth. " | |||||
"The format is var0:file0;var1:file1... to specify data files for " | |||||
"input vars. It can also be #rand(min,max,shape...) for generating " | |||||
"random input data, for example, #rand(0,255), " | |||||
"#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means " | |||||
"the remaining part of the original shape. " | |||||
"If the shape is not specified, the shape of " | |||||
"corresponding input tensors in the network will be used. " | |||||
"If there is only one input var, its name can be omitted. " | |||||
"Each data file can either be an image which can be loaded by opencv, " | |||||
"or a pickled numpy.ndarray. " | |||||
"This option can be given multiple times to add multiple testcases. " | |||||
" *NOTE* " | |||||
"If you start the data with the letter @, the rest should be a " | |||||
"filename, and each line in the file should be a single datum in " | |||||
"the format described above. ", | |||||
) | |||||
parser.add_argument( | |||||
"--repeat", | |||||
type=int, | |||||
default=1, | |||||
help="Specify how many times the input image is repeated. " | |||||
"Useful when running benchmark for batch size other than one. " | |||||
"Have no effect on randomly generated input data.", | |||||
) | |||||
parser.add_argument( | |||||
"--silent", | |||||
action="store_true", | |||||
help="set verbose to False in asserti_equal opr", | |||||
) | |||||
parser.add_argument( | |||||
"--optimize-for-inference", | |||||
action="store_true", | |||||
help="enable optimization for inference", | |||||
) | |||||
parser.add_argument( | |||||
"--no-assert", | |||||
action="store_true", | |||||
help="do not insert assert_equal opr to check result; " | |||||
"this option is useful for benchmarking", | |||||
) | |||||
parser.add_argument( | |||||
"--maxerr", | |||||
type=float, | |||||
default=1e-4, | |||||
help="max error for assert_equal check during runtime", | |||||
) | |||||
parser.add_argument( | |||||
"--resize-input", | |||||
action="store_true", | |||||
help="resize input image to fit input var shape", | |||||
) | |||||
parser.add_argument( | |||||
"--input-transform", | |||||
help="a python expression to transform the input data. " | |||||
"Example: data / np.std(data)", | |||||
) | |||||
parser.add_argument( | |||||
"--discard-var-name", | |||||
action="store_true", | |||||
help="discard variable and param names in the " "generated output", | |||||
) | |||||
parser.add_argument( | |||||
"--output-strip-info", action="store_true", help="output code strip information" | |||||
) | |||||
parser.add_argument( | |||||
"--enable-io16xc32", | |||||
action="store_true", | |||||
help="transform the mode to float16 io float32 compute", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-ioc16", | |||||
action="store_true", | |||||
help="transform the dtype of the model to float16 io " "and compute", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-fuse-conv-bias-nonlinearity", | |||||
action="store_true", | |||||
help="fuse convolution bias and nonlinearity opr to a " | |||||
"conv_bias opr and compute", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-hwcd4", | |||||
action="store_true", | |||||
help="transform the model format from NCHW to NHWCD4 " | |||||
"for inference; you may need to disable CUDA and set " | |||||
"MGB_USE_MEGDNN_DBG=2", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-nchw4", | |||||
action="store_true", | |||||
help="transform the model format from NCHW to NCHW4 " "for inference", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-nchw88", | |||||
action="store_true", | |||||
help="transform the model format from NCHW to NCHW88 " "for inference", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-nchw44", | |||||
action="store_true", | |||||
help="transform the model format from NCHW to NCHW44 " "for inference", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-nchw44-dot", | |||||
action="store_true", | |||||
help="transform the model format from NCHW to NCHW44_DOT " | |||||
"for optimizing armv8.2 dot in inference", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-nchw32", | |||||
action="store_true", | |||||
help="transform the model format from NCHW4 to NCHW32 " | |||||
"for inference on nvidia TensoCore", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-chwn4", | |||||
action="store_true", | |||||
help="transform the model format to CHWN4 " | |||||
"for inference, mainly used for nvidia tensorcore", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-fuse-conv-bias-with-z", | |||||
action="store_true", | |||||
help="fuse conv_bias with z input for inference on " | |||||
"nvidia GPU (this optimization pass will result in mismatch " | |||||
"of the precision of output of training and inference)", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-fuse-preprocess", | |||||
action="store_true", | |||||
help="fuse astype\pad_channel\dimshuffle and etc opr " | |||||
"from h2d opr", | |||||
) | |||||
args = parser.parse_args() | |||||
feeds = make_feeds(args) | |||||
assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty" | |||||
output_mgbvars = feeds["outputs"] | |||||
output_mgbvars = optimize_for_inference(args, output_mgbvars) | |||||
inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") | |||||
inputs = sorted((i.name, i.dtype) for i in inputs) | |||||
if args.discard_var_name: | |||||
sereg_kwargs = dict(keep_var_name=0, keep_param_name=False) | |||||
else: | |||||
sereg_kwargs = dict(keep_var_name=2, keep_param_name=True) | |||||
strip_info_file = args.output + ".json" if args.output_strip_info else None | |||||
with open(args.output, "wb") as fout: | |||||
fout.write(b"mgbtest0") | |||||
fout.write(struct.pack("I", len(feeds["testcases"]))) | |||||
dump_content, stat = G.dump_graph( | |||||
output_mgbvars, | |||||
append_json=True, | |||||
strip_info_file=strip_info_file, | |||||
**sereg_kwargs, | |||||
) | |||||
fout.write(dump_content) | |||||
logger.info( | |||||
"graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format( | |||||
stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024 | |||||
) | |||||
) | |||||
def make_dev_tensor(value, dtype=None, device=None): | |||||
return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||||
for testcase in feeds["testcases"]: | |||||
assert isinstance(testcase, dict) | |||||
cg = G.Graph() | |||||
output_mgbvars = [] | |||||
for name, dtype in inputs: | |||||
output_mgbvars.append( | |||||
cg.make_const( | |||||
make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux") | |||||
) | |||||
) | |||||
assert not testcase, "extra inputs provided in testcase: {}".format( | |||||
testcase.keys() | |||||
) | |||||
with open(args.output, "ab") as fout: | |||||
dump_content, _ = G.dump_graph( | |||||
output_mgbvars, strip_info_file=strip_info_file, append_json=True | |||||
) | |||||
fout.write(dump_content) | |||||
if __name__ == "__main__": | |||||
main() |