#! /usr/bin/env python3
import argparse
import collections
import json
import re
import textwrap

import numpy as np
from tabulate import tabulate

from megengine.utils.profile_analyzer import (
    NonExistNum,
    ProfileAnalyzer,
    TimeFuncHelper,
)


def _tabulate_ml(tab, **kwargs):
    r"""Tabulate profile output with multi-line support.

    Cells containing embedded newlines are expanded into extra table rows;
    the separator lines that :func:`tabulate` emits between such
    continuation rows are then blanked out so each logical row reads as one
    visual unit.
    """
    new_tab = []
    # True for the first physical row of each logical row, False for
    # continuation rows produced by multi-line cells.
    new_tab_is_row = []
    for row in tab:
        col_lines = [str(i).split("\n") for i in row]
        max_nr_line = max(map(len, col_lines))
        new_tab_is_row.append(True)
        if max_nr_line > 1:
            new_tab_is_row.extend([False] * (max_nr_line - 1))
            # Pad shorter cells so every column has the same number of lines.
            for i in col_lines:
                if len(i) < max_nr_line:
                    i.extend([""] * (max_nr_line - len(i)))
            new_tab.extend(zip(*col_lines))
        else:
            new_tab.append(row)
    assert len(new_tab_is_row) == len(new_tab)
    ret = [i + "\n" for i in tabulate(new_tab, **kwargs).split("\n")]
    # In "fancy_grid"-style output every data row is followed by a separator
    # line; drop the separator before each continuation row (offset +2 skips
    # the header and its underline).
    for idx, val in enumerate(new_tab_is_row):
        if not val:
            ret[idx * 2 + 2] = ""
    return "".join(ret)[:-1]


def _tabulate_confluence(tab, **kwargs):
    r"""Tabulate profile output in confluence-markdown-compatible form.

    Renders with the ``orgtbl`` format and rewrites the header separator so
    the result pastes cleanly into a confluence/wiki page.
    """
    kwargs.pop("tablefmt", None)
    s = tabulate(tab, tablefmt="orgtbl", **kwargs)
    lines = s.split("\n")
    # orgtbl separates header with "|---+---|"; confluence wants "|---|---|".
    lines[1] = lines[1].replace("+", "|")
    return "\n".join(lines)


def main(passed_args=None):  # pylint: disable=too-many-statements
    r"""Analyses profile info from :mod:`~.utils.profile_analyzer` .

    Run this file with ``--help`` to get more usage.

    :param passed_args: optional argument list for :meth:`parse_args`;
        ``None`` means use ``sys.argv`` (useful for testing).
    """
    parser = argparse.ArgumentParser(
        description="analyze analyzer result",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("dump")
    parser.add_argument(
        "-t",
        "--top",
        type=int,
        default=3,
        help="number of most time-consuming operators to print",
    )
    parser.add_argument(
        "--type", action="append", help="filter oprs in the top list by type"
    )
    parser.add_argument(
        "--aggregate-by",
        default=None,
        choices=["type"],
        help="aggregate profiling result by",
    )
    parser.add_argument(
        "--opr-name", help="filter oprs in the top list by regex of name"
    )
    parser.add_argument(
        "--input-dtype", type=str, help="filter oprs in the top list by input dtype"
    )
    parser.add_argument(
        "--top-end-key",
        default="end",
        choices=["end", "kern"],
        help="how time in top is calculated; end corresponds "
        "to total device time, and kern corresponds to only "
        "wait time",
    )
    parser.add_argument(
        "--aggregate",
        default=None,
        help="aggregate operations",
        choices=["max", "min", "sum", "mean"],
    )
    parser.add_argument(
        "--order-by",
        default="time",
        help="sort result according to given column; the param can be "
        "<col> or +<col>, meaning sorting in descending or "
        "ascending order respectively",
    )
    parser.add_argument(
        "--copy-time", action="store_true", help="show copy time related result"
    )
    parser.add_argument(
        "--min-time",
        type=float,
        default=float("-inf"),
        help="minimal time of a result to be printed",
    )
    parser.add_argument(
        "--max-time",
        type=float,
        default=float("inf"),
        help="maximal time of a result to be printed",
    )
    parser.add_argument(
        "--show-host", action="store_true", help="show host profiling info"
    )
    parser.add_argument(
        "--dump-only-opr",
        action="store_true",
        help="only dump operator info as plaintext; useful "
        "for diff between two filtered profile results",
    )
    parser.add_argument(
        "--confluence",
        "--wiki",
        action="store_true",
        help="output confluence-markdown-compatible table",
    )
    parser.add_argument(
        "--print-only",
        choices={"summary", "device", "host"},
        help="print only chosen info",
    )
    args = parser.parse_args(passed_args)

    # Each filter receives (opr_info, inputs, outputs); an opr is kept only
    # if ALL registered filters accept it.
    opr_filters = []
    if args.type:
        opr_filters.append(lambda o, a, b: o["type"] in args.type)
    if args.opr_name:
        # Bind the compiled regex as a default arg so it is compiled once.
        opr_filters.append(
            lambda o, a, b, r=re.compile(args.opr_name): r.match(o["name"])
        )
    if args.input_dtype:
        opr_filters.append(
            lambda o, a, b: any(
                i["mem_plan"]["layout"]["dtype"] == args.input_dtype for i in a
            )
        )
    if not opr_filters:

        def opr_filter(o, a, b):  # pylint: disable=unused-argument
            return True

    else:

        def opr_filter(o, a, b):
            return all(i(o, a, b) for i in opr_filters)

    with open(args.dump) as fin:
        dump = json.load(fin)

    # analyzer: only oprs passing the filters; analyzer_tot: every opr, used
    # for the totals in the summary so ratios are relative to the full run.
    analyzer = ProfileAnalyzer(dump, opr_filter)
    analyzer_tot = ProfileAnalyzer(dump, lambda _, __, ___: True)

    def summary():
        """Build the summary table; return (table_str, tot_dev, tot_host)."""
        device_end_func = TimeFuncHelper.eval_time_func("device", "end", np.max)
        device_kern_func = TimeFuncHelper.eval_time_func("device", "kern", np.max)
        host_end_func = TimeFuncHelper.eval_time_func("host", "end", np.max)

        def get_tot_time(func):
            # Sum over all oprs; "N/A" if the profile has no matching records.
            rec = analyzer_tot.select(func, aggregate=np.sum)
            if not rec:
                return "N/A"
            rec = rec[0]
            return rec.time

        tab = []
        tot_dev_time = get_tot_time(device_end_func)
        tot_host_time = get_tot_time(host_end_func)
        tab.append(("total device time", tot_dev_time))
        tab.append(("total host time", tot_host_time))
        if args.copy_time:

            def fmt(a, b):
                a = a[0]
                b = b[0]
                return "tot={:.4f} avg={:.4f}".format(a.time, b.time)

            tab.append(
                (
                    "copy time",
                    fmt(
                        analyzer.select(
                            device_end_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.sum,
                        ),
                        analyzer.select(
                            device_end_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.mean,
                        ),
                    ),
                )
            )
            tab.append(
                (
                    "copy wait time",
                    fmt(
                        analyzer.select(
                            device_kern_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.sum,
                        ),
                        analyzer.select(
                            device_kern_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.mean,
                        ),
                    ),
                )
            )

        if args.confluence:
            tab_str = _tabulate_confluence(tab, headers=["name", "value"])
        else:
            tab_str = tabulate(tab)
        return tab_str, tot_dev_time, tot_host_time

    def prof_details(prof_type, tot_time):
        """Render the top-k operator table for ``prof_type`` (device/host).

        :param prof_type: "device" or "host" profiling times.
        :param tot_time: total time used as denominator for ratio columns.
        """
        tab = []

        def func(
            opr,
            *,
            f0=TimeFuncHelper.eval_time_func(prof_type, args.top_end_key, np.max)
        ):
            # Drop records outside the [min-time, max-time] window.
            t = f0(opr)
            if t is not None and (t < args.min_time or t > args.max_time):
                return None
            return t

        records = analyzer.select(
            func,
            aggregate=args.aggregate,
            aggregate_by=args.aggregate_by,
            top_k=args.top,
            sort_by=args.order_by,
        )

        if args.dump_only_opr:
            # Plaintext opr info only, one opr per line (diff-friendly).
            ret = []
            for i in records:
                ret.append(" ".join(i.info.values()))
            return "\n".join(ret)

        def format_shapes(shapes, layouts=None, sep="\n"):
            """Format shapes as ``{a,b,c}``; append layout as ``-[...]``."""
            if isinstance(shapes, NonExistNum) or shapes is None:
                return repr(shapes)
            if layouts is None:
                layouts = [None] * len(shapes)
            comp = []
            for i, j in zip(shapes, layouts):
                i = "{" + ",".join(map(str, i)) + "}"
                if j:
                    i += "\n -[" + ",".join(map(str, j)) + "]"
                comp.append(i)
            return sep.join(comp)

        def fix_num_and_find_unit(x, base):
            """Scale ``x`` down by ``base`` and return (scaled, SI prefix)."""
            if isinstance(x, NonExistNum) or (
                isinstance(x, float) and not np.isfinite(x)
            ):
                return x, ""
            unit = iter(["", "K", "M", "G", "T", "P"])
            while x >= base:
                x /= base
                next(unit)
            return x, next(unit)

        def get_number_with_unit(num, unit, base, sep="\n"):
            # ``unit`` may be a [no-prefix, with-prefix] pair, e.g.
            # ["byte", "iB"] so 512 -> "512 byte" but 2048 -> "2.00 KiB".
            num, unit_prefix = fix_num_and_find_unit(num, base)
            if isinstance(unit, list):
                unit = unit[int(unit_prefix != "")]
            return ("{:.2f}" + sep + "{}{}").format(num, unit_prefix, unit)

        if args.confluence:
            rows = []
            cum_time = 0
            max_time = max([r.time for r in records])
            max_bandwidth = max([r.bandwidth for r in records])
            # default=0 avoids ValueError when no record reports flops; the
            # value is then never used since every consumer is guarded by an
            # isinstance(..., NonExistNum) check.
            max_flops = max(
                [r.flops for r in records if not isinstance(r.flops, NonExistNum)],
                default=0,
            )
            bar_length = 15
            for idx, record in enumerate(records):
                cum_time += record.time
                opr_info = [("opr " + k, v) for k, v in record.info.items()]
                row = collections.OrderedDict(
                    [
                        ("#", idx),
                        ("time", "{:.3}".format(record.time)),
                        ("ratio", "{:.1f}%".format(record.time / tot_time * 100)),
                        ("time bar", "#" * int(record.time / max_time * bar_length)),
                        ("cum-time", cum_time),
                        ("cum-time ratio", cum_time / tot_time),
                    ]
                    + opr_info
                    + [
                        (
                            "computation (MFLO)",
                            "{:.1f}".format(record.computation / 1000 ** 2),
                        ),
                        ("MFLOPS", "{:.1f}".format(record.flops / 1000 ** 2)),
                        (
                            "MFLOPS-bar",
                            ""
                            if isinstance(record.flops, NonExistNum)
                            else ("#" * int(record.flops / max_flops * bar_length)),
                        ),
                        ("memory (MB)", "{:.1f}".format(record.memory / 1024 ** 2)),
                        (
                            "bandwidth (MiB/s)",
                            "{:.1f}".format(record.bandwidth / 1024 ** 2),
                        ),
                        (
                            "bandwidth bar",
                            "#"
                            * int(record.bandwidth / max_bandwidth * bar_length),
                        ),
                        (
                            "in_shapes",
                            format_shapes(
                                record.in_shapes, record.in_layouts, sep=", "
                            ),
                        ),
                        ("out_shapes", format_shapes(record.out_shapes, sep=", ")),
                    ]
                )
                rows.append(row)

            headers = list(rows[0].keys())
            tab = [[row[i] for i in headers] for row in rows]
            return _tabulate_confluence(tab, headers=headers)
        else:
            cum_time = 0
            for idx, record in enumerate(records):
                cum_time += record.time
                tab.append(
                    (
                        "#{}\n{:.3}\n{:.1f}%".format(
                            idx, record.time, record.time / tot_time * 100
                        ),
                        "{:.3}\n{:.1f}%".format(
                            cum_time, cum_time / tot_time * 100
                        ),
                        "\n".join(
                            "\n- ".join(textwrap.wrap(str(i), width=30))
                            for i in record.info.values()
                        ),
                        get_number_with_unit(record.computation, "FLO", 1000),
                        get_number_with_unit(record.flops, "FLOPS", 1000),
                        get_number_with_unit(record.memory, ["byte", "iB"], 1024),
                        get_number_with_unit(
                            record.bandwidth, ["byte/s", "iB/s"], 1024
                        ),
                        format_shapes(record.in_shapes, record.in_layouts),
                        format_shapes(record.out_shapes),
                    )
                )
            return _tabulate_ml(
                tab,
                headers=[
                    "{} self time".format(prof_type),
                    "cumulative",
                    "operator info",
                    "computation",
                    "FLOPS",
                    "memory",
                    "bandwidth",
                    "in_shapes",
                    "out_shapes",
                ],
                tablefmt="fancy_grid",
            )

    summary_tab, tot_dev_time, tot_host_time = summary()
    if args.print_only:
        # Dispatch table; lambdas defer prof_details() until actually chosen.
        print(
            {
                "summary": lambda: summary_tab,
                "device": lambda: prof_details("device", tot_dev_time),
                "host": lambda: prof_details("host", tot_host_time),
            }[args.print_only]()
        )
    else:
        print(summary_tab)
        print()
        print(prof_details("device", tot_dev_time))
        if args.show_host:
            print()
            print(prof_details("host", tot_host_time))


if __name__ == "__main__":
    main()