@@ -11,7 +11,7 @@ import fnmatch
 import itertools
 import re
 from collections import OrderedDict
-from typing import Dict, List
+from typing import Dict, List, Sequence
 
 import numpy as np
 
@@ -87,6 +87,58 @@ class Network:
             for o in opr.outputs:
                 self.all_vars_map[o.var.id] = o
 
+    def optimize_for_inference(self, dest_vars, **kwargs):
+        r"""
+        Applies the optimize_for_inference pass to the operator graph.
+
+        :param dest_vars: list of output vars in the operator graph
+
+        :Keyword Arguments:
+
+            * enable_io16xc32 --
+                whether to use float16 for I/O between oprs and use
+                float32 as internal computation precision. Note that the output vars
+                would be changed to float16.
+            * enable_ioc16 --
+                whether to use float16 for both I/O and computation
+                precision.
+
+            * enable_hwcd4 --
+                whether to use NHWCD4 data layout. This is faster on some
+                OpenCL backends.
+            * enable_nchw88 --
+                whether to use NCHW88 data layout, currently
+                used in the x86 AVX backend.
+            * enable_nchw44 --
+                whether to use NCHW44 data layout, currently
+                used in the ARM backend.
+            * enable_nchw44_dot --
+                whether to use NCHW44_dot data layout, currently
+                used in the ARMv8.2+dotprod backend.
+            * enable_nchw4 --
+                whether to use NCHW4 data layout, currently
+                used in the NVIDIA backend (based on cuDNN).
+            * enable_nchw32 --
+                whether to use NCHW32 data layout, currently
+                used in the NVIDIA backend with TensorCore (based on cuDNN).
+            * enable_chwn4 --
+                whether to use CHWN4 data layout, currently
+                used in the NVIDIA backend with TensorCore.
+
+            * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
+                into one opr.
+            * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
+                input for inference on the NVIDIA backend (this optimization pass will
+                cause the output precision to differ between training and
+                inference).
+        """
+
+        if not isinstance(dest_vars, Sequence):
+            dest_vars = [dest_vars]
+        dest_vars = list(G.VarNode(var.var) for var in dest_vars)
+        new_vars = G.optimize_for_inference(dest_vars, **kwargs)
+        return list(self._get_var(var) for var in new_vars)
+
     def dump(
         self,
         file,
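A rough usage sketch of the new method (not part of the patch itself; the model path, the ``Network.load`` entry point and the ``output_vars`` attribute are assumptions about the surrounding ``Network`` API):

    from megengine.utils.network import Network

    net = Network.load("model.mge")        # hypothetical serialized model file
    # Run the pass on the graph outputs: float16 I/O with float32 compute.
    new_outputs = net.optimize_for_inference(
        net.output_vars, enable_io16xc32=True
    )
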
@@ -126,42 +178,8 @@ class Network:
 
         :Keyword Arguments:
 
-            * enable_io16xc32 --
-                whether to use float16 for I/O between oprs and use
-                float32 as internal computation precision. Note the output var would be
-                changed to float16.
-            * enable_ioc16 --
-                whether to use float16 for both I/O and computation
-                precision.
-
-            * enable_hwcd4 --
-                whether to use NHWCD4 data layout. This is faster on some
-                OpenCL backend.
-            * enable_nchw88 --
-                whether to use NCHW88 data layout, currently
-                used in X86 AVX backend.
-            * enable_nchw44 --
-                whether to use NCHW44 data layout, currently
-                used in arm backend.
-            * enable_nchw44_dot --
-                whether to use NCHW44_dot data layout, currently
-                used in armv8.2+dotprod backend.
-            * enable_nchw4 --
-                whether to use NCHW4 data layout, currently
-                used in nvidia backend(based on cudnn).
-            * enable_nchw32 --
-                whether to use NCHW32 data layout, currently
-                used in nvidia backend with tensorcore(based on cudnn).
-            * enable_chwn4 --
-                whether to use CHWN4 data layout, currently
-                used in nvidia backend with tensorcore.
-
-            * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
-                into one opr.
-            * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
-                input for inference on nvidia backend(this optimization pass will
-                result in mismatch of the precision of output of training and
-                inference)
+        See also :py:meth:`optimize_for_inference`.
+
         """
 
         self._compile()
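
Since ``dump``'s docstring now simply defers to ``optimize_for_inference`` for the optimization keyword arguments, the corresponding dump-time call would look roughly like this (continuing the sketch above; the file name is illustrative and it is assumed that ``dump`` still accepts these kwargs):

    net.dump("model_opt.mge", enable_io16xc32=True)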