You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

gen_header_for_bin_reduce.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import argparse
  4. import json
  5. import os
  6. import re
  7. import subprocess
  8. import sys
  9. import tempfile
  10. from pathlib import Path
  11. if sys.version_info[0] != 3 or sys.version_info[1] < 5:
  12. print("This script requires Python version 3.5")
  13. sys.exit(1)
  14. MIDOUT_TRACE_MAGIC = "midout_trace v1\n"
  15. class HeaderGen:
  16. _dtypes = None
  17. _oprs = None
  18. _fout = None
  19. _elemwise_modes = None
  20. _has_netinfo = False
  21. _midout_files = None
  22. _file_without_hash = False
  23. def __init__(self):
  24. self._dtypes = set()
  25. self._oprs = set()
  26. self._elemwise_modes = set()
  27. self._graph_hashes = set()
  28. self._midout_files = []
  29. _megvii3_root_cache = None
  30. @classmethod
  31. def get_megvii3_root(cls):
  32. if cls._megvii3_root_cache is not None:
  33. return cls._megvii3_root_cache
  34. wd = Path(__file__).resolve().parent
  35. while wd.parent != wd:
  36. workspace_file = wd / "WORKSPACE"
  37. if workspace_file.is_file():
  38. cls._megvii3_root_cache = str(wd)
  39. return cls._megvii3_root_cache
  40. wd = wd.parent
  41. return None
  42. _megengine_root_cache = None
  43. @classmethod
  44. def get_megengine_root(cls):
  45. if cls._megengine_root_cache is not None:
  46. return cls._megengine_root_cache
  47. wd = Path(__file__).resolve().parent.parent
  48. cls._megengine_root_cache = str(wd)
  49. return cls._megengine_root_cache
  50. def extend_netinfo(self, data):
  51. self._has_netinfo = True
  52. if "hash" not in data:
  53. self._file_without_hash = True
  54. else:
  55. self._graph_hashes.add(str(data["hash"]))
  56. for i in data["dtypes"]:
  57. self._dtypes.add(i)
  58. for i in data["opr_types"]:
  59. self._oprs.add(i)
  60. def extend_midout(self, fname):
  61. self._midout_files.append(fname)
  62. def extend_elemwise_mode_info(self, fname):
  63. for line in open(fname):
  64. # tag write in dnn/src/common/elemwise/opr_impl.cpp
  65. idx = line.find("megdnn_common_elemwise_mode")
  66. if idx > 0:
  67. cmd = "c++filt -t {}".format(line)
  68. demangle = subprocess.check_output(cmd, shell=True).decode("utf-8")
  69. demangle = demangle.replace(">", "").split()
  70. is_find_number = False
  71. for i in demangle:
  72. if i.isnumeric():
  73. self._elemwise_modes.add(i)
  74. is_find_number = True
  75. break
  76. assert (
  77. is_find_number
  78. ), "code issue happened!! can not find elemwise mode in: {}".format(
  79. line
  80. )
  81. def generate(self, fout):
  82. self._fout = fout
  83. self._write_def("MGB_BINREDUCE_VERSION", "20220507")
  84. if self._has_netinfo:
  85. self._write_dtype()
  86. if len(self._elemwise_modes) > 0:
  87. self._write_elemwise_modes()
  88. if self._has_netinfo:
  89. self._write_oprs()
  90. self._write_hash()
  91. self._write_midout()
  92. del self._fout
  93. def strip_opr_name_with_version(self, name):
  94. pos = len(name)
  95. t = re.search(r"V\d+$", name)
  96. if t:
  97. pos = t.start()
  98. return name[:pos]
  99. def _write_oprs(self):
  100. defs = ["}", "namespace opr {"]
  101. already_declare = set()
  102. already_instance = set()
  103. for i in self._oprs:
  104. i = self.strip_opr_name_with_version(i)
  105. if i in already_declare:
  106. continue
  107. else:
  108. already_declare.add(i)
  109. defs.append("class {};".format(i))
  110. defs.append("}")
  111. defs.append("namespace serialization {")
  112. defs.append(
  113. """
  114. template<class Opr, class Callee>
  115. struct OprRegistryCaller {
  116. }; """
  117. )
  118. for i in sorted(self._oprs):
  119. i = self.strip_opr_name_with_version(i)
  120. if i in already_instance:
  121. continue
  122. else:
  123. already_instance.add(i)
  124. defs.append(
  125. """
  126. template<class Callee>
  127. struct OprRegistryCaller<opr::{}, Callee>: public
  128. OprRegistryCallerDefaultImpl<Callee> {{
  129. }}; """.format(
  130. i
  131. )
  132. )
  133. self._write_def("MGB_OPR_REGISTRY_CALLER_SPECIALIZE", defs)
  134. def _write_elemwise_modes(self):
  135. with tempfile.NamedTemporaryFile() as ftmp:
  136. fpath = os.path.realpath(ftmp.name)
  137. subprocess.check_call(
  138. [
  139. "./dnn/scripts/gen_param_defs.py",
  140. "--write-enum-items",
  141. "Elemwise:Mode",
  142. "./dnn/scripts/opr_param_defs.py",
  143. fpath,
  144. ],
  145. cwd=self.get_megengine_root(),
  146. )
  147. with open(fpath) as fin:
  148. mode_list = [i.strip() for i in fin]
  149. all_elemwise_modes = set()
  150. for i in mode_list:
  151. i_type = i.replace(" ", "").replace("=", " ").split()[0]
  152. i_id = i.replace(" ", "").replace("=", " ").split()[1]
  153. all_elemwise_modes.add(i_id)
  154. if i_id in self._elemwise_modes:
  155. content = "_cb({})".format(i_type)
  156. else:
  157. content = ""
  158. self._write_def(
  159. "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(i_type), content,
  160. )
  161. # write end of elemwise macro
  162. self._write_def(
  163. "MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)",
  164. "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)",
  165. )
  166. # finally check all self._elemwise_modes is in all_elemwise_modes
  167. for i in self._elemwise_modes:
  168. assert (
  169. i in all_elemwise_modes
  170. ), "code issue happened, can not find elemwise mode: {} in {}".format(
  171. i, all_elemwise_modes
  172. )
  173. def _write_dtype(self):
  174. if "Float16" not in self._dtypes:
  175. # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
  176. # support in the past; however `FLOT16' is really a typo. We plan to
  177. # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
  178. # To prevent issues in the transition, we decide to define both
  179. # macros (`FLOT16' and `FLOAT16') here.
  180. #
  181. # In the future when the situation is settled and no one would ever
  182. # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be
  183. # safely deleted.
  184. self._write_def("MEGDNN_DISABLE_FLOT16", 1)
  185. self._write_def("MEGDNN_DISABLE_FLOAT16", 1)
  186. def _write_hash(self):
  187. if self._file_without_hash:
  188. print(
  189. "WARNING: network info has no graph hash. Using json file "
  190. "generated by MegBrain >= 7.28.0 is recommended"
  191. )
  192. else:
  193. defs = "ULL,".join(self._graph_hashes) + "ULL"
  194. self._write_def("MGB_BINREDUCE_GRAPH_HASHES", defs)
  195. def _write_def(self, name, val):
  196. if isinstance(val, list):
  197. val = "\n".join(val)
  198. val = str(val).strip().replace("\n", " \\\n")
  199. self._fout.write("#define {} {}\n".format(name, val))
  200. def _write_midout(self):
  201. if not self._midout_files:
  202. return
  203. gen = os.path.join(
  204. self.get_megengine_root(), "third_party", "midout", "gen_header.py"
  205. )
  206. if self.get_megvii3_root():
  207. gen = os.path.join(
  208. self.get_megvii3_root(), "brain", "midout", "gen_header.py"
  209. )
  210. print("use {} to gen bin_reduce header".format(gen))
  211. cvt = subprocess.run(
  212. [gen] + self._midout_files, stdout=subprocess.PIPE, check=True,
  213. ).stdout.decode("utf-8")
  214. self._fout.write("// midout \n")
  215. self._fout.write(cvt)
  216. if cvt.find(" half,") > 0:
  217. change = open(self._fout.name).read().replace(" half,", " __fp16,")
  218. with open("fix_fp16_bin_reduce.h", "w") as fix_fp16:
  219. fix_fp16.write(change)
  220. msg = (
  221. "WARNING:\n"
  222. "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n"
  223. "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n"
  224. "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!"
  225. )
  226. print(msg)
  227. def main():
  228. parser = argparse.ArgumentParser(
  229. description="generate header file for reducing binary size by "
  230. "stripping unused oprs in a particular network; output file would "
  231. "be written to bin_reduce.h",
  232. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  233. )
  234. parser.add_argument(
  235. "inputs",
  236. nargs="+",
  237. help="input files that describe specific traits of the network; "
  238. "can be one of the following:"
  239. " 1. json files generated by "
  240. "megbrain.serialize_comp_graph_to_file() in python; "
  241. " 2. trace files generated by midout library",
  242. )
  243. default_file = os.path.join(
  244. HeaderGen.get_megengine_root(), "src", "bin_reduce_cmake.h"
  245. )
  246. is_megvii3 = HeaderGen.get_megvii3_root()
  247. if is_megvii3:
  248. default_file = os.path.join(
  249. HeaderGen.get_megvii3_root(), "utils", "bin_reduce.h"
  250. )
  251. parser.add_argument("-o", "--output", help="output file", default=default_file)
  252. args = parser.parse_args()
  253. print("config output file: {}".format(args.output))
  254. gen = HeaderGen()
  255. for i in args.inputs:
  256. print("==== processing {}".format(i))
  257. with open(i) as fin:
  258. if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
  259. gen.extend_midout(i)
  260. gen.extend_elemwise_mode_info(i)
  261. else:
  262. fin.seek(0)
  263. gen.extend_netinfo(json.loads(fin.read()))
  264. with open(args.output, "w") as fout:
  265. gen.generate(fout)
  266. if __name__ == "__main__":
  267. main()