GitOrigin-RevId: 337d95c7c2
release-1.1
@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
 import json
+import os
 import threading
 import weakref
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -274,7 +275,8 @@ def dump_graph(
     keep_var_name: int = 1,
     keep_param_name: bool = False,
     keep_opr_priority: bool = False,
-    strip_info_file=None
+    strip_info_file=None,
+    append_json=False
 ):
     """serialize the computing graph of `output_vars` and get byte result.
@@ -295,6 +297,9 @@ def dump_graph(
     :param keep_opr_priority: whether to keep priority setting for operators
     :param strip_info_file: a string for path or a file handler. if is not None,
         then the dump information for code strip would be written to ``strip_info_file``
+    :param append_json: only takes effect when ``strip_info_file`` is not None. If set
+        to True, the code-strip information will be appended to ``strip_info_file``;
+        if set to False, ``strip_info_file`` will be overwritten
     :return: dump result as byte string, and an instance of namedtuple
         :class:`CompGraphDumpResult`, whose fields are:
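For reference, a minimal usage sketch of the new flag (not part of this patch; `net1_outputs` and `net2_outputs` are hypothetical lists of output vars, and the call signature follows the change above). The first dump overwrites or creates the strip-info file; the second merges its information into the same file:

    import megengine.core.tensor.megbrain_graph as G

    # first dump: overwrite (or create) the strip-info file
    content1, info1 = G.dump_graph(
        net1_outputs, strip_info_file="strip_info.json", append_json=False
    )
    # second dump: merge its opr/dtype/elemwise-mode lists into the same file
    content2, info2 = G.dump_graph(
        net2_outputs, strip_info_file="strip_info.json", append_json=True
    )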
@@ -342,10 +347,25 @@ def dump_graph(
     if strip_info_file is not None:
         if isinstance(strip_info_file, str):
-            strip_info_file = open(strip_info_file, "w")
-        strip_info = json.loads(_imperative_rt.get_info_for_strip(ov))
-        strip_info["hash"] = dump_info.content_hash
-        json.dump(strip_info, strip_info_file)
+            if not os.path.exists(strip_info_file):
+                os.mknod(strip_info_file)
+            strip_info_file = open(strip_info_file, "r+")
+        new_strip_dict = json.loads(_imperative_rt.get_info_for_strip(ov))
+        ori_strip_dict = new_strip_dict
+        json_content = strip_info_file.read()
+        if append_json and len(json_content) != 0:
+            # the json file already has contents: read them first, then merge in the new information
+            ori_strip_dict = json.loads(json_content)
+            for k in ori_strip_dict:
+                new_strip_dict_v = new_strip_dict.get(k)
+                if new_strip_dict_v is not None:
+                    for value in new_strip_dict_v:
+                        if value not in ori_strip_dict[k]:
+                            ori_strip_dict[k].append(value)
+        ori_strip_dict["hash"] = dump_info.content_hash
+        strip_info_file.seek(0)
+        strip_info_file.truncate()
+        json.dump(ori_strip_dict, strip_info_file)

     return dump_content, dump_info
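The appending branch above is, in effect, a per-key, order-preserving union of the old and new strip-info lists, with the hash always replaced by that of the latest dump. A standalone sketch of that merge, using only the standard library (`merge_strip_info` is a hypothetical helper, not part of the patch):

    import json

    def merge_strip_info(existing: dict, new: dict, content_hash: str) -> dict:
        # copy list values so the caller's dict is not mutated
        merged = {k: list(v) if isinstance(v, list) else v for k, v in existing.items()}
        for key, values in merged.items():
            if not isinstance(values, list):
                continue
            # append entries from the new dump that are not already recorded
            for value in new.get(key, []):
                if value not in values:
                    values.append(value)
        merged["hash"] = content_hash  # the hash always reflects the latest dump
        return merged

    old = {"opr_types": ["Elemwise"], "dtypes": ["float32"], "hash": "aaa"}
    new = {"opr_types": ["Elemwise", "ConvolutionForward"], "dtypes": ["float32"]}
    print(json.dumps(merge_strip_info(old, new, "bbb")))
    # {"opr_types": ["Elemwise", "ConvolutionForward"], "dtypes": ["float32"], "hash": "bbb"}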
@@ -267,7 +267,7 @@ void init_graph_rt(py::module m) {
             {"opr_types", to_json(opr_types)},
             {"dtypes", to_json(dtype_names)},
             {"elemwise_modes", to_json(elemwise_modes)},
-        });
+        })->to_string();
     });
     m.def("dump_graph", [](
@@ -17,10 +17,10 @@ import numpy as np
 import megengine as mge
 import megengine.core._imperative_rt as rt
 import megengine.core.tensor.megbrain_graph as G
-from megengine.core.tensor.megbrain_graph import VarNode
 from megengine import cgtools
 from megengine.core.ops import builtin
 from megengine.core.tensor.core import apply
+from megengine.core.tensor.megbrain_graph import VarNode
 from megengine.core.tensor.raw_tensor import as_raw_tensor

 logger = mge.get_logger(__name__)
@@ -485,13 +485,30 @@ def main():
         sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
     else:
         sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)
+    strip_info_file = args.output + ".json" if args.output_strip_info else None
+
     with open(args.output, "wb") as fout:
         fout.write(b"mgbtest0")
         fout.write(struct.pack("I", len(feeds["testcases"])))
-        dump_content, _ = G.dump_graph([VarNode(i) for i in output_mgbvars])
+        if isinstance(output_mgbvars, dict):
+            wrap_output_vars = dict([(i, VarNode(j)) for i, j in output_mgbvars.items()])
+        else:
+            wrap_output_vars = [VarNode(i) for i in output_mgbvars]
+        dump_content, stat = G.dump_graph(
+            wrap_output_vars,
+            append_json=True,
+            strip_info_file=strip_info_file,
+            **sereg_kwargs
+        )
         fout.write(dump_content)
+        logger.info(
+            "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format(
+                stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024
+            )
+        )

 def make_dev_tensor(value, dtype=None, device=None):
     return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
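A quick sanity check of the generated strip-info file, assuming the path convention from the hunk above (``args.output + ".json"``; ``model.mge.json`` is a hypothetical output name). The top-level keys come from `get_info_for_strip` (see the C++ hunk) plus the `hash` written by `dump_graph`:

    import json

    with open("model.mge.json") as f:  # i.e. args.output + ".json"
        strip_info = json.load(f)
    print(sorted(strip_info))  # expected: ['dtypes', 'elemwise_modes', 'hash', 'opr_types']
    print(len(strip_info["opr_types"]), "operator types used across the dumped graphs")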
@@ -509,8 +526,11 @@ def main():
             testcase.keys()
         )
         with open(args.output, "ab") as fout:
-            dump_content, _ = G.dump_graph(output_mgbvars)
-            fout.write(dump_content)
+            dump_content, _ = G.dump_graph(
+                output_mgbvars, strip_info_file=strip_info_file, append_json=True
+            )
+            fout.write(dump_content)