@@ -11,13 +11,12 @@ import json
 import os
 import weakref
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple, Union
 import numpy as np
 from .. import _imperative_rt
-from .._imperative_rt import GraphOptimizeOptions
-from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
+from .._imperative_rt import GraphOptimizeOptions, SerializationFormat
 from .._wrap import as_device
 from ..ops.builtin import OpDef
@@ -377,7 +376,8 @@ def dump_graph(
     keep_opr_priority: bool = False,
     strip_info_file=None,
     append_json=False,
-    metadata=None
+    metadata=None,
+    dump_format=None
 ) -> Tuple[bytes, CompGraphDumpResult]:
     r"""serialize the computing graph of `output_vars` and get byte result.
@@ -398,6 +398,7 @@ def dump_graph(
         append_json: will be checked when `strip_info_file` is not None. if set
             true, the information for code strip will be appended to strip_info_file.
             if set false, will rewrite strip_info_file
+        dump_format: the serialization format of the dumped graph. Only "FBS"
+            (FlatBuffers) is currently supported; None selects the default format.
 
     Note:
         The underlying C++ API only accepts a var list. If a dict is given,
@@ -434,6 +435,12 @@ def dump_graph(
     outputs = []
     params = []
 
+    dump_format_map = {
+        None: None,
+        "FBS": SerializationFormat.FBS,
+    }
+    dump_format = dump_format_map[dump_format]
+
     dump_content = _imperative_rt.dump_graph(
         ov,
         keep_var_name,
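
A side note on the hunk above (not part of the patch): the lookup table accepts only the two keys it lists, so any other string surfaces as a bare KeyError. A minimal standalone sketch of that behavior, with SerializationFormat stubbed out because the real enum lives in the compiled _imperative_rt extension:

# Standalone sketch; SerializationFormat here is a stand-in for the pybind11 enum.
import enum

class SerializationFormat(enum.Enum):
    FBS = 0

dump_format_map = {
    None: None,
    "FBS": SerializationFormat.FBS,
}

print(dump_format_map["FBS"])  # SerializationFormat.FBS
print(dump_format_map[None])   # None, i.e. fall back to the default format

try:
    dump_format_map["fbs"]     # the lookup is case-sensitive
except KeyError as exc:
    print("unsupported dump format:", exc)
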
@@ -441,6 +448,7 @@ def dump_graph(
         keep_param_name,
         keep_opr_priority,
         metadata,
+        dump_format,
         stat,
         inputs,
         outputs,
@@ -1008,6 +1008,7 @@ class trace:
         maxerr=1e-4,
         resize_input=False,
         input_transform=None,
+        dump_format: str = None,
         **kwargs
     ):
         r"""Serializes trace to file system.
@@ -1059,6 +1060,7 @@ class trace:
             resize_input: whether to resize the input image to fit the input var shape.
             input_transform: a Python expression to transform the input data.
                 Example: data / np.std(data)
+            dump_format: the serialization format of the dumped model. Only "FBS"
+                (FlatBuffers) is currently supported; None selects the default format.
 
         Keyword Arguments:
@@ -1265,6 +1267,7 @@ class trace:
             strip_info_file=strip_info_file,
             append_json=append_json,
             metadata=metadata,
+            dump_format=dump_format,
         )
         file.write(dump_content)
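
With the Python side wired through, the new option is reachable from the usual trace.dump call. A rough usage sketch, assuming a MegEngine build that includes this patch (the traced function and shapes are illustrative, modelled on the test further down):

import io

import numpy as np

import megengine as mge
from megengine.jit import trace


@trace(symbolic=True, capture_as_const=True)
def add(a, b):
    return a + b


a = mge.tensor(np.ones((2, 2), dtype=np.float32))
b = mge.tensor(np.ones((2, 2), dtype=np.float32))
add(a, b)  # run once so the trace is captured before dumping

buf = io.BytesIO()
# dump_format="FBS" selects the FlatBuffers serializer registered below;
# leaving it out keeps the previous default behaviour.
add.dump(buf, dump_format="FBS")
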
@@ -35,6 +35,7 @@ using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions; | |||||
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform; | using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform; | ||||
using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | ||||
using _SerializationMetadata = mgb::serialization::Metadata; | using _SerializationMetadata = mgb::serialization::Metadata; | ||||
using _SerializationFormat = mgb::serialization::GraphDumpFormat; | |||||
namespace { | namespace { | ||||
class _CompGraphProfilerImpl { | class _CompGraphProfilerImpl { | ||||
@@ -310,6 +311,10 @@ void init_graph_rt(py::module m) {
             .value("NCHW64", _LayoutTransform::NCHW64)
             .export_values();
 
+    py::enum_<_SerializationFormat>(m, "SerializationFormat")
+            .value("FBS", _SerializationFormat::FLATBUFFERS)
+            .export_values();
+
     m.def("optimize_for_inference",
           [](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) {
              SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
@@ -380,11 +385,18 @@ void init_graph_rt(py::module m) {
     m.def("dump_graph",
           [](const std::vector<VarNode*>& dest_vars, int keep_var_name,
              bool keep_opr_name, bool keep_param_name, bool keep_opr_priority,
-             std::optional<_SerializationMetadata> metadata, py::list& stat,
+             std::optional<_SerializationMetadata> metadata,
+             std::optional<_SerializationFormat> dump_format, py::list& stat,
              py::list& inputs, py::list& outputs, py::list& params) {
              std::vector<uint8_t> buf;
-             auto dumper =
-                     ser::GraphDumper::make(ser::OutputFile::make_vector_proxy(&buf));
+             ser::GraphDumpFormat format;
+             if (dump_format.has_value()) {
+                 format = dump_format.value();
+             } else {
+                 format = {};
+             }
+             auto dumper = ser::GraphDumper::make(
+                     ser::OutputFile::make_vector_proxy(&buf), format);
              SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
 
              ser::GraphDumper::DumpConfig config{
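
Once the binding above is rebuilt, the new enum is visible from Python alongside the extra dump_graph argument. A quick check sketch (the module path follows the relative import in the first hunk, megengine.core._imperative_rt):

# Check sketch, assuming a rebuilt _imperative_rt that includes this patch.
from megengine.core import _imperative_rt as rt

print(rt.SerializationFormat.FBS)                # the FlatBuffers format
print(list(rt.SerializationFormat.__members__))  # ['FBS']
print(rt.FBS == rt.SerializationFormat.FBS)      # True, thanks to export_values()

As an aside, the has_value()/value() branch in the lambda is equivalent to dump_format.value_or(ser::GraphDumpFormat{}); the patch simply keeps the explicit form.
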
@@ -190,7 +190,13 @@ def test_print_in_trace():
     np.testing.assert_equal(z, buf)
 
 
-def test_dump():
+@pytest.mark.parametrize(
+    "dump_format",
+    [
+        "FBS",
+    ],
+)
+def test_dump(dump_format):
     @trace(symbolic=True, capture_as_const=True)
     def f(a, b):
         return a + b
@@ -205,7 +211,7 @@ def test_dump():
     np.testing.assert_equal(f(a, b).numpy(), y)
 
     file = io.BytesIO()
-    dump_info = f.dump(file)
+    dump_info = f.dump(file, dump_format=dump_format)
     assert dump_info.nr_opr == 3
     np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
     np.testing.assert_equal(dump_info.outputs, ["ADD"])