GitOrigin-RevId: fbc0d51c2b
tags/v1.0.0-rc1
@@ -130,32 +130,31 @@ def optimize_for_inference(dest_vars, **kwargs): | |||
inference) | |||
""" | |||
inference_options = GraphOptimizeOptions() | |||
if optimize_for_inference: | |||
inference_optimize_layout_transform_map = { | |||
"enable_hwcd4": GraphOptimizeOptions.LayoutTransform.NHWCD4, | |||
"enable_nchw4": GraphOptimizeOptions.LayoutTransform.NCHW4, | |||
"enable_nchw88": GraphOptimizeOptions.LayoutTransform.NCHW88, | |||
"enable_nchw32": GraphOptimizeOptions.LayoutTransform.NCHW32, | |||
"enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44, | |||
"enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT, | |||
"enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4, | |||
} | |||
for k, v in inference_optimize_layout_transform_map.items(): | |||
if kwargs.pop(k, False): | |||
inference_options.layout_transform = v | |||
if kwargs.pop("enable_io16xc32", False): | |||
inference_options.f16_io_f32_comp = True | |||
if kwargs.pop("enable_ioc16", False): | |||
inference_options.f16_io_comp = True | |||
if kwargs.pop("enable_fuse_conv_bias_nonlinearity", False): | |||
inference_options.fuse_conv_bias_nonlinearity = True | |||
if kwargs.pop("enable_fuse_conv_bias_with_z", False): | |||
inference_options.fuse_conv_bias_with_z = True | |||
if kwargs: | |||
raise ValueError("unknown options: %s" % list(kwargs)) | |||
inference_optimize_layout_transform_map = { | |||
"enable_hwcd4": GraphOptimizeOptions.LayoutTransform.NHWCD4, | |||
"enable_nchw4": GraphOptimizeOptions.LayoutTransform.NCHW4, | |||
"enable_nchw88": GraphOptimizeOptions.LayoutTransform.NCHW88, | |||
"enable_nchw32": GraphOptimizeOptions.LayoutTransform.NCHW32, | |||
"enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44, | |||
"enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT, | |||
"enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4, | |||
} | |||
for k, v in inference_optimize_layout_transform_map.items(): | |||
if kwargs.pop(k, False): | |||
inference_options.layout_transform = v | |||
if kwargs.pop("enable_io16xc32", False): | |||
inference_options.f16_io_f32_comp = True | |||
if kwargs.pop("enable_ioc16", False): | |||
inference_options.f16_io_comp = True | |||
if kwargs.pop("enable_fuse_conv_bias_nonlinearity", False): | |||
inference_options.fuse_conv_bias_nonlinearity = True | |||
if kwargs.pop("enable_fuse_conv_bias_with_z", False): | |||
inference_options.fuse_conv_bias_with_z = True | |||
if kwargs: | |||
raise ValueError("unknown options: %s" % list(kwargs)) | |||
res_vars = _imperative_rt.optimize_for_inference( | |||
[i._node for i in dest_vars], inference_options | |||
@@ -458,7 +458,16 @@ class trace: | |||
self._process_outputs(outputs) | |||
return outputs | |||
def dump(self, file, *, arg_names=None, output_names=None, append=False, **kwargs): | |||
def dump( | |||
self, | |||
file, | |||
*, | |||
arg_names=None, | |||
output_names=None, | |||
append=False, | |||
optimize_for_inference=True, | |||
**kwargs | |||
): | |||
r"""Serializes trace to file system. | |||
:param file: output file, could be file object or filename. | |||
@@ -467,6 +476,8 @@ class trace: | |||
use the default name if not specified. | |||
:param append: whether output is appended to ``file``. | |||
Only works when ``file`` is str. | |||
:param optimize_for_inference: enable optimizations, | |||
will skip all optimization options if this is False. Default: True | |||
:Keyword Arguments: | |||
@@ -572,7 +583,8 @@ class trace: | |||
v.name = output_names[i] | |||
dest_vars.append(v) | |||
dest_vars = G.optimize_for_inference(dest_vars, **kwargs) | |||
if optimize_for_inference: | |||
dest_vars = G.optimize_for_inference(dest_vars, **kwargs) | |||
if isinstance(file, str): | |||
permission = "wb" if append == False else "ab" | |||
@@ -155,6 +155,9 @@ void init_graph_rt(py::module m) { | |||
}) | |||
.def_property_readonly("id",[](cg::VarNode* v){ | |||
return (v->id()); | |||
}) | |||
.def("__repr__", [](cg::VarNode* v) { | |||
return "Var:" + v->name(); | |||
}); | |||
py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode") | |||
@@ -175,6 +178,9 @@ void init_graph_rt(py::module m) { | |||
}) | |||
.def_property_readonly("type",[](cg::OperatorNodeBase* opr){ | |||
return opr->dyn_typeinfo()->name; | |||
}) | |||
.def("__repr__", [](cg::OperatorNodeBase* opr){ | |||
return "Opr:" + opr->name(); | |||
}); | |||
@@ -67,7 +67,6 @@ def test_replace_oprs(): | |||
np.testing.assert_equal(res, np.array([5.0 * 5.0 * 1.25])) | |||
@pytest.mark.skip(reason="Please check opr index") | |||
def test_graph_traversal(): | |||
net = M.Conv2d(3, 32, 3) | |||
@@ -77,11 +76,11 @@ def test_graph_traversal(): | |||
return x | |||
data = np.random.random([1, 3, 224, 224]).astype(np.float32) | |||
for i in range(3): | |||
for _ in range(3): | |||
fun(megengine.tensor(data)) | |||
file = io.BytesIO() | |||
fun.dump(file) | |||
fun.dump(file, optimize_for_inference=False) | |||
file.seek(0) | |||
cg, _, outputs = mgb_graph.load_graph(file) | |||
@@ -13,7 +13,6 @@ import numpy as np | |||
import pytest | |||
import megengine | |||
import megengine.core.tensor.megbrain_graph as G | |||
import megengine.module as M | |||
from megengine import cgtools, tensor | |||
from megengine.core._trace_option import set_tensor_shape | |||
@@ -150,7 +149,6 @@ def test_capture_dump(): | |||
np.testing.assert_equal(result[0], y) | |||
@pytest.mark.skip(reason="get MultipleDeviceTensorHolder instead of SharedDeviceTensor") | |||
def test_dump_volatile(): | |||
p = as_raw_tensor([2]) | |||
@@ -167,7 +165,7 @@ def test_dump_volatile(): | |||
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||
file = io.BytesIO() | |||
f.dump(file) | |||
f.dump(file, optimize_for_inference=False) | |||
file.seek(0) | |||
cg, _, outputs = G.load_graph(file) | |||
(out,) = outputs | |||
@@ -196,26 +194,7 @@ def test_trace_profiler(): | |||
assert out.get("profiler") | |||
@pytest.mark.skip(reason="eq_to_unit failed in inplace.cpp") | |||
def test_goptions_div_zero(): | |||
@trace(symbolic=True, opt_level=0) | |||
def f(x): | |||
return x / x | |||
@trace(symbolic=True, opt_level=1) | |||
def g(x): | |||
return x / x | |||
out = f(tensor(0.0)) | |||
if out == out: | |||
raise ValueError("actual result should be nan") | |||
out = g(tensor(0.0)) | |||
if out != out: | |||
raise ValueError("actual result should be 1") | |||
@pytest.mark.skip(reason="cast to Elemwise failed in inplace.cpp") | |||
@pytest.mark.skip(reason="could not disable opt_level") | |||
def test_goptions_log_exp(): | |||
@trace(symbolic=True, opt_level=0, capture_as_const=True) | |||
def f(x): | |||
@@ -227,19 +206,19 @@ def test_goptions_log_exp(): | |||
f(tensor(1.0)) | |||
_, out = mkstemp() | |||
f.dump(out) | |||
*_, outputs = G.load_comp_graph_from_file(out) | |||
f.dump(out, optimize_for_inference=False) | |||
*_, outputs = G.load_graph(out) | |||
oprs_1 = cgtools.get_oprs_seq(outputs) | |||
g(tensor(1.0)) | |||
g.dump(out) | |||
*_, outputs = G.load_comp_graph_from_file(out) | |||
g.dump(out, optimize_for_inference=False) | |||
*_, outputs = G.load_graph(out) | |||
oprs_2 = cgtools.get_oprs_seq(outputs) | |||
assert len(oprs_1) - len(oprs_2) == 2 | |||
@pytest.mark.skip(reason="need cgtools to check final oprs") | |||
@pytest.mark.skip(reason="could not disable opt_level") | |||
def test_goptions_log_sum_exp(): | |||
@trace(symbolic=True, opt_level=0, capture_as_const=True) | |||
def f(x, y): | |||
@@ -251,19 +230,18 @@ def test_goptions_log_sum_exp(): | |||
f(tensor(1.0), tensor(2.0)) | |||
_, out = mkstemp() | |||
f.dump(out) | |||
*_, outputs = G.load_comp_graph_from_file(out) | |||
f.dump(out, optimize_for_inference=False) | |||
*_, outputs = G.load_graph(out) | |||
oprs_1 = cgtools.get_oprs_seq(outputs) | |||
g(tensor(1.0), tensor(2.0)) | |||
g.dump(out) | |||
*_, outputs = G.load_comp_graph_from_file(out) | |||
g.dump(out, optimize_for_inference=False) | |||
*_, outputs = G.load_graph(out) | |||
oprs_2 = cgtools.get_oprs_seq(outputs) | |||
assert len(oprs_1) - len(oprs_2) == 2 | |||
@pytest.mark.skip(reason="need cgtools to check computing input dtype") | |||
def test_optimize_for_inference(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def f(x): | |||
@@ -271,9 +249,9 @@ def test_optimize_for_inference(): | |||
_, out = mkstemp() | |||
f(tensor(5.0)) | |||
f.dump(out, optimize_for_inference=True, optimize_options={"enable_io16xc32": True}) | |||
f.dump(out, enable_io16xc32=True) | |||
res = G.load_comp_graph_from_file(out) | |||
res = G.load_graph(out) | |||
computing_input = res.output_vars_list[0].owner.inputs[0] | |||
assert computing_input.dtype == np.float16 | |||