
feat(src/serialization): add support for serializing metadata

GitOrigin-RevId: b563c94451
Branch: release-1.5
Megvii Engine Team (huangxinda), 4 years ago
Commit: 54a4d70eb5
17 changed files with 365 additions and 34 deletions
 1. +51  -6  imperative/python/megengine/core/tensor/megbrain_graph.py
 2. +20  -2  imperative/python/megengine/jit/tracing.py
 3.  +2  -1  imperative/python/megengine/tools/load_network_and_run.py
 4.  +2  -1  imperative/python/megengine/utils/comp_graph_tools.py
 5. +46  -7  imperative/python/megengine/utils/network.py
 6. +26  -3  imperative/python/src/graph_rt.cpp
 7.  +3  -4  imperative/python/test/unit/jit/test_tracing.py
 8.  +5  -3  imperative/python/test/unit/utils/test_cgtools.py
 9.  +2  -2  imperative/python/test/unit/utils/test_dump_naming.py
10. +44  -0  imperative/python/test/unit/utils/test_network.py
11.  +2  -1  sdk/load-and-run/dump_with_testcase_mge.py
12. +25  -1  src/gopt/include/megbrain/gopt/inference.h
13.  +8  -0  src/serialization/impl/schema.fbs
14. +40  -2  src/serialization/impl/serializer_oss.cpp
15. +46  -0  src/serialization/include/megbrain/serialization/metadata.h
16.  +6  -1  src/serialization/include/megbrain/serialization/serializer.h
17. +37  -0  src/serialization/test/serializer_oss.cpp

+51 -6 imperative/python/megengine/core/tensor/megbrain_graph.py

@@ -11,7 +11,7 @@ import json
import os
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

@@ -256,6 +256,9 @@ def optimize_for_inference(dest_vars, **kwargs):
* enable_chwn4 --
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.
* enable_nchw64 --
whether to use NCHW64 data layout, used for fast int4
support on Nvidia GPU.

* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
into one opr.
@@ -273,6 +276,7 @@ def optimize_for_inference(dest_vars, **kwargs):
"enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44,
"enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT,
"enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4,
"enable_nchw64": GraphOptimizeOptions.LayoutTransform.NCHW64,
}

for k, v in inference_optimize_layout_transform_map.items():
@@ -293,7 +297,46 @@ def optimize_for_inference(dest_vars, **kwargs):

dest_vars = _unwrap(dest_vars)
res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options)
return _wrap(res_vars)
return _wrap(res_vars), inference_options.serialize()


def deserialize_infer_option(x: int) -> Dict[str, bool]:
r"""
Deserialize optimize options generated by ``imperative_rt.GraphOptimizeOptions``.

:param x: inference options represented by int.
:return: inference options represented by dict.
"""

inference_options = GraphOptimizeOptions.deserialize(x)

inference_optimize_layout_transform_map = {
GraphOptimizeOptions.LayoutTransform.NHWCD4: "enable_hwcd4",
GraphOptimizeOptions.LayoutTransform.NCHW4: "enable_nchw4",
GraphOptimizeOptions.LayoutTransform.NCHW88: "enable_nchw88",
GraphOptimizeOptions.LayoutTransform.NCHW32: "enable_nchw32",
GraphOptimizeOptions.LayoutTransform.NCHW44: "enable_nchw44",
GraphOptimizeOptions.LayoutTransform.NCHW44_DOT: "enable_nchw44_dot",
GraphOptimizeOptions.LayoutTransform.CHWN4: "enable_chwn4",
GraphOptimizeOptions.LayoutTransform.NCHW64: "enable_nchw64",
}

ret = dict()

layout = inference_options.layout_transform
if layout != GraphOptimizeOptions.LayoutTransform.DEFAULT:
ret[inference_optimize_layout_transform_map[layout]] = True

if inference_options.f16_io_f32_comp:
ret["enable_io16xc32"] = True
if inference_options.f16_io_comp:
ret["enable_ioc16"] = True
if inference_options.fuse_conv_bias_nonlinearity:
ret["enable_fuse_conv_bias_nonlinearity"] = True
if inference_options.fuse_conv_bias_with_z:
ret["enable_fuse_conv_bias_with_z"] = True

return ret


def modify_opr_algo_strategy_inplace(dest_vars, strategy: str):
@@ -331,7 +374,8 @@ def dump_graph(
keep_param_name: bool = False,
keep_opr_priority: bool = False,
strip_info_file=None,
append_json=False
append_json=False,
metadata=None
) -> Tuple[bytes, CompGraphDumpResult]:
"""
serialize the computing graph of `output_vars` and get byte result.
@@ -393,6 +437,7 @@ def dump_graph(
keep_opr_name,
keep_param_name,
keep_opr_priority,
metadata,
stat,
inputs,
outputs,
@@ -427,7 +472,7 @@ def dump_graph(


CompGraphLoadResult = collections.namedtuple(
"CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
"CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list", "metadata"]
)


@@ -450,8 +495,8 @@ def load_graph(fpath) -> CompGraphLoadResult:
buf = open(fpath, "rb").read()
else:
buf = fpath.read()
cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list)
cg, metadata = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list, metadata)


def _wrap(x):


+20 -2 imperative/python/megengine/jit/tracing.py

@@ -12,10 +12,12 @@ import functools
import itertools
import json
import os
import pickle
from typing import Any

import numpy as np

from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt import GraphProfiler, SerializationMetadata
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
TensorWeakRef,
@@ -670,6 +672,8 @@ class trace:
strip_info_file=None,
append_json=False,
optimize_for_inference=True,
user_info: Any = None,
enable_metadata: bool = True,
**kwargs
):
r"""
@@ -697,6 +701,8 @@ class trace:
if set false, will rewrite strip_info_file
:param optimize_for_inference: enable optimizations,
will skip all optimize options if this is False. Default: True
:param user_info: an object of any type, which will be pickled to bytes.
:param enable_metadata: whether to save metadata into the output file.

:Keyword Arguments:

@@ -729,6 +735,9 @@ class trace:
* enable_chwn4 --
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.
* enable_nchw64 --
whether to use NCHW64 data layout, used for fast int4
support on Nvidia GPU.

* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
into one opr.
@@ -851,7 +860,15 @@ class trace:
dest_vars.append(v)

if optimize_for_inference:
dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)

metadata = SerializationMetadata()
if enable_metadata:
metadata.user_info = pickle.dumps(user_info)
metadata.is_valid = True
metadata.graph_modified = False
if optimize_for_inference:
metadata.optimize_options = optimize_options

if isinstance(file, str):
permission = "wb" if append == False else "ab"
@@ -864,6 +881,7 @@ class trace:
keep_opr_priority=keep_opr_priority,
strip_info_file=strip_info_file,
append_json=append_json,
metadata=metadata,
)
file.write(dump_content)
return dump_info
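From user code the new knobs surface through trace.dump. A small sketch, assuming a trivial traced function (names are illustrative):

    import io
    import megengine.functional as F
    from megengine import Tensor
    from megengine.jit import trace

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        return F.exp(x)

    fwd(Tensor(1.0))
    file = io.BytesIO()
    # user_info may be any picklable object; it lands in the dumped
    # model's SerializationMetadata. enable_metadata=False skips the block.
    fwd.dump(file, user_info={"version": "1.5"}, optimize_for_inference=False)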


+2 -1 imperative/python/megengine/tools/load_network_and_run.py

@@ -411,7 +411,8 @@ def main():
args.embed_input = True

logger.info("loading model ...")
graph, _, output_vars = G.load_graph(args.net)
ret = G.load_graph(args.net)
graph, output_vars = ret.graph, ret.output_vars_list
input_vars = tools.get_dep_vars(output_vars, "Host2DeviceCopy")

if args.output_name is not None:


+2 -1 imperative/python/megengine/utils/comp_graph_tools.py

@@ -391,7 +391,8 @@ class GraphInference:
optimize_for_inference: bool = False,
**kwargs
):
self._graph, _, output_nodes = G.load_graph(file)
ret = G.load_graph(file)
self._graph, output_nodes = ret.graph, ret.output_vars_list
if outputs is not None:
output_nodes = find_vars_by_name(output_nodes, outputs)
self._origin_outputs = output_nodes


+46 -7 imperative/python/megengine/utils/network.py

@@ -9,14 +9,12 @@
import collections
import fnmatch
import itertools
import pickle
import re
from collections import OrderedDict
from typing import Dict, List, Sequence
from typing import Any, Dict, List, Sequence

import numpy as np

from ..core._imperative_rt import ComputingGraph
from ..core._imperative_rt.core2 import SymbolVar
from ..core._imperative_rt import ComputingGraph, SerializationMetadata
from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
from ..core.tensor import megbrain_graph as G
from ..logger import get_logger
@@ -42,6 +40,30 @@ class Network:
self.all_oprs_map = OrderedDict()
self.all_vars_map = OrderedDict()
self.graph = ComputingGraph()
self._metadata = None

@property
def metadata(self):
r"""
Load metadata as a dict.
"""
if self._metadata is None or not self._metadata.is_valid:
logger.info("metadata is not valid!")
return None
ret = dict()
try:
user_info = pickle.loads(self._metadata.user_info)
except: # pylint: disable=bare-except
logger.warning(
"can't unpickle user_info; returning the original bytes object"
)
user_info = self._metadata.user_info
ret["user_info"] = user_info
ret["graph_modified"] = self._metadata.graph_modified
ret["optimized_for_inference"] = self._metadata.optimized_for_inference
if ret["optimized_for_inference"]:
ret.update(G.deserialize_infer_option(self._metadata.optimize_options))
return ret

@classmethod
def load(cls, model_path: str, outspec: List[str] = None):
@@ -51,7 +73,8 @@ class Network:
:param outspec: only load the subgraph with outspec as its endpoints.
"""
self = cls()
_, _, outputs = G.load_graph(model_path)
ret = G.load_graph(model_path)
outputs, self._metadata = ret.output_vars_list, ret.metadata
if outspec is not None:
output_spec = outspec.copy()
all_vars = get_dep_vars(outputs) + outputs
@@ -125,6 +148,9 @@ class Network:
* enable_chwn4 --
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.
* enable_nchw64 --
whether to use NCHW64 data layout, used for fast int4
support on Nvidia GPU.

* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
into one opr.
@@ -152,6 +178,8 @@ class Network:
append_json=False,
optimize_for_inference=True,
append=False,
user_info: Any = None,
enable_metadata=True,
**kwargs
):
"""
@@ -176,6 +204,8 @@ class Network:
if set false, will rewrite strip_info_file
:param optimize_for_inference: enable optimizations,
will skip all optimize options if this is False. Default: True
:param user_info: an object of any type, which will be pickled to bytes.
:param enable_metadata: whether to save metadata into the output file.

:Keyword Arguments:

@@ -201,7 +231,15 @@ class Network:
)

if optimize_for_inference:
out = G.optimize_for_inference(out, **kwargs)
out, optimize_options = G.optimize_for_inference(out, **kwargs)

metadata = SerializationMetadata()
if enable_metadata:
metadata.is_valid = True
metadata.graph_modified = True
metadata.user_info = pickle.dumps(user_info)
if optimize_for_inference:
metadata.optimize_options = optimize_options

dump_content, _ = G.dump_graph(
out,
@@ -211,6 +249,7 @@ class Network:
keep_opr_priority=keep_opr_priority,
strip_info_file=strip_info_file,
append_json=append_json,
metadata=metadata,
)
if isinstance(file, str):
permission = "wb" if append == False else "ab"
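On the reading side, Network.metadata unpickles user_info and expands the optimize-option bits back into the keyword names used at dump time. A sketch (the model path is illustrative):

    from megengine.utils.network import Network as Net

    net = Net.load("model.mge")
    meta = net.metadata  # None when the file carries no valid metadata
    if meta is not None:
        print(meta["user_info"], meta["graph_modified"])
        if meta["optimized_for_inference"]:
            # per-option flags, e.g. {"enable_nchw4": True, ...}
            print({k: v for k, v in meta.items() if k.startswith("enable_")})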


+26 -3 imperative/python/src/graph_rt.cpp

@@ -34,6 +34,7 @@ namespace ser = mgb::serialization;
using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
using _SerializationMetadata = mgb::serialization::Metadata;

namespace {
class _CompGraphProfilerImpl {
@@ -240,6 +241,8 @@ void init_graph_rt(py::module m) {

auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
.def(py::init())
.def("serialize", &_OptimizeForInferenceOptions::serialize)
.def_static("deserialize", &_OptimizeForInferenceOptions::deserialize)
.def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp)
.def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
.def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
@@ -256,6 +259,7 @@ void init_graph_rt(py::module m) {
.value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
.value("NCHW32", _LayoutTransform::NCHW32)
.value("CHWN4", _LayoutTransform::CHWN4)
.value("NCHW64", _LayoutTransform::NCHW64)
.export_values()
;

@@ -307,12 +311,24 @@ void init_graph_rt(py::module m) {
})->to_string();
});

py::class_<_SerializationMetadata>(m, "SerializationMetadata")
.def(py::init())
.def_property("user_info", [](const _SerializationMetadata& meta){return py::bytes(meta.get_user_info()); },
&_SerializationMetadata::set_user_info)
.def_readonly("optimized_for_inference", &_SerializationMetadata::optimized_for_inference)
.def_property("optimize_options", &_SerializationMetadata::get_optimize_options,
&_SerializationMetadata::set_optimize_options)
.def_readwrite("graph_modified", &_SerializationMetadata::graph_modified)
.def_readwrite("is_valid", &_SerializationMetadata::is_valid)
;

m.def("dump_graph", [](
const std::vector<VarNode*>& dest_vars,
int keep_var_name,
bool keep_opr_name,
bool keep_param_name,
bool keep_opr_priority,
std::optional<_SerializationMetadata> metadata,
py::list& stat,
py::list& inputs,
py::list& outputs,
@@ -325,7 +341,12 @@ void init_graph_rt(py::module m) {
ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
keep_opr_priority, keep_opr_name};

auto rst = dumper->dump(symvars, config);
ser::GraphDumper::DumpResult rst;
if (metadata)
rst = dumper->dump(symvars, config, *metadata);
else
rst = dumper->dump(symvars, config);

for (auto i : rst.inputs) {
inputs.append(py::cast(i));
}
@@ -377,8 +398,10 @@ void init_graph_rt(py::module m) {
for (const auto& var : rst.output_var_list) {
iter.add(var);
}
return rst.graph;

auto ret = py::tuple(2);
ret[0] = py::cast(rst.graph);
ret[1] = py::cast(rst.metadata);
return ret;
});

#define CURRENT_CLASS cg::ComputingGraph::Options
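The binding exposes serialization::Metadata to Python as SerializationMetadata, with user_info marshalled as bytes. Normally trace.dump and Network.dump populate it, but it can also be driven by hand; a sketch:

    import pickle
    from megengine.core._imperative_rt import SerializationMetadata

    meta = SerializationMetadata()
    meta.user_info = pickle.dumps({"note": "hand-built metadata"})
    meta.graph_modified = False
    meta.is_valid = True
    # optimize_options takes the uint64 from GraphOptimizeOptions.serialize();
    # per the C++ setter, assigning it also flips optimized_for_inference.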


+3 -4 imperative/python/test/unit/jit/test_tracing.py

@@ -239,8 +239,7 @@ def test_dump_volatile():
file = io.BytesIO()
f.dump(file, optimize_for_inference=False)
file.seek(0)
cg, _, outputs = G.load_graph(file)
(out,) = outputs
(out,) = G.load_graph(file).output_vars_list
assert (
cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
== "ImmutableTensor"
@@ -337,12 +336,12 @@ def test_goptions_log_exp():
f(tensor(1.0))
_, out = mkstemp()
f.dump(out, optimize_for_inference=False)
*_, outputs = G.load_graph(out)
outputs = G.load_graph(out).output_vars_list
oprs_1 = cgtools.get_oprs_seq(outputs)

g(tensor(1.0))
g.dump(out, optimize_for_inference=False)
*_, outputs = G.load_graph(out)
outputs = G.load_graph(out).output_vars_list
oprs_2 = cgtools.get_oprs_seq(outputs)

assert len(oprs_1) - len(oprs_2) == 2


+5 -3 imperative/python/test/unit/utils/test_cgtools.py

@@ -88,7 +88,7 @@ def test_graph_traversal():
file = io.BytesIO()
fun.dump(file, optimize_for_inference=False)
file.seek(0)
cg, _, outputs = mgb_graph.load_graph(file)
outputs = mgb_graph.load_graph(file).output_vars_list

_, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
input_var = map_vars[1]
@@ -101,7 +101,9 @@ def test_load_refcnt():
graph = mgb_graph.Graph()
varnode = graph.make_const(0)
buf, _ = mgb_graph.dump_graph([varnode])
graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf))
ret = mgb_graph.load_graph(io.BytesIO(buf))
graph, (varnode,) = ret.graph, ret.output_vars_list
del ret
del graph
varnode.owner

@@ -132,7 +134,7 @@ def test_get_opr_seq():
file = io.BytesIO()
func.dump(file, optimize_for_inference=False)
file.seek(0)
*_, outputs = mgb_graph.load_graph(file)
outputs = mgb_graph.load_graph(file).output_vars_list

seq_1 = cgtools.get_oprs_seq(outputs, True)
assert len(seq_1) == 5


+2 -2 imperative/python/test/unit/utils/test_dump_naming.py

@@ -35,7 +35,7 @@ def _dump_and_load(func, symbolic, keep_opr_name=True):
keep_var_name=2,
)
file.seek(0)
*_, outputs = G.load_graph(file)
outputs = G.load_graph(file).output_vars_list
ops = cgtools.get_oprs_seq(outputs)
return ops

@@ -223,7 +223,7 @@ def test_catch_input_name(tensor_name, var_name):
file = io.BytesIO()
func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2)
file.seek(0)
*_, outputs = G.load_graph(file)
outputs = G.load_graph(file).output_vars_list
op = cgtools.get_oprs_seq(outputs)[-1]
assert op.inputs[0].name == var_name



+44 -0 imperative/python/test/unit/utils/test_network.py

@@ -14,6 +14,50 @@ from megengine.utils.network import as_oprnode, set_symbolic_shape
from megengine.utils.network_node import Host2DeviceCopy, VarNode


def test_metadata():
x = Tensor(0)

@trace(symbolic=True, capture_as_const=True)
def fwd(x):
return x * 2

fwd(x)

orig_model = io.BytesIO()
fwd.dump(orig_model, user_info="test", optimize_for_inference=False)
orig_model.seek(0)
graph = Net.load(orig_model)
assert graph.metadata == {
"user_info": "test",
"graph_modified": False, # False: tracing.dump
"optimized_for_inference": False,
}

orig_model.seek(0)
graph.dump(
orig_model,
user_info={"str": "x", "tensor": x, "module": M.Module, "none": None},
optimize_for_inference=True,
enable_nchw4=True,
enable_ioc16=True,
)
orig_model.seek(0)
graph = Net.load(orig_model)
assert graph.metadata == {
"user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None},
"graph_modified": True, # True: Network.dump
"optimized_for_inference": True,
"enable_nchw4": True,
"enable_ioc16": True,
}

orig_model.seek(0)
fwd.dump(orig_model, enable_metadata=False)
orig_model.seek(0)
graph = Net.load(orig_model)
assert graph.metadata is None


def test_replace_var():

a = Tensor([1, 2])


+2 -1 sdk/load-and-run/dump_with_testcase_mge.py

@@ -170,7 +170,8 @@ def gen_one_testcase(args, inputs, spec):


def make_feeds(args):
cg_rt, _, outputs = G.load_graph(args.input)
ret = G.load_graph(args.input)
cg_rt, outputs = ret.graph, ret.output_vars_list
inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")

inputs = {i.name: i for i in inputs}


+25 -1 src/gopt/include/megbrain/gopt/inference.h

@@ -322,7 +322,31 @@ namespace gopt {
static std::unique_ptr<EnableNchw44DotPass> make_nchw44_dot_converter();
};

struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {};
struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
uint64_t serialize() {
uint64_t ret = 0;
ret |= (uint64_t)layout_transform << 32;
if (f16_io_f32_comp) ret |= 1u;
if (f16_io_comp) ret |= 1u << 1;
if (fuse_conv_bias_nonlinearity) ret |= 1u << 2;
if (fuse_conv_bias_with_z) ret |= 1u << 3;
if (weight_preprocess) ret |= 1u << 4;
if (fuse_preprocess) ret |= 1u << 5;
return ret;
}

static OptimizeForInferenceOptions deserialize(uint64_t buf) {
OptimizeForInferenceOptions ret;
ret.f16_io_f32_comp = buf & 1u;
ret.f16_io_comp = buf & 1u << 1;
ret.fuse_conv_bias_nonlinearity = buf & 1u << 2;
ret.fuse_conv_bias_with_z = buf & 1u << 3;
ret.weight_preprocess = buf & 1u << 4;
ret.fuse_preprocess = buf & 1u << 5;
ret.layout_transform = (LayoutTransform)(buf >> 32);
return ret;
}
};

/*!
* \brief optimize a computing graph for inference
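The packing keeps the boolean passes in the low bits and the layout enum in the upper 32 bits of a single uint64. A worked Python sketch of the scheme (the NCHW4 enum value below is a hypothetical placeholder; the authoritative values live in GraphCommonOptimizeOptions::LayoutTransform):

    F16_IO_F32_COMP = 1 << 0
    FUSE_CONV_BIAS_NONLINEARITY = 1 << 2
    NCHW4 = 2  # hypothetical LayoutTransform enum value

    packed = (NCHW4 << 32) | F16_IO_F32_COMP | FUSE_CONV_BIAS_NONLINEARITY
    assert packed >> 32 == NCHW4      # layout comes back out of the high bits
    assert packed & (1 << 2)          # each flag comes out of its low bit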


+8 -0 src/serialization/impl/schema.fbs

@@ -128,6 +128,13 @@ table Operator {
name:string;
}

table Metadata {
is_valid:bool;
graph_modified:bool;
user_info:string;
optimize_options:ulong;
}

struct OutputVar {
compact_id:uint;
original_id:uint;
@@ -141,6 +148,7 @@ table Graph {
nr_shared_tensor:uint;
oprs:[Operator];
output_vars_idx:[OutputVar];
metadata:Metadata;
}

root_type Graph;

+40 -2 src/serialization/impl/serializer_oss.cpp

@@ -30,6 +30,7 @@
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_generated.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/version.h"

@@ -115,6 +116,7 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
std::vector<flatbuffers::Offset<void>> m_cur_opr_param;

void init_oprs_to_dump(const SymbolVarArray& endpoints);
flatbuffers::Offset<fbs::Metadata> build_metadata(const Metadata& metadata);
flatbuffers::Offset<fbs::Operator> build_single_opr(
cg::OperatorNodeBase* opr, const OprRegistry* registry);

@@ -123,7 +125,8 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
public:
GraphDumperOSS(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {}
DumpResult dump(const SymbolVarArray& output_vars,
const DumpConfig& config = {}) override;
const DumpConfig& config = {},
const Metadata& metadata = {}) override;
const GraphDumpConfig& config() const override { return m_config; }
void dump_tensor(const std::string& name, const HostTensorND& tensor,
TensorWriteMethod method) override;
@@ -185,6 +188,17 @@ void GraphDumperOSS::init_oprs_to_dump(const SymbolVarArray& endpoints) {
}
}

flatbuffers::Offset<fbs::Metadata> GraphDumperOSS::build_metadata(
const Metadata& metadata) {
auto user_info = m_builder.CreateSharedString(metadata.user_info);
fbs::MetadataBuilder builder(m_builder);
builder.add_is_valid(metadata.is_valid);
builder.add_graph_modified(metadata.graph_modified);
builder.add_user_info(user_info);
builder.add_optimize_options(metadata.optimize_options);
return builder.Finish();
}

flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
cg::OperatorNodeBase* opr, const OprRegistry* registry) {
m_cur_opr = opr;
@@ -282,7 +296,8 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
}

GraphDumper::DumpResult GraphDumperOSS::dump(
const SymbolVarArray& output_vars, const DumpConfig& config) {
const SymbolVarArray& output_vars,
const DumpConfig& config, const Metadata& metadata) {
mgb_throw_if(output_vars.empty(), SerializationError,
"Can't dump empty graph");

@@ -323,6 +338,9 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
uint64_t offset_to_fbs = 0;
m_file->write(&offset_to_fbs, sizeof(offset_to_fbs));

// Dump metadata
auto fbmeta = build_metadata(metadata);

// Dump operators
init_oprs_to_dump(output_vars);
std::vector<flatbuffers::Offset<fbs::Operator>> oprs;
@@ -350,6 +368,7 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
graph.add_oprs(fb_oprs);
graph.add_output_vars_idx(fb_output_vars);
graph.add_nr_shared_tensor(m_nr_shared_tensor);
graph.add_metadata(fbmeta);
m_builder.FinishSizePrefixed(graph.Finish(), fbs::GraphIdentifier());

// Write actual offset_to_fbs
@@ -531,6 +550,7 @@ public:
mgb_assert(nr == 1);
}

Metadata load_metadata();
LoadResult load_oprs();
CompNode load_comp_node(const fbs::CompNode* comp_node);

@@ -700,6 +720,22 @@ GraphLoaderOSS::OprLoadContextImpl::load_tensor_shared() {
return sh_ptr_ref;
}

Metadata GraphLoaderOSS::OprLoadContextImpl::load_metadata() {
const auto* fbmeta = m_loader->m_graph->metadata();
Metadata ret;
if (!fbmeta) // model was dumped before the metadata table existed
return ret;
ret.is_valid = fbmeta->is_valid();
ret.graph_modified = fbmeta->graph_modified();
if (fbmeta->user_info()) {
ret.user_info = fbmeta->user_info()->str();
ret.has_user_info = true;
}
if (fbmeta->optimize_options()) {
ret.optimize_options = fbmeta->optimize_options();
ret.optimized_for_inference = true;
}
return ret;
}

void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
const fbs::Operator* fbopr) {
m_cur_opr_tensor_cnt = 0;
@@ -872,7 +908,9 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config,
}

OprLoadContextImpl ctx{this, m_graph->mgb_version()};
auto metadata = ctx.load_metadata();
auto result = ctx.load_oprs();
result.metadata = metadata;

auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size;
auto cur = m_file->tell();
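Since Metadata is an optional table field of Graph, a model dumped by an older MegEngine simply lacks it; with the null check in load_metadata above, loading such a file yields a default Metadata whose is_valid stays false. A sketch of the expected behaviour (path illustrative):

    from megengine.utils.network import Network as Net

    legacy = Net.load("model_dumped_by_old_megengine.mge")
    # Network.metadata logs "metadata is not valid!" and returns None
    assert legacy.metadata is None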


+46 -0 src/serialization/include/megbrain/serialization/metadata.h

@@ -0,0 +1,46 @@
/**
* \file src/serialization/include/megbrain/serialization/metadata.h
*
* This file is part of MegBrain, a deep learning framework developed by Megvii.
*
* \brief MegEngine model's metadata
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*/
#pragma once

#include <cstdint>
#include <string>

namespace mgb {
namespace serialization {

struct Metadata {
bool is_valid = false;

bool graph_modified = false;

bool has_user_info = false;
std::string user_info;

bool optimized_for_inference = false;
uint64_t optimize_options = 0;

#define ADD_PROPERTY(type, name) \
type get_##name() const { return name; } \
void set_##name(type x) { \
name = x; \
has_##name = true; \
}
ADD_PROPERTY(std::string, user_info)
#undef ADD_PROPERTY

uint64_t get_optimize_options() { return optimize_options; }
void set_optimize_options(uint64_t value) {
optimized_for_inference = true;
optimize_options = value;
}
};

} // namespace serialization
} // namespace mgb

+6 -1 src/serialization/include/megbrain/serialization/serializer.h

@@ -15,6 +15,7 @@
#include "megbrain/serialization/dump_format.h"
#include "megbrain/serialization/file.h"
#include "megbrain/serialization/load_dump_config.h"
#include "megbrain/serialization/metadata.h"

namespace mgb {
namespace serialization {
@@ -32,6 +33,9 @@ namespace serialization {
//! expliit dtor decl to reduce binary size
~LoadResult() noexcept;

//! metadata
Metadata metadata;

using TensorMap = std::unordered_map<
std::string, std::shared_ptr<HostTensorND>>;

@@ -178,7 +182,8 @@ namespace serialization {

virtual DumpResult dump(
const SymbolVarArray &output_vars,
const DumpConfig &config = {}) = 0;
const DumpConfig &config = {},
const Metadata &metadata = {}) = 0;
virtual GraphDumpFormat format() const = 0;
};


+37 -0 src/serialization/test/serializer_oss.cpp

@@ -92,6 +92,43 @@ TEST(TestSerializer2, MultiGraphDumpLoad) {
load();
}

TEST(TestSerializer2, Metadata) {
auto fname = GET_OUTPUT_FILE();
TensorShape shape{2, 3};

auto dump = [&]() {
auto cn = CompNode::load("xpu0");
auto host_x = std::make_shared<HostTensorND>(cn, shape),
host_y = std::make_shared<HostTensorND>(cn, shape);
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}),
y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"});
using Mode = opr::Elemwise::Mode;
auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"});

Metadata metadata;
metadata.user_info = "TEST_METADATA";
metadata.has_user_info = true;

auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
GraphDumpFormat::FLATBUFFERS);
dumper->dump({z.rename("z")}, {}, metadata);
};

auto load = [&]() {
HostTensorGenerator<> gen;
auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
GraphDumpFormat::FLATBUFFERS);
auto rst = loader->load();
auto metadata = rst.metadata;
EXPECT_EQ(metadata.user_info, "TEST_METADATA");
};

dump();
load();
}

TEST(TestSerializer2, APlusB) {
auto fname = GET_OUTPUT_FILE();
TensorShape shape{2, 3};

