
feat(src/serialization): add support of serializing metadata

Megvii Engine Team (huangxinda) authored 4 years ago · branch: release-1.5
parent commit: 54a4d70eb5
GitOrigin-RevId: b563c94451
17 changed files with 365 additions and 34 deletions
 1. +51  -6   imperative/python/megengine/core/tensor/megbrain_graph.py
 2. +20  -2   imperative/python/megengine/jit/tracing.py
 3. +2   -1   imperative/python/megengine/tools/load_network_and_run.py
 4. +2   -1   imperative/python/megengine/utils/comp_graph_tools.py
 5. +46  -7   imperative/python/megengine/utils/network.py
 6. +26  -3   imperative/python/src/graph_rt.cpp
 7. +3   -4   imperative/python/test/unit/jit/test_tracing.py
 8. +5   -3   imperative/python/test/unit/utils/test_cgtools.py
 9. +2   -2   imperative/python/test/unit/utils/test_dump_naming.py
10. +44  -0   imperative/python/test/unit/utils/test_network.py
11. +2   -1   sdk/load-and-run/dump_with_testcase_mge.py
12. +25  -1   src/gopt/include/megbrain/gopt/inference.h
13. +8   -0   src/serialization/impl/schema.fbs
14. +40  -2   src/serialization/impl/serializer_oss.cpp
15. +46  -0   src/serialization/include/megbrain/serialization/metadata.h
16. +6   -1   src/serialization/include/megbrain/serialization/serializer.h
17. +37  -0   src/serialization/test/serializer_oss.cpp

+51 -6   imperative/python/megengine/core/tensor/megbrain_graph.py

@@ -11,7 +11,7 @@ import json
 import os
 import weakref
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np

@@ -256,6 +256,9 @@ def optimize_for_inference(dest_vars, **kwargs):
         * enable_chwn4 --
           whether to use CHWN4 data layout, currently
           used in nvidia backend with tensorcore.
+        * enable_nchw64 --
+          whether to use NCHW64 data layout, used for fast int4
+          support on Nvidia GPU.

     * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
       into one opr.
@@ -273,6 +276,7 @@ def optimize_for_inference(dest_vars, **kwargs):
         "enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44,
         "enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT,
         "enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4,
+        "enable_nchw64": GraphOptimizeOptions.LayoutTransform.NCHW64,
     }

     for k, v in inference_optimize_layout_transform_map.items():
@@ -293,7 +297,46 @@ def optimize_for_inference(dest_vars, **kwargs):

     dest_vars = _unwrap(dest_vars)
     res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options)
-    return _wrap(res_vars)
+    return _wrap(res_vars), inference_options.serialize()
+
+
+def deserialize_infer_option(x: int) -> Dict[str, bool]:
+    r"""
+    Deserialize optimize options generated by ``imperative_rt.GraphOptimizeOptions``.
+
+    :param x: inference options represented by int.
+    :return: inference options represented by dict.
+    """
+
+    inference_options = GraphOptimizeOptions.deserialize(x)
+
+    inference_optimize_layout_transform_map = {
+        GraphOptimizeOptions.LayoutTransform.NHWCD4: "enable_hwcd4",
+        GraphOptimizeOptions.LayoutTransform.NCHW4: "enable_nchw4",
+        GraphOptimizeOptions.LayoutTransform.NCHW88: "enable_nchw88",
+        GraphOptimizeOptions.LayoutTransform.NCHW32: "enable_nchw32",
+        GraphOptimizeOptions.LayoutTransform.NCHW44: "enable_nchw44",
+        GraphOptimizeOptions.LayoutTransform.NCHW44_DOT: "enable_nchw44_dot",
+        GraphOptimizeOptions.LayoutTransform.CHWN4: "enable_chwn4",
+        GraphOptimizeOptions.LayoutTransform.NCHW64: "enable_nchw64",
+    }
+
+    ret = dict()
+
+    layout = inference_options.layout_transform
+    if layout != GraphOptimizeOptions.LayoutTransform.DEFAULT:
+        ret[inference_optimize_layout_transform_map[layout]] = True
+
+    if inference_options.f16_io_f32_comp:
+        ret["enable_io16xc32"] = True
+    if inference_options.f16_io_comp:
+        ret["enable_ioc16"] = True
+    if inference_options.fuse_conv_bias_nonlinearity:
+        ret["enable_fuse_conv_bias_nonlinearity"] = True
+    if inference_options.fuse_conv_bias_with_z:
+        ret["enable_fuse_conv_bias_with_z"] = True
+
+    return ret
+
+
 def modify_opr_algo_strategy_inplace(dest_vars, strategy: str):
@@ -331,7 +374,8 @@ def dump_graph(
     keep_param_name: bool = False,
     keep_opr_priority: bool = False,
     strip_info_file=None,
-    append_json=False
+    append_json=False,
+    metadata=None
 ) -> Tuple[bytes, CompGraphDumpResult]:
     """
     serialize the computing graph of `output_vars` and get byte result.
@@ -393,6 +437,7 @@ def dump_graph(
         keep_opr_name,
         keep_param_name,
         keep_opr_priority,
+        metadata,
         stat,
         inputs,
         outputs,
@@ -427,7 +472,7 @@ def dump_graph(

 CompGraphLoadResult = collections.namedtuple(
-    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
+    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list", "metadata"]
 )

@@ -450,8 +495,8 @@ def load_graph(fpath) -> CompGraphLoadResult:
         buf = open(fpath, "rb").read()
     else:
         buf = fpath.read()
-    cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
-    return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list)
+    cg, metadata = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
+    return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list, metadata)


 def _wrap(x):
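With this change, `optimize_for_inference` returns a `(vars, packed_options)` pair, and the new `deserialize_infer_option` maps the packed integer back to the keyword-argument names. A minimal sketch of the round trip (not from the commit; `dest_vars` stands for output VarNodes of a traced graph, obtained elsewhere):

    # Sketch: the second element returned by optimize_for_inference() is the
    # packed options word; deserialize_infer_option() turns it back into the
    # familiar keyword names.
    from megengine.core.tensor import megbrain_graph as G

    dest_vars, packed = G.optimize_for_inference(dest_vars, enable_nchw4=True)
    opts = G.deserialize_infer_option(packed)
    assert opts == {"enable_nchw4": True}  # only enabled options appear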


+20 -2   imperative/python/megengine/jit/tracing.py

@@ -12,10 +12,12 @@ import functools
 import itertools
 import json
 import os
+import pickle
+from typing import Any

 import numpy as np

-from ..core._imperative_rt import GraphProfiler
+from ..core._imperative_rt import GraphProfiler, SerializationMetadata
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..core._imperative_rt.core2 import (
     TensorWeakRef,
@@ -670,6 +672,8 @@ class trace:
         strip_info_file=None,
         append_json=False,
         optimize_for_inference=True,
+        user_info: Any = None,
+        enable_metadata: bool = True,
         **kwargs
     ):
         r"""
@@ -697,6 +701,8 @@ class trace:
             if set false, will rewrite strip_info_file
         :param optimize_for_inference: enbale optmizations,
             will skip all optimize options if this is False. Default: True
+        :param user_info: any type object, which will be pickled to bytes.
+        :param enable_metadata: whether to save metadata into output file.

         :Keyword Arguments:

@@ -729,6 +735,9 @@ class trace:
         * enable_chwn4 --
           whether to use CHWN4 data layout, currently
           used in nvidia backend with tensorcore.
+        * enable_nchw64 --
+          whether to use NCHW64 data layout, used for fast int4
+          support on Nvidia GPU.

         * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
           into one opr.
@@ -851,7 +860,15 @@ class trace:
             dest_vars.append(v)

         if optimize_for_inference:
-            dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
+            dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)
+
+        metadata = SerializationMetadata()
+        if enable_metadata:
+            metadata.user_info = pickle.dumps(user_info)
+            metadata.is_valid = True
+            metadata.graph_modified = False
+            if optimize_for_inference:
+                metadata.optimize_options = optimize_options

         if isinstance(file, str):
             permission = "wb" if append == False else "ab"
@@ -864,6 +881,7 @@ class trace:
             keep_opr_priority=keep_opr_priority,
             strip_info_file=strip_info_file,
             append_json=append_json,
+            metadata=metadata,
         )
         file.write(dump_content)
         return dump_info
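Put together, a traced function can now attach arbitrary picklable user data when dumping. A small usage sketch (mirrors the new test below; the `user_info` payload is illustrative):

    import io
    from megengine import Tensor
    from megengine.jit import trace

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        return x * 2

    fwd(Tensor(1.0))
    file = io.BytesIO()
    # user_info may be any picklable object; enable_metadata=False drops it.
    fwd.dump(file, user_info={"model": "demo", "version": 3},
             optimize_for_inference=False)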


+2 -1   imperative/python/megengine/tools/load_network_and_run.py

@@ -411,7 +411,8 @@ def main():
         args.embed_input = True

     logger.info("loading model ...")
-    graph, _, output_vars = G.load_graph(args.net)
+    ret = G.load_graph(args.net)
+    graph, output_vars = ret.graph, ret.output_vars_list
     input_vars = tools.get_dep_vars(output_vars, "Host2DeviceCopy")

     if args.output_name is not None:


+2 -1   imperative/python/megengine/utils/comp_graph_tools.py

@@ -391,7 +391,8 @@ class GraphInference:
         optimize_for_inference: bool = False,
         **kwargs
     ):
-        self._graph, _, output_nodes = G.load_graph(file)
+        ret = G.load_graph(file)
+        self._graph, output_nodes = ret.graph, ret.output_vars_list
         if outputs is not None:
             output_nodes = find_vars_by_name(output_nodes, outputs)
         self._origin_outputs = output_nodes


+46 -7   imperative/python/megengine/utils/network.py

@@ -9,14 +9,12 @@
 import collections
 import fnmatch
 import itertools
+import pickle
 import re
 from collections import OrderedDict
-from typing import Dict, List, Sequence
+from typing import Any, Dict, List, Sequence

-import numpy as np
-
-from ..core._imperative_rt import ComputingGraph
-from ..core._imperative_rt.core2 import SymbolVar
+from ..core._imperative_rt import ComputingGraph, SerializationMetadata
 from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
 from ..logger import get_logger
@@ -42,6 +40,30 @@ class Network:
         self.all_oprs_map = OrderedDict()
         self.all_vars_map = OrderedDict()
         self.graph = ComputingGraph()
+        self._metadata = None
+
+    @property
+    def metadata(self):
+        r"""
+        Load metadata as a dict.
+        """
+        if not self._metadata.is_valid:
+            logger.info("metadata is not valid!")
+            return None
+        ret = dict()
+        try:
+            user_info = pickle.loads(self._metadata.user_info)
+        except:  # pylint: disable=bare-except
+            logger.warning(
+                "can't parse user info by pickle, so return the original bytes object!"
+            )
+            user_info = self._metadata.user_info
+        ret["user_info"] = user_info
+        ret["graph_modified"] = self._metadata.graph_modified
+        ret["optimized_for_inference"] = self._metadata.optimized_for_inference
+        if ret["optimized_for_inference"]:
+            ret.update(G.deserialize_infer_option(self._metadata.optimize_options))
+        return ret

     @classmethod
     def load(cls, model_path: str, outspec: List[str] = None):
@@ -51,7 +73,8 @@ class Network:
         :param outspec: only load the subgraph with outspec as its endpoints.
         """
         self = cls()
-        _, _, outputs = G.load_graph(model_path)
+        ret = G.load_graph(model_path)
+        outputs, self._metadata = ret.output_vars_list, ret.metadata
         if outspec is not None:
             output_spec = outspec.copy()
             all_vars = get_dep_vars(outputs) + outputs
@@ -125,6 +148,9 @@ class Network:
         * enable_chwn4 --
           whether to use CHWN4 data layout, currently
           used in nvidia backend with tensorcore.
+        * enable_nchw64 --
+          whether to use NCHW64 data layout, used for fast int4
+          support on Nvidia GPU.

         * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
           into one opr.
@@ -152,6 +178,8 @@ class Network:
         append_json=False,
         optimize_for_inference=True,
         append=False,
+        user_info: Any = None,
+        enable_metadata=True,
         **kwargs
     ):
         """
@@ -176,6 +204,8 @@ class Network:
             if set false, will rewrite strip_info_file
         :param optimize_for_inference: enbale optmizations,
             will skip all optimize options if this is False. Default: True
+        :param user_info: any type object, which will be pickled to bytes.
+        :param enable_metadata: whether to save metadata into output file.

         :Keyword Arguments:

@@ -201,7 +231,15 @@ class Network:
         )

         if optimize_for_inference:
-            out = G.optimize_for_inference(out, **kwargs)
+            out, optimize_options = G.optimize_for_inference(out, **kwargs)
+
+        metadata = SerializationMetadata()
+        if enable_metadata:
+            metadata.is_valid = True
+            metadata.graph_modified = True
+            metadata.user_info = pickle.dumps(user_info)
+            if optimize_for_inference:
+                metadata.optimize_options = optimize_options

         dump_content, _ = G.dump_graph(
             out,
@@ -211,6 +249,7 @@ class Network:
             keep_opr_priority=keep_opr_priority,
             strip_info_file=strip_info_file,
             append_json=append_json,
+            metadata=metadata,
         )
         if isinstance(file, str):
             permission = "wb" if append == False else "ab"
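On the loading side, `Network.metadata` unpickles `user_info` and expands the packed optimize options back into the familiar `enable_*` keys. A reading sketch (the model path is illustrative):

    from megengine.utils.network import Network as Net

    net = Net.load("model.mge")  # path illustrative
    meta = net.metadata          # dict, or None when dumped with enable_metadata=False
    if meta is not None:
        print(meta["user_info"], meta["graph_modified"], meta["optimized_for_inference"])
        # when optimized for inference, enabled options appear as extra keys,
        # e.g. meta.get("enable_nchw4"), meta.get("enable_ioc16")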


+26 -3   imperative/python/src/graph_rt.cpp

@@ -34,6 +34,7 @@ namespace ser = mgb::serialization;
 using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
 using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
 using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
+using _SerializationMetadata = mgb::serialization::Metadata;

 namespace {
 class _CompGraphProfilerImpl {
@@ -240,6 +241,8 @@ void init_graph_rt(py::module m) {
     auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
         .def(py::init())
+        .def("serialize", &_OptimizeForInferenceOptions::serialize)
+        .def_static("deserialize", &_OptimizeForInferenceOptions::deserialize)
         .def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp)
         .def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
         .def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
@@ -256,6 +259,7 @@ void init_graph_rt(py::module m) {
         .value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
         .value("NCHW32", _LayoutTransform::NCHW32)
         .value("CHWN4", _LayoutTransform::CHWN4)
+        .value("NCHW64", _LayoutTransform::NCHW64)
         .export_values()
         ;

@@ -307,12 +311,24 @@ void init_graph_rt(py::module m) {
         })->to_string();
     });

+    py::class_<_SerializationMetadata>(m, "SerializationMetadata")
+        .def(py::init())
+        .def_property("user_info",
+                      [](const _SerializationMetadata& meta) { return py::bytes(meta.get_user_info()); },
+                      &_SerializationMetadata::set_user_info)
+        .def_readonly("optimized_for_inference", &_SerializationMetadata::optimized_for_inference)
+        .def_property("optimize_options", &_SerializationMetadata::get_optimize_options,
+                      &_SerializationMetadata::set_optimize_options)
+        .def_readwrite("graph_modified", &_SerializationMetadata::graph_modified)
+        .def_readwrite("is_valid", &_SerializationMetadata::is_valid)
+        ;
+
     m.def("dump_graph", [](
             const std::vector<VarNode*>& dest_vars,
             int keep_var_name,
             bool keep_opr_name,
             bool keep_param_name,
             bool keep_opr_priority,
+            std::optional<_SerializationMetadata> metadata,
             py::list& stat,
             py::list& inputs,
             py::list& outputs,
@@ -325,7 +341,12 @@ void init_graph_rt(py::module m) {
         ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
                                             keep_opr_priority, keep_opr_name};

-        auto rst = dumper->dump(symvars, config);
+        ser::GraphDumper::DumpResult rst;
+        if (metadata)
+            rst = dumper->dump(symvars, config, *metadata);
+        else
+            rst = dumper->dump(symvars, config);

         for (auto i : rst.inputs) {
             inputs.append(py::cast(i));
         }
@@ -377,8 +398,10 @@ void init_graph_rt(py::module m) {
         for (const auto& var : rst.output_var_list) {
             iter.add(var);
         }
-        return rst.graph;
+
+        auto ret = py::tuple(2);
+        ret[0] = py::cast(rst.graph);
+        ret[1] = py::cast(rst.metadata);
+        return ret;
     });

 #define CURRENT_CLASS cg::ComputingGraph::Options
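For completeness, the raw binding can also be driven directly; a sketch of constructing the metadata object on the Python side (field values illustrative):

    from megengine.core._imperative_rt import SerializationMetadata

    meta = SerializationMetadata()
    meta.is_valid = True
    meta.graph_modified = False
    meta.user_info = b"arbitrary bytes"  # the wrappers store pickle.dumps(obj) here
    # Assigning optimize_options also flips optimized_for_inference to True,
    # via set_optimize_options() in metadata.h below.
    meta.optimize_options = 1 << 32      # e.g. layout-transform enum value 1
    # `meta` is threaded through _imperative_rt.dump_graph(), and
    # _imperative_rt.load_graph() now returns a (graph, metadata) pair.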


+3 -4   imperative/python/test/unit/jit/test_tracing.py

@@ -239,8 +239,7 @@ def test_dump_volatile():
     file = io.BytesIO()
     f.dump(file, optimize_for_inference=False)
     file.seek(0)
-    cg, _, outputs = G.load_graph(file)
-    (out,) = outputs
+    (out,) = G.load_graph(file).output_vars_list
     assert (
         cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
         == "ImmutableTensor"
@@ -337,12 +336,12 @@ def test_goptions_log_exp():
     f(tensor(1.0))
     _, out = mkstemp()
     f.dump(out, optimize_for_inference=False)
-    *_, outputs = G.load_graph(out)
+    outputs = G.load_graph(out).output_vars_list
     oprs_1 = cgtools.get_oprs_seq(outputs)

     g(tensor(1.0))
     g.dump(out, optimize_for_inference=False)
-    *_, outputs = G.load_graph(out)
+    outputs = G.load_graph(out).output_vars_list
     oprs_2 = cgtools.get_oprs_seq(outputs)

     assert len(oprs_1) - len(oprs_2) == 2


+5 -3   imperative/python/test/unit/utils/test_cgtools.py

@@ -88,7 +88,7 @@ def test_graph_traversal():
     file = io.BytesIO()
     fun.dump(file, optimize_for_inference=False)
     file.seek(0)
-    cg, _, outputs = mgb_graph.load_graph(file)
+    outputs = mgb_graph.load_graph(file).output_vars_list

     _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
     input_var = map_vars[1]
@@ -101,7 +101,9 @@ def test_load_refcnt():
     graph = mgb_graph.Graph()
     varnode = graph.make_const(0)
     buf, _ = mgb_graph.dump_graph([varnode])
-    graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf))
+    ret = mgb_graph.load_graph(io.BytesIO(buf))
+    graph, (varnode,) = ret.graph, ret.output_vars_list
+    del ret
     del graph
     varnode.owner

@@ -132,7 +134,7 @@ def test_get_opr_seq():
     file = io.BytesIO()
     func.dump(file, optimize_for_inference=False)
     file.seek(0)
-    *_, outputs = mgb_graph.load_graph(file)
+    outputs = mgb_graph.load_graph(file).output_vars_list

     seq_1 = cgtools.get_oprs_seq(outputs, True)
     assert len(seq_1) == 5


+2 -2   imperative/python/test/unit/utils/test_dump_naming.py

@@ -35,7 +35,7 @@ def _dump_and_load(func, symbolic, keep_opr_name=True):
         keep_var_name=2,
     )
     file.seek(0)
-    *_, outputs = G.load_graph(file)
+    outputs = G.load_graph(file).output_vars_list
     ops = cgtools.get_oprs_seq(outputs)
     return ops

@@ -223,7 +223,7 @@ def test_catch_input_name(tensor_name, var_name):
     file = io.BytesIO()
     func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2)
     file.seek(0)
-    *_, outputs = G.load_graph(file)
+    outputs = G.load_graph(file).output_vars_list
     op = cgtools.get_oprs_seq(outputs)[-1]
     assert op.inputs[0].name == var_name



+44 -0   imperative/python/test/unit/utils/test_network.py

@@ -14,6 +14,50 @@ from megengine.utils.network import as_oprnode, set_symbolic_shape
 from megengine.utils.network_node import Host2DeviceCopy, VarNode


+def test_metadata():
+    x = Tensor(0)
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fwd(x):
+        return x * 2
+
+    fwd(x)
+
+    orig_model = io.BytesIO()
+    fwd.dump(orig_model, user_info="test", optimize_for_inference=False)
+    orig_model.seek(0)
+    graph = Net.load(orig_model)
+    assert graph.metadata == {
+        "user_info": "test",
+        "graph_modified": False,  # False: tracing.dump
+        "optimized_for_inference": False,
+    }
+
+    orig_model.seek(0)
+    graph.dump(
+        orig_model,
+        user_info={"str": "x", "tensor": x, "module": M.Module, "none": None},
+        optimize_for_inference=True,
+        enable_nchw4=True,
+        enable_ioc16=True,
+    )
+    orig_model.seek(0)
+    graph = Net.load(orig_model)
+    assert graph.metadata == {
+        "user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None},
+        "graph_modified": True,  # True: Network.dump
+        "optimized_for_inference": True,
+        "enable_nchw4": True,
+        "enable_ioc16": True,
+    }
+
+    orig_model.seek(0)
+    fwd.dump(orig_model, enable_metadata=False)
+    orig_model.seek(0)
+    graph = Net.load(orig_model)
+    assert graph.metadata is None
+
+
 def test_replace_var():

     a = Tensor([1, 2])


+2 -1   sdk/load-and-run/dump_with_testcase_mge.py

@@ -170,7 +170,8 @@ def gen_one_testcase(args, inputs, spec):


 def make_feeds(args):
-    cg_rt, _, outputs = G.load_graph(args.input)
+    ret = G.load_graph(args.input)
+    cg_rt, outputs = ret.graph, ret.output_vars_list
     inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")

     inputs = {i.name: i for i in inputs}


+25 -1   src/gopt/include/megbrain/gopt/inference.h

@@ -322,7 +322,31 @@ namespace gopt {
         static std::unique_ptr<EnableNchw44DotPass> make_nchw44_dot_converter();
     };

-    struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {};
+    struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
+        uint64_t serialize() {
+            uint64_t ret = 0;
+            ret |= (uint64_t)layout_transform << 32;
+            if (f16_io_f32_comp) ret |= 1u;
+            if (f16_io_comp) ret |= 1u << 1;
+            if (fuse_conv_bias_nonlinearity) ret |= 1u << 2;
+            if (fuse_conv_bias_with_z) ret |= 1u << 3;
+            if (weight_preprocess) ret |= 1u << 4;
+            if (fuse_preprocess) ret |= 1u << 5;
+            return ret;
+        }
+
+        static OptimizeForInferenceOptions deserialize(uint64_t buf) {
+            OptimizeForInferenceOptions ret;
+            ret.f16_io_f32_comp = buf & 1u;
+            ret.f16_io_comp = buf & 1u << 1;
+            ret.fuse_conv_bias_nonlinearity = buf & 1u << 2;
+            ret.fuse_conv_bias_with_z = buf & 1u << 3;
+            ret.weight_preprocess = buf & 1u << 4;
+            ret.fuse_preprocess = buf & 1u << 5;
+            ret.layout_transform = (LayoutTransform)(buf >> 32);
+            return ret;
+        }
+    };

     /*!
      * \brief optimize a computing graph for inference
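The packed word keeps the six boolean flags in bits 0-5 and the LayoutTransform enum value in the upper 32 bits. A minimal Python sketch of the same encoding, for reference (names mirror the C++ fields above; this is not part of any public API):

    # Illustrative re-implementation of the 64-bit optimize-options encoding.
    FLAG_BITS = [
        "f16_io_f32_comp",              # bit 0
        "f16_io_comp",                  # bit 1
        "fuse_conv_bias_nonlinearity",  # bit 2
        "fuse_conv_bias_with_z",        # bit 3
        "weight_preprocess",            # bit 4
        "fuse_preprocess",              # bit 5
    ]

    def serialize_options(layout_transform: int, **flags) -> int:
        word = layout_transform << 32   # enum value in the high 32 bits
        for bit, name in enumerate(FLAG_BITS):
            if flags.get(name):
                word |= 1 << bit        # one boolean per low bit
        return word

    def deserialize_options(word: int):
        flags = {name: bool(word & (1 << bit)) for bit, name in enumerate(FLAG_BITS)}
        return word >> 32, flags

    word = serialize_options(3, f16_io_comp=True, fuse_preprocess=True)
    layout, flags = deserialize_options(word)
    assert layout == 3 and flags["f16_io_comp"] and flags["fuse_preprocess"]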


+8 -0   src/serialization/impl/schema.fbs

@@ -128,6 +128,13 @@ table Operator {
     name:string;
 }

+table Metadata {
+    is_valid:bool;
+    graph_modified:bool;
+    user_info:string;
+    optimize_options:ulong;
+}
+
 struct OutputVar {
     compact_id:uint;
     original_id:uint;
@@ -141,6 +148,7 @@ table Graph {
     nr_shared_tensor:uint;
     oprs:[Operator];
     output_vars_idx:[OutputVar];
+    metadata:Metadata;
 }

 root_type Graph;

+40 -2   src/serialization/impl/serializer_oss.cpp

@@ -30,6 +30,7 @@
 #include "megbrain/serialization/internal/flatbuffers_helper.h"
 #include "megbrain/serialization/internal/schema_generated.h"
 #include "megbrain/serialization/opr_load_dump.h"
+#include "megbrain/serialization/metadata.h"
 #include "megbrain/serialization/serializer.h"
 #include "megbrain/version.h"

@@ -115,6 +116,7 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
     std::vector<flatbuffers::Offset<void>> m_cur_opr_param;

     void init_oprs_to_dump(const SymbolVarArray& endpoints);
+    flatbuffers::Offset<fbs::Metadata> build_metadata(const Metadata& metadata);
     flatbuffers::Offset<fbs::Operator> build_single_opr(
             cg::OperatorNodeBase* opr, const OprRegistry* registry);

@@ -123,7 +125,8 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
 public:
     GraphDumperOSS(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {}
     DumpResult dump(const SymbolVarArray& output_vars,
-                    const DumpConfig& config = {}) override;
+                    const DumpConfig& config = {},
+                    const Metadata& metadata = {}) override;
     const GraphDumpConfig& config() const override { return m_config; }
     void dump_tensor(const std::string& name, const HostTensorND& tensor,
                      TensorWriteMethod method) override;

@@ -185,6 +188,17 @@ void GraphDumperOSS::init_oprs_to_dump(const SymbolVarArray& endpoints) {
     }
 }

+flatbuffers::Offset<fbs::Metadata> GraphDumperOSS::build_metadata(
+        const Metadata& metadata) {
+    auto user_info = m_builder.CreateSharedString(metadata.user_info);
+    fbs::MetadataBuilder builder(m_builder);
+    builder.add_is_valid(metadata.is_valid);
+    builder.add_graph_modified(metadata.graph_modified);
+    builder.add_user_info(user_info);
+    builder.add_optimize_options(metadata.optimize_options);
+    return builder.Finish();
+}
+
 flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
         cg::OperatorNodeBase* opr, const OprRegistry* registry) {
     m_cur_opr = opr;
@@ -282,7 +296,8 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
 }

 GraphDumper::DumpResult GraphDumperOSS::dump(
-        const SymbolVarArray& output_vars, const DumpConfig& config) {
+        const SymbolVarArray& output_vars,
+        const DumpConfig& config, const Metadata& metadata) {
     mgb_throw_if(output_vars.empty(), SerializationError,
                  "Can't dump empty graph");

@@ -323,6 +338,9 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
     uint64_t offset_to_fbs = 0;
     m_file->write(&offset_to_fbs, sizeof(offset_to_fbs));

+    // Dump metadata
+    auto fbmeta = build_metadata(metadata);
+
     // Dump operators
     init_oprs_to_dump(output_vars);
     std::vector<flatbuffers::Offset<fbs::Operator>> oprs;
@@ -350,6 +368,7 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
     graph.add_oprs(fb_oprs);
     graph.add_output_vars_idx(fb_output_vars);
     graph.add_nr_shared_tensor(m_nr_shared_tensor);
+    graph.add_metadata(fbmeta);
     m_builder.FinishSizePrefixed(graph.Finish(), fbs::GraphIdentifier());

     // Write actual offset_to_fbs
@@ -531,6 +550,7 @@ public:
         mgb_assert(nr == 1);
     }

+    Metadata load_metadata();
     LoadResult load_oprs();
     CompNode load_comp_node(const fbs::CompNode* comp_node);

@@ -700,6 +720,22 @@ GraphLoaderOSS::OprLoadContextImpl::load_tensor_shared() {
     return sh_ptr_ref;
 }

+Metadata GraphLoaderOSS::OprLoadContextImpl::load_metadata() {
+    const auto* fbmeta = m_loader->m_graph->metadata();
+    Metadata ret;
+    ret.is_valid = fbmeta->is_valid();
+    ret.graph_modified = fbmeta->graph_modified();
+    if (fbmeta->user_info()) {
+        ret.user_info = fbmeta->user_info()->str();
+        ret.has_user_info = true;
+    }
+    if (fbmeta->optimize_options()) {
+        ret.optimize_options = fbmeta->optimize_options();
+        ret.optimized_for_inference = true;
+    }
+    return ret;
+}
+
 void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
         const fbs::Operator* fbopr) {
     m_cur_opr_tensor_cnt = 0;
@@ -872,7 +908,9 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config,
     }

     OprLoadContextImpl ctx{this, m_graph->mgb_version()};
+    auto metadata = ctx.load_metadata();
     auto result = ctx.load_oprs();
+    result.metadata = metadata;

     auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size;
     auto cur = m_file->tell();


+46 -0   src/serialization/include/megbrain/serialization/metadata.h

@@ -0,0 +1,46 @@
+/**
+ * \file src/serialization/include/megbrain/serialization/metadata.h
+ *
+ * This file is part of MegBrain, a deep learning framework developed by Megvii.
+ *
+ * \brief MegEngine model's metadata
+ *
+ * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ */
+#pragma once
+
+#include <string>
+
+namespace mgb {
+namespace serialization {
+
+struct Metadata {
+    bool is_valid = false;
+
+    bool graph_modified = false;
+
+    bool has_user_info = false;
+    std::string user_info;
+
+    bool optimized_for_inference = false;
+    uint64_t optimize_options;
+
+#define ADD_PROPERTY(type, name)              \
+    type get_##name() const { return name; } \
+    void set_##name(type x) {                 \
+        name = x;                             \
+        has_##name = true;                    \
+    }
+    ADD_PROPERTY(std::string, user_info)
+#undef ADD_PROPERTY
+
+    uint64_t get_optimize_options() { return optimize_options; }
+    void set_optimize_options(uint64_t value) {
+        optimized_for_inference = true;
+        optimize_options = value;
+    }
+};
+
+}  // namespace serialization
+}  // namespace mgb

+6 -1   src/serialization/include/megbrain/serialization/serializer.h

@@ -15,6 +15,7 @@
 #include "megbrain/serialization/dump_format.h"
 #include "megbrain/serialization/file.h"
 #include "megbrain/serialization/load_dump_config.h"
+#include "megbrain/serialization/metadata.h"

 namespace mgb {
 namespace serialization {
@@ -32,6 +33,9 @@ namespace serialization {
         //! expliit dtor decl to reduce binary size
         ~LoadResult() noexcept;

+        //! metadata
+        Metadata metadata;
+
         using TensorMap = std::unordered_map<
                 std::string, std::shared_ptr<HostTensorND>>;

@@ -178,7 +182,8 @@ namespace serialization {
         virtual DumpResult dump(
                 const SymbolVarArray &output_vars,
-                const DumpConfig &config = {}) = 0;
+                const DumpConfig &config = {},
+                const Metadata &metadata = {}) = 0;
         virtual GraphDumpFormat format() const = 0;
     };


+37 -0   src/serialization/test/serializer_oss.cpp

@@ -92,6 +92,43 @@ TEST(TestSerializer2, MultiGraphDumpLoad) {
     load();
 }

+TEST(TestSerializer2, Metadata) {
+    auto fname = GET_OUTPUT_FILE();
+    TensorShape shape{2, 3};
+
+    auto dump = [&]() {
+        auto cn = CompNode::load("xpu0");
+        auto host_x = std::make_shared<HostTensorND>(cn, shape),
+             host_y = std::make_shared<HostTensorND>(cn, shape);
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}),
+             y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"});
+        using Mode = opr::Elemwise::Mode;
+        auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"});
+
+        Metadata metadata;
+        metadata.user_info = "TEST_METADATA";
+        metadata.has_user_info = true;
+
+        auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = dumper->dump({z.rename("z")}, {}, metadata);
+    };
+
+    auto load = [&]() {
+        HostTensorGenerator<> gen;
+        auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = loader->load();
+        auto metadata = rst.metadata;
+        int cmp = strcmp(metadata.user_info.c_str(), "TEST_METADATA");
+        EXPECT_EQ(cmp, 0);
+    };
+
+    dump();
+    load();
+}
+
 TEST(TestSerializer2, APlusB) {
     auto fname = GET_OUTPUT_FILE();
     TensorShape shape{2, 3};