@@ -14,6 +14,7 @@ from concurrent.futures import Future, ThreadPoolExecutor
import numpy as np

from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.ops import BackwardGraph
from .._wrap import device as as_device
from ..ops.builtin import OpDef
@@ -83,6 +84,84 @@ class Graph(_imperative_rt.ComputingGraph):
        return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name))
def optimize_for_inference(dest_vars, **kwargs):
    r"""Applies the optimize_for_inference pass to a computing graph.

    :param dest_vars: list of output vars in the computing graph.

    :Keyword Arguments:

        * enable_io16xc32 --
          whether to use float16 for I/O between oprs and use
          float32 as internal computation precision. Note that the output vars
          would be changed to float16.
        * enable_ioc16 --
          whether to use float16 for both I/O and computation
          precision.
        * enable_hwcd4 --
          whether to use NHWCD4 data layout. This is faster on some
          OpenCL backends.
        * enable_nchw88 --
          whether to use NCHW88 data layout, currently
          used in the x86 AVX backend.
        * enable_nchw44 --
          whether to use NCHW44 data layout, currently
          used in the ARM backend.
        * enable_nchw44_dot --
          whether to use NCHW44_DOT data layout, currently
          used in the ARMv8.2+dotprod backend.
        * enable_nchw4 --
          whether to use NCHW4 data layout, currently
          used in the NVIDIA backend (based on cuDNN).
        * enable_nchw32 --
          whether to use NCHW32 data layout, currently
          used in the NVIDIA backend with TensorCore (based on cuDNN).
        * enable_chwn4 --
          whether to use CHWN4 data layout, currently
          used in the NVIDIA backend with TensorCore.
        * enable_fuse_conv_bias_nonlinearity --
          whether to fuse conv+bias+nonlinearity into one opr.
        * enable_fuse_conv_bias_with_z --
          whether to fuse conv_bias with the z input for inference on
          the NVIDIA backend (this optimization pass will cause the output
          precision of training and inference to mismatch).
    """
    inference_options = GraphOptimizeOptions()
    inference_optimize_layout_transform_map = {
        "enable_hwcd4": GraphOptimizeOptions.LayoutTransform.NHWCD4,
        "enable_nchw4": GraphOptimizeOptions.LayoutTransform.NCHW4,
        "enable_nchw88": GraphOptimizeOptions.LayoutTransform.NCHW88,
        "enable_nchw32": GraphOptimizeOptions.LayoutTransform.NCHW32,
        "enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44,
        "enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT,
        "enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4,
    }

    for k, v in inference_optimize_layout_transform_map.items():
        if kwargs.pop(k, False):
            inference_options.layout_transform = v

    if kwargs.pop("enable_io16xc32", False):
        inference_options.f16_io_f32_comp = True
    if kwargs.pop("enable_ioc16", False):
        inference_options.f16_io_comp = True
    if kwargs.pop("enable_fuse_conv_bias_nonlinearity", False):
        inference_options.fuse_conv_bias_nonlinearity = True
    if kwargs.pop("enable_fuse_conv_bias_with_z", False):
        inference_options.fuse_conv_bias_with_z = True

    if kwargs:
        raise ValueError("unknown options: %s" % list(kwargs))

    res_vars = _imperative_rt.optimize_for_inference(
        [i._node for i in dest_vars], inference_options
    )
    return [VarNode(i) for i in res_vars]
def dump(*args):
    return _imperative_rt.dump_graph([i._node for i in args])
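Together these two helpers form the graph-level export path: run the inference passes, then serialize. A minimal usage sketch (the `net_outputs` list of output `VarNode`s and the file name are placeholders for illustration):

```python
from megengine.core.tensor import megbrain_graph as G

# `net_outputs` is assumed to be a list of output VarNodes of a built graph.
optimized = G.optimize_for_inference(
    net_outputs,
    enable_io16xc32=True,                    # float16 I/O, float32 internal compute
    enable_fuse_conv_bias_nonlinearity=True,
)
with open("model.mge", "wb") as fout:
    fout.write(G.dump(*optimized))           # dump_graph returns the serialized bytes
```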
@@ -11,6 +11,7 @@ import numpy as np
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import OprAttr
from ..core._trace_option import set_tensor_shape
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
@@ -76,6 +77,22 @@ class TensorInfo:
class trace:
    """
    Wraps a callable and provides:

    * tracing via :meth:`.trace` and :meth:`.dump`
    * accelerated evaluation via :meth:`.__call__`

    :param function: the function to be traced.
    :param symbolic: whether to apply symbolic execution for tracing. Default: False
    :param capture_as_const: capture global vars or closures as const values. Default: False
    :param sublinear_memory_config: configuration for sublinear memory optimization.
        If not None, it enables sublinear memory optimization with the given setting.
    :param profiling: whether to profile the compiled trace. Default: False
    :param opt_level: optimization level for compiling the trace.
    :param tensor_shape: whether to use symbolic (tensor) shape for tracing. Default: True
    """

    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
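Returning a `functools.partial` from `__new__` when no positional argument is given lets the same class serve as both a bare and a parameterized decorator:

```python
from megengine.jit import trace

@trace                 # bare form: trace(function)
def f(x):
    return x * 2

@trace(symbolic=True)  # keyword-only call yields a partial, which then wraps the function
def g(x):
    return x * 2
```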
@@ -88,6 +105,8 @@ class trace:
        capture_as_const=False,
        sublinear_memory_config: SublinearMemoryConfig = None,
        profiling: bool = False,
        opt_level: int = None,
        tensor_shape: bool = True,
    ):
        self.__wrapped__ = function
        self._symbolic = symbolic
@@ -95,6 +114,8 @@ class trace:
        self._sublinear_memory_config = sublinear_memory_config
        self._profiling = profiling
        self._profiler = None
        self._graph_opt_level = opt_level
        self._tensor_shape = tensor_shape

        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
@@ -112,6 +133,8 @@ class trace:
        self._output_bindings = None
        self._output_names = None

        set_tensor_shape(self._tensor_shape)

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
@@ -307,6 +330,9 @@ class trace:
    def _apply_graph_options(self, graph):

        graph.options.seq_opt.enable_seq_comp_node_opt = False

        # graph opt level
        if self._graph_opt_level is not None:
            graph.options.graph_opt_level = self._graph_opt_level

        # sublinear
        if self._sublinear_memory_config is not None:
            graph.options.enable_sublinear_memory_opt = True
@@ -320,6 +346,7 @@ class trace:
            )
            sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
            sublinear_config.num_worker = self._sublinear_memory_config.num_worker

        # profile
        if self._profiling:
            self._profiler = GraphProfiler(graph)
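With these options wired through `_apply_graph_options`, profiling and sublinear memory optimization are both requested at trace construction time. A hedged sketch (the `SublinearMemoryConfig` import path is assumed from the signature above):

```python
from megengine import tensor
from megengine.jit import SublinearMemoryConfig, trace  # import path assumed

@trace(symbolic=True, profiling=True,
       sublinear_memory_config=SublinearMemoryConfig())
def f(x):
    return x * 2

f(tensor(1.0))            # first call compiles and runs the graph
print(f.get_profile())    # profiling data collected by GraphProfiler
```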
@@ -416,7 +443,55 @@ class trace:
            self._process_outputs(outputs)
        return outputs

    def dump(self, file, *, arg_names=None, output_names=None, append=False, **kwargs):
r"""Serializes trace to file system. | |||
:param file: output file, could be file object or filename. | |||
:param arg_names: names of the input tensors in the traced function. | |||
:param output_names: names of the output tensors in the traced function, | |||
use the default name if not specified. | |||
:param append: whether output is appended to ``file``. | |||
Only works when ``file`` is str. | |||
:Keyword Arguments: | |||
* enable_io16xc32 -- | |||
whether to use float16 for I/O between oprs and use | |||
float32 as internal computation precision. Note the output var would be | |||
changed to float16. | |||
* enable_ioc16 -- | |||
whether to use float16 for both I/O and computation | |||
precision. | |||
* enable_hwcd4 -- | |||
whether to use NHWCD4 data layout. This is faster on some | |||
OpenCL backend. | |||
* enable_nchw88 -- | |||
whether to use NCHW88 data layout, currently | |||
used in X86 AVX backend. | |||
* enable_nchw44 -- | |||
whether to use NCHW44 data layout, currently | |||
used in arm backend. | |||
* enable_nchw44_dot -- | |||
whether to use NCHW44_dot data layout, currently | |||
used in armv8.2+dotprod backend. | |||
* enable_nchw4 -- | |||
whether to use NCHW4 data layout, currently | |||
used in nvidia backend(based on cudnn). | |||
* enable_nchw32 -- | |||
whether to use NCHW32 data layout, currently | |||
used in nvidia backend with tensorcore(based on cudnn). | |||
* enable_chwn4 -- | |||
whether to use CHWN4 data layout, currently | |||
used in nvidia backend with tensorcore. | |||
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty | |||
into one opr. | |||
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z | |||
input for inference on nvidia backend(this optimization pass will | |||
result in mismatch of the precision of output of training and | |||
inference) | |||
""" | |||
        if not self._capture_as_const:
            raise ValueError(
                "you must specify capture_as_const=True at __init__ to use dump"
            )
@@ -482,8 +557,11 @@
                v.name = output_names[i]
            dest_vars.append(v)

        dest_vars = G.optimize_for_inference(dest_vars, **kwargs)

        if isinstance(file, str):
            permission = "ab" if append else "wb"
            file = open(file, permission)

        file.write(G.dump(*dest_vars))
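The new `append` flag plus the forwarded keyword options let one entry point produce inference-ready dumps and multi-graph files. A usage sketch (file name illustrative):

```python
import numpy as np
from megengine import tensor
from megengine.jit import trace

@trace(symbolic=True, capture_as_const=True)  # dump requires capture_as_const=True
def f(x):
    return x * 2

f(tensor(np.float32(3.0)))                    # trace once before dumping
f.dump("model.mge", arg_names=["data"], enable_io16xc32=True)
f.dump("model.mge", append=True)              # append a second graph to the same file
```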
    def _process_inputs(self, *args, **kwargs):
@@ -20,12 +20,17 @@
#include "./helper.h"
#include "megbrain/plugin/profiler.h"
#include "./common.h"
#include "megbrain/gopt/inference.h"

namespace py = pybind11;

using namespace mgb;
using namespace imperative;

using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;

namespace {
class _CompGraphProfilerImpl {
    std::shared_ptr<ComputingGraph> m_comp_graph;
@@ -138,6 +143,37 @@ void init_graph_rt(py::module m) {
        return py::bytes(reinterpret_cast<const char*>(&buf[0]), buf.size());
    });

    auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
        .def(py::init())
        .def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp)
        .def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
        .def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
        .def_readwrite("fuse_conv_bias_with_z", &_OptimizeForInferenceOptions::fuse_conv_bias_with_z)
        .def_readwrite("layout_transform", &_OptimizeForInferenceOptions::layout_transform);

    py::enum_<_LayoutTransform>(GraphOptimizeOptions, "LayoutTransform")
        .value("DEFAULT", _LayoutTransform::DEFAULT)
        .value("NCHW4", _LayoutTransform::NCHW4)
        .value("NHWCD4", _LayoutTransform::NHWCD4)
        .value("NCHW88", _LayoutTransform::NCHW88)
        .value("NCHW44", _LayoutTransform::NCHW44)
        .value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
        .value("NCHW32", _LayoutTransform::NCHW32)
        .value("CHWN4", _LayoutTransform::CHWN4)
        .export_values();

    m.def("optimize_for_inference", [](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) {
        SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
        auto res_symvars = mgb::gopt::optimize_for_inference(symvars, opt);
        VarNodeArray vars;
        for (auto& si : res_symvars)
            vars.push_back(si.node());
        return vars;
    });
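On the Python side these bindings expose exactly the fields consumed by `optimize_for_inference` above, so the options object can also be driven by hand:

```python
from megengine.core._imperative_rt import GraphOptimizeOptions

opts = GraphOptimizeOptions()
opts.f16_io_f32_comp = True      # equivalent of enable_io16xc32
opts.fuse_conv_bias_nonlinearity = True
opts.layout_transform = GraphOptimizeOptions.LayoutTransform.NCHW88
# _imperative_rt.optimize_for_inference(var_nodes, opts) returns the new VarNodes
```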
#define CURRENT_CLASS cg::ComputingGraph::Options

    auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
@@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pytest

from megengine.core import Tensor

# from megengine.core.interpreter.hints import function


@pytest.mark.skip(reason="under rewrite")
def test_1():
    @function
    def f(x, p):
        x = x + 1
        if p:
            return x * x
        return x * 2

    x = Tensor(0)
    for _ in range(5):
        assert f(x, 0).numpy() == 2
        assert f(x, 1).numpy() == 1
@@ -1,10 +1,23 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io
from tempfile import mkstemp

import numpy as np
import pytest

from megengine import tensor
from megengine.core.ops import builtin as ops
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace
@@ -101,3 +114,85 @@ def test_trace_profiler():
    out = f.get_profile()
    assert out.get("profiler")
@pytest.mark.skip(reason="eq_to_unit failed in inplace.cpp") | |||
def test_goptions_div_zero(): | |||
@trace(symbolic=True, opt_level=0) | |||
def f(x): | |||
return x / x | |||
@trace(symbolic=True, opt_level=1) | |||
def g(x): | |||
return x / x | |||
out = f(tensor(0.0)) | |||
if out == out: | |||
raise ValueError("actual result should be nan") | |||
out = g(tensor(0.0)) | |||
if out != out: | |||
raise ValueError("actual result should be 1") | |||
@pytest.mark.skip(reason="cast to Elemwise failed in inplace.cpp") | |||
def test_goptions_log_exp(): | |||
@trace(symbolic=True, opt_level=0, capture_as_const=True) | |||
def f(x): | |||
return log(exp(x)) | |||
@trace(symbolic=True, opt_level=1, capture_as_const=True) | |||
def g(x): | |||
return log(exp(x)) | |||
f(tensor(1.0)) | |||
_, out = mkstemp() | |||
f.dump(out) | |||
*_, outputs = G.load_comp_graph_from_file(out) | |||
oprs_1 = cgtools.get_oprs_seq(outputs) | |||
g(tensor(1.0)) | |||
g.dump(out) | |||
*_, outputs = G.load_comp_graph_from_file(out) | |||
oprs_2 = cgtools.get_oprs_seq(outputs) | |||
assert len(oprs_1) - len(oprs_2) == 2 | |||
@pytest.mark.skip(reason="need cgtools to check final oprs") | |||
def test_goptions_log_sum_exp(): | |||
@trace(symbolic=True, opt_level=0, capture_as_const=True) | |||
def f(x, y): | |||
return log(exp(x) + exp(y)) | |||
@trace(symbolic=True, opt_level=1, capture_as_const=True) | |||
def g(x, y): | |||
return log(exp(x) + exp(y)) | |||
f(tensor(1.0), tensor(2.0)) | |||
_, out = mkstemp() | |||
f.dump(out) | |||
*_, outputs = G.load_comp_graph_from_file(out) | |||
oprs_1 = cgtools.get_oprs_seq(outputs) | |||
g(tensor(1.0), tensor(2.0)) | |||
g.dump(out) | |||
*_, outputs = G.load_comp_graph_from_file(out) | |||
oprs_2 = cgtools.get_oprs_seq(outputs) | |||
assert len(oprs_1) - len(oprs_2) == 2 | |||
@pytest.mark.skip(reason="need cgtools to check computing input dtype") | |||
def test_optimize_for_inference(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def f(x): | |||
return exp(x) | |||
_, out = mkstemp() | |||
f(tensor(5.0)) | |||
f.dump(out, optimize_for_inference=True, optimize_options={"enable_io16xc32": True}) | |||
res = G.load_comp_graph_from_file(out) | |||
computing_input = res.output_vars_list[0].owner.inputs[0] | |||
assert computing_input.dtype == np.float16 |