You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

gen_header_for_bin_reduce.py 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. #
  5. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. #
  7. # Unless required by applicable law or agreed to in writing,
  8. # software distributed under the License is distributed on an
  9. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. import sys
  11. import re
  12. if sys.version_info[0] != 3 or sys.version_info[1] < 5:
  13. print('This script requires Python version 3.5')
  14. sys.exit(1)
  15. import argparse
  16. import json
  17. import os
  18. import subprocess
  19. import tempfile
  20. from pathlib import Path
# Leading text of a midout trace file. main() compares the first characters of
# every input against this value: a match routes the file to
# HeaderGen.extend_midout(); anything else is parsed as JSON network info.
MIDOUT_TRACE_MAGIC = 'midout_trace v1\n'
  22. class HeaderGen:
  23. _dtypes = None
  24. _oprs = None
  25. _fout = None
  26. _elemwise_modes = None
  27. _has_netinfo = False
  28. _midout_files = None
  29. _file_without_hash = False
  30. def __init__(self):
  31. self._dtypes = set()
  32. self._oprs = set()
  33. self._elemwise_modes = set()
  34. self._graph_hashes = set()
  35. self._midout_files = []
  36. _megvii3_root_cache = None
  37. @classmethod
  38. def get_megvii3_root(cls):
  39. if cls._megvii3_root_cache is not None:
  40. return cls._megvii3_root_cache
  41. wd = Path(__file__).resolve().parent
  42. while wd.parent != wd:
  43. workspace_file = wd / 'WORKSPACE'
  44. if workspace_file.is_file():
  45. cls._megvii3_root_cache = str(wd)
  46. return cls._megvii3_root_cache
  47. wd = wd.parent
  48. raise RuntimeError('This script is supposed to run in megvii3.')
  49. def extend_netinfo(self, data):
  50. self._has_netinfo = True
  51. if 'hash' not in data:
  52. self._file_without_hash = True
  53. else:
  54. self._graph_hashes.add(str(data['hash']))
  55. for i in data['dtypes']:
  56. self._dtypes.add(i)
  57. for i in data['opr_types']:
  58. self._oprs.add(i)
  59. for i in data['elemwise_modes']:
  60. self._elemwise_modes.add(i)
  61. def extend_midout(self, fname):
  62. self._midout_files.append(fname)
  63. def generate(self, fout):
  64. self._fout = fout
  65. self._write_def('MGB_BINREDUCE_VERSION', '20190219')
  66. if self._has_netinfo:
  67. self._write_dtype()
  68. self._write_elemwise_modes()
  69. self._write_oprs()
  70. self._write_hash()
  71. self._write_midout()
  72. del self._fout
  73. def strip_opr_name_with_version(self, name):
  74. pos = len(name)
  75. t = re.search(r'V\d+$', name)
  76. if t:
  77. pos = t.start()
  78. return name[:pos]
  79. def _write_oprs(self):
  80. defs = ['}', 'namespace opr {']
  81. already_declare = set()
  82. already_instance = set()
  83. for i in self._oprs:
  84. i = self.strip_opr_name_with_version(i)
  85. if i in already_declare:
  86. continue
  87. else:
  88. already_declare.add(i)
  89. defs.append('class {};'.format(i))
  90. defs.append('}')
  91. defs.append('namespace serialization {')
  92. defs.append("""
  93. template<class Opr, class Callee>
  94. struct OprRegistryCaller {
  95. }; """)
  96. for i in sorted(self._oprs):
  97. i = self.strip_opr_name_with_version(i)
  98. if i in already_instance:
  99. continue
  100. else:
  101. already_instance.add(i)
  102. defs.append("""
  103. template<class Callee>
  104. struct OprRegistryCaller<opr::{}, Callee>: public
  105. OprRegistryCallerDefaultImpl<Callee> {{
  106. }}; """.format(i))
  107. self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs)
  108. def _write_elemwise_modes(self):
  109. with tempfile.NamedTemporaryFile() as ftmp:
  110. fpath = os.path.realpath(ftmp.name)
  111. subprocess.check_call(
  112. ['./brain/megbrain/dnn/scripts/gen_param_defs.py',
  113. '--write-enum-items', 'Elemwise:Mode',
  114. './brain/megbrain/dnn/scripts/opr_param_defs.py',
  115. fpath],
  116. cwd=self.get_megvii3_root()
  117. )
  118. with open(fpath) as fin:
  119. mode_list = [i.strip() for i in fin]
  120. for i in mode_list:
  121. i = i.split(' ')[0].split('=')[0]
  122. if i in self._elemwise_modes:
  123. content = '_cb({})'.format(i)
  124. else:
  125. content = ''
  126. self._write_def(
  127. '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i.split(' ')[0].split('=')[0]), content)
  128. self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)',
  129. '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)')
  130. def _write_dtype(self):
  131. if 'Float16' not in self._dtypes:
  132. # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
  133. # support in the past; however `FLOT16' is really a typo. We plan to
  134. # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
  135. # To prevent issues in the transition, we decide to define both
  136. # macros (`FLOT16' and `FLOAT16') here.
  137. #
  138. # In the future when the situation is settled and no one would ever
  139. # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be
  140. # safely deleted.
  141. self._write_def('MEGDNN_DISABLE_FLOT16', 1)
  142. self._write_def('MEGDNN_DISABLE_FLOAT16', 1)
  143. def _write_hash(self):
  144. if self._file_without_hash:
  145. print('WARNING: network info has no graph hash. Using json file '
  146. 'generated by MegBrain >= 7.28.0 is recommended')
  147. else:
  148. defs = 'ULL,'.join(self._graph_hashes) + 'ULL'
  149. self._write_def('MGB_BINREDUCE_GRAPH_HASHES', defs)
  150. def _write_def(self, name, val):
  151. if isinstance(val, list):
  152. val = '\n'.join(val)
  153. val = str(val).strip().replace('\n', ' \\\n')
  154. self._fout.write('#define {} {}\n'.format(name, val))
  155. def _write_midout(self):
  156. if not self._midout_files:
  157. return
  158. gen = os.path.join(self.get_megvii3_root(), 'brain', 'midout',
  159. 'gen_header.py')
  160. cvt = subprocess.run(
  161. [gen] + self._midout_files,
  162. stdout=subprocess.PIPE, check=True,
  163. ).stdout.decode('utf-8')
  164. self._fout.write('// midout \n')
  165. self._fout.write(cvt)
  166. if cvt.find(" half,") > 0:
  167. change = open(self._fout.name).read().replace(" half,", " __fp16,")
  168. with open("fix_fp16_bin_reduce.h", "w") as fix_fp16:
  169. fix_fp16.write(change)
  170. msg = (
  171. "WARNING:\n"
  172. "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n"
  173. "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n"
  174. "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!"
  175. )
  176. print(msg)
  177. def main():
  178. parser = argparse.ArgumentParser(
  179. description='generate header file for reducing binary size by '
  180. 'stripping unused oprs in a particular network; output file would '
  181. 'be written to bin_reduce.h',
  182. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  183. parser.add_argument(
  184. 'inputs', nargs='+',
  185. help='input files that describe specific traits of the network; '
  186. 'can be one of the following:'
  187. ' 1. json files generated by '
  188. 'megbrain.serialize_comp_graph_to_file() in python; '
  189. ' 2. trace files generated by midout library')
  190. parser.add_argument('-o', '--output', help='output file',
  191. default=os.path.join(HeaderGen.get_megvii3_root(),
  192. 'utils', 'bin_reduce.h'))
  193. args = parser.parse_args()
  194. gen = HeaderGen()
  195. for i in args.inputs:
  196. print('==== processing {}'.format(i))
  197. with open(i) as fin:
  198. if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
  199. gen.extend_midout(i)
  200. else:
  201. fin.seek(0)
  202. gen.extend_netinfo(json.loads(fin.read()))
  203. with open(args.output, 'w') as fout:
  204. gen.generate(fout)
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台