@@ -11,13 +11,12 @@ import json | |||
import os | |||
import weakref | |||
from concurrent.futures import ThreadPoolExecutor | |||
from typing import Dict, List, Optional, Tuple, Union | |||
from typing import Dict, List, Tuple, Union | |||
import numpy as np | |||
from .. import _imperative_rt | |||
from .._imperative_rt import GraphOptimizeOptions | |||
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode | |||
from .._imperative_rt import GraphOptimizeOptions, SerializationFormat | |||
from .._wrap import as_device | |||
from ..ops.builtin import OpDef | |||
@@ -377,7 +376,8 @@ def dump_graph( | |||
keep_opr_priority: bool = False, | |||
strip_info_file=None, | |||
append_json=False, | |||
metadata=None | |||
metadata=None, | |||
dump_format=None | |||
) -> Tuple[bytes, CompGraphDumpResult]: | |||
r"""serialize the computing graph of `output_vars` and get byte result. | |||
@@ -398,6 +398,7 @@ def dump_graph( | |||
append_json: will be check when `strip_info_file` is not None. if set | |||
true, the information for code strip will be append to strip_info_file. | |||
if set false, will rewrite strip_info_file | |||
dump_format: using different dump formats. | |||
Note: | |||
The underlying C++ API only accepts a var list. If a dict is given, | |||
@@ -434,6 +435,12 @@ def dump_graph( | |||
outputs = [] | |||
params = [] | |||
dump_format_map = { | |||
None: None, | |||
"FBS": SerializationFormat.FBS, | |||
} | |||
dump_format = dump_format_map[dump_format] | |||
dump_content = _imperative_rt.dump_graph( | |||
ov, | |||
keep_var_name, | |||
@@ -441,6 +448,7 @@ def dump_graph( | |||
keep_param_name, | |||
keep_opr_priority, | |||
metadata, | |||
dump_format, | |||
stat, | |||
inputs, | |||
outputs, | |||
@@ -1008,6 +1008,7 @@ class trace: | |||
maxerr=1e-4, | |||
resize_input=False, | |||
input_transform=None, | |||
dump_format: str = None, | |||
**kwargs | |||
): | |||
r"""Serializes trace to file system. | |||
@@ -1059,6 +1060,7 @@ class trace: | |||
resize_input: whether resize input image to fit input var shape. | |||
input_transform: a python expression to transform the input data. | |||
Example: data / np.std(data) | |||
dump_format: using different dump formats. | |||
Keyword Arguments: | |||
@@ -1265,6 +1267,7 @@ class trace: | |||
strip_info_file=strip_info_file, | |||
append_json=append_json, | |||
metadata=metadata, | |||
dump_format=dump_format, | |||
) | |||
file.write(dump_content) | |||
@@ -35,6 +35,7 @@ using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions; | |||
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform; | |||
using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||
using _SerializationMetadata = mgb::serialization::Metadata; | |||
using _SerializationFormat = mgb::serialization::GraphDumpFormat; | |||
namespace { | |||
class _CompGraphProfilerImpl { | |||
@@ -310,6 +311,10 @@ void init_graph_rt(py::module m) { | |||
.value("NCHW64", _LayoutTransform::NCHW64) | |||
.export_values(); | |||
py::enum_<_SerializationFormat>(m, "SerializationFormat") | |||
.value("FBS", _SerializationFormat::FLATBUFFERS) | |||
.export_values(); | |||
m.def("optimize_for_inference", | |||
[](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) { | |||
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); | |||
@@ -380,11 +385,18 @@ void init_graph_rt(py::module m) { | |||
m.def("dump_graph", | |||
[](const std::vector<VarNode*>& dest_vars, int keep_var_name, | |||
bool keep_opr_name, bool keep_param_name, bool keep_opr_priority, | |||
std::optional<_SerializationMetadata> metadata, py::list& stat, | |||
std::optional<_SerializationMetadata> metadata, | |||
std::optional<_SerializationFormat> dump_format, py::list& stat, | |||
py::list& inputs, py::list& outputs, py::list& params) { | |||
std::vector<uint8_t> buf; | |||
auto dumper = | |||
ser::GraphDumper::make(ser::OutputFile::make_vector_proxy(&buf)); | |||
ser::GraphDumpFormat format; | |||
if (dump_format.has_value()) { | |||
format = dump_format.value(); | |||
} else { | |||
format = {}; | |||
} | |||
auto dumper = ser::GraphDumper::make( | |||
ser::OutputFile::make_vector_proxy(&buf), format); | |||
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); | |||
ser::GraphDumper::DumpConfig config{ | |||
@@ -190,7 +190,13 @@ def test_print_in_trace(): | |||
np.testing.assert_equal(z, buf) | |||
def test_dump(): | |||
@pytest.mark.parametrize( | |||
"dump_format", | |||
[ | |||
"FBS", | |||
], | |||
) | |||
def test_dump(dump_format): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def f(a, b): | |||
return a + b | |||
@@ -205,7 +211,7 @@ def test_dump(): | |||
np.testing.assert_equal(f(a, b).numpy(), y) | |||
file = io.BytesIO() | |||
dump_info = f.dump(file) | |||
dump_info = f.dump(file, dump_format=dump_format) | |||
assert dump_info.nr_opr == 3 | |||
np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) | |||
np.testing.assert_equal(dump_info.outputs, ["ADD"]) | |||