You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

draw_graph.py 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. #! /usr/bin/env python3
  2. import argparse
  3. import json
  4. import math
  5. import os
  6. from graphviz import Digraph
  7. class Node:
  8. def __init__(self, data):
  9. self.data = data
  10. self.label = ""
  11. self.output_labels = {i: "" for i in data["output"]}
  12. self.input_labels = {i: "" for i in data["input"]}
  13. def __str__(self):
  14. def quote(s):
  15. r = {
  16. "\\": "\\\\",
  17. "{": r"\{",
  18. "}": r"\}",
  19. "|": r"\|",
  20. "<": r"\<",
  21. ">": r"\>",
  22. "\n": r"\n",
  23. }
  24. for k, v in r.items():
  25. s = s.replace(k, v)
  26. return s
  27. def pport(d):
  28. return "|".join("<{}> {}".format(k, quote(v)) for k, v in d.items())
  29. in_ports = pport(self.input_labels)
  30. out_ports = pport(self.output_labels)
  31. return "{{%s}|%s|{%s}}" % (in_ports, quote(self.label), out_ports)
  32. class CompGraphPlotter:
  33. _args = None
  34. _jgraph = None
  35. """original graph represented by json"""
  36. _jgraph_profile = None
  37. _profile_normalize = None
  38. _profile_max_size = 3
  39. _profile_min_size = 1
  40. _dest = None
  41. _finished_vars = None
  42. _finished_oprs = None
  43. _var_attr = None
  44. def __init__(self, args):
  45. self._finished_vars = set()
  46. self._finished_oprs = {}
  47. self._args = args
  48. self._load_data()
  49. self._do_plot()
  50. def _do_plot(self):
  51. self._node_commands = []
  52. self._edge_commands = []
  53. n0, c0 = map(len, [self._finished_oprs, self._finished_vars])
  54. if self._args.dest_nodes:
  55. for i in map(int, self._args.dest_nodes.split(",")):
  56. self._add_var(i)
  57. elif not self._args.prune_dangling_vars:
  58. for i in self._jgraph["var"].keys():
  59. self._add_var(i)
  60. else:
  61. for i in self._jgraph["operator"]:
  62. self._add_opr(i, 0)
  63. n1, c1 = map(len, [self._finished_oprs, self._finished_vars])
  64. print("plot with {} oprs, {} vars".format(n1 - n0, c1 - c0))
  65. for i in self._node_commands:
  66. i()
  67. for i in self._edge_commands:
  68. i()
  69. del self._node_commands
  70. del self._edge_commands
  71. @property
  72. def dot_graph(self):
  73. return self._dest
  74. def _make_node_attr_for_size(self, size):
  75. return dict(
  76. height=str(size / 2),
  77. width=str(size),
  78. fontsize=str(size * 5),
  79. fixedsize="true",
  80. )
  81. @classmethod
  82. def load_single_graph(cls, fpath):
  83. prof = None
  84. with open(fpath) as fin:
  85. data = json.load(fin)
  86. if "graph_exec" in data:
  87. prof = {int(k): v for k, v in data["profiler"]["device"].items()}
  88. data = data["graph_exec"]
  89. for t in ["operator", "var"]:
  90. data[t] = {int(i): j for i, j in data[t].items()}
  91. gvars = data["var"]
  92. for oid, i in data["operator"].items():
  93. i["input"] = list(map(int, i["input"]))
  94. out = i["output"] = list(map(int, i["output"]))
  95. for j in out:
  96. gvars[j]["owner_opr"] = oid
  97. for var in data["var"].values():
  98. mp = var.get("mem_plan", None)
  99. if mp:
  100. var["shape"] = "{" + ",".join(map(str, mp["layout"]["shape"])) + "}"
  101. else:
  102. var["shape"] = "<?>"
  103. return data, prof
  104. def _load_data(self):
  105. args = self._args
  106. self._jgraph, prof = self.load_single_graph(args.input)
  107. if args.profile:
  108. for k, v in list(prof.items()):
  109. v = max(i["end"] - i["start"] for i in v.values())
  110. prof[k] = v
  111. self._jgraph_profile = prof
  112. self._profile_normalize = self._profile_max_size / max(
  113. map(math.sqrt, prof.values())
  114. )
  115. self._dest = Digraph(comment="plot for {}".format(args.input))
  116. if args.end_vars_from:
  117. eg, _ = self.load_single_graph(args.end_vars_from)
  118. for i in eg["operator"].keys():
  119. self._finished_oprs[i] = None
  120. for i in eg["var"].keys():
  121. self._finished_vars.add(i)
  122. def _add_opr(self, oprid, depth):
  123. name = "opr{}".format(oprid)
  124. if oprid in self._finished_oprs:
  125. return name
  126. oprobj = self._jgraph["operator"][oprid]
  127. if oprobj["type"] == "ImmutableTensor":
  128. self._finished_oprs[oprid] = None
  129. return name
  130. self._finished_oprs[oprid] = node = Node(oprobj)
  131. all_vars = self._jgraph["var"]
  132. dispname = [oprobj["name"], oprobj["type"]]
  133. for i in self._args.opr_attr:
  134. dispname.append("{}: {}".format(i, oprobj["extra"].get(i, "N/A")))
  135. attr = {}
  136. if self._jgraph_profile:
  137. time = self._jgraph_profile.get(oprid, 0)
  138. attr = self._make_node_attr_for_size(
  139. max(self._profile_normalize * time ** 0.5, self._profile_min_size)
  140. )
  141. dispname.append("time: {:.3f}ms".format(time * 1e3))
  142. node.label = "\n".join(dispname)
  143. self._node_commands.append(
  144. lambda: self._dest.node(name, str(node), shape="record", **attr)
  145. )
  146. for i in oprobj["input"]:
  147. inpopr = self._jgraph["operator"][all_vars[i]["owner_opr"]]
  148. if inpopr["type"] == "ImmutableTensor":
  149. node.input_labels[i] = "<const>"
  150. continue
  151. node.input_labels[i] = all_vars[i]["shape"]
  152. vi = self._add_var(i, depth)
  153. self._edge_commands.append(
  154. lambda vi=vi, name="{}:{}".format(name, i): self._dest.edge(vi, name)
  155. )
  156. return name
  157. def _add_var(self, varid, depth=0):
  158. varobj = self._jgraph["var"][varid]
  159. name = "opr{}:{}".format(varobj["owner_opr"], varid)
  160. if self._args.depth and depth > self._args.depth:
  161. return name
  162. if varid in self._finished_vars:
  163. return name
  164. self._finished_vars.add(varid)
  165. oprid = varobj["owner_opr"]
  166. oprobj = self._jgraph["operator"][oprid]
  167. dispname = [varobj["name"]] if varobj["name"] != oprobj["name"] else []
  168. dispname += [varobj["shape"]]
  169. dispname = "\n".join(dispname)
  170. self._add_opr(oprid, depth + 1)
  171. if self._finished_oprs[oprid] is not None:
  172. self._finished_oprs[oprid].output_labels[varid] = dispname
  173. return name
  174. def main():
  175. parser = argparse.ArgumentParser(
  176. "plot megbrain computing graph",
  177. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  178. )
  179. parser.add_argument(
  180. "-d",
  181. "--dest-nodes",
  182. help="target var nodes; a comma-separated list of var ids. The "
  183. "dependency graph would be plotted. If not given, "
  184. "all nodes are plotted",
  185. )
  186. parser.add_argument("--end-vars-from", help="set end vars from another file")
  187. parser.add_argument(
  188. "-i", "--input", required=True, help="input computing graph file"
  189. )
  190. parser.add_argument(
  191. "-o", "--output", required=True, help="write dot source to file"
  192. )
  193. parser.add_argument(
  194. "--profile", action="store_true", help="anonotate graph by profiling result"
  195. )
  196. parser.add_argument(
  197. "--prune-dangling-vars",
  198. action="store_true",
  199. help="remove vars not used by any opr",
  200. )
  201. parser.add_argument(
  202. "--opr-attr",
  203. action="append",
  204. default=[],
  205. help="extra opr attributes to be plotted",
  206. )
  207. parser.add_argument(
  208. "--depth",
  209. type=int,
  210. help="max depth (i.e. distance from dest nodes) " "of nodes to be plotted",
  211. )
  212. parser.add_argument(
  213. "--output-format",
  214. default="dot",
  215. help="output file format, could be .dot/.png/.pdf",
  216. )
  217. args = parser.parse_args()
  218. graph = CompGraphPlotter(args).dot_graph
  219. if args.output:
  220. output_name = args.output.split(".")[0]
  221. graph.save("{}.dot".format(output_name))
  222. if args.output_format != "dot":
  223. os.system(
  224. "dot -T{} -o {}.{} {}.dot".format(
  225. args.output_format, output_name, args.output_format, output_name
  226. )
  227. )
  228. os.system("rm -f {}.dot".format(output_name))
  229. if __name__ == "__main__":
  230. main()