GitOrigin-RevId: b563c94451
release-1.5
@@ -11,7 +11,7 @@ import json
 import os
 import weakref
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -256,6 +256,9 @@ def optimize_for_inference(dest_vars, **kwargs):
     * enable_chwn4 --
         whether to use CHWN4 data layout, currently
         used in nvidia backend with tensorcore.
+    * enable_nchw64 --
+        whether to use NCHW64 data layout, used for fast int4
+        support on Nvidia GPU.
     * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
         into one opr.
@@ -273,6 +276,7 @@ def optimize_for_inference(dest_vars, **kwargs): | |||||
"enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44, | "enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44, | ||||
"enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT, | "enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT, | ||||
"enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4, | "enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4, | ||||
"enable_nchw64": GraphOptimizeOptions.LayoutTransform.NCHW64, | |||||
} | } | ||||
for k, v in inference_optimize_layout_transform_map.items(): | for k, v in inference_optimize_layout_transform_map.items(): | ||||
@@ -293,7 +297,46 @@ def optimize_for_inference(dest_vars, **kwargs):
     dest_vars = _unwrap(dest_vars)
     res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options)
-    return _wrap(res_vars)
+    return _wrap(res_vars), inference_options.serialize()
+
+
+def deserialize_infer_option(x: int) -> Dict[str, bool]:
+    r"""
+    Deserialize optimize options generated by ``imperative_rt.GraphOptimizeOptions``.
+
+    :param x: inference options represented by int.
+    :return: inference options represented by dict.
+    """
+    inference_options = GraphOptimizeOptions.deserialize(x)
+    inference_optimize_layout_transform_map = {
+        GraphOptimizeOptions.LayoutTransform.NHWCD4: "enable_hwcd4",
+        GraphOptimizeOptions.LayoutTransform.NCHW4: "enable_nchw4",
+        GraphOptimizeOptions.LayoutTransform.NCHW88: "enable_nchw88",
+        GraphOptimizeOptions.LayoutTransform.NCHW32: "enable_nchw32",
+        GraphOptimizeOptions.LayoutTransform.NCHW44: "enable_nchw44",
+        GraphOptimizeOptions.LayoutTransform.NCHW44_DOT: "enable_nchw44_dot",
+        GraphOptimizeOptions.LayoutTransform.CHWN4: "enable_chwn4",
+        GraphOptimizeOptions.LayoutTransform.NCHW64: "enable_nchw64",
+    }
+    ret = dict()
+    layout = inference_options.layout_transform
+    if layout != GraphOptimizeOptions.LayoutTransform.DEFAULT:
+        ret[inference_optimize_layout_transform_map[layout]] = True
+    if inference_options.f16_io_f32_comp:
+        ret["enable_io16xc32"] = True
+    if inference_options.f16_io_comp:
+        ret["enable_ioc16"] = True
+    if inference_options.fuse_conv_bias_nonlinearity:
+        ret["enable_fuse_conv_bias_nonlinearity"] = True
+    if inference_options.fuse_conv_bias_with_z:
+        ret["enable_fuse_conv_bias_with_z"] = True
+    return ret
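
With this change, `optimize_for_inference` returns the optimized vars together with the options packed into an int, and `deserialize_infer_option` maps that int back to keyword-argument form. A minimal round-trip sketch, assuming `dest_vars` holds the output VarNodes of a traced function (module path per this diff):

    import megengine.core.tensor.megbrain_graph as G

    # pack: returns (wrapped vars, options serialized to a 64-bit int)
    opt_vars, packed = G.optimize_for_inference(dest_vars, enable_nchw4=True)
    # unpack: only the options that were enabled appear in the dict
    assert G.deserialize_infer_option(packed) == {"enable_nchw4": True}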
 def modify_opr_algo_strategy_inplace(dest_vars, strategy: str):
@@ -331,7 +374,8 @@ def dump_graph(
     keep_param_name: bool = False,
     keep_opr_priority: bool = False,
     strip_info_file=None,
-    append_json=False
+    append_json=False,
+    metadata=None
 ) -> Tuple[bytes, CompGraphDumpResult]:
""" | """ | ||||
serialize the computing graph of `output_vars` and get byte result. | serialize the computing graph of `output_vars` and get byte result. | ||||
@@ -393,6 +437,7 @@ def dump_graph( | |||||
keep_opr_name, | keep_opr_name, | ||||
keep_param_name, | keep_param_name, | ||||
keep_opr_priority, | keep_opr_priority, | ||||
metadata, | |||||
stat, | stat, | ||||
inputs, | inputs, | ||||
outputs, | outputs, | ||||
@@ -427,7 +472,7 @@ def dump_graph(
 CompGraphLoadResult = collections.namedtuple(
-    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
+    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list", "metadata"]
 )
@@ -450,8 +495,8 @@ def load_graph(fpath) -> CompGraphLoadResult:
         buf = open(fpath, "rb").read()
     else:
         buf = fpath.read()
-    cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
-    return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list)
+    cg, metadata = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
+    return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list, metadata)


 def _wrap(x):
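
Because `CompGraphLoadResult` gains a fourth field, old-style tuple unpacking like `cg, _, outputs = G.load_graph(f)` breaks; the rest of this diff migrates callers to named attributes. A minimal sketch (the model path is hypothetical):

    import megengine.core.tensor.megbrain_graph as G

    ret = G.load_graph("model.mge")            # CompGraphLoadResult namedtuple
    graph, outputs = ret.graph, ret.output_vars_list
    raw_meta = ret.metadata                    # SerializationMetadata; user_info is still pickled bytes here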
@@ -12,10 +12,12 @@ import functools
 import itertools
 import json
 import os
+import pickle
+from typing import Any

 import numpy as np

-from ..core._imperative_rt import GraphProfiler
+from ..core._imperative_rt import GraphProfiler, SerializationMetadata
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..core._imperative_rt.core2 import (
     TensorWeakRef,
@@ -670,6 +672,8 @@ class trace:
         strip_info_file=None,
         append_json=False,
         optimize_for_inference=True,
+        user_info: Any = None,
+        enable_metadata: bool = True,
         **kwargs
     ):
r""" | r""" | ||||
@@ -697,6 +701,8 @@ class trace: | |||||
if set false, will rewrite strip_info_file | if set false, will rewrite strip_info_file | ||||
:param optimize_for_inference: enbale optmizations, | :param optimize_for_inference: enbale optmizations, | ||||
will skip all optimize options if this is False. Default: True | will skip all optimize options if this is False. Default: True | ||||
:param user_info: any type object, which will be pickled to bytes. | |||||
:param enable_metadata: whether to save metadata into output file. | |||||
:Keyword Arguments: | :Keyword Arguments: | ||||
@@ -729,6 +735,9 @@ class trace:
         * enable_chwn4 --
             whether to use CHWN4 data layout, currently
             used in nvidia backend with tensorcore.
+        * enable_nchw64 --
+            whether to use NCHW64 data layout, used for fast int4
+            support on Nvidia GPU.
         * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
             into one opr.
@@ -851,7 +860,15 @@ class trace:
                 dest_vars.append(v)

         if optimize_for_inference:
-            dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
+            dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)
+
+        metadata = SerializationMetadata()
+        if enable_metadata:
+            metadata.user_info = pickle.dumps(user_info)
+            metadata.is_valid = True
+            metadata.graph_modified = False
+            if optimize_for_inference:
+                metadata.optimize_options = optimize_options

         if isinstance(file, str):
             permission = "wb" if append == False else "ab"
@@ -864,6 +881,7 @@ class trace:
             keep_opr_priority=keep_opr_priority,
             strip_info_file=strip_info_file,
             append_json=append_json,
+            metadata=metadata,
         )
         file.write(dump_content)
         return dump_info
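
A usage sketch of the extended `trace.dump` (the model and output path are hypothetical; `user_info` may be any picklable object):

    from megengine import Tensor
    from megengine.jit import trace

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        return x * 2

    fwd(Tensor(0))
    fwd.dump("model.mge", user_info={"version": "1.5"}, optimize_for_inference=False)
    # pass enable_metadata=False to dump a model without the metadata section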
@@ -411,7 +411,8 @@ def main():
         args.embed_input = True

     logger.info("loading model ...")
-    graph, _, output_vars = G.load_graph(args.net)
+    ret = G.load_graph(args.net)
+    graph, output_vars = ret.graph, ret.output_vars_list
     input_vars = tools.get_dep_vars(output_vars, "Host2DeviceCopy")

     if args.output_name is not None:
@@ -391,7 +391,8 @@ class GraphInference:
         optimize_for_inference: bool = False,
         **kwargs
     ):
-        self._graph, _, output_nodes = G.load_graph(file)
+        ret = G.load_graph(file)
+        self._graph, output_nodes = ret.graph, ret.output_vars_list
         if outputs is not None:
             output_nodes = find_vars_by_name(output_nodes, outputs)
         self._origin_outputs = output_nodes
@@ -9,14 +9,12 @@
 import collections
 import fnmatch
 import itertools
+import pickle
 import re
 from collections import OrderedDict
-from typing import Dict, List, Sequence
+from typing import Any, Dict, List, Sequence
-import numpy as np
-from ..core._imperative_rt import ComputingGraph
-from ..core._imperative_rt.core2 import SymbolVar
+from ..core._imperative_rt import ComputingGraph, SerializationMetadata
 from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
 from ..logger import get_logger
@@ -42,6 +40,30 @@ class Network:
         self.all_oprs_map = OrderedDict()
         self.all_vars_map = OrderedDict()
         self.graph = ComputingGraph()
+        self._metadata = None
+
+    @property
+    def metadata(self):
+        r"""
+        Load metadata as a dict.
+        """
+        if not self._metadata.is_valid:
+            logger.info("metadata is not valid!")
+            return None
+        ret = dict()
+        try:
+            user_info = pickle.loads(self._metadata.user_info)
+        except:  # pylint: disable=bare-except
+            logger.warning(
+                "can't parse user info by pickle, so return the original bytes object!"
+            )
+            user_info = self._metadata.user_info
+        ret["user_info"] = user_info
+        ret["graph_modified"] = self._metadata.graph_modified
+        ret["optimized_for_inference"] = self._metadata.optimized_for_inference
+        if ret["optimized_for_inference"]:
+            ret.update(G.deserialize_infer_option(self._metadata.optimize_options))
+        return ret
+
     @classmethod
     def load(cls, model_path: str, outspec: List[str] = None):
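
Reading the metadata back through `Network.metadata` yields a plain dict, with the layout/optimization flags merged in via `deserialize_infer_option` when the model was optimized. A sketch (the model file is hypothetical):

    from megengine.utils.network import Network as Net

    net = Net.load("model.mge")
    meta = net.metadata   # None if the model was dumped with enable_metadata=False
    if meta is not None:
        print(meta["user_info"], meta["graph_modified"], meta["optimized_for_inference"])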
@@ -51,7 +73,8 @@ class Network:
         :param outspec: only load the subgraph with outspec as its endpoints.
         """
         self = cls()
-        _, _, outputs = G.load_graph(model_path)
+        ret = G.load_graph(model_path)
+        outputs, self._metadata = ret.output_vars_list, ret.metadata
         if outspec is not None:
             output_spec = outspec.copy()
             all_vars = get_dep_vars(outputs) + outputs
@@ -125,6 +148,9 @@ class Network:
         * enable_chwn4 --
             whether to use CHWN4 data layout, currently
             used in nvidia backend with tensorcore.
+        * enable_nchw64 --
+            whether to use NCHW64 data layout, used for fast int4
+            support on Nvidia GPU.
         * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
             into one opr.
@@ -152,6 +178,8 @@ class Network:
         append_json=False,
         optimize_for_inference=True,
         append=False,
+        user_info: Any = None,
+        enable_metadata=True,
         **kwargs
     ):
         """
@@ -176,6 +204,8 @@ class Network:
             if set false, will rewrite strip_info_file
         :param optimize_for_inference: enable optimizations,
             will skip all optimize options if this is False. Default: True
+        :param user_info: an object of any type, which will be pickled to bytes.
+        :param enable_metadata: whether to save metadata into the output file.

         :Keyword Arguments:
@@ -201,7 +231,15 @@ class Network:
         )

         if optimize_for_inference:
-            out = G.optimize_for_inference(out, **kwargs)
+            out, optimize_options = G.optimize_for_inference(out, **kwargs)
+
+        metadata = SerializationMetadata()
+        if enable_metadata:
+            metadata.is_valid = True
+            metadata.graph_modified = True
+            metadata.user_info = pickle.dumps(user_info)
+            if optimize_for_inference:
+                metadata.optimize_options = optimize_options

         dump_content, _ = G.dump_graph(
             out,
@@ -211,6 +249,7 @@ class Network:
             keep_opr_priority=keep_opr_priority,
             strip_info_file=strip_info_file,
             append_json=append_json,
+            metadata=metadata,
         )
         if isinstance(file, str):
             permission = "wb" if append == False else "ab"
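
A matching sketch for re-dumping through `Network` with fresh metadata (file names hypothetical; keyword options as documented above):

    from megengine.utils.network import Network as Net

    net = Net.load("model.mge")
    net.dump("model_opt.mge", user_info="opt-v2", enable_nchw4=True)
    # the reloaded metadata will report graph_modified=True, since Network.dump rebuilds the graph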
@@ -34,6 +34,7 @@ namespace ser = mgb::serialization;
 using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
 using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
 using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
+using _SerializationMetadata = mgb::serialization::Metadata;

 namespace {
 class _CompGraphProfilerImpl {
@@ -240,6 +241,8 @@ void init_graph_rt(py::module m) {
     auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
         .def(py::init())
+        .def("serialize", &_OptimizeForInferenceOptions::serialize)
+        .def_static("deserialize", &_OptimizeForInferenceOptions::deserialize)
         .def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp)
         .def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
         .def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
@@ -256,6 +259,7 @@ void init_graph_rt(py::module m) {
         .value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
         .value("NCHW32", _LayoutTransform::NCHW32)
         .value("CHWN4", _LayoutTransform::CHWN4)
+        .value("NCHW64", _LayoutTransform::NCHW64)
         .export_values()
         ;
@@ -307,12 +311,24 @@ void init_graph_rt(py::module m) {
         })->to_string();
     });

+    py::class_<_SerializationMetadata>(m, "SerializationMetadata")
+        .def(py::init())
+        .def_property("user_info",
+                      [](const _SerializationMetadata& meta) { return py::bytes(meta.get_user_info()); },
+                      &_SerializationMetadata::set_user_info)
+        .def_readonly("optimized_for_inference", &_SerializationMetadata::optimized_for_inference)
+        .def_property("optimize_options",
+                      &_SerializationMetadata::get_optimize_options,
+                      &_SerializationMetadata::set_optimize_options)
+        .def_readwrite("graph_modified", &_SerializationMetadata::graph_modified)
+        .def_readwrite("is_valid", &_SerializationMetadata::is_valid)
+        ;
+
     m.def("dump_graph", [](
         const std::vector<VarNode*>& dest_vars,
         int keep_var_name,
         bool keep_opr_name,
         bool keep_param_name,
         bool keep_opr_priority,
+        std::optional<_SerializationMetadata> metadata,
         py::list& stat,
         py::list& inputs,
         py::list& outputs,
@@ -325,7 +341,12 @@ void init_graph_rt(py::module m) {
         ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
                                             keep_opr_priority, keep_opr_name};
-        auto rst = dumper->dump(symvars, config);
+        ser::GraphDumper::DumpResult rst;
+        if (metadata)
+            rst = dumper->dump(symvars, config, *metadata);
+        else
+            rst = dumper->dump(symvars, config);
         for (auto i : rst.inputs) {
             inputs.append(py::cast(i));
         }
@@ -377,8 +398,10 @@ void init_graph_rt(py::module m) {
         for (const auto& var : rst.output_var_list) {
             iter.add(var);
         }
-        return rst.graph;
+        auto ret = py::tuple(2);
+        ret[0] = py::cast(rst.graph);
+        ret[1] = py::cast(rst.metadata);
+        return ret;
     });

 #define CURRENT_CLASS cg::ComputingGraph::Options
@@ -239,8 +239,7 @@ def test_dump_volatile():
     file = io.BytesIO()
     f.dump(file, optimize_for_inference=False)
     file.seek(0)
-    cg, _, outputs = G.load_graph(file)
-    (out,) = outputs
+    (out,) = G.load_graph(file).output_vars_list
     assert (
         cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
         == "ImmutableTensor"
@@ -337,12 +336,12 @@ def test_goptions_log_exp():
     f(tensor(1.0))
     _, out = mkstemp()
     f.dump(out, optimize_for_inference=False)
-    *_, outputs = G.load_graph(out)
+    outputs = G.load_graph(out).output_vars_list
     oprs_1 = cgtools.get_oprs_seq(outputs)

     g(tensor(1.0))
     g.dump(out, optimize_for_inference=False)
-    *_, outputs = G.load_graph(out)
+    outputs = G.load_graph(out).output_vars_list
     oprs_2 = cgtools.get_oprs_seq(outputs)

     assert len(oprs_1) - len(oprs_2) == 2
@@ -88,7 +88,7 @@ def test_graph_traversal():
     file = io.BytesIO()
     fun.dump(file, optimize_for_inference=False)
     file.seek(0)
-    cg, _, outputs = mgb_graph.load_graph(file)
+    outputs = mgb_graph.load_graph(file).output_vars_list
     _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
     input_var = map_vars[1]
@@ -101,7 +101,9 @@ def test_load_refcnt():
     graph = mgb_graph.Graph()
     varnode = graph.make_const(0)
     buf, _ = mgb_graph.dump_graph([varnode])
-    graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf))
+    ret = mgb_graph.load_graph(io.BytesIO(buf))
+    graph, (varnode,) = ret.graph, ret.output_vars_list
+    del ret
     del graph
     varnode.owner
@@ -132,7 +134,7 @@ def test_get_opr_seq():
     file = io.BytesIO()
     func.dump(file, optimize_for_inference=False)
     file.seek(0)
-    *_, outputs = mgb_graph.load_graph(file)
+    outputs = mgb_graph.load_graph(file).output_vars_list
     seq_1 = cgtools.get_oprs_seq(outputs, True)
     assert len(seq_1) == 5
@@ -35,7 +35,7 @@ def _dump_and_load(func, symbolic, keep_opr_name=True):
         keep_var_name=2,
     )
     file.seek(0)
-    *_, outputs = G.load_graph(file)
+    outputs = G.load_graph(file).output_vars_list
     ops = cgtools.get_oprs_seq(outputs)
     return ops
@@ -223,7 +223,7 @@ def test_catch_input_name(tensor_name, var_name):
     file = io.BytesIO()
     func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2)
     file.seek(0)
-    *_, outputs = G.load_graph(file)
+    outputs = G.load_graph(file).output_vars_list
     op = cgtools.get_oprs_seq(outputs)[-1]
     assert op.inputs[0].name == var_name
@@ -14,6 +14,50 @@ from megengine.utils.network import as_oprnode, set_symbolic_shape
 from megengine.utils.network_node import Host2DeviceCopy, VarNode

+def test_metadata():
+    x = Tensor(0)
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fwd(x):
+        return x * 2
+
+    fwd(x)
+    orig_model = io.BytesIO()
+    fwd.dump(orig_model, user_info="test", optimize_for_inference=False)
+    orig_model.seek(0)
+    graph = Net.load(orig_model)
+    assert graph.metadata == {
+        "user_info": "test",
+        "graph_modified": False,  # False: tracing.dump
+        "optimized_for_inference": False,
+    }
+
+    orig_model.seek(0)
+    graph.dump(
+        orig_model,
+        user_info={"str": "x", "tensor": x, "module": M.Module, "none": None},
+        optimize_for_inference=True,
+        enable_nchw4=True,
+        enable_ioc16=True,
+    )
+    orig_model.seek(0)
+    graph = Net.load(orig_model)
+    assert graph.metadata == {
+        "user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None},
+        "graph_modified": True,  # True: Network.dump
+        "optimized_for_inference": True,
+        "enable_nchw4": True,
+        "enable_ioc16": True,
+    }
+
+    orig_model.seek(0)
+    fwd.dump(orig_model, enable_metadata=False)
+    orig_model.seek(0)
+    graph = Net.load(orig_model)
+    assert graph.metadata is None
+
+
 def test_replace_var():
     a = Tensor([1, 2])
@@ -170,7 +170,8 @@ def gen_one_testcase(args, inputs, spec):

 def make_feeds(args):
-    cg_rt, _, outputs = G.load_graph(args.input)
+    ret = G.load_graph(args.input)
+    cg_rt, outputs = ret.graph, ret.output_vars_list
     inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
     inputs = {i.name: i for i in inputs}
@@ -322,7 +322,31 @@ namespace gopt {
         static std::unique_ptr<EnableNchw44DotPass> make_nchw44_dot_converter();
     };

-    struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {};
+    struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
+        uint64_t serialize() {
+            uint64_t ret = 0;
+            ret |= (uint64_t)layout_transform << 32;
+            if (f16_io_f32_comp) ret |= 1u;
+            if (f16_io_comp) ret |= 1u << 1;
+            if (fuse_conv_bias_nonlinearity) ret |= 1u << 2;
+            if (fuse_conv_bias_with_z) ret |= 1u << 3;
+            if (weight_preprocess) ret |= 1u << 4;
+            if (fuse_preprocess) ret |= 1u << 5;
+            return ret;
+        }
+
+        static OptimizeForInferenceOptions deserialize(uint64_t buf) {
+            OptimizeForInferenceOptions ret;
+            ret.f16_io_f32_comp = buf & 1u;
+            ret.f16_io_comp = buf & 1u << 1;
+            ret.fuse_conv_bias_nonlinearity = buf & 1u << 2;
+            ret.fuse_conv_bias_with_z = buf & 1u << 3;
+            ret.weight_preprocess = buf & 1u << 4;
+            ret.fuse_preprocess = buf & 1u << 5;
+            ret.layout_transform = (LayoutTransform)(buf >> 32);
+            return ret;
+        }
+    };

     /*!
      * \brief optimize a computing graph for inference
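
The packing scheme above keeps one boolean flag per low bit and stores the LayoutTransform enum value in the upper 32 bits. A Python mirror of the same encoding, for reference only (flag order assumed from the C++ above):

    FLAGS = [
        "f16_io_f32_comp", "f16_io_comp", "fuse_conv_bias_nonlinearity",
        "fuse_conv_bias_with_z", "weight_preprocess", "fuse_preprocess",
    ]

    def pack(layout_transform: int = 0, **options) -> int:
        # low bits: one flag per bit, in FLAGS order; high 32 bits: layout enum
        ret = layout_transform << 32
        for bit, name in enumerate(FLAGS):
            if options.get(name):
                ret |= 1 << bit
        return ret

    def unpack(buf: int) -> dict:
        ret = {name: bool(buf & (1 << bit)) for bit, name in enumerate(FLAGS)}
        ret["layout_transform"] = buf >> 32
        return ret

    assert unpack(pack(f16_io_comp=True))["f16_io_comp"] is True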
@@ -128,6 +128,13 @@ table Operator {
     name:string;
 }

+table Metadata {
+    is_valid:bool;
+    graph_modified:bool;
+    user_info:string;
+    optimize_options:ulong;
+}
+
 struct OutputVar {
     compact_id:uint;
     original_id:uint;
@@ -141,6 +148,7 @@ table Graph {
     nr_shared_tensor:uint;
     oprs:[Operator];
     output_vars_idx:[OutputVar];
+    metadata:Metadata;
 }

 root_type Graph;
@@ -30,6 +30,7 @@
 #include "megbrain/serialization/internal/flatbuffers_helper.h"
 #include "megbrain/serialization/internal/schema_generated.h"
 #include "megbrain/serialization/opr_load_dump.h"
+#include "megbrain/serialization/metadata.h"
 #include "megbrain/serialization/serializer.h"
 #include "megbrain/version.h"
@@ -115,6 +116,7 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
     std::vector<flatbuffers::Offset<void>> m_cur_opr_param;

     void init_oprs_to_dump(const SymbolVarArray& endpoints);
+    flatbuffers::Offset<fbs::Metadata> build_metadata(const Metadata& metadata);
     flatbuffers::Offset<fbs::Operator> build_single_opr(
             cg::OperatorNodeBase* opr, const OprRegistry* registry);
@@ -123,7 +125,8 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
 public:
     GraphDumperOSS(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {}
     DumpResult dump(const SymbolVarArray& output_vars,
-                    const DumpConfig& config = {}) override;
+                    const DumpConfig& config = {},
+                    const Metadata& metadata = {}) override;
     const GraphDumpConfig& config() const override { return m_config; }
     void dump_tensor(const std::string& name, const HostTensorND& tensor,
                      TensorWriteMethod method) override;
@@ -185,6 +188,17 @@ void GraphDumperOSS::init_oprs_to_dump(const SymbolVarArray& endpoints) {
     }
 }

+flatbuffers::Offset<fbs::Metadata> GraphDumperOSS::build_metadata(
+        const Metadata& metadata) {
+    auto user_info = m_builder.CreateSharedString(metadata.user_info);
+    fbs::MetadataBuilder builder(m_builder);
+    builder.add_is_valid(metadata.is_valid);
+    builder.add_graph_modified(metadata.graph_modified);
+    builder.add_user_info(user_info);
+    builder.add_optimize_options(metadata.optimize_options);
+    return builder.Finish();
+}
+
 flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
         cg::OperatorNodeBase* opr, const OprRegistry* registry) {
     m_cur_opr = opr;
@@ -282,7 +296,8 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
 }

 GraphDumper::DumpResult GraphDumperOSS::dump(
-        const SymbolVarArray& output_vars, const DumpConfig& config) {
+        const SymbolVarArray& output_vars,
+        const DumpConfig& config, const Metadata& metadata) {
     mgb_throw_if(output_vars.empty(), SerializationError,
                  "Can't dump empty graph");
@@ -323,6 +338,9 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
     uint64_t offset_to_fbs = 0;
     m_file->write(&offset_to_fbs, sizeof(offset_to_fbs));

+    // Dump metadata
+    auto fbmeta = build_metadata(metadata);
+
     // Dump operators
     init_oprs_to_dump(output_vars);
     std::vector<flatbuffers::Offset<fbs::Operator>> oprs;
@@ -350,6 +368,7 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
     graph.add_oprs(fb_oprs);
     graph.add_output_vars_idx(fb_output_vars);
     graph.add_nr_shared_tensor(m_nr_shared_tensor);
+    graph.add_metadata(fbmeta);
     m_builder.FinishSizePrefixed(graph.Finish(), fbs::GraphIdentifier());

     // Write actual offset_to_fbs
@@ -531,6 +550,7 @@ public:
         mgb_assert(nr == 1);
     }

+    Metadata load_metadata();
    LoadResult load_oprs();
    CompNode load_comp_node(const fbs::CompNode* comp_node);
@@ -700,6 +720,22 @@ GraphLoaderOSS::OprLoadContextImpl::load_tensor_shared() {
     return sh_ptr_ref;
 }

+Metadata GraphLoaderOSS::OprLoadContextImpl::load_metadata() {
+    const auto* fbmeta = m_loader->m_graph->metadata();
+    Metadata ret;
+    ret.is_valid = fbmeta->is_valid();
+    ret.graph_modified = fbmeta->graph_modified();
+    if (fbmeta->user_info()) {
+        ret.user_info = fbmeta->user_info()->str();
+        ret.has_user_info = true;
+    }
+    if (fbmeta->optimize_options()) {
+        ret.optimize_options = fbmeta->optimize_options();
+        ret.optimized_for_inference = true;
+    }
+    return ret;
+}
+
 void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
         const fbs::Operator* fbopr) {
     m_cur_opr_tensor_cnt = 0;
@@ -872,7 +908,9 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config,
     }

     OprLoadContextImpl ctx{this, m_graph->mgb_version()};
+    auto metadata = ctx.load_metadata();
     auto result = ctx.load_oprs();
+    result.metadata = metadata;

     auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size;
     auto cur = m_file->tell();
@@ -0,0 +1,46 @@
+/**
+ * \file src/serialization/include/megbrain/serialization/metadata.h
+ *
+ * This file is part of MegBrain, a deep learning framework developed by Megvii.
+ *
+ * \brief MegEngine model's metadata
+ *
+ * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ */
+#pragma once
+
+#include <cstdint>
+#include <string>
+
+namespace mgb {
+namespace serialization {
+
+struct Metadata {
+    bool is_valid = false;
+    bool graph_modified = false;
+
+    bool has_user_info = false;
+    std::string user_info;
+
+    bool optimized_for_inference = false;
+    uint64_t optimize_options = 0;
+
+#define ADD_PROPERTY(type, name)              \
+    type get_##name() const { return name; } \
+    void set_##name(type x) {                 \
+        name = x;                             \
+        has_##name = true;                    \
+    }
+
+    ADD_PROPERTY(std::string, user_info)
+
+#undef ADD_PROPERTY
+
+    uint64_t get_optimize_options() { return optimize_options; }
+    void set_optimize_options(uint64_t value) {
+        optimized_for_inference = true;
+        optimize_options = value;
+    }
+};
+
+}  // namespace serialization
+}  // namespace mgb
@@ -15,6 +15,7 @@
 #include "megbrain/serialization/dump_format.h"
 #include "megbrain/serialization/file.h"
 #include "megbrain/serialization/load_dump_config.h"
+#include "megbrain/serialization/metadata.h"

 namespace mgb {
 namespace serialization {
@@ -32,6 +33,9 @@ namespace serialization {
         //! explicit dtor decl to reduce binary size
         ~LoadResult() noexcept;

+        //! metadata
+        Metadata metadata;
+
         using TensorMap = std::unordered_map<
                 std::string, std::shared_ptr<HostTensorND>>;
@@ -178,7 +182,8 @@ namespace serialization {
         virtual DumpResult dump(
                 const SymbolVarArray &output_vars,
-                const DumpConfig &config = {}) = 0;
+                const DumpConfig &config = {},
+                const Metadata &metadata = {}) = 0;

         virtual GraphDumpFormat format() const = 0;
     };
@@ -92,6 +92,43 @@ TEST(TestSerializer2, MultiGraphDumpLoad) {
     load();
 }

+TEST(TestSerializer2, Metadata) {
+    auto fname = GET_OUTPUT_FILE();
+    TensorShape shape{2, 3};
+
+    auto dump = [&]() {
+        auto cn = CompNode::load("xpu0");
+        auto host_x = std::make_shared<HostTensorND>(cn, shape),
+             host_y = std::make_shared<HostTensorND>(cn, shape);
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}),
+             y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"});
+        using Mode = opr::Elemwise::Mode;
+        auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"});
+
+        Metadata metadata;
+        metadata.user_info = "TEST_METADATA";
+        metadata.has_user_info = true;
+
+        auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = dumper->dump({z.rename("z")}, {}, metadata);
+    };
+
+    auto load = [&]() {
+        HostTensorGenerator<> gen;
+        auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = loader->load();
+        auto metadata = rst.metadata;
+        int cmp = strcmp(metadata.user_info.c_str(), "TEST_METADATA");
+        EXPECT_EQ(cmp, 0);
+    };
+
+    dump();
+    load();
+}
+
 TEST(TestSerializer2, APlusB) {
     auto fname = GET_OUTPUT_FILE();
     TensorShape shape{2, 3};