From fad5bc74c34d2555f15196dd727c11e12f8865f4 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 4 Mar 2021 17:24:45 +0800
Subject: [PATCH] feat(mge/tools): add load_network_and_run for megengine

GitOrigin-RevId: 7d8ed5af79619b50c14e062f50da29cbee13e875
---
 imperative/python/megengine/serialization.py        |  10 +-
 .../python/megengine/tools/load_network_and_run.py  | 418 +++++++++++++++++++++
 .../python/megengine/utils/profile_analyzer.py      |   2 +-
 imperative/python/src/graph_rt.cpp                  |   1 +
 4 files changed, 425 insertions(+), 6 deletions(-)
 create mode 100755 imperative/python/megengine/tools/load_network_and_run.py

diff --git a/imperative/python/megengine/serialization.py b/imperative/python/megengine/serialization.py
index a6a01ecd..38af5145 100644
--- a/imperative/python/megengine/serialization.py
+++ b/imperative/python/megengine/serialization.py
@@ -107,14 +107,14 @@ def load(f, map_location=None, pickle_module=pickle):
 
     .. testcode:
 
         import megengine as mge
-        # Load tensors to the same device as defined in model.mge
-        mge.load('model.mge')
+        # Load tensors to the same device as defined in model.pkl
+        mge.load('model.pkl')
         # Load all tensors to gpu0.
-        mge.load('model.mge', map_location='gpu0')
+        mge.load('model.pkl', map_location='gpu0')
         # Load all tensors originally on gpu0 to cpu0
-        mge.load('model.mge', map_location={'gpu0':'cpu0'})
+        mge.load('model.pkl', map_location={'gpu0':'cpu0'})
         # Load all tensors to cpu0
-        mge.load('model.mge', map_location=lambda dev: 'cpu0')
+        mge.load('model.pkl', map_location=lambda dev: 'cpu0')
 
     """
     if isinstance(f, str):
diff --git a/imperative/python/megengine/tools/load_network_and_run.py b/imperative/python/megengine/tools/load_network_and_run.py
new file mode 100755
index 00000000..2b56a2fb
--- /dev/null
+++ b/imperative/python/megengine/tools/load_network_and_run.py
@@ -0,0 +1,418 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
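+
+# This tool loads a serialized MegEngine computing graph, feeds it randomly
+# generated (or user-provided) input data, optionally rewrites the graph for
+# inference, and reports per-iteration timing and, optionally, output RMS.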
+import argparse
+import logging
+import time
+from collections import OrderedDict
+
+import numpy as np
+
+import megengine as mge
+from megengine.core.tensor import megbrain_graph as G
+from megengine.device import get_device_count, set_default_device
+from megengine.functional.debug_param import set_execution_strategy
+from megengine.logger import enable_debug_log, get_logger, set_log_file
+from megengine.utils import comp_graph_tools as tools
+
+logger = get_logger(__name__)
+
+
+def make_data_given_desc(args, inputs, shape0_multiply=1):
+    if args.load_input_data:
+        logger.info("load data from {}".format(args.load_input_data))
+        data = mge.load(args.load_input_data)
+        data_names = [inp.name for inp in inputs]
+
+        if isinstance(data, np.ndarray):
+            assert len(data_names) == 1, (
+                "data is given as a single numpy array, so there should be "
+                "exactly one input in the graph; got: {}".format(data_names)
+            )
+            data = {data_names[0]: data}
+
+        assert isinstance(data, dict)
+        for v in data.values():
+            assert isinstance(
+                v, np.ndarray
+            ), "data should provide ndarray; got {} instead".format(v)
+
+        if args.batchsize:
+            for k, v in list(data.items()):
+                assert (
+                    args.batchsize % v.shape[0] == 0
+                ), "current batch size must divide given batch size: {} {}".format(
+                    args.batchsize, v.shape[0]
+                )
+                data[k] = np.repeat(v, args.batchsize // v.shape[0], axis=0)
+        return data
+
+    def iter_inpdesc(desc):
+        if not desc:
+            return
+        for pair in desc.split(";"):
+            name, value = pair.split(":")
+            if name not in data_shapes:
+                logger.warning("rng name {} not in data provider".format(name))
+            yield name, value
+
+    rng = np.random.RandomState(args.seed)
+
+    data_shapes = OrderedDict((inp.name, list(inp.shape)) for inp in inputs)
+    data_dtypes = OrderedDict((inp.name, inp.dtype) for inp in inputs)
+
+    for name, shape in iter_inpdesc(args.input_desc):
+        data_shapes[name] = list(map(int, shape.split(",")))
+
+    if args.batchsize:
+        for i in data_shapes.values():
+            i[0] = args.batchsize
+
+    data_rngs = dict(iter_inpdesc(args.rng))
+
+    result = OrderedDict()
+    for name, shape in data_shapes.items():
+        shape[0] *= shape0_multiply
+        rng_expr = data_rngs.get(name)
+        if rng_expr:
+            value = eval("rng.{}".format(rng_expr).format(shape), {"rng": rng})
+        else:
+            value = rng.uniform(size=shape)
+
+        value = np.ascontiguousarray(value, dtype=data_dtypes[name])
+        assert value.shape == tuple(shape)
+        result[name] = value
+
+    return result
+
+
+def get_execution_strategy(args):
+    if not args.fast_run:
+        logger.warning("--fast-run not enabled; execution may be slow")
+        strategy = "HEURISTIC"
+    else:
+        logger.warning("--fast-run enabled; compile may be slow")
+        strategy = "PROFILE"
+    if args.reproducible:
+        strategy += "_REPRODUCIBLE"
+    return strategy
+
+
+def get_opt_kwargs(args):
+    args_list = [
+        "enable_io16xc32",
+        "enable_ioc16",
+        "enable_hwcd4",
+        "enable_nchw4",
+        "enable_nchw88",
+        "enable_nchw44",
+        "enable_nchw44_dot",
+        "enable_nchw32",
+        "enable_chwn4",
+        "enable_fuse_conv_bias_nonlinearity",
+        "enable_fuse_conv_bias_with_z",
+    ]
+    kwargs = {}
+    for k in args_list:
+        if getattr(args, k):
+            kwargs[k] = True
+    return kwargs
+
+
+def run_model(args, graph, inputs, outputs, data):
+    # must use level0 to avoid unintended opr modification
+    graph.options.graph_opt_level = 0
+
+    logger.info("input tensors: ")
+    for k, v in data.items():
+        logger.info(" {}: {}".format(k, v.shape))
+
+    G.modify_opr_algo_strategy_inplace(outputs, get_execution_strategy(args))
+
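+    # optionally rewrite the graph for inference: the --enable-* flags gathered
+    # by get_opt_kwargs() select the dtype/layout transforms and fusion passes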
+    if args.optimize_for_inference:
+        opt_kwargs = get_opt_kwargs(args)
+        outputs = G.optimize_for_inference(outputs, **opt_kwargs)
+
+    # embedding inputs must come last, to avoid constant folding
+    if args.embed_input:
+        outputs, inp_dict = tools.embed_inputs(outputs, data.values(), inputs=inputs)
+    else:
+        outputs, inp_dict = tools.convert_inputs(outputs, inputs=inputs)
+
+    if args.dump_cpp_model:
+        dump_content, _ = G.dump_graph(outputs, keep_var_name=2)
+        with open(args.dump_cpp_model, "wb") as file:
+            file.write(dump_content)
+        logger.info("C++ model written to {}".format(args.dump_cpp_model))
+
+    outputs, output_dict = tools.convert_outputs(outputs)
+
+    if args.profile:
+        profiler = tools.GraphProfiler(graph)
+
+    func = graph.compile(outputs)
+
+    def run():
+        if not args.embed_input:
+            for key in inp_dict:
+                inp_dict[key].set_value(mge.Tensor(data[key])._dev_tensor())
+        func.execute()
+        func.wait()
+        return [oup_node.get_value().numpy() for oup_node in output_dict.values()]
+
+    if args.warm_up:
+        logger.info("warming up")
+        run()
+
+    total_time = 0
+
+    for i in range(args.iter):
+        logger.info("iter {}".format(i))
+        start_time = time.time()
+        retval = run()
+        cur_time = time.time() - start_time
+        total_time += cur_time
+
+        avg_speed = (i + 1) / total_time
+        if "data" in data:
+            avg_speed *= data["data"].shape[0]
+            avg_speed_txt = "{:.3f}sample/s".format(avg_speed)
+        else:
+            avg_speed_txt = "{:.3f}batch/s".format(avg_speed)
+
+        msg = (
+            "iter {}: duration={:.4f}({:.4f})s average={:.4f}s "
+            "avg_speed={} time={:.4f}s"
+        ).format(
+            i,
+            cur_time,
+            func.get_prev_exec_time(),
+            total_time / (i + 1),
+            avg_speed_txt,
+            total_time,
+        )
+        if args.calc_output_rms:
+            rms = []
+            for v in retval:
+                rms.append("{:.3g}".format(float(((v ** 2).mean()) ** 0.5)))
+            msg += " output_rms=[{}]".format(", ".join(rms))
+        if logger.level > logging.INFO:
+            print(msg)
+        else:
+            logger.info(msg)
+
+    if args.focused_nvprof:
+        if get_device_count("gpu") < 1:
+            logger.warning(
+                "No cuda device detected. ``focused_nvprof`` will be ignored."
+            )
+        else:
+            try:
+                import pycuda.driver as D
+
+                D.start_profiler()
+                func.execute()
+                func.wait()
+                D.stop_profiler()
+            except ImportError:
+                logger.error("`focused_nvprof` needs pycuda", exc_info=True)
+
+    if args.profile:
+        with open(args.profile, "w") as fout:
+            fout.write(profiler.get())
+
+    return avg_speed
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="load a network and run inference on random data",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("net")
+    parser.add_argument(
+        "--device", "-d", help="set default device, like 'gpux' or 'cpux'"
+    )
+    parser.add_argument(
+        "--calc-output-rms",
+        action="store_true",
+        help="compute RMS of outputs; useful for comparing computation results",
+    )
+    parser.add_argument(
+        "--output-name",
+        nargs="*",
+        help="Specify output name. This option can be"
+        " specified multiple times. We will look for opr/var"
+        " in the graph",
+    )
+    parser.add_argument(
+        "--load-input-data",
+        help="load input data from pickle file; it should be"
+        " a numpy array or a dict of numpy arrays",
+    )
+    parser.add_argument("--profile", help="profiler output file")
+    parser.add_argument(
+        "--fast-run",
+        action="store_true",
+        help="enable fast running by profiling conv algorithms during compiling.",
+    )
+    parser.add_argument(
+        "--reproducible", action="store_true", help="use reproducible kernels"
+    )
+    parser.add_argument(
+        "--input-desc",
+        help="specify input names and shapes manually in"
+        " format: <name>:<shape>[;<name>:<shape>, ...], where"
+        " name is a string and shape is a comma separated"
+        ' string, e.g., "data:128,1,28,28;label:128";'
+        " different input tensors are separated by semicolons.",
+    )
+    parser.add_argument(
+        "--batchsize",
+        type=int,
+        help="change batchsize; the first dimension of each"
+        " input is assumed to be batch size",
+    )
+    parser.add_argument(
+        "--warm-up",
+        action="store_true",
+        help="warm up model before timing for better estimation",
+    )
+    parser.add_argument(
+        "--verbose",
+        "-v",
+        action="store_true",
+        help="verbose output, logging in debug mode",
+    )
+    parser.add_argument(
+        "--iter", type=int, default=1, help="number of iters to run the model"
+    )
+    parser.add_argument("--log", help="give a file path to duplicate log to")
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=0,
+        help="seed for random number generator for input data",
+    )
+    parser.add_argument(
+        "--rng",
+        help="special RNG options to generate input data in"
+        " format: <name>:func[;<name>:func, ...] where name is"
+        " a string and func is a python expression containing"
+        ' "{}" for the size param, e.g. '
+        ' "label:randint(low=0,high=1000,size={})"',
+    )
+    parser.add_argument(
+        "--focused-nvprof",
+        action="store_true",
+        help="only profile last iter for `nvprof --profile-from-start off`",
+    )
+    parser.add_argument(
+        "--optimize-for-inference",
+        action="store_true",
+        help="optimize model for inference",
+    )
+    parser.add_argument(
+        "--enable-io16xc32",
+        action="store_true",
+        help="transform the model to float16 io and float32 compute",
+    )
+    parser.add_argument(
+        "--enable-ioc16",
+        action="store_true",
+        help="transform the dtype of the model to float16 io and compute",
+    )
+    parser.add_argument(
+        "--enable-hwcd4",
+        action="store_true",
+        help="transform the model format from NCHW to NHWCD4 for inference",
+    )
+    parser.add_argument(
+        "--enable-nchw4",
+        action="store_true",
+        help="transform the model format from NCHW to NCHW4 for inference",
+    )
+    parser.add_argument(
+        "--enable-nchw88",
+        action="store_true",
+        help="transform the model format from NCHW to NCHW88 for inference",
+    )
+    parser.add_argument(
+        "--enable-nchw44",
+        action="store_true",
+        help="transform the model format from NCHW to NCHW44 for inference",
+    )
+    parser.add_argument(
+        "--enable-nchw44-dot",
+        action="store_true",
+        help="transform the model format from NCHW to NCHW44_DOT "
+        "for optimizing armv8.2 dot in inference",
+    )
+    parser.add_argument(
+        "--enable-chwn4",
+        action="store_true",
+        help="transform the model format to CHWN4 "
+        "for inference, mainly used for nvidia TensorCore",
+    )
+    parser.add_argument(
+        "--enable-nchw32",
+        action="store_true",
+        help="transform the model format from NCHW4 to NCHW32 "
+        "for inference on nvidia TensorCore",
+    )
+    parser.add_argument(
+        "--enable-fuse-conv-bias-nonlinearity",
+        action="store_true",
+        help="fuse convolution, bias and nonlinearity oprs into a single "
+        "conv_bias opr",
+    )
+    parser.add_argument(
+        "--enable-fuse-conv-bias-with-z",
+        action="store_true",
+        help="fuse conv_bias with z input for inference on "
+        "nvidia GPU (this optimization pass will result in mismatch "
+        "of the precision of output of training and inference)",
+    )
+    parser.add_argument(
+        "--dump-cpp-model",
+        help="write a C++ model that can be loaded by "
+        "megbrain/sdk/load-and-run; "
+        "this implies --embed-input",
+    )
+    parser.add_argument(
+        "--embed-input",
+        action="store_true",
+        help="embed input data as SharedDeviceTensor in model, "
+        "to remove memory copy for inputs",
+    )
+    args = parser.parse_args()
+
+    if args.verbose:
+        enable_debug_log()
+    if args.log:
+        set_log_file(args.log)
+
+    if args.device:
+        set_default_device(args.device)
+
+    if args.dump_cpp_model:
+        args.embed_input = True
+
+    logger.info("loading model ...")
+    graph, _, output_vars = G.load_graph(args.net)
+    input_vars = tools.get_dep_vars(output_vars, "Host2DeviceCopy")
+
+    if args.output_name is not None:
+        output_vars = tools.find_vars_by_name(output_vars, args.output_name)
+
+    data = make_data_given_desc(args, input_vars)
+
+    run_model(args, graph, input_vars, output_vars, data)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/imperative/python/megengine/utils/profile_analyzer.py b/imperative/python/megengine/utils/profile_analyzer.py
index fab2f4d6..59945ca4 100644
--- a/imperative/python/megengine/utils/profile_analyzer.py
+++ b/imperative/python/megengine/utils/profile_analyzer.py
@@ -181,7 +181,7 @@ class ProfileAnalyzer:
         :param opr_filter: function that filter oprs.
         """
         self._opr_set = dict()  # type: dict
-        assert isinstance(obj, dict)
+        assert isinstance(obj, dict), type(obj)
         varz = obj["graph_exec"]["var"]
         for opr_id, entry in obj["graph_exec"]["operator"].items():
             inp = [varz[i] for i in entry["input"]]
diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp
index db6f10a2..151d8a8f 100644
--- a/imperative/python/src/graph_rt.cpp
+++ b/imperative/python/src/graph_rt.cpp
@@ -195,6 +195,7 @@ void init_graph_rt(py::module m) {
     py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
             .def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>())
             .def("wait", &cg::AsyncExecutable::wait, py::call_guard<py::gil_scoped_release>())
+            .def("get_prev_exec_time", &cg::AsyncExecutable::get_prev_exec_time, py::call_guard<py::gil_scoped_release>())
             // only used for exception handle
             .def_property_readonly("_all_rendezvous", [](cg::AsyncExecutable* exec) {
                 auto ud = exec->owner_graph()->options().user_data
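
Usage note: the flow added above can also be driven directly from Python. The
sketch below is a minimal, illustrative outline that mirrors the calls used in
run_model()/main(); the file name "model.mge" and the randomly generated inputs
are placeholders, not part of the patch.

    import numpy as np

    import megengine as mge
    from megengine.core.tensor import megbrain_graph as G
    from megengine.utils import comp_graph_tools as tools

    # load a dumped graph ("model.mge" is a placeholder path) and find the
    # Host2DeviceCopy vars that act as its inputs
    graph, _, output_vars = G.load_graph("model.mge")
    input_vars = tools.get_dep_vars(output_vars, "Host2DeviceCopy")

    # random data matching each input's shape and dtype, in the spirit of
    # make_data_given_desc()
    data = {
        inp.name: np.random.uniform(size=inp.shape).astype(inp.dtype)
        for inp in input_vars
    }

    # make the inputs settable, convert the outputs, then compile and run once
    outputs, inp_dict = tools.convert_inputs(output_vars, inputs=input_vars)
    outputs, output_dict = tools.convert_outputs(outputs)
    func = graph.compile(outputs)

    for name, node in inp_dict.items():
        node.set_value(mge.Tensor(data[name])._dev_tensor())
    func.execute()
    func.wait()
    results = [node.get_value().numpy() for node in output_dict.values()]

From the shell, the equivalent is roughly
"python -m megengine.tools.load_network_and_run model.mge --iter 1 --warm-up",
assuming the tools package exposes the new file as a runnable module.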