|
- #! /usr/bin/env python3
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import argparse
- import json
- import math
- import os
-
- from graphviz import Digraph
-
-
class Node:
    """One operator box in the dot output, rendered as a graphviz record.

    The record has three rows: the input ports (one per input var id), the
    operator label, and the output ports (one per output var id). Port labels
    start out empty and are filled in by the plotter before rendering.
    """

    def __init__(self, data):
        self.data = data
        self.label = ""
        # Every port exists from the start with an empty label.
        self.output_labels = dict.fromkeys(data["output"], "")
        self.input_labels = dict.fromkeys(data["input"], "")

    def __str__(self):
        # Characters that are special in graphviz record labels. The backslash
        # must be handled first so later replacements are not double-escaped.
        escapes = (
            ("\\", "\\\\"),
            ("{", r"\{"),
            ("}", r"\}"),
            ("|", r"\|"),
            ("<", r"\<"),
            (">", r"\>"),
            ("\n", r"\n"),
        )

        def escape(text):
            for raw, escaped in escapes:
                text = text.replace(raw, escaped)
            return text

        def ports(labels):
            # Each field is "<port_id> label", fields separated by "|".
            fields = ["<{}> {}".format(port, escape(text)) for port, text in labels.items()]
            return "|".join(fields)

        inputs_row = ports(self.input_labels)
        outputs_row = ports(self.output_labels)
        return "{{%s}|%s|{%s}}" % (inputs_row, escape(self.label), outputs_row)
-
-
class CompGraphPlotter:
    """Build a graphviz ``Digraph`` from a MegBrain computing-graph JSON dump.

    Construction loads the file(s) named by ``args`` and plots immediately;
    the result is available as :attr:`dot_graph`.

    ``args`` is expected to provide ``input``, ``profile``, ``end_vars_from``,
    ``dest_nodes``, ``prune_dangling_vars``, ``opr_attr`` and ``depth``
    (as produced by the argument parser in ``main``).
    """

    _args = None

    # original graph represented by json; after load_single_graph the
    # "operator" and "var" sub-dicts are keyed by int ids
    _jgraph = None

    # opr id -> longest per-device execution time (only with --profile)
    _jgraph_profile = None
    # factor mapping sqrt(time) to a node size in [min_size, max_size]
    _profile_normalize = None
    _profile_max_size = 3
    _profile_min_size = 1

    _dest = None

    _finished_vars = None
    _finished_oprs = None
    _var_attr = None

    def __init__(self, args):
        self._finished_vars = set()
        self._finished_oprs = {}
        self._args = args

        self._load_data()
        self._do_plot()

    def _do_plot(self):
        """Walk the graph from the selected roots, then emit nodes and edges.

        Node/edge creation is deferred (collected as callables) so that all
        port labels are filled in before any node's record label is rendered.
        """
        self._node_commands = []
        self._edge_commands = []

        n0, c0 = len(self._finished_oprs), len(self._finished_vars)
        if self._args.dest_nodes:
            for i in map(int, self._args.dest_nodes.split(",")):
                self._add_var(i)
        elif not self._args.prune_dangling_vars:
            for i in self._jgraph["var"]:
                self._add_var(i)
        else:
            # Start from operators: only vars consumed by some opr get plotted.
            for i in self._jgraph["operator"]:
                self._add_opr(i, 0)

        n1, c1 = len(self._finished_oprs), len(self._finished_vars)
        print("plot with {} oprs, {} vars".format(n1 - n0, c1 - c0))

        for cmd in self._node_commands:
            cmd()
        for cmd in self._edge_commands:
            cmd()
        del self._node_commands
        del self._edge_commands

    @property
    def dot_graph(self):
        """The generated graphviz ``Digraph``."""
        return self._dest

    def _make_node_attr_for_size(self, size):
        """Return graphviz node attributes for a fixed-size node of ``size``."""
        return dict(
            height=str(size / 2),
            width=str(size),
            fontsize=str(size * 5),
            fixedsize="true",
        )

    @classmethod
    def load_single_graph(cls, fpath):
        """Load a graph dump and normalize it.

        Operator/var ids (and opr input/output lists) are converted to int.
        Each var gets an ``owner_opr`` id and a printable ``shape`` string
        (``"<?>"`` when it has no ``mem_plan``).

        :return: ``(graph, profile)``. ``profile`` is ``None`` unless the file
            is a ``graph_exec`` dump, in which case it maps opr id to the
            ``profiler.device`` record.
        """
        prof = None
        # JSON is UTF-8 by spec; don't depend on the locale's default encoding.
        with open(fpath, encoding="utf-8") as fin:
            data = json.load(fin)
        if "graph_exec" in data:
            prof = {int(k): v for k, v in data["profiler"]["device"].items()}
            data = data["graph_exec"]

        for t in ["operator", "var"]:
            data[t] = {int(i): j for i, j in data[t].items()}

        gvars = data["var"]
        for oid, opr in data["operator"].items():
            opr["input"] = list(map(int, opr["input"]))
            out = opr["output"] = list(map(int, opr["output"]))
            for j in out:
                gvars[j]["owner_opr"] = oid

        for var in gvars.values():
            mp = var.get("mem_plan", None)
            if mp:
                var["shape"] = "{" + ",".join(map(str, mp["layout"]["shape"])) + "}"
            else:
                var["shape"] = "<?>"

        return data, prof

    def _apply_profile(self, prof):
        """Reduce per-device profile records to one duration per opr and
        compute the node-size normalization factor.

        :param prof: opr id -> {device: {"start": ..., "end": ...}}
        """
        durations = {
            oprid: max(
                (rec["end"] - rec["start"] for rec in per_dev.values()), default=0
            )
            for oprid, per_dev in prof.items()
        }
        self._jgraph_profile = durations
        peak = max(map(math.sqrt, durations.values()), default=0.0)
        # With no non-zero time, fall back to 0 so every node uses min size
        # instead of dividing by zero.
        self._profile_normalize = self._profile_max_size / peak if peak > 0 else 0.0

    def _load_data(self):
        """Load the input graph, optional profile and end-var graph; create
        the destination ``Digraph``.

        :raises ValueError: if ``--profile`` is requested but the input file
            carries no profiling results.
        """
        args = self._args
        self._jgraph, prof = self.load_single_graph(args.input)
        if args.profile:
            if prof is None:
                raise ValueError(
                    "--profile requires {} to contain profiling results "
                    "(a 'graph_exec' dump with a 'profiler' section)".format(args.input)
                )
            self._apply_profile(prof)

        self._dest = Digraph(comment="plot for {}".format(args.input))

        if args.end_vars_from:
            # Everything present in the other graph is treated as already
            # finished, so traversal stops at those oprs/vars.
            eg, _ = self.load_single_graph(args.end_vars_from)
            for i in eg["operator"]:
                self._finished_oprs[i] = None
            for i in eg["var"]:
                self._finished_vars.add(i)

    def _add_opr(self, oprid, depth):
        """Register operator ``oprid`` (and its inputs, recursively) for plotting.

        ImmutableTensor oprs are not drawn; consumers label such inputs
        ``<const>`` instead.

        :return: the dot node name ``"opr<oprid>"``.
        """
        name = "opr{}".format(oprid)
        if oprid in self._finished_oprs:
            return name
        oprobj = self._jgraph["operator"][oprid]
        if oprobj["type"] == "ImmutableTensor":
            self._finished_oprs[oprid] = None
            return name

        # Mark as finished before recursing so cycles cannot loop forever.
        self._finished_oprs[oprid] = node = Node(oprobj)

        all_vars = self._jgraph["var"]
        extra = oprobj.get("extra", {})
        dispname = [oprobj["name"], oprobj["type"]]
        for attr_name in self._args.opr_attr:
            dispname.append("{}: {}".format(attr_name, extra.get(attr_name, "N/A")))

        attr = {}
        if self._jgraph_profile:
            time = self._jgraph_profile.get(oprid, 0)
            attr = self._make_node_attr_for_size(
                max(self._profile_normalize * time ** 0.5, self._profile_min_size)
            )
            dispname.append("time: {:.3f}ms".format(time * 1e3))

        node.label = "\n".join(dispname)

        # str(node) is evaluated later, after output labels have been filled.
        self._node_commands.append(
            lambda: self._dest.node(name, str(node), shape="record", **attr)
        )

        for i in oprobj["input"]:
            inpopr = self._jgraph["operator"][all_vars[i]["owner_opr"]]
            if inpopr["type"] == "ImmutableTensor":
                node.input_labels[i] = "<const>"
                continue
            node.input_labels[i] = all_vars[i]["shape"]
            vi = self._add_var(i, depth)
            # Bind loop values as defaults to avoid late-binding closures.
            self._edge_commands.append(
                lambda vi=vi, name="{}:{}".format(name, i): self._dest.edge(vi, name)
            )

        return name

    def _add_var(self, varid, depth=0):
        """Register var ``varid`` (plotting its owner opr) unless it is beyond
        the ``--depth`` limit or already handled.

        :return: the dot port name ``"opr<owner>:<varid>"``.
        """
        varobj = self._jgraph["var"][varid]
        name = "opr{}:{}".format(varobj["owner_opr"], varid)
        # Compare against None so that an explicit --depth 0 is honored.
        if self._args.depth is not None and depth > self._args.depth:
            return name
        if varid in self._finished_vars:
            return name
        self._finished_vars.add(varid)

        oprid = varobj["owner_opr"]
        oprobj = self._jgraph["operator"][oprid]
        # Omit the var name when it just repeats its owner opr's name.
        dispname = [varobj["name"]] if varobj["name"] != oprobj["name"] else []
        dispname.append(varobj["shape"])
        dispname = "\n".join(dispname)

        self._add_opr(oprid, depth + 1)
        if self._finished_oprs[oprid] is not None:
            self._finished_oprs[oprid].output_labels[varid] = dispname

        return name
-
-
def main():
    """Command-line entry point: parse arguments, plot the graph, write output.

    The dot source is always written to ``<output stem>.dot``. For any other
    ``--output-format`` the graphviz ``dot`` executable renders
    ``<output stem>.<format>``. The intermediate ``.dot`` file is removed only
    if rendering succeeds.
    """
    import subprocess  # only needed for the optional rendering step

    parser = argparse.ArgumentParser(
        description="plot megbrain computing graph",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-d",
        "--dest-nodes",
        help="target var nodes; a comma-separated list of var ids. The "
        "dependency graph would be plotted. If not given, "
        "all nodes are plotted",
    )
    parser.add_argument("--end-vars-from", help="set end vars from another file")
    parser.add_argument(
        "-i", "--input", required=True, help="input computing graph file"
    )
    parser.add_argument(
        "-o", "--output", required=True, help="write dot source to file"
    )
    parser.add_argument(
        "--profile", action="store_true", help="annotate graph by profiling result"
    )
    parser.add_argument(
        "--prune-dangling-vars",
        action="store_true",
        help="remove vars not used by any opr",
    )
    parser.add_argument(
        "--opr-attr",
        action="append",
        default=[],
        help="extra opr attributes to be plotted",
    )
    parser.add_argument(
        "--depth",
        type=int,
        help="max depth (i.e. distance from dest nodes) of nodes to be plotted",
    )
    parser.add_argument(
        "--output-format",
        default="dot",
        help="output file format, could be .dot/.png/.pdf",
    )
    args = parser.parse_args()

    graph = CompGraphPlotter(args).dot_graph
    if not args.output:
        return

    # splitext strips only the final extension, so dots in directory names or
    # in the stem (e.g. "./out/a.b.png") are preserved.
    output_name = os.path.splitext(args.output)[0]
    dot_path = "{}.dot".format(output_name)
    graph.save(dot_path)
    if args.output_format == "dot":
        return

    target = "{}.{}".format(output_name, args.output_format)
    # Argument list without a shell: paths with spaces or shell
    # metacharacters are passed through verbatim.
    cmd = ["dot", "-T{}".format(args.output_format), "-o", target, dot_path]
    try:
        returncode = subprocess.run(cmd).returncode
    except OSError as exc:  # e.g. the graphviz `dot` executable is missing
        print("failed to run {}: {}".format(cmd[0], exc))
        returncode = None
    if returncode == 0:
        os.remove(dot_path)
    else:
        print("rendering {} failed; dot source kept at {}".format(target, dot_path))
-
-
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|