You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

gen_header_for_bin_reduce.py 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import sys
  4. import re
  5. if sys.version_info[0] != 3 or sys.version_info[1] < 5:
  6. print('This script requires Python version 3.5')
  7. sys.exit(1)
  8. import argparse
  9. import json
  10. import os
  11. import subprocess
  12. import tempfile
  13. from pathlib import Path
# Leading text of a midout trace file. main() compares the first
# len(MIDOUT_TRACE_MAGIC) characters of every input file against this value
# to decide whether the file is a midout trace or a JSON network description.
MIDOUT_TRACE_MAGIC = 'midout_trace v1\n'
  15. class HeaderGen:
  16. _dtypes = None
  17. _oprs = None
  18. _fout = None
  19. _elemwise_modes = None
  20. _has_netinfo = False
  21. _midout_files = None
  22. _file_without_hash = False
  23. def __init__(self):
  24. self._dtypes = set()
  25. self._oprs = set()
  26. self._elemwise_modes = set()
  27. self._graph_hashes = set()
  28. self._midout_files = []
  29. _megvii3_root_cache = None
  30. @classmethod
  31. def get_megvii3_root(cls):
  32. if cls._megvii3_root_cache is not None:
  33. return cls._megvii3_root_cache
  34. wd = Path(__file__).resolve().parent
  35. while wd.parent != wd:
  36. workspace_file = wd / 'WORKSPACE'
  37. if workspace_file.is_file():
  38. cls._megvii3_root_cache = str(wd)
  39. return cls._megvii3_root_cache
  40. wd = wd.parent
  41. raise RuntimeError('This script is supposed to run in megvii3.')
  42. def extend_netinfo(self, data):
  43. self._has_netinfo = True
  44. if 'hash' not in data:
  45. self._file_without_hash = True
  46. else:
  47. self._graph_hashes.add(str(data['hash']))
  48. for i in data['dtypes']:
  49. self._dtypes.add(i)
  50. for i in data['opr_types']:
  51. self._oprs.add(i)
  52. for i in data['elemwise_modes']:
  53. self._elemwise_modes.add(i)
  54. def extend_midout(self, fname):
  55. self._midout_files.append(fname)
  56. def generate(self, fout):
  57. self._fout = fout
  58. self._write_def('MGB_BINREDUCE_VERSION', '20190219')
  59. if self._has_netinfo:
  60. self._write_dtype()
  61. self._write_elemwise_modes()
  62. self._write_oprs()
  63. self._write_hash()
  64. self._write_midout()
  65. del self._fout
  66. def strip_opr_name_with_version(self, name):
  67. pos = len(name)
  68. t = re.search(r'V\d+$', name)
  69. if t:
  70. pos = t.start()
  71. return name[:pos]
  72. def _write_oprs(self):
  73. defs = ['}', 'namespace opr {']
  74. already_declare = set()
  75. already_instance = set()
  76. for i in self._oprs:
  77. i = self.strip_opr_name_with_version(i)
  78. if i in already_declare:
  79. continue
  80. else:
  81. already_declare.add(i)
  82. defs.append('class {};'.format(i))
  83. defs.append('}')
  84. defs.append('namespace serialization {')
  85. defs.append("""
  86. template<class Opr, class Callee>
  87. struct OprRegistryCaller {
  88. }; """)
  89. for i in sorted(self._oprs):
  90. i = self.strip_opr_name_with_version(i)
  91. if i in already_instance:
  92. continue
  93. else:
  94. already_instance.add(i)
  95. defs.append("""
  96. template<class Callee>
  97. struct OprRegistryCaller<opr::{}, Callee>: public
  98. OprRegistryCallerDefaultImpl<Callee> {{
  99. }}; """.format(i))
  100. self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs)
  101. def _write_elemwise_modes(self):
  102. with tempfile.NamedTemporaryFile() as ftmp:
  103. fpath = os.path.realpath(ftmp.name)
  104. subprocess.check_call(
  105. ['./brain/megbrain/dnn/scripts/gen_param_defs.py',
  106. '--write-enum-items', 'Elemwise:Mode',
  107. './brain/megbrain/dnn/scripts/opr_param_defs.py',
  108. fpath],
  109. cwd=self.get_megvii3_root()
  110. )
  111. with open(fpath) as fin:
  112. mode_list = [i.strip() for i in fin]
  113. for i in mode_list:
  114. if i in self._elemwise_modes:
  115. content = '_cb({})'.format(i)
  116. else:
  117. content = ''
  118. self._write_def(
  119. '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i), content)
  120. self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)',
  121. '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)')
  122. def _write_dtype(self):
  123. if 'Float16' not in self._dtypes:
  124. # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
  125. # support in the past; however `FLOT16' is really a typo. We plan to
  126. # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
  127. # To prevent issues in the transition, we decide to define both
  128. # macros (`FLOT16' and `FLOAT16') here.
  129. #
  130. # In the future when the situation is settled and no one would ever
  131. # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be
  132. # safely deleted.
  133. self._write_def('MEGDNN_DISABLE_FLOT16', 1)
  134. self._write_def('MEGDNN_DISABLE_FLOAT16', 1)
  135. def _write_hash(self):
  136. if self._file_without_hash:
  137. print('WARNING: network info has no graph hash. Using json file '
  138. 'generated by MegBrain >= 7.28.0 is recommended')
  139. else:
  140. defs = 'ULL,'.join(self._graph_hashes) + 'ULL'
  141. self._write_def('MGB_BINREDUCE_GRAPH_HASHES', defs)
  142. def _write_def(self, name, val):
  143. if isinstance(val, list):
  144. val = '\n'.join(val)
  145. val = str(val).strip().replace('\n', ' \\\n')
  146. self._fout.write('#define {} {}\n'.format(name, val))
  147. def _write_midout(self):
  148. if not self._midout_files:
  149. return
  150. gen = os.path.join(self.get_megvii3_root(), 'brain', 'midout',
  151. 'gen_header.py')
  152. cvt = subprocess.run(
  153. [gen] + self._midout_files,
  154. stdout=subprocess.PIPE, check=True,
  155. ).stdout.decode('utf-8')
  156. self._fout.write('// midout \n')
  157. self._fout.write(cvt)
  158. def main():
  159. parser = argparse.ArgumentParser(
  160. description='generate header file for reducing binary size by '
  161. 'stripping unused oprs in a particular network; output file would '
  162. 'be written to bin_reduce.h',
  163. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  164. parser.add_argument(
  165. 'inputs', nargs='+',
  166. help='input files that describe specific traits of the network; '
  167. 'can be one of the following:'
  168. ' 1. json files generated by '
  169. 'megbrain.serialize_comp_graph_to_file() in python; '
  170. ' 2. trace files generated by midout library')
  171. parser.add_argument('-o', '--output', help='output file',
  172. default=os.path.join(HeaderGen.get_megvii3_root(),
  173. 'utils', 'bin_reduce.h'))
  174. args = parser.parse_args()
  175. gen = HeaderGen()
  176. for i in args.inputs:
  177. print('==== processing {}'.format(i))
  178. with open(i) as fin:
  179. if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
  180. gen.extend_midout(i)
  181. else:
  182. fin.seek(0)
  183. gen.extend_netinfo(json.loads(fin.read()))
  184. with open(args.output, 'w') as fout:
  185. gen.generate(fout)
  186. if __name__ == '__main__':
  187. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。