diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py
index b9af63a6..af96ec2d 100644
--- a/imperative/python/megengine/core/tensor/megbrain_graph.py
+++ b/imperative/python/megengine/core/tensor/megbrain_graph.py
@@ -14,6 +14,7 @@ from concurrent.futures import Future, ThreadPoolExecutor
 import numpy as np
 
 from .. import _imperative_rt
+from .._imperative_rt import GraphOptimizeOptions
 from .._imperative_rt.ops import BackwardGraph
 from .._wrap import device as as_device
 from ..ops.builtin import OpDef
@@ -83,6 +84,84 @@ class Graph(_imperative_rt.ComputingGraph):
         return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name))
 
 
+def optimize_for_inference(dest_vars, **kwargs):
+    r"""Applies an optimize-for-inference pass to the computing graph.
+
+    :param dest_vars: list of output vars in the computing graph.
+
+    :Keyword Arguments:
+
+        * enable_io16xc32 --
+            whether to use float16 for I/O between oprs and float32 as
+            internal computation precision. Note that the output vars would
+            be changed to float16.
+        * enable_ioc16 --
+            whether to use float16 for both I/O and computation precision.
+
+        * enable_hwcd4 --
+            whether to use NHWCD4 data layout. This is faster on some
+            OpenCL backends.
+        * enable_nchw88 --
+            whether to use NCHW88 data layout, currently used in the X86
+            AVX backend.
+        * enable_nchw44 --
+            whether to use NCHW44 data layout, currently used in the ARM
+            backend.
+        * enable_nchw44_dot --
+            whether to use NCHW44_DOT data layout, currently used in the
+            ARMv8.2+dotprod backend.
+        * enable_nchw4 --
+            whether to use NCHW4 data layout, currently used in the NVIDIA
+            backend (based on cuDNN).
+        * enable_nchw32 --
+            whether to use NCHW32 data layout, currently used in the NVIDIA
+            backend with TensorCore (based on cuDNN).
+        * enable_chwn4 --
+            whether to use CHWN4 data layout, currently used in the NVIDIA
+            backend with TensorCore.
+
+        * enable_fuse_conv_bias_nonlinearity --
+            whether to fuse conv+bias+nonlinearity into one opr.
+        * enable_fuse_conv_bias_with_z --
+            whether to fuse conv_bias with z input for inference on the
+            NVIDIA backend (this optimization pass may cause a precision
+            mismatch between the outputs of training and inference).
+    """
+    inference_options = GraphOptimizeOptions()
+    inference_optimize_layout_transform_map = {
+        "enable_hwcd4": GraphOptimizeOptions.LayoutTransform.NHWCD4,
+        "enable_nchw4": GraphOptimizeOptions.LayoutTransform.NCHW4,
+        "enable_nchw88": GraphOptimizeOptions.LayoutTransform.NCHW88,
+        "enable_nchw32": GraphOptimizeOptions.LayoutTransform.NCHW32,
+        "enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44,
+        "enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT,
+        "enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4,
+    }
+
+    for k, v in inference_optimize_layout_transform_map.items():
+        if kwargs.pop(k, False):
+            inference_options.layout_transform = v
+
+    if kwargs.pop("enable_io16xc32", False):
+        inference_options.f16_io_f32_comp = True
+    if kwargs.pop("enable_ioc16", False):
+        inference_options.f16_io_comp = True
+    if kwargs.pop("enable_fuse_conv_bias_nonlinearity", False):
+        inference_options.fuse_conv_bias_nonlinearity = True
+    if kwargs.pop("enable_fuse_conv_bias_with_z", False):
+        inference_options.fuse_conv_bias_with_z = True
+
+    if kwargs:
+        raise ValueError("unknown options: %s" % list(kwargs))
+
+    res_vars = _imperative_rt.optimize_for_inference(
+        [i._node for i in dest_vars], inference_options
+    )
+    return [VarNode(i) for i in res_vars]
+
+
 def dump(*args):
     return _imperative_rt.dump_graph([i._node for i in args])
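
For context, a minimal usage sketch of the new pass (not part of the patch); `net_outputs` is a hypothetical list of output VarNodes from an already-built graph:

    from megengine.core.tensor import megbrain_graph as G

    # Request float16 I/O with float32 compute plus conv+bias+nonlinearity
    # fusion; any unrecognized keyword raises ValueError.
    optimized_outputs = G.optimize_for_inference(
        net_outputs, enable_io16xc32=True, enable_fuse_conv_bias_nonlinearity=True
    )
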
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index baa40c34..2d272dba 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -11,6 +11,7 @@ import numpy as np
 from ..core._imperative_rt import GraphProfiler
 from ..core._imperative_rt.ops import OprAttr
+from ..core._trace_option import set_tensor_shape
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
 from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
@@ -76,6 +77,22 @@ class TensorInfo:
 
 
 class trace:
+    """
+    Wraps a callable and provides:
+
+    * tracing via :meth:`.trace` and :meth:`.dump`
+    * accelerated evaluation via :meth:`.__call__`
+
+    :param function: the function to be traced.
+    :param symbolic: whether to apply symbolic execution for tracing. Default: False
+    :param capture_as_const: capture global vars or closures as const value. Default: False
+    :param sublinear_memory_config: configuration for sublinear memory optimization.
+        If not None, it enables sublinear memory optimization with the given setting.
+    :param profiling: whether to profile the compiled trace. Default: False
+    :param opt_level: optimization level for compiling the trace.
+    :param tensor_shape: whether to use static tensor shape for tracing. Default: True
+    """
+
     def __new__(cls, *args, **kwargs):
         if not args:
             return functools.partial(cls, **kwargs)
@@ -88,6 +105,8 @@
         capture_as_const=False,
         sublinear_memory_config: SublinearMemoryConfig = None,
         profiling: bool = False,
+        opt_level: int = None,
+        tensor_shape: bool = True,
     ):
         self.__wrapped__ = function
         self._symbolic = symbolic
@@ -95,6 +114,8 @@
         self._sublinear_memory_config = sublinear_memory_config
         self._profiling = profiling
         self._profiler = None
+        self._graph_opt_level = opt_level
+        self._tensor_shape = tensor_shape
 
         self._untraced = True
         self._tinfo = []  # handle -> TensorInfo
@@ -112,6 +133,8 @@
         self._output_bindings = None
         self._output_names = None
 
+        set_tensor_shape(self._tensor_shape)
+
     def _new_handle(self):
         handle = len(self._tinfo)
         info = TensorInfo()
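
A short sketch of the new `opt_level` and `tensor_shape` arguments at decoration time (the traced function body is illustrative only):

    import numpy as np
    from megengine import tensor
    from megengine.jit import trace

    @trace(symbolic=True, opt_level=1, tensor_shape=True, capture_as_const=True)
    def f(x):
        return x * 2

    out = f(tensor(np.ones((2, 2), dtype=np.float32)))
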
@@ -307,6 +330,9 @@
     def _apply_graph_options(self, graph):
 
         graph.options.seq_opt.enable_seq_comp_node_opt = False
+        # graph opt level
+        if self._graph_opt_level is not None:
+            graph.options.graph_opt_level = self._graph_opt_level
         # sublinear
         if self._sublinear_memory_config is not None:
             graph.options.enable_sublinear_memory_opt = True
@@ -320,6 +346,7 @@
             )
             sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
             sublinear_config.num_worker = self._sublinear_memory_config.num_worker
+
         # profile
         if self._profiling:
             self._profiler = GraphProfiler(graph)
@@ -416,7 +443,55 @@
             self._process_outputs(outputs)
         return outputs
 
-    def dump(self, file, *, arg_names=None, output_names=None):
+    def dump(self, file, *, arg_names=None, output_names=None, append=False, **kwargs):
+        r"""Serializes the trace to the file system.
+
+        :param file: output file; can be a file object or a filename.
+        :param arg_names: names of the input tensors in the traced function.
+        :param output_names: names of the output tensors in the traced function;
+            the default names are used if not specified.
+        :param append: whether the output is appended to ``file``.
+            Only works when ``file`` is a str.
+
+        :Keyword Arguments:
+
+            * enable_io16xc32 --
+                whether to use float16 for I/O between oprs and float32 as
+                internal computation precision. Note that the output vars would
+                be changed to float16.
+            * enable_ioc16 --
+                whether to use float16 for both I/O and computation precision.
+
+            * enable_hwcd4 --
+                whether to use NHWCD4 data layout. This is faster on some
+                OpenCL backends.
+            * enable_nchw88 --
+                whether to use NCHW88 data layout, currently used in the X86
+                AVX backend.
+            * enable_nchw44 --
+                whether to use NCHW44 data layout, currently used in the ARM
+                backend.
+            * enable_nchw44_dot --
+                whether to use NCHW44_DOT data layout, currently used in the
+                ARMv8.2+dotprod backend.
+            * enable_nchw4 --
+                whether to use NCHW4 data layout, currently used in the NVIDIA
+                backend (based on cuDNN).
+            * enable_nchw32 --
+                whether to use NCHW32 data layout, currently used in the NVIDIA
+                backend with TensorCore (based on cuDNN).
+            * enable_chwn4 --
+                whether to use CHWN4 data layout, currently used in the NVIDIA
+                backend with TensorCore.
+
+            * enable_fuse_conv_bias_nonlinearity --
+                whether to fuse conv+bias+nonlinearity into one opr.
+            * enable_fuse_conv_bias_with_z --
+                whether to fuse conv_bias with z input for inference on the
+                NVIDIA backend (this optimization pass may cause a precision
+                mismatch between the outputs of training and inference).
+        """
         if not self._capture_as_const:
             raise ValueError(
                 "you must specify capture_as_const=True at __init__ to use dump"
@@ -482,8 +557,11 @@
                 v.name = output_names[i]
             dest_vars.append(v)
 
+        dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
+
         if isinstance(file, str):
-            file = open(file, "wb")
+            file = open(file, "ab" if append else "wb")
         file.write(G.dump(*dest_vars))
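
Continuing the sketch above: `dump` now forwards its extra keywords to `G.optimize_for_inference` and honors `append` for str filenames (the file name is illustrative):

    f(tensor(np.ones((2, 2), dtype=np.float32)))  # trace once before dumping
    f.dump("model.mge", arg_names=["x"], enable_io16xc32=True)
    f.dump("model.mge", append=True)  # append a second graph to the same file
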
diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp
index f6073644..6874388b 100644
--- a/imperative/python/src/graph_rt.cpp
+++ b/imperative/python/src/graph_rt.cpp
@@ -20,12 +20,17 @@
 #include "./helper.h"
 #include "megbrain/plugin/profiler.h"
 #include "./common.h"
+#include "megbrain/gopt/inference.h"
+
 namespace py = pybind11;
 
 using namespace mgb;
 using namespace imperative;
 
+using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
+using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
+
 namespace {
 class _CompGraphProfilerImpl {
     std::shared_ptr<ComputingGraph> m_comp_graph;
@@ -138,6 +143,37 @@
             return py::bytes(reinterpret_cast<const char*>(&buf[0]), buf.size());
         });
 
+    auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
+        .def(py::init<>())
+        .def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp)
+        .def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
+        .def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
+        .def_readwrite("fuse_conv_bias_with_z", &_OptimizeForInferenceOptions::fuse_conv_bias_with_z)
+        .def_readwrite("layout_transform", &_OptimizeForInferenceOptions::layout_transform)
+        ;
+
+    py::enum_<_LayoutTransform>(GraphOptimizeOptions, "LayoutTransform")
+        .value("DEFAULT", _LayoutTransform::DEFAULT)
+        .value("NCHW4", _LayoutTransform::NCHW4)
+        .value("NHWCD4", _LayoutTransform::NHWCD4)
+        .value("NCHW88", _LayoutTransform::NCHW88)
+        .value("NCHW44", _LayoutTransform::NCHW44)
+        .value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
+        .value("NCHW32", _LayoutTransform::NCHW32)
+        .value("CHWN4", _LayoutTransform::CHWN4)
+        .export_values()
+        ;
+
+    m.def("optimize_for_inference", [](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) {
+        SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
+        auto res_symvars = mgb::gopt::optimize_for_inference(symvars, opt);
+        VarNodeArray vars;
+        for (auto& si: res_symvars)
+            vars.push_back(si.node());
+        return vars;
+    });
+
 #define CURRENT_CLASS cg::ComputingGraph::Options
 
     auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
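
For orientation, this is roughly how the binding above is driven from Python by `megbrain_graph.optimize_for_inference`; `dest_vars` is a hypothetical list of raw VarNodes unwrapped from the Python-level wrappers:

    from megengine.core._imperative_rt import GraphOptimizeOptions, optimize_for_inference

    options = GraphOptimizeOptions()
    options.f16_io_f32_comp = True
    options.layout_transform = GraphOptimizeOptions.LayoutTransform.NCHW88
    # dest_vars: raw VarNodes (e.g. [v._node for v in wrapped_vars])
    result_vars = optimize_for_inference(dest_vars, options)
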
diff --git a/imperative/python/test/unit/test_jit.py b/imperative/python/test/unit/test_jit.py
deleted file mode 100644
index 4bc9c2f1..00000000
--- a/imperative/python/test/unit/test_jit.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import pytest
-
-from megengine.core import Tensor
-
-# from megengine.core.interpreter.hints import function
-
-
-@pytest.mark.skip(reason="under rewrite")
-def test_1():
-    @function
-    def f(x, p):
-        x = x + 1
-        if p:
-            return x * x
-        return x * 2
-
-    x = Tensor(0)
-
-    for _ in range(5):
-        assert f(x, 0).numpy() == 2
-        assert f(x, 1).numpy() == 1
diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py
index 45202c4e..d78d231c 100644
--- a/imperative/python/test/unit/test_tracing.py
+++ b/imperative/python/test/unit/test_tracing.py
@@ -1,10 +1,23 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import io
+from tempfile import mkstemp
 
 import numpy as np
+import pytest
 
+from megengine import tensor
 from megengine.core.ops import builtin as ops
+from megengine.core.tensor import megbrain_graph as G
 from megengine.core.tensor.core import apply
 from megengine.core.tensor.raw_tensor import as_raw_tensor
+from megengine.functional import exp, log
 from megengine.jit import exclude_from_trace, trace
 
@@ -101,3 +114,85 @@
     out = f.get_profile()
     assert out.get("profiler")
+
+
+@pytest.mark.skip(reason="eq_to_unit failed in inplace.cpp")
+def test_goptions_div_zero():
+    @trace(symbolic=True, opt_level=0)
+    def f(x):
+        return x / x
+
+    @trace(symbolic=True, opt_level=1)
+    def g(x):
+        return x / x
+
+    out = f(tensor(0.0))
+    if out == out:
+        raise ValueError("actual result should be nan")
+
+    out = g(tensor(0.0))
+    if out != out:
+        raise ValueError("actual result should be 1")
+
+
+@pytest.mark.skip(reason="cast to Elemwise failed in inplace.cpp")
+def test_goptions_log_exp():
+    @trace(symbolic=True, opt_level=0, capture_as_const=True)
+    def f(x):
+        return log(exp(x))
+
+    @trace(symbolic=True, opt_level=1, capture_as_const=True)
+    def g(x):
+        return log(exp(x))
+
+    f(tensor(1.0))
+    _, out = mkstemp()
+    f.dump(out)
+    *_, outputs = G.load_comp_graph_from_file(out)
+    oprs_1 = cgtools.get_oprs_seq(outputs)
+
+    g(tensor(1.0))
+    g.dump(out)
+    *_, outputs = G.load_comp_graph_from_file(out)
+    oprs_2 = cgtools.get_oprs_seq(outputs)
+
+    assert len(oprs_1) - len(oprs_2) == 2
+
+
+@pytest.mark.skip(reason="need cgtools to check final oprs")
+def test_goptions_log_sum_exp():
+    @trace(symbolic=True, opt_level=0, capture_as_const=True)
+    def f(x, y):
+        return log(exp(x) + exp(y))
+
+    @trace(symbolic=True, opt_level=1, capture_as_const=True)
+    def g(x, y):
+        return log(exp(x) + exp(y))
+
+    f(tensor(1.0), tensor(2.0))
+    _, out = mkstemp()
+    f.dump(out)
+    *_, outputs = G.load_comp_graph_from_file(out)
+    oprs_1 = cgtools.get_oprs_seq(outputs)
+
+    g(tensor(1.0), tensor(2.0))
+    g.dump(out)
+    *_, outputs = G.load_comp_graph_from_file(out)
+    oprs_2 = cgtools.get_oprs_seq(outputs)
+
+    assert len(oprs_1) - len(oprs_2) == 2
+
+
+@pytest.mark.skip(reason="need cgtools to check computing input dtype")
+def test_optimize_for_inference():
+    @trace(symbolic=True, capture_as_const=True)
+    def f(x):
+        return exp(x)
+
+    _, out = mkstemp()
+    f(tensor(5.0))
+    f.dump(out, enable_io16xc32=True)
+
+    *_, outputs = G.load_comp_graph_from_file(out)
+    computing_input = outputs[0].owner.inputs[0]
+    assert computing_input.dtype == np.float16