#!/usr/bin/env python3 # -*- coding: utf-8 -*- import argparse import json import os import re import subprocess import sys import tempfile from pathlib import Path if sys.version_info[0] != 3 or sys.version_info[1] < 5: print("This script requires Python version 3.5") sys.exit(1) MIDOUT_TRACE_MAGIC = "midout_trace v1\n" class HeaderGen: _dtypes = None _oprs = None _fout = None _elemwise_modes = None _has_netinfo = False _midout_files = None _file_without_hash = False def __init__(self): self._dtypes = set() self._oprs = set() self._elemwise_modes = set() self._graph_hashes = set() self._midout_files = [] _megvii3_root_cache = None @classmethod def get_megvii3_root(cls): if cls._megvii3_root_cache is not None: return cls._megvii3_root_cache wd = Path(__file__).resolve().parent while wd.parent != wd: workspace_file = wd / "WORKSPACE" if workspace_file.is_file(): cls._megvii3_root_cache = str(wd) return cls._megvii3_root_cache wd = wd.parent return None _megengine_root_cache = None @classmethod def get_megengine_root(cls): if cls._megengine_root_cache is not None: return cls._megengine_root_cache wd = Path(__file__).resolve().parent.parent cls._megengine_root_cache = str(wd) return cls._megengine_root_cache def extend_netinfo(self, data): self._has_netinfo = True if "hash" not in data: self._file_without_hash = True else: self._graph_hashes.add(str(data["hash"])) for i in data["dtypes"]: self._dtypes.add(i) for i in data["opr_types"]: self._oprs.add(i) def extend_midout(self, fname): self._midout_files.append(fname) def extend_elemwise_mode_info(self, fname): for line in open(fname): # tag write in dnn/src/common/elemwise/opr_impl.cpp idx = line.find("megdnn_common_elemwise_mode") if idx > 0: cmd = "c++filt -t {}".format(line) demangle = subprocess.check_output(cmd, shell=True).decode("utf-8") demangle = demangle.replace(">", "").split() is_find_number = False for i in demangle: if i.isnumeric(): self._elemwise_modes.add(i) is_find_number = True break assert ( is_find_number ), "code issue happened!! can not find elemwise mode in: {}".format( line ) def generate(self, fout): self._fout = fout self._write_def("MGB_BINREDUCE_VERSION", "20220507") if self._has_netinfo: self._write_dtype() if len(self._elemwise_modes) > 0: self._write_elemwise_modes() if self._has_netinfo: self._write_oprs() self._write_hash() self._write_midout() del self._fout def strip_opr_name_with_version(self, name): pos = len(name) t = re.search(r"V\d+$", name) if t: pos = t.start() return name[:pos] def _write_oprs(self): defs = ["}", "namespace opr {"] already_declare = set() already_instance = set() for i in self._oprs: i = self.strip_opr_name_with_version(i) if i in already_declare: continue else: already_declare.add(i) defs.append("class {};".format(i)) defs.append("}") defs.append("namespace serialization {") defs.append( """ template struct OprRegistryCaller { }; """ ) for i in sorted(self._oprs): i = self.strip_opr_name_with_version(i) if i in already_instance: continue else: already_instance.add(i) defs.append( """ template struct OprRegistryCaller: public OprRegistryCallerDefaultImpl {{ }}; """.format( i ) ) self._write_def("MGB_OPR_REGISTRY_CALLER_SPECIALIZE", defs) def _write_elemwise_modes(self): with tempfile.NamedTemporaryFile() as ftmp: fpath = os.path.realpath(ftmp.name) subprocess.check_call( [ "./dnn/scripts/gen_param_defs.py", "--write-enum-items", "Elemwise:Mode", "./dnn/scripts/opr_param_defs.py", fpath, ], cwd=self.get_megengine_root(), ) with open(fpath) as fin: mode_list = [i.strip() for i in fin] all_elemwise_modes = set() for i in mode_list: i_type = i.replace(" ", "").replace("=", " ").split()[0] i_id = i.replace(" ", "").replace("=", " ").split()[1] all_elemwise_modes.add(i_id) if i_id in self._elemwise_modes: content = "_cb({})".format(i_type) else: content = "" self._write_def( "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(i_type), content, ) # write end of elemwise macro self._write_def( "MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)", "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)", ) # finally check all self._elemwise_modes is in all_elemwise_modes for i in self._elemwise_modes: assert ( i in all_elemwise_modes ), "code issue happened, can not find elemwise mode: {} in {}".format( i, all_elemwise_modes ) def _write_dtype(self): if "Float16" not in self._dtypes: # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16 # support in the past; however `FLOT16' is really a typo. We plan to # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon. # To prevent issues in the transition, we decide to define both # macros (`FLOT16' and `FLOAT16') here. # # In the future when the situation is settled and no one would ever # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be # safely deleted. self._write_def("MEGDNN_DISABLE_FLOT16", 1) self._write_def("MEGDNN_DISABLE_FLOAT16", 1) def _write_hash(self): if self._file_without_hash: print( "WARNING: network info has no graph hash. Using json file " "generated by MegBrain >= 7.28.0 is recommended" ) else: defs = "ULL,".join(self._graph_hashes) + "ULL" self._write_def("MGB_BINREDUCE_GRAPH_HASHES", defs) def _write_def(self, name, val): if isinstance(val, list): val = "\n".join(val) val = str(val).strip().replace("\n", " \\\n") self._fout.write("#define {} {}\n".format(name, val)) def _write_midout(self): if not self._midout_files: return gen = os.path.join( self.get_megengine_root(), "third_party", "midout", "gen_header.py" ) if self.get_megvii3_root(): gen = os.path.join( self.get_megvii3_root(), "brain", "midout", "gen_header.py" ) print("use {} to gen bin_reduce header".format(gen)) cvt = subprocess.run( [gen] + self._midout_files, stdout=subprocess.PIPE, check=True, ).stdout.decode("utf-8") self._fout.write("// midout \n") self._fout.write(cvt) if cvt.find(" half,") > 0: change = open(self._fout.name).read().replace(" half,", " __fp16,") with open("fix_fp16_bin_reduce.h", "w") as fix_fp16: fix_fp16.write(change) msg = ( "WARNING:\n" "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n" "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n" "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!" ) print(msg) def main(): parser = argparse.ArgumentParser( description="generate header file for reducing binary size by " "stripping unused oprs in a particular network; output file would " "be written to bin_reduce.h", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "inputs", nargs="+", help="input files that describe specific traits of the network; " "can be one of the following:" " 1. json files generated by " "megbrain.serialize_comp_graph_to_file() in python; " " 2. trace files generated by midout library", ) default_file = os.path.join( HeaderGen.get_megengine_root(), "src", "bin_reduce_cmake.h" ) is_megvii3 = HeaderGen.get_megvii3_root() if is_megvii3: default_file = os.path.join( HeaderGen.get_megvii3_root(), "utils", "bin_reduce.h" ) parser.add_argument("-o", "--output", help="output file", default=default_file) args = parser.parse_args() print("config output file: {}".format(args.output)) gen = HeaderGen() for i in args.inputs: print("==== processing {}".format(i)) with open(i) as fin: if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: gen.extend_midout(i) gen.extend_elemwise_mode_info(i) else: fin.seek(0) gen.extend_netinfo(json.loads(fin.read())) with open(args.output, "w") as fout: gen.generate(fout) if __name__ == "__main__": main()