You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_header_for_bin_reduce.py 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. #
  5. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. #
  7. # Unless required by applicable law or agreed to in writing,
  8. # software distributed under the License is distributed on an
  9. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. import sys
  11. import re
  12. if sys.version_info[0] != 3 or sys.version_info[1] < 5:
  13. print('This script requires Python version 3.5')
  14. sys.exit(1)
  15. import argparse
  16. import json
  17. import os
  18. import subprocess
  19. import tempfile
  20. from pathlib import Path
  21. MIDOUT_TRACE_MAGIC = 'midout_trace v1\n'
  22. class HeaderGen:
  23. _dtypes = None
  24. _oprs = None
  25. _fout = None
  26. _elemwise_modes = None
  27. _has_netinfo = False
  28. _midout_files = None
  29. _file_without_hash = False
  30. def __init__(self):
  31. self._dtypes = set()
  32. self._oprs = set()
  33. self._elemwise_modes = set()
  34. self._graph_hashes = set()
  35. self._midout_files = []
  36. _megvii3_root_cache = None
  37. @classmethod
  38. def get_megvii3_root(cls):
  39. if cls._megvii3_root_cache is not None:
  40. return cls._megvii3_root_cache
  41. wd = Path(__file__).resolve().parent
  42. while wd.parent != wd:
  43. workspace_file = wd / 'WORKSPACE'
  44. if workspace_file.is_file():
  45. cls._megvii3_root_cache = str(wd)
  46. return cls._megvii3_root_cache
  47. wd = wd.parent
  48. return None
  49. _megengine_root_cache = None
  50. @classmethod
  51. def get_megengine_root(cls):
  52. if cls._megengine_root_cache is not None:
  53. return cls._megengine_root_cache
  54. wd = Path(__file__).resolve().parent.parent
  55. cls._megengine_root_cache = str(wd)
  56. return cls._megengine_root_cache
  57. def extend_netinfo(self, data):
  58. self._has_netinfo = True
  59. if 'hash' not in data:
  60. self._file_without_hash = True
  61. else:
  62. self._graph_hashes.add(str(data['hash']))
  63. for i in data['dtypes']:
  64. self._dtypes.add(i)
  65. for i in data['opr_types']:
  66. self._oprs.add(i)
  67. for i in data['elemwise_modes']:
  68. self._elemwise_modes.add(i)
  69. def extend_midout(self, fname):
  70. self._midout_files.append(fname)
  71. def generate(self, fout):
  72. self._fout = fout
  73. self._write_def('MGB_BINREDUCE_VERSION', '20190219')
  74. if self._has_netinfo:
  75. self._write_dtype()
  76. self._write_elemwise_modes()
  77. self._write_oprs()
  78. self._write_hash()
  79. self._write_midout()
  80. del self._fout
  81. def strip_opr_name_with_version(self, name):
  82. pos = len(name)
  83. t = re.search(r'V\d+$', name)
  84. if t:
  85. pos = t.start()
  86. return name[:pos]
  87. def _write_oprs(self):
  88. defs = ['}', 'namespace opr {']
  89. already_declare = set()
  90. already_instance = set()
  91. for i in self._oprs:
  92. i = self.strip_opr_name_with_version(i)
  93. if i in already_declare:
  94. continue
  95. else:
  96. already_declare.add(i)
  97. defs.append('class {};'.format(i))
  98. defs.append('}')
  99. defs.append('namespace serialization {')
  100. defs.append("""
  101. template<class Opr, class Callee>
  102. struct OprRegistryCaller {
  103. }; """)
  104. for i in sorted(self._oprs):
  105. i = self.strip_opr_name_with_version(i)
  106. if i in already_instance:
  107. continue
  108. else:
  109. already_instance.add(i)
  110. defs.append("""
  111. template<class Callee>
  112. struct OprRegistryCaller<opr::{}, Callee>: public
  113. OprRegistryCallerDefaultImpl<Callee> {{
  114. }}; """.format(i))
  115. self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs)
  116. def _write_elemwise_modes(self):
  117. with tempfile.NamedTemporaryFile() as ftmp:
  118. fpath = os.path.realpath(ftmp.name)
  119. subprocess.check_call(
  120. ['./dnn/scripts/gen_param_defs.py',
  121. '--write-enum-items', 'Elemwise:Mode',
  122. './dnn/scripts/opr_param_defs.py',
  123. fpath],
  124. cwd=self.get_megengine_root()
  125. )
  126. with open(fpath) as fin:
  127. mode_list = [i.strip() for i in fin]
  128. for i in mode_list:
  129. i = i.split(' ')[0].split('=')[0]
  130. if i in self._elemwise_modes:
  131. content = '_cb({})'.format(i)
  132. else:
  133. content = ''
  134. self._write_def(
  135. '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i.split(' ')[0].split('=')[0]), content)
  136. self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)',
  137. '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)')
  138. def _write_dtype(self):
  139. if 'Float16' not in self._dtypes:
  140. # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
  141. # support in the past; however `FLOT16' is really a typo. We plan to
  142. # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
  143. # To prevent issues in the transition, we decide to define both
  144. # macros (`FLOT16' and `FLOAT16') here.
  145. #
  146. # In the future when the situation is settled and no one would ever
  147. # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be
  148. # safely deleted.
  149. self._write_def('MEGDNN_DISABLE_FLOT16', 1)
  150. self._write_def('MEGDNN_DISABLE_FLOAT16', 1)
  151. def _write_hash(self):
  152. if self._file_without_hash:
  153. print('WARNING: network info has no graph hash. Using json file '
  154. 'generated by MegBrain >= 7.28.0 is recommended')
  155. else:
  156. defs = 'ULL,'.join(self._graph_hashes) + 'ULL'
  157. self._write_def('MGB_BINREDUCE_GRAPH_HASHES', defs)
  158. def _write_def(self, name, val):
  159. if isinstance(val, list):
  160. val = '\n'.join(val)
  161. val = str(val).strip().replace('\n', ' \\\n')
  162. self._fout.write('#define {} {}\n'.format(name, val))
  163. def _write_midout(self):
  164. if not self._midout_files:
  165. return
  166. gen = os.path.join(self.get_megengine_root(), 'third_party', 'midout', 'gen_header.py')
  167. if self.get_megvii3_root():
  168. gen = os.path.join(self.get_megvii3_root(), 'brain', 'midout', 'gen_header.py')
  169. print('use {} to gen bin_reduce header'.format(gen))
  170. cvt = subprocess.run(
  171. [gen] + self._midout_files,
  172. stdout=subprocess.PIPE, check=True,
  173. ).stdout.decode('utf-8')
  174. self._fout.write('// midout \n')
  175. self._fout.write(cvt)
  176. if cvt.find(" half,") > 0:
  177. change = open(self._fout.name).read().replace(" half,", " __fp16,")
  178. with open("fix_fp16_bin_reduce.h", "w") as fix_fp16:
  179. fix_fp16.write(change)
  180. msg = (
  181. "WARNING:\n"
  182. "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n"
  183. "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n"
  184. "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!"
  185. )
  186. print(msg)
  187. def main():
  188. parser = argparse.ArgumentParser(
  189. description='generate header file for reducing binary size by '
  190. 'stripping unused oprs in a particular network; output file would '
  191. 'be written to bin_reduce.h',
  192. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  193. parser.add_argument(
  194. 'inputs', nargs='+',
  195. help='input files that describe specific traits of the network; '
  196. 'can be one of the following:'
  197. ' 1. json files generated by '
  198. 'megbrain.serialize_comp_graph_to_file() in python; '
  199. ' 2. trace files generated by midout library')
  200. default_file=os.path.join(HeaderGen.get_megengine_root(), 'src', 'bin_reduce_cmake.h')
  201. is_megvii3 = HeaderGen.get_megvii3_root()
  202. if is_megvii3:
  203. default_file=os.path.join(HeaderGen.get_megvii3_root(), 'utils', 'bin_reduce.h')
  204. parser.add_argument('-o', '--output', help='output file', default=default_file)
  205. args = parser.parse_args()
  206. print('config output file: {}'.format(args.output))
  207. gen = HeaderGen()
  208. for i in args.inputs:
  209. print('==== processing {}'.format(i))
  210. with open(i) as fin:
  211. if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
  212. gen.extend_midout(i)
  213. else:
  214. fin.seek(0)
  215. gen.extend_netinfo(json.loads(fin.read()))
  216. with open(args.output, 'w') as fout:
  217. gen.generate(fout)
  218. if __name__ == '__main__':
  219. main()