GitOrigin-RevId: b563c94451
release-1.5
@@ -11,7 +11,7 @@ import json | |||
import os | |||
import weakref | |||
from concurrent.futures import ThreadPoolExecutor | |||
from typing import Dict, List, Tuple, Union | |||
from typing import Dict, List, Optional, Tuple, Union | |||
import numpy as np | |||
@@ -256,6 +256,9 @@ def optimize_for_inference(dest_vars, **kwargs): | |||
* enable_chwn4 -- | |||
whether to use CHWN4 data layout, currently | |||
used in nvidia backend with tensorcore. | |||
* enable_nchw64 -- | |||
whether to use NCHW64 data layout, used for fast int4 | |||
support on NVIDIA GPUs.
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
into one opr. | |||
@@ -273,6 +276,7 @@ def optimize_for_inference(dest_vars, **kwargs): | |||
"enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44, | |||
"enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT, | |||
"enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4, | |||
"enable_nchw64": GraphOptimizeOptions.LayoutTransform.NCHW64, | |||
} | |||
for k, v in inference_optimize_layout_transform_map.items(): | |||
@@ -293,7 +297,46 @@ def optimize_for_inference(dest_vars, **kwargs): | |||
dest_vars = _unwrap(dest_vars) | |||
res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options) | |||
return _wrap(res_vars) | |||
return _wrap(res_vars), inference_options.serialize() | |||
def deserialize_infer_option(x: int) -> Dict[str, bool]: | |||
r""" | |||
Deserialize optimize options serialized by ``imperative_rt.GraphOptimizeOptions``.
:param x: inference options packed into an int.
:return: inference options represented by a dict of booleans.
""" | |||
inference_options = GraphOptimizeOptions.deserialize(x) | |||
inference_optimize_layout_transform_map = { | |||
GraphOptimizeOptions.LayoutTransform.NHWCD4: "enable_hwcd4", | |||
GraphOptimizeOptions.LayoutTransform.NCHW4: "enable_nchw4", | |||
GraphOptimizeOptions.LayoutTransform.NCHW88: "enable_nchw88", | |||
GraphOptimizeOptions.LayoutTransform.NCHW32: "enable_nchw32", | |||
GraphOptimizeOptions.LayoutTransform.NCHW44: "enable_nchw44", | |||
GraphOptimizeOptions.LayoutTransform.NCHW44_DOT: "enable_nchw44_dot", | |||
GraphOptimizeOptions.LayoutTransform.CHWN4: "enable_chwn4", | |||
GraphOptimizeOptions.LayoutTransform.NCHW64: "enable_nchw64", | |||
} | |||
ret = dict() | |||
layout = inference_options.layout_transform | |||
if layout != GraphOptimizeOptions.LayoutTransform.DEFAULT: | |||
ret[inference_optimize_layout_transform_map[layout]] = True | |||
if inference_options.f16_io_f32_comp: | |||
ret["enable_io16xc32"] = True | |||
if inference_options.f16_io_comp: | |||
ret["enable_ioc16"] = True | |||
if inference_options.fuse_conv_bias_nonlinearity: | |||
ret["enable_fuse_conv_bias_nonlinearity"] = True | |||
if inference_options.fuse_conv_bias_with_z: | |||
ret["enable_fuse_conv_bias_with_z"] = True | |||
return ret | |||
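A minimal usage sketch of the new two-value return together with ``deserialize_infer_option``; here ``G`` stands for this module (``megengine.core.tensor.megbrain_graph``) and ``dest_vars`` is assumed to be an existing list of output vars:

# hedged sketch: `dest_vars` is a placeholder list of output vars
dest_vars, packed = G.optimize_for_inference(dest_vars, enable_nchw4=True)
options = G.deserialize_infer_option(packed)
assert options.get("enable_nchw4") is True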
def modify_opr_algo_strategy_inplace(dest_vars, strategy: str): | |||
@@ -331,7 +374,8 @@ def dump_graph( | |||
keep_param_name: bool = False, | |||
keep_opr_priority: bool = False, | |||
strip_info_file=None, | |||
append_json=False | |||
append_json=False, | |||
metadata=None | |||
) -> Tuple[bytes, CompGraphDumpResult]: | |||
""" | |||
serialize the computing graph of `output_vars` and get byte result. | |||
@@ -393,6 +437,7 @@ def dump_graph( | |||
keep_opr_name, | |||
keep_param_name, | |||
keep_opr_priority, | |||
metadata, | |||
stat, | |||
inputs, | |||
outputs, | |||
@@ -427,7 +472,7 @@ def dump_graph( | |||
CompGraphLoadResult = collections.namedtuple( | |||
"CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"] | |||
"CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list", "metadata"] | |||
) | |||
@@ -450,8 +495,8 @@ def load_graph(fpath) -> CompGraphLoadResult: | |||
buf = open(fpath, "rb").read() | |||
else: | |||
buf = fpath.read() | |||
cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list) | |||
return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list) | |||
cg, metadata = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list) | |||
return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list, metadata) | |||
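Callers that previously unpacked a 3-tuple from ``load_graph`` should switch to the named fields; a sketch with a placeholder model path:

ret = G.load_graph("model.mge")          # "model.mge" is a placeholder path
graph, outputs = ret.graph, ret.output_vars_list
raw_meta = ret.metadata                  # SerializationMetadata; expected to be invalid for models dumped without metadata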
def _wrap(x): | |||
@@ -12,10 +12,12 @@ import functools | |||
import itertools | |||
import json | |||
import os | |||
import pickle | |||
from typing import Any | |||
import numpy as np | |||
from ..core._imperative_rt import GraphProfiler | |||
from ..core._imperative_rt import GraphProfiler, SerializationMetadata | |||
from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
from ..core._imperative_rt.core2 import ( | |||
TensorWeakRef, | |||
@@ -670,6 +672,8 @@ class trace: | |||
strip_info_file=None, | |||
append_json=False, | |||
optimize_for_inference=True, | |||
user_info: Any = None, | |||
enable_metadata: bool = True, | |||
**kwargs | |||
): | |||
r""" | |||
@@ -697,6 +701,8 @@ class trace: | |||
if set false, will rewrite strip_info_file | |||
:param optimize_for_inference: enable optimizations;
all optimize options are skipped if this is False. Default: True
:param user_info: an object of any type, which will be pickled to bytes.
:param enable_metadata: whether to save metadata into the output file.
:Keyword Arguments: | |||
@@ -729,6 +735,9 @@ class trace: | |||
* enable_chwn4 -- | |||
whether to use CHWN4 data layout, currently | |||
used in nvidia backend with tensorcore. | |||
* enable_nchw64 -- | |||
whether to use NCHW64 data layout, used for fast int4 | |||
support on NVIDIA GPUs.
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
into one opr. | |||
@@ -851,7 +860,15 @@ class trace: | |||
dest_vars.append(v) | |||
if optimize_for_inference: | |||
dest_vars = G.optimize_for_inference(dest_vars, **kwargs) | |||
dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs) | |||
metadata = SerializationMetadata() | |||
if enable_metadata: | |||
metadata.user_info = pickle.dumps(user_info) | |||
metadata.is_valid = True | |||
metadata.graph_modified = False | |||
if optimize_for_inference: | |||
metadata.optimize_options = optimize_options | |||
if isinstance(file, str): | |||
permission = "wb" if append == False else "ab" | |||
@@ -864,6 +881,7 @@ class trace: | |||
keep_opr_priority=keep_opr_priority, | |||
strip_info_file=strip_info_file, | |||
append_json=append_json, | |||
metadata=metadata, | |||
) | |||
file.write(dump_content) | |||
return dump_info | |||
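A hedged end-to-end sketch of the new ``trace.dump`` arguments; the traced function and the output path are placeholders:

import megengine as mge
from megengine.jit import trace

@trace(symbolic=True, capture_as_const=True)
def fwd(x):
    return x * 2

fwd(mge.tensor(1.0))
# user_info is pickled into the metadata section; enable_metadata=False skips metadata entirely
fwd.dump("model.mge", user_info={"version": "v1"}, optimize_for_inference=False)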
@@ -411,7 +411,8 @@ def main(): | |||
args.embed_input = True | |||
logger.info("loading model ...") | |||
graph, _, output_vars = G.load_graph(args.net) | |||
ret = G.load_graph(args.net) | |||
graph, output_vars = ret.graph, ret.output_vars_list | |||
input_vars = tools.get_dep_vars(output_vars, "Host2DeviceCopy") | |||
if args.output_name is not None: | |||
@@ -391,7 +391,8 @@ class GraphInference: | |||
optimize_for_inference: bool = False, | |||
**kwargs | |||
): | |||
self._graph, _, output_nodes = G.load_graph(file) | |||
ret = G.load_graph(file) | |||
self._graph, output_nodes = ret.graph, ret.output_vars_list | |||
if outputs is not None: | |||
output_nodes = find_vars_by_name(output_nodes, outputs) | |||
self._origin_outputs = output_nodes | |||
@@ -9,14 +9,12 @@ | |||
import collections | |||
import fnmatch | |||
import itertools | |||
import pickle | |||
import re | |||
from collections import OrderedDict | |||
from typing import Dict, List, Sequence | |||
from typing import Any, Dict, List, Sequence | |||
import numpy as np | |||
from ..core._imperative_rt import ComputingGraph | |||
from ..core._imperative_rt.core2 import SymbolVar | |||
from ..core._imperative_rt import ComputingGraph, SerializationMetadata | |||
from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape | |||
from ..core.tensor import megbrain_graph as G | |||
from ..logger import get_logger | |||
@@ -42,6 +40,30 @@ class Network: | |||
self.all_oprs_map = OrderedDict() | |||
self.all_vars_map = OrderedDict() | |||
self.graph = ComputingGraph() | |||
self._metadata = None | |||
@property | |||
def metadata(self): | |||
r""" | |||
Return the metadata of the loaded model as a dict, or ``None`` if the model carries no valid metadata.
""" | |||
if self._metadata is None or not self._metadata.is_valid: | |||
logger.info("metadata is not valid!") | |||
return None | |||
ret = dict() | |||
try: | |||
user_info = pickle.loads(self._metadata.user_info) | |||
except: # pylint: disable=bare-except | |||
logger.warning( | |||
"can't parse user info by pickle, so return the original bytes object!" | |||
) | |||
user_info = self._metadata.user_info | |||
ret["user_info"] = user_info | |||
ret["graph_modified"] = self._metadata.graph_modified | |||
ret["optimized_for_inference"] = self._metadata.optimized_for_inference | |||
if ret["optimized_for_inference"]: | |||
ret.update(G.deserialize_infer_option(self._metadata.optimize_options)) | |||
return ret | |||
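A reading-side sketch of the new property, assuming the import path used elsewhere in this diff and a placeholder model path:

from megengine.utils.network import Network as Net

net = Net.load("model.mge")              # placeholder path
meta = net.metadata                      # None when the file carries no valid metadata
if meta is not None:
    print(meta["user_info"], meta["graph_modified"], meta["optimized_for_inference"])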
@classmethod | |||
def load(cls, model_path: str, outspec: List[str] = None): | |||
@@ -51,7 +73,8 @@ class Network: | |||
:param outspec: only load the subgraph with outspec as its endpoints. | |||
""" | |||
self = cls() | |||
_, _, outputs = G.load_graph(model_path) | |||
ret = G.load_graph(model_path) | |||
outputs, self._metadata = ret.output_vars_list, ret.metadata | |||
if outspec is not None: | |||
output_spec = outspec.copy() | |||
all_vars = get_dep_vars(outputs) + outputs | |||
@@ -125,6 +148,9 @@ class Network: | |||
* enable_chwn4 -- | |||
whether to use CHWN4 data layout, currently | |||
used in nvidia backend with tensorcore. | |||
* enable_nchw64 -- | |||
whether to use NCHW64 data layout, used for fast int4 | |||
support on NVIDIA GPUs.
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
into one opr. | |||
@@ -152,6 +178,8 @@ class Network: | |||
append_json=False, | |||
optimize_for_inference=True, | |||
append=False, | |||
user_info: Any = None, | |||
enable_metadata=True, | |||
**kwargs | |||
): | |||
""" | |||
@@ -176,6 +204,8 @@ class Network: | |||
if set false, will rewrite strip_info_file | |||
:param optimize_for_inference: enable optimizations;
all optimize options are skipped if this is False. Default: True
:param user_info: an object of any type, which will be pickled to bytes.
:param enable_metadata: whether to save metadata into the output file.
:Keyword Arguments: | |||
@@ -201,7 +231,15 @@ class Network: | |||
) | |||
if optimize_for_inference: | |||
out = G.optimize_for_inference(out, **kwargs) | |||
out, optimize_options = G.optimize_for_inference(out, **kwargs) | |||
metadata = SerializationMetadata() | |||
if enable_metadata: | |||
metadata.is_valid = True | |||
metadata.graph_modified = True | |||
metadata.user_info = pickle.dumps(user_info) | |||
if optimize_for_inference: | |||
metadata.optimize_options = optimize_options | |||
dump_content, _ = G.dump_graph( | |||
out, | |||
@@ -211,6 +249,7 @@ class Network: | |||
keep_opr_priority=keep_opr_priority, | |||
strip_info_file=strip_info_file, | |||
append_json=append_json, | |||
metadata=metadata, | |||
) | |||
if isinstance(file, str): | |||
permission = "wb" if append == False else "ab" | |||
@@ -34,6 +34,7 @@ namespace ser = mgb::serialization; | |||
using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions; | |||
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform; | |||
using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||
using _SerializationMetadata = mgb::serialization::Metadata; | |||
namespace { | |||
class _CompGraphProfilerImpl { | |||
@@ -240,6 +241,8 @@ void init_graph_rt(py::module m) { | |||
auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions") | |||
.def(py::init()) | |||
.def("serialize", &_OptimizeForInferenceOptions::serialize) | |||
.def_static("deserialize", &_OptimizeForInferenceOptions::deserialize) | |||
.def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp) | |||
.def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp) | |||
.def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity) | |||
@@ -256,6 +259,7 @@ void init_graph_rt(py::module m) { | |||
.value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT) | |||
.value("NCHW32", _LayoutTransform::NCHW32) | |||
.value("CHWN4", _LayoutTransform::CHWN4) | |||
.value("NCHW64", _LayoutTransform::NCHW64) | |||
.export_values() | |||
; | |||
@@ -307,12 +311,24 @@ void init_graph_rt(py::module m) { | |||
})->to_string(); | |||
}); | |||
py::class_<_SerializationMetadata>(m, "SerializationMetadata") | |||
.def(py::init()) | |||
.def_property("user_info", [](const _SerializationMetadata& meta){return py::bytes(meta.get_user_info()); }, | |||
&_SerializationMetadata::set_user_info) | |||
.def_readonly("optimized_for_inference", &_SerializationMetadata::optimized_for_inference) | |||
.def_property("optimize_options", &_SerializationMetadata::get_optimize_options, | |||
&_SerializationMetadata::set_optimize_options) | |||
.def_readwrite("graph_modified", &_SerializationMetadata::graph_modified) | |||
.def_readwrite("is_valid", &_SerializationMetadata::is_valid) | |||
; | |||
m.def("dump_graph", []( | |||
const std::vector<VarNode*>& dest_vars, | |||
int keep_var_name, | |||
bool keep_opr_name, | |||
bool keep_param_name, | |||
bool keep_opr_priority, | |||
std::optional<_SerializationMetadata> metadata, | |||
py::list& stat, | |||
py::list& inputs, | |||
py::list& outputs, | |||
@@ -325,7 +341,12 @@ void init_graph_rt(py::module m) { | |||
ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name, | |||
keep_opr_priority, keep_opr_name}; | |||
auto rst = dumper->dump(symvars, config); | |||
ser::GraphDumper::DumpResult rst; | |||
if (metadata) | |||
rst = dumper->dump(symvars, config, *metadata); | |||
else | |||
rst = dumper->dump(symvars, config); | |||
for (auto i : rst.inputs) { | |||
inputs.append(py::cast(i)); | |||
} | |||
@@ -377,8 +398,10 @@ void init_graph_rt(py::module m) { | |||
for (const auto& var : rst.output_var_list) { | |||
iter.add(var); | |||
} | |||
return rst.graph; | |||
auto ret = py::tuple(2); | |||
ret[0] = py::cast(rst.graph); | |||
ret[1] = py::cast(rst.metadata); | |||
return ret; | |||
}); | |||
#define CURRENT_CLASS cg::ComputingGraph::Options | |||
@@ -239,8 +239,7 @@ def test_dump_volatile(): | |||
file = io.BytesIO() | |||
f.dump(file, optimize_for_inference=False) | |||
file.seek(0) | |||
cg, _, outputs = G.load_graph(file) | |||
(out,) = outputs | |||
(out,) = G.load_graph(file).output_vars_list | |||
assert ( | |||
cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1]) | |||
== "ImmutableTensor" | |||
@@ -337,12 +336,12 @@ def test_goptions_log_exp(): | |||
f(tensor(1.0)) | |||
_, out = mkstemp() | |||
f.dump(out, optimize_for_inference=False) | |||
*_, outputs = G.load_graph(out) | |||
outputs = G.load_graph(out).output_vars_list | |||
oprs_1 = cgtools.get_oprs_seq(outputs) | |||
g(tensor(1.0)) | |||
g.dump(out, optimize_for_inference=False) | |||
*_, outputs = G.load_graph(out) | |||
outputs = G.load_graph(out).output_vars_list | |||
oprs_2 = cgtools.get_oprs_seq(outputs) | |||
assert len(oprs_1) - len(oprs_2) == 2 | |||
@@ -88,7 +88,7 @@ def test_graph_traversal(): | |||
file = io.BytesIO() | |||
fun.dump(file, optimize_for_inference=False) | |||
file.seek(0) | |||
cg, _, outputs = mgb_graph.load_graph(file) | |||
outputs = mgb_graph.load_graph(file).output_vars_list | |||
_, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs) | |||
input_var = map_vars[1] | |||
@@ -101,7 +101,9 @@ def test_load_refcnt(): | |||
graph = mgb_graph.Graph() | |||
varnode = graph.make_const(0) | |||
buf, _ = mgb_graph.dump_graph([varnode]) | |||
graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf)) | |||
ret = mgb_graph.load_graph(io.BytesIO(buf)) | |||
graph, (varnode,) = ret.graph, ret.output_vars_list | |||
del ret | |||
del graph | |||
varnode.owner | |||
@@ -132,7 +134,7 @@ def test_get_opr_seq(): | |||
file = io.BytesIO() | |||
func.dump(file, optimize_for_inference=False) | |||
file.seek(0) | |||
*_, outputs = mgb_graph.load_graph(file) | |||
outputs = mgb_graph.load_graph(file).output_vars_list | |||
seq_1 = cgtools.get_oprs_seq(outputs, True) | |||
assert len(seq_1) == 5 | |||
@@ -35,7 +35,7 @@ def _dump_and_load(func, symbolic, keep_opr_name=True): | |||
keep_var_name=2, | |||
) | |||
file.seek(0) | |||
*_, outputs = G.load_graph(file) | |||
outputs = G.load_graph(file).output_vars_list | |||
ops = cgtools.get_oprs_seq(outputs) | |||
return ops | |||
@@ -223,7 +223,7 @@ def test_catch_input_name(tensor_name, var_name): | |||
file = io.BytesIO() | |||
func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2) | |||
file.seek(0) | |||
*_, outputs = G.load_graph(file) | |||
outputs = G.load_graph(file).output_vars_list | |||
op = cgtools.get_oprs_seq(outputs)[-1] | |||
assert op.inputs[0].name == var_name | |||
@@ -14,6 +14,50 @@ from megengine.utils.network import as_oprnode, set_symbolic_shape | |||
from megengine.utils.network_node import Host2DeviceCopy, VarNode | |||
def test_metadata(): | |||
x = Tensor(0) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(x): | |||
return x * 2 | |||
fwd(x) | |||
orig_model = io.BytesIO() | |||
fwd.dump(orig_model, user_info="test", optimize_for_inference=False) | |||
orig_model.seek(0) | |||
graph = Net.load(orig_model) | |||
assert graph.metadata == { | |||
"user_info": "test", | |||
"graph_modified": False, # False: tracing.dump | |||
"optimized_for_inference": False, | |||
} | |||
orig_model.seek(0) | |||
graph.dump( | |||
orig_model, | |||
user_info={"str": "x", "tensor": x, "module": M.Module, "none": None}, | |||
optimize_for_inference=True, | |||
enable_nchw4=True, | |||
enable_ioc16=True, | |||
) | |||
orig_model.seek(0) | |||
graph = Net.load(orig_model) | |||
assert graph.metadata == { | |||
"user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None}, | |||
"graph_modified": True, # True: Network.dump | |||
"optimized_for_inference": True, | |||
"enable_nchw4": True, | |||
"enable_ioc16": True, | |||
} | |||
orig_model.seek(0) | |||
fwd.dump(orig_model, enable_metadata=False) | |||
orig_model.seek(0) | |||
graph = Net.load(orig_model) | |||
assert graph.metadata is None | |||
def test_replace_var(): | |||
a = Tensor([1, 2]) | |||
@@ -170,7 +170,8 @@ def gen_one_testcase(args, inputs, spec): | |||
def make_feeds(args): | |||
cg_rt, _, outputs = G.load_graph(args.input) | |||
ret = G.load_graph(args.input) | |||
cg_rt, outputs = ret.graph, ret.output_vars_list | |||
inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") | |||
inputs = {i.name: i for i in inputs} | |||
@@ -322,7 +322,31 @@ namespace gopt { | |||
static std::unique_ptr<EnableNchw44DotPass> make_nchw44_dot_converter(); | |||
}; | |||
struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {}; | |||
struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions { | |||
uint64_t serialize() { | |||
uint64_t ret = 0; | |||
ret |= (uint64_t)layout_transform << 32; | |||
if (f16_io_f32_comp) ret |= 1u; | |||
if (f16_io_comp) ret |= 1u << 1; | |||
if (fuse_conv_bias_nonlinearity) ret |= 1u << 2; | |||
if (fuse_conv_bias_with_z) ret |= 1u << 3; | |||
if (weight_preprocess) ret |= 1u << 4; | |||
if (fuse_preprocess) ret |= 1u << 5; | |||
return ret; | |||
} | |||
static OptimizeForInferenceOptions deserialize(uint64_t buf) { | |||
OptimizeForInferenceOptions ret; | |||
ret.f16_io_f32_comp = buf & 1u; | |||
ret.f16_io_comp = buf & 1u << 1; | |||
ret.fuse_conv_bias_nonlinearity = buf & 1u << 2; | |||
ret.fuse_conv_bias_with_z = buf & 1u << 3; | |||
ret.weight_preprocess = buf & 1u << 4; | |||
ret.fuse_preprocess = buf & 1u << 5; | |||
ret.layout_transform = (LayoutTransform)(buf >> 32); | |||
return ret; | |||
} | |||
}; | |||
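The packed format puts the boolean flags in the low bits (bit 2 is fuse_conv_bias_nonlinearity) and the layout transform in the upper 32 bits. A round-trip sketch through the Python binding; the import path is assumed to mirror how SerializationMetadata is imported elsewhere in this diff, and a readwrite ``layout_transform`` attribute is assumed:

from megengine.core._imperative_rt import GraphOptimizeOptions  # assumed import path

opts = GraphOptimizeOptions()
opts.fuse_conv_bias_nonlinearity = True
opts.layout_transform = GraphOptimizeOptions.LayoutTransform.NCHW64  # assumes readwrite binding
packed = opts.serialize()
assert packed & (1 << 2)        # flag bits occupy the low word; bit 2 is the fuse flag
assert (packed >> 32) != 0      # layout transform occupies the high 32 bits

restored = GraphOptimizeOptions.deserialize(packed)
assert restored.fuse_conv_bias_nonlinearity
assert restored.layout_transform == GraphOptimizeOptions.LayoutTransform.NCHW64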
/*! | |||
* \brief optimize a computing graph for inference | |||
@@ -128,6 +128,13 @@ table Operator { | |||
name:string; | |||
} | |||
table Metadata { | |||
is_valid:bool; | |||
graph_modified:bool; | |||
user_info:string; | |||
optimize_options:ulong; | |||
} | |||
struct OutputVar { | |||
compact_id:uint; | |||
original_id:uint; | |||
@@ -141,6 +148,7 @@ table Graph { | |||
nr_shared_tensor:uint; | |||
oprs:[Operator]; | |||
output_vars_idx:[OutputVar]; | |||
metadata:Metadata; | |||
} | |||
root_type Graph; |
@@ -30,6 +30,7 @@ | |||
#include "megbrain/serialization/internal/flatbuffers_helper.h" | |||
#include "megbrain/serialization/internal/schema_generated.h" | |||
#include "megbrain/serialization/opr_load_dump.h" | |||
#include "megbrain/serialization/metadata.h" | |||
#include "megbrain/serialization/serializer.h" | |||
#include "megbrain/version.h" | |||
@@ -115,6 +116,7 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers { | |||
std::vector<flatbuffers::Offset<void>> m_cur_opr_param; | |||
void init_oprs_to_dump(const SymbolVarArray& endpoints); | |||
flatbuffers::Offset<fbs::Metadata> build_metadata(const Metadata& metadata); | |||
flatbuffers::Offset<fbs::Operator> build_single_opr( | |||
cg::OperatorNodeBase* opr, const OprRegistry* registry); | |||
@@ -123,7 +125,8 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers { | |||
public: | |||
GraphDumperOSS(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {} | |||
DumpResult dump(const SymbolVarArray& output_vars, | |||
const DumpConfig& config = {}) override; | |||
const DumpConfig& config = {}, | |||
const Metadata& metadata = {}) override; | |||
const GraphDumpConfig& config() const override { return m_config; } | |||
void dump_tensor(const std::string& name, const HostTensorND& tensor, | |||
TensorWriteMethod method) override; | |||
@@ -185,6 +188,17 @@ void GraphDumperOSS::init_oprs_to_dump(const SymbolVarArray& endpoints) { | |||
} | |||
} | |||
flatbuffers::Offset<fbs::Metadata> GraphDumperOSS::build_metadata( | |||
const Metadata& metadata) { | |||
auto user_info = m_builder.CreateSharedString(metadata.user_info); | |||
fbs::MetadataBuilder builder(m_builder); | |||
builder.add_is_valid(metadata.is_valid); | |||
builder.add_graph_modified(metadata.graph_modified); | |||
builder.add_user_info(user_info); | |||
builder.add_optimize_options(metadata.optimize_options); | |||
return builder.Finish(); | |||
} | |||
flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr( | |||
cg::OperatorNodeBase* opr, const OprRegistry* registry) { | |||
m_cur_opr = opr; | |||
@@ -282,7 +296,8 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr( | |||
} | |||
GraphDumper::DumpResult GraphDumperOSS::dump( | |||
const SymbolVarArray& output_vars, const DumpConfig& config) { | |||
const SymbolVarArray& output_vars, | |||
const DumpConfig& config, const Metadata& metadata) { | |||
mgb_throw_if(output_vars.empty(), SerializationError, | |||
"Can't dump empty graph"); | |||
@@ -323,6 +338,9 @@ GraphDumper::DumpResult GraphDumperOSS::dump( | |||
uint64_t offset_to_fbs = 0; | |||
m_file->write(&offset_to_fbs, sizeof(offset_to_fbs)); | |||
// Dump metadata | |||
auto fbmeta = build_metadata(metadata); | |||
// Dump operators | |||
init_oprs_to_dump(output_vars); | |||
std::vector<flatbuffers::Offset<fbs::Operator>> oprs; | |||
@@ -350,6 +368,7 @@ GraphDumper::DumpResult GraphDumperOSS::dump( | |||
graph.add_oprs(fb_oprs); | |||
graph.add_output_vars_idx(fb_output_vars); | |||
graph.add_nr_shared_tensor(m_nr_shared_tensor); | |||
graph.add_metadata(fbmeta); | |||
m_builder.FinishSizePrefixed(graph.Finish(), fbs::GraphIdentifier()); | |||
// Write actual offset_to_fbs | |||
@@ -531,6 +550,7 @@ public: | |||
mgb_assert(nr == 1); | |||
} | |||
Metadata load_metadata(); | |||
LoadResult load_oprs(); | |||
CompNode load_comp_node(const fbs::CompNode* comp_node); | |||
@@ -700,6 +720,22 @@ GraphLoaderOSS::OprLoadContextImpl::load_tensor_shared() { | |||
return sh_ptr_ref; | |||
} | |||
Metadata GraphLoaderOSS::OprLoadContextImpl::load_metadata() { | |||
const auto* fbmeta = m_loader->m_graph->metadata(); | |||
Metadata ret;
// models dumped before this change carry no metadata table; return the default (invalid) Metadata
if (!fbmeta) return ret;
ret.is_valid = fbmeta->is_valid();
ret.graph_modified = fbmeta->graph_modified(); | |||
if (fbmeta->user_info()) { | |||
ret.user_info = fbmeta->user_info()->str(); | |||
ret.has_user_info = true; | |||
} | |||
if (fbmeta->optimize_options()) { | |||
ret.optimize_options = fbmeta->optimize_options(); | |||
ret.optimized_for_inference = true; | |||
} | |||
return ret; | |||
} | |||
void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( | |||
const fbs::Operator* fbopr) { | |||
m_cur_opr_tensor_cnt = 0; | |||
@@ -872,7 +908,9 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, | |||
} | |||
OprLoadContextImpl ctx{this, m_graph->mgb_version()}; | |||
auto metadata = ctx.load_metadata(); | |||
auto result = ctx.load_oprs(); | |||
result.metadata = metadata; | |||
auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size; | |||
auto cur = m_file->tell(); | |||
@@ -0,0 +1,46 @@ | |||
/** | |||
* \file src/serialization/include/megbrain/serialization/metadata.h | |||
* | |||
* This file is part of MegBrain, a deep learning framework developed by Megvii. | |||
* | |||
* \brief MegEngine model's metadata | |||
* | |||
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
*/ | |||
#pragma once | |||
#include <string> | |||
namespace mgb { | |||
namespace serialization { | |||
struct Metadata { | |||
bool is_valid = false; | |||
bool graph_modified = false; | |||
bool has_user_info = false; | |||
std::string user_info; | |||
bool optimized_for_inference = false; | |||
uint64_t optimize_options = 0;
#define ADD_PROPERTY(type, name) \ | |||
type get_##name() const { return name; } \ | |||
void set_##name(type x) { \ | |||
name = x; \ | |||
has_##name = true; \ | |||
} | |||
ADD_PROPERTY(std::string, user_info) | |||
#undef ADD_PROPERTY | |||
uint64_t get_optimize_options() { return optimize_options; } | |||
void set_optimize_options(uint64_t value) { | |||
optimized_for_inference = true; | |||
optimize_options = value; | |||
} | |||
}; | |||
} // namespace serialization | |||
} // namespace mgb |
@@ -15,6 +15,7 @@ | |||
#include "megbrain/serialization/dump_format.h" | |||
#include "megbrain/serialization/file.h" | |||
#include "megbrain/serialization/load_dump_config.h" | |||
#include "megbrain/serialization/metadata.h" | |||
namespace mgb { | |||
namespace serialization { | |||
@@ -32,6 +33,9 @@ namespace serialization { | |||
//! explicit dtor decl to reduce binary size
~LoadResult() noexcept; | |||
//! metadata | |||
Metadata metadata; | |||
using TensorMap = std::unordered_map< | |||
std::string, std::shared_ptr<HostTensorND>>; | |||
@@ -178,7 +182,8 @@ namespace serialization { | |||
virtual DumpResult dump( | |||
const SymbolVarArray &output_vars, | |||
const DumpConfig &config = {}) = 0; | |||
const DumpConfig &config = {}, | |||
const Metadata &metadata = {}) = 0; | |||
virtual GraphDumpFormat format() const = 0; | |||
}; | |||
@@ -92,6 +92,43 @@ TEST(TestSerializer2, MultiGraphDumpLoad) { | |||
load(); | |||
} | |||
TEST(TestSerializer2, Metadata) { | |||
auto fname = GET_OUTPUT_FILE(); | |||
TensorShape shape{2, 3}; | |||
auto dump = [&]() { | |||
auto cn = CompNode::load("xpu0"); | |||
auto host_x = std::make_shared<HostTensorND>(cn, shape), | |||
host_y = std::make_shared<HostTensorND>(cn, shape); | |||
auto graph = ComputingGraph::make(); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}), | |||
y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"}); | |||
using Mode = opr::Elemwise::Mode; | |||
auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"}); | |||
Metadata metadata; | |||
metadata.user_info = "TEST_METADATA"; | |||
metadata.has_user_info = true; | |||
auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()), | |||
GraphDumpFormat::FLATBUFFERS); | |||
dumper->dump({z.rename("z")}, {}, metadata);
}; | |||
auto load = [&]() { | |||
HostTensorGenerator<> gen; | |||
auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()), | |||
GraphDumpFormat::FLATBUFFERS); | |||
auto rst = loader->load(); | |||
auto metadata = rst.metadata; | |||
EXPECT_EQ(metadata.user_info, "TEST_METADATA");
}; | |||
dump(); | |||
load(); | |||
} | |||
TEST(TestSerializer2, APlusB) { | |||
auto fname = GET_OUTPUT_FILE(); | |||
TensorShape shape{2, 3}; | |||