diff --git a/tools/gen_header_for_bin_reduce.py b/tools/gen_header_for_bin_reduce.py index 882deba4..0414db40 100755 --- a/tools/gen_header_for_bin_reduce.py +++ b/tools/gen_header_for_bin_reduce.py @@ -8,21 +8,22 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import sys -import re - -if sys.version_info[0] != 3 or sys.version_info[1] < 5: - print('This script requires Python version 3.5') - sys.exit(1) - import argparse import json import os +import re import subprocess +import sys import tempfile from pathlib import Path -MIDOUT_TRACE_MAGIC = 'midout_trace v1\n' +if sys.version_info[0] != 3 or sys.version_info[1] < 5: + print("This script requires Python version 3.5") + sys.exit(1) + + +MIDOUT_TRACE_MAGIC = "midout_trace v1\n" + class HeaderGen: _dtypes = None @@ -42,20 +43,22 @@ class HeaderGen: self._midout_files = [] _megvii3_root_cache = None + @classmethod def get_megvii3_root(cls): if cls._megvii3_root_cache is not None: return cls._megvii3_root_cache wd = Path(__file__).resolve().parent while wd.parent != wd: - workspace_file = wd / 'WORKSPACE' - if workspace_file.is_file(): - cls._megvii3_root_cache = str(wd) - return cls._megvii3_root_cache - wd = wd.parent + workspace_file = wd / "WORKSPACE" + if workspace_file.is_file(): + cls._megvii3_root_cache = str(wd) + return cls._megvii3_root_cache + wd = wd.parent return None _megengine_root_cache = None + @classmethod def get_megengine_root(cls): if cls._megengine_root_cache is not None: @@ -66,15 +69,15 @@ class HeaderGen: def extend_netinfo(self, data): self._has_netinfo = True - if 'hash' not in data: + if "hash" not in data: self._file_without_hash = True else: - self._graph_hashes.add(str(data['hash'])) - for i in data['dtypes']: + self._graph_hashes.add(str(data["hash"])) + for i in data["dtypes"]: self._dtypes.add(i) - for i in data['opr_types']: + for i in data["opr_types"]: self._oprs.add(i) - for i in data['elemwise_modes']: + for i in data["elemwise_modes"]: self._elemwise_modes.add(i) def extend_midout(self, fname): @@ -82,7 +85,7 @@ class HeaderGen: def generate(self, fout): self._fout = fout - self._write_def('MGB_BINREDUCE_VERSION', '20190219') + self._write_def("MGB_BINREDUCE_VERSION", "20190219") if self._has_netinfo: self._write_dtype() self._write_elemwise_modes() @@ -93,13 +96,13 @@ class HeaderGen: def strip_opr_name_with_version(self, name): pos = len(name) - t = re.search(r'V\d+$', name) + t = re.search(r"V\d+$", name) if t: pos = t.start() return name[:pos] def _write_oprs(self): - defs = ['}', 'namespace opr {'] + defs = ["}", "namespace opr {"] already_declare = set() already_instance = set() for i in self._oprs: @@ -109,13 +112,15 @@ class HeaderGen: else: already_declare.add(i) - defs.append('class {};'.format(i)) - defs.append('}') - defs.append('namespace serialization {') - defs.append(""" + defs.append("class {};".format(i)) + defs.append("}") + defs.append("namespace serialization {") + defs.append( + """ template struct OprRegistryCaller { - }; """) + }; """ + ) for i in sorted(self._oprs): i = self.strip_opr_name_with_version(i) if i in already_instance: @@ -123,40 +128,53 @@ class HeaderGen: else: already_instance.add(i) - defs.append(""" + defs.append( + """ template struct OprRegistryCaller: public OprRegistryCallerDefaultImpl {{ - }}; """.format(i)) - self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs) + }}; """.format( + i + ) + ) + self._write_def("MGB_OPR_REGISTRY_CALLER_SPECIALIZE", defs) def _write_elemwise_modes(self): with tempfile.NamedTemporaryFile() as ftmp: fpath = os.path.realpath(ftmp.name) subprocess.check_call( - ['./dnn/scripts/gen_param_defs.py', - '--write-enum-items', 'Elemwise:Mode', - './dnn/scripts/opr_param_defs.py', - fpath], - cwd=self.get_megengine_root() + [ + "./dnn/scripts/gen_param_defs.py", + "--write-enum-items", + "Elemwise:Mode", + "./dnn/scripts/opr_param_defs.py", + fpath, + ], + cwd=self.get_megengine_root(), ) with open(fpath) as fin: mode_list = [i.strip() for i in fin] for i in mode_list: - i = i.split(' ')[0].split('=')[0] + i = i.split(" ")[0].split("=")[0] if i in self._elemwise_modes: - content = '_cb({})'.format(i) + content = "_cb({})".format(i) else: - content = '' + content = "" self._write_def( - '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i.split(' ')[0].split('=')[0]), content) - self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)', - '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)') + "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format( + i.split(" ")[0].split("=")[0] + ), + content, + ) + self._write_def( + "MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)", + "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)", + ) def _write_dtype(self): - if 'Float16' not in self._dtypes: + if "Float16" not in self._dtypes: # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16 # support in the past; however `FLOT16' is really a typo. We plan to # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon. @@ -166,74 +184,86 @@ class HeaderGen: # In the future when the situation is settled and no one would ever # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be # safely deleted. - self._write_def('MEGDNN_DISABLE_FLOT16', 1) - self._write_def('MEGDNN_DISABLE_FLOAT16', 1) + self._write_def("MEGDNN_DISABLE_FLOT16", 1) + self._write_def("MEGDNN_DISABLE_FLOAT16", 1) def _write_hash(self): if self._file_without_hash: - print('WARNING: network info has no graph hash. Using json file ' - 'generated by MegBrain >= 7.28.0 is recommended') + print( + "WARNING: network info has no graph hash. Using json file " + "generated by MegBrain >= 7.28.0 is recommended" + ) else: - defs = 'ULL,'.join(self._graph_hashes) + 'ULL' - self._write_def('MGB_BINREDUCE_GRAPH_HASHES', defs) + defs = "ULL,".join(self._graph_hashes) + "ULL" + self._write_def("MGB_BINREDUCE_GRAPH_HASHES", defs) def _write_def(self, name, val): if isinstance(val, list): - val = '\n'.join(val) - val = str(val).strip().replace('\n', ' \\\n') - self._fout.write('#define {} {}\n'.format(name, val)) + val = "\n".join(val) + val = str(val).strip().replace("\n", " \\\n") + self._fout.write("#define {} {}\n".format(name, val)) def _write_midout(self): if not self._midout_files: return - gen = os.path.join(self.get_megengine_root(), 'third_party', 'midout', 'gen_header.py') + gen = os.path.join( + self.get_megengine_root(), "third_party", "midout", "gen_header.py" + ) if self.get_megvii3_root(): - gen = os.path.join(self.get_megvii3_root(), 'brain', 'midout', 'gen_header.py') - print('use {} to gen bin_reduce header'.format(gen)) + gen = os.path.join( + self.get_megvii3_root(), "brain", "midout", "gen_header.py" + ) + print("use {} to gen bin_reduce header".format(gen)) cvt = subprocess.run( - [gen] + self._midout_files, - stdout=subprocess.PIPE, check=True, - ).stdout.decode('utf-8') - self._fout.write('// midout \n') + [gen] + self._midout_files, stdout=subprocess.PIPE, check=True, + ).stdout.decode("utf-8") + self._fout.write("// midout \n") self._fout.write(cvt) if cvt.find(" half,") > 0: change = open(self._fout.name).read().replace(" half,", " __fp16,") with open("fix_fp16_bin_reduce.h", "w") as fix_fp16: fix_fp16.write(change) msg = ( - "WARNING:\n" - "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n" - "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n" - "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!" - ) + "WARNING:\n" + "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n" + "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n" + "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!" + ) print(msg) def main(): parser = argparse.ArgumentParser( - description='generate header file for reducing binary size by ' - 'stripping unused oprs in a particular network; output file would ' - 'be written to bin_reduce.h', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + description="generate header file for reducing binary size by " + "stripping unused oprs in a particular network; output file would " + "be written to bin_reduce.h", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( - 'inputs', nargs='+', - help='input files that describe specific traits of the network; ' - 'can be one of the following:' - ' 1. json files generated by ' - 'megbrain.serialize_comp_graph_to_file() in python; ' - ' 2. trace files generated by midout library') - default_file=os.path.join(HeaderGen.get_megengine_root(), 'src', 'bin_reduce_cmake.h') + "inputs", + nargs="+", + help="input files that describe specific traits of the network; " + "can be one of the following:" + " 1. json files generated by " + "megbrain.serialize_comp_graph_to_file() in python; " + " 2. trace files generated by midout library", + ) + default_file = os.path.join( + HeaderGen.get_megengine_root(), "src", "bin_reduce_cmake.h" + ) is_megvii3 = HeaderGen.get_megvii3_root() if is_megvii3: - default_file=os.path.join(HeaderGen.get_megvii3_root(), 'utils', 'bin_reduce.h') - parser.add_argument('-o', '--output', help='output file', default=default_file) + default_file = os.path.join( + HeaderGen.get_megvii3_root(), "utils", "bin_reduce.h" + ) + parser.add_argument("-o", "--output", help="output file", default=default_file) args = parser.parse_args() - print('config output file: {}'.format(args.output)) + print("config output file: {}".format(args.output)) gen = HeaderGen() for i in args.inputs: - print('==== processing {}'.format(i)) + print("==== processing {}".format(i)) with open(i) as fin: if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: gen.extend_midout(i) @@ -241,8 +271,9 @@ def main(): fin.seek(0) gen.extend_netinfo(json.loads(fin.read())) - with open(args.output, 'w') as fout: + with open(args.output, "w") as fout: gen.generate(fout) -if __name__ == '__main__': + +if __name__ == "__main__": main()