You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

draw_graph.py 8.9 kB


  1. #! /usr/bin/env python3
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import json
  11. import math
  12. import os
  13. from graphviz import Digraph
  14. class Node:
  15. def __init__(self, data):
  16. self.data = data
  17. self.label = ""
  18. self.output_labels = {i: "" for i in data["output"]}
  19. self.input_labels = {i: "" for i in data["input"]}
  20. def __str__(self):
  21. def quote(s):
  22. r = {
  23. "\\": "\\\\",
  24. "{": r"\{",
  25. "}": r"\}",
  26. "|": r"\|",
  27. "<": r"\<",
  28. ">": r"\>",
  29. "\n": r"\n",
  30. }
  31. for k, v in r.items():
  32. s = s.replace(k, v)
  33. return s
  34. def pport(d):
  35. return "|".join("<{}> {}".format(k, quote(v)) for k, v in d.items())
  36. in_ports = pport(self.input_labels)
  37. out_ports = pport(self.output_labels)
  38. return "{{%s}|%s|{%s}}" % (in_ports, quote(self.label), out_ports)
  39. class CompGraphPlotter:
  40. _args = None
  41. _jgraph = None
  42. """original graph represented by json"""
  43. _jgraph_profile = None
  44. _profile_normalize = None
  45. _profile_max_size = 3
  46. _profile_min_size = 1
  47. _dest = None
  48. _finished_vars = None
  49. _finished_oprs = None
  50. _var_attr = None
  51. def __init__(self, args):
  52. self._finished_vars = set()
  53. self._finished_oprs = {}
  54. self._args = args
  55. self._load_data()
  56. self._do_plot()
  57. def _do_plot(self):
  58. self._node_commands = []
  59. self._edge_commands = []
  60. n0, c0 = map(len, [self._finished_oprs, self._finished_vars])
  61. if self._args.dest_nodes:
  62. for i in map(int, self._args.dest_nodes.split(",")):
  63. self._add_var(i)
  64. elif not self._args.prune_dangling_vars:
  65. for i in self._jgraph["var"].keys():
  66. self._add_var(i)
  67. else:
  68. for i in self._jgraph["operator"]:
  69. self._add_opr(i, 0)
  70. n1, c1 = map(len, [self._finished_oprs, self._finished_vars])
  71. print("plot with {} oprs, {} vars".format(n1 - n0, c1 - c0))
  72. for i in self._node_commands:
  73. i()
  74. for i in self._edge_commands:
  75. i()
  76. del self._node_commands
  77. del self._edge_commands
  78. @property
  79. def dot_graph(self):
  80. return self._dest
  81. def _make_node_attr_for_size(self, size):
  82. return dict(
  83. height=str(size / 2),
  84. width=str(size),
  85. fontsize=str(size * 5),
  86. fixedsize="true",
  87. )
  88. @classmethod
  89. def load_single_graph(cls, fpath):
  90. prof = None
  91. with open(fpath) as fin:
  92. data = json.load(fin)
  93. if "graph_exec" in data:
  94. prof = {int(k): v for k, v in data["profiler"]["device"].items()}
  95. data = data["graph_exec"]
  96. for t in ["operator", "var"]:
  97. data[t] = {int(i): j for i, j in data[t].items()}
  98. gvars = data["var"]
  99. for oid, i in data["operator"].items():
  100. i["input"] = list(map(int, i["input"]))
  101. out = i["output"] = list(map(int, i["output"]))
  102. for j in out:
  103. gvars[j]["owner_opr"] = oid
  104. for var in data["var"].values():
  105. mp = var.get("mem_plan", None)
  106. if mp:
  107. var["shape"] = "{" + ",".join(map(str, mp["layout"]["shape"])) + "}"
  108. else:
  109. var["shape"] = "<?>"
  110. return data, prof
  111. def _load_data(self):
  112. args = self._args
  113. self._jgraph, prof = self.load_single_graph(args.input)
  114. if args.profile:
  115. for k, v in list(prof.items()):
  116. v = max(i["end"] - i["start"] for i in v.values())
  117. prof[k] = v
  118. self._jgraph_profile = prof
  119. self._profile_normalize = self._profile_max_size / max(
  120. map(math.sqrt, prof.values())
  121. )
  122. self._dest = Digraph(comment="plot for {}".format(args.input))
  123. if args.end_vars_from:
  124. eg, _ = self.load_single_graph(args.end_vars_from)
  125. for i in eg["operator"].keys():
  126. self._finished_oprs[i] = None
  127. for i in eg["var"].keys():
  128. self._finished_vars.add(i)
  129. def _add_opr(self, oprid, depth):
  130. name = "opr{}".format(oprid)
  131. if oprid in self._finished_oprs:
  132. return name
  133. oprobj = self._jgraph["operator"][oprid]
  134. if oprobj["type"] == "ImmutableTensor":
  135. self._finished_oprs[oprid] = None
  136. return name
  137. self._finished_oprs[oprid] = node = Node(oprobj)
  138. all_vars = self._jgraph["var"]
  139. dispname = [oprobj["name"], oprobj["type"]]
  140. for i in self._args.opr_attr:
  141. dispname.append("{}: {}".format(i, oprobj["extra"].get(i, "N/A")))
  142. attr = {}
  143. if self._jgraph_profile:
  144. time = self._jgraph_profile.get(oprid, 0)
  145. attr = self._make_node_attr_for_size(
  146. max(self._profile_normalize * time ** 0.5, self._profile_min_size)
  147. )
  148. dispname.append("time: {:.3f}ms".format(time * 1e3))
  149. node.label = "\n".join(dispname)
  150. self._node_commands.append(
  151. lambda: self._dest.node(name, str(node), shape="record", **attr)
  152. )
  153. for i in oprobj["input"]:
  154. inpopr = self._jgraph["operator"][all_vars[i]["owner_opr"]]
  155. if inpopr["type"] == "ImmutableTensor":
  156. node.input_labels[i] = "<const>"
  157. continue
  158. node.input_labels[i] = all_vars[i]["shape"]
  159. vi = self._add_var(i, depth)
  160. self._edge_commands.append(
  161. lambda vi=vi, name="{}:{}".format(name, i): self._dest.edge(vi, name)
  162. )
  163. return name
  164. def _add_var(self, varid, depth=0):
  165. varobj = self._jgraph["var"][varid]
  166. name = "opr{}:{}".format(varobj["owner_opr"], varid)
  167. if self._args.depth and depth > self._args.depth:
  168. return name
  169. if varid in self._finished_vars:
  170. return name
  171. self._finished_vars.add(varid)
  172. oprid = varobj["owner_opr"]
  173. oprobj = self._jgraph["operator"][oprid]
  174. dispname = [varobj["name"]] if varobj["name"] != oprobj["name"] else []
  175. dispname += [varobj["shape"]]
  176. dispname = "\n".join(dispname)
  177. self._add_opr(oprid, depth + 1)
  178. if self._finished_oprs[oprid] is not None:
  179. self._finished_oprs[oprid].output_labels[varid] = dispname
  180. return name
  181. def main():
  182. parser = argparse.ArgumentParser(
  183. "plot megbrain computing graph",
  184. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  185. )
  186. parser.add_argument(
  187. "-d",
  188. "--dest-nodes",
  189. help="target var nodes; a comma-separated list of var ids. The "
  190. "dependency graph would be plotted. If not given, "
  191. "all nodes are plotted",
  192. )
  193. parser.add_argument("--end-vars-from", help="set end vars from another file")
  194. parser.add_argument(
  195. "-i", "--input", required=True, help="input computing graph file"
  196. )
  197. parser.add_argument(
  198. "-o", "--output", required=True, help="write dot source to file"
  199. )
  200. parser.add_argument(
  201. "--profile", action="store_true", help="anonotate graph by profiling result"
  202. )
  203. parser.add_argument(
  204. "--prune-dangling-vars",
  205. action="store_true",
  206. help="remove vars not used by any opr",
  207. )
  208. parser.add_argument(
  209. "--opr-attr",
  210. action="append",
  211. default=[],
  212. help="extra opr attributes to be plotted",
  213. )
  214. parser.add_argument(
  215. "--depth",
  216. type=int,
  217. help="max depth (i.e. distance from dest nodes) " "of nodes to be plotted",
  218. )
  219. parser.add_argument(
  220. "--output-format",
  221. default="dot",
  222. help="output file format, could be .dot/.png/.pdf",
  223. )
  224. args = parser.parse_args()
  225. graph = CompGraphPlotter(args).dot_graph
  226. if args.output:
  227. output_name = args.output.split(".")[0]
  228. graph.save("{}.dot".format(output_name))
  229. if args.output_format != "dot":
  230. os.system(
  231. "dot -T{} -o {}.{} {}.dot".format(
  232. args.output_format, output_name, args.output_format, output_name
  233. )
  234. )
  235. os.system("rm -f {}.dot".format(output_name))
  236. if __name__ == "__main__":
  237. main()