You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_header_for_bin_reduce.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. #
  5. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. #
  7. # Unless required by applicable law or agreed to in writing,
  8. # software distributed under the License is distributed on an
  9. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. import argparse
  11. import json
  12. import os
  13. import re
  14. import subprocess
  15. import sys
  16. import tempfile
  17. from pathlib import Path
  18. if sys.version_info[0] != 3 or sys.version_info[1] < 5:
  19. print("This script requires Python version 3.5")
  20. sys.exit(1)
  21. MIDOUT_TRACE_MAGIC = "midout_trace v1\n"
  22. class HeaderGen:
  23. _dtypes = None
  24. _oprs = None
  25. _fout = None
  26. _elemwise_modes = None
  27. _has_netinfo = False
  28. _midout_files = None
  29. _file_without_hash = False
  30. def __init__(self):
  31. self._dtypes = set()
  32. self._oprs = set()
  33. self._elemwise_modes = set()
  34. self._graph_hashes = set()
  35. self._midout_files = []
  36. _megvii3_root_cache = None
  37. @classmethod
  38. def get_megvii3_root(cls):
  39. if cls._megvii3_root_cache is not None:
  40. return cls._megvii3_root_cache
  41. wd = Path(__file__).resolve().parent
  42. while wd.parent != wd:
  43. workspace_file = wd / "WORKSPACE"
  44. if workspace_file.is_file():
  45. cls._megvii3_root_cache = str(wd)
  46. return cls._megvii3_root_cache
  47. wd = wd.parent
  48. return None
  49. _megengine_root_cache = None
  50. @classmethod
  51. def get_megengine_root(cls):
  52. if cls._megengine_root_cache is not None:
  53. return cls._megengine_root_cache
  54. wd = Path(__file__).resolve().parent.parent
  55. cls._megengine_root_cache = str(wd)
  56. return cls._megengine_root_cache
  57. def extend_netinfo(self, data):
  58. self._has_netinfo = True
  59. if "hash" not in data:
  60. self._file_without_hash = True
  61. else:
  62. self._graph_hashes.add(str(data["hash"]))
  63. for i in data["dtypes"]:
  64. self._dtypes.add(i)
  65. for i in data["opr_types"]:
  66. self._oprs.add(i)
  67. def extend_midout(self, fname):
  68. self._midout_files.append(fname)
  69. def extend_elemwise_mode_info(self, fname):
  70. for line in open(fname):
  71. # tag write in dnn/src/common/elemwise/opr_impl.cpp
  72. idx = line.find("megdnn_common_elemwise_mode")
  73. if idx > 0:
  74. cmd = "c++filt -t {}".format(line)
  75. demangle = subprocess.check_output(cmd, shell=True).decode("utf-8")
  76. demangle = demangle.replace(">", "").split()
  77. is_find_number = False
  78. for i in demangle:
  79. if i.isnumeric():
  80. self._elemwise_modes.add(i)
  81. is_find_number = True
  82. break
  83. assert (
  84. is_find_number
  85. ), "code issue happened!! can not find elemwise mode in: {}".format(
  86. line
  87. )
  88. def generate(self, fout):
  89. self._fout = fout
  90. self._write_def("MGB_BINREDUCE_VERSION", "20220507")
  91. if self._has_netinfo:
  92. self._write_dtype()
  93. if len(self._elemwise_modes) > 0:
  94. self._write_elemwise_modes()
  95. if self._has_netinfo:
  96. self._write_oprs()
  97. self._write_hash()
  98. self._write_midout()
  99. del self._fout
  100. def strip_opr_name_with_version(self, name):
  101. pos = len(name)
  102. t = re.search(r"V\d+$", name)
  103. if t:
  104. pos = t.start()
  105. return name[:pos]
  106. def _write_oprs(self):
  107. defs = ["}", "namespace opr {"]
  108. already_declare = set()
  109. already_instance = set()
  110. for i in self._oprs:
  111. i = self.strip_opr_name_with_version(i)
  112. if i in already_declare:
  113. continue
  114. else:
  115. already_declare.add(i)
  116. defs.append("class {};".format(i))
  117. defs.append("}")
  118. defs.append("namespace serialization {")
  119. defs.append(
  120. """
  121. template<class Opr, class Callee>
  122. struct OprRegistryCaller {
  123. }; """
  124. )
  125. for i in sorted(self._oprs):
  126. i = self.strip_opr_name_with_version(i)
  127. if i in already_instance:
  128. continue
  129. else:
  130. already_instance.add(i)
  131. defs.append(
  132. """
  133. template<class Callee>
  134. struct OprRegistryCaller<opr::{}, Callee>: public
  135. OprRegistryCallerDefaultImpl<Callee> {{
  136. }}; """.format(
  137. i
  138. )
  139. )
  140. self._write_def("MGB_OPR_REGISTRY_CALLER_SPECIALIZE", defs)
  141. def _write_elemwise_modes(self):
  142. with tempfile.NamedTemporaryFile() as ftmp:
  143. fpath = os.path.realpath(ftmp.name)
  144. subprocess.check_call(
  145. [
  146. "./dnn/scripts/gen_param_defs.py",
  147. "--write-enum-items",
  148. "Elemwise:Mode",
  149. "./dnn/scripts/opr_param_defs.py",
  150. fpath,
  151. ],
  152. cwd=self.get_megengine_root(),
  153. )
  154. with open(fpath) as fin:
  155. mode_list = [i.strip() for i in fin]
  156. all_elemwise_modes = set()
  157. for i in mode_list:
  158. i_type = i.replace(" ", "").replace("=", " ").split()[0]
  159. i_id = i.replace(" ", "").replace("=", " ").split()[1]
  160. all_elemwise_modes.add(i_id)
  161. if i_id in self._elemwise_modes:
  162. content = "_cb({})".format(i_type)
  163. else:
  164. content = ""
  165. self._write_def(
  166. "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(i_type), content,
  167. )
  168. # write end of elemwise macro
  169. self._write_def(
  170. "MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)",
  171. "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)",
  172. )
  173. # finally check all self._elemwise_modes is in all_elemwise_modes
  174. for i in self._elemwise_modes:
  175. assert (
  176. i in all_elemwise_modes
  177. ), "code issue happened, can not find elemwise mode: {} in {}".format(
  178. i, all_elemwise_modes
  179. )
  180. def _write_dtype(self):
  181. if "Float16" not in self._dtypes:
  182. # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
  183. # support in the past; however `FLOT16' is really a typo. We plan to
  184. # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
  185. # To prevent issues in the transition, we decide to define both
  186. # macros (`FLOT16' and `FLOAT16') here.
  187. #
  188. # In the future when the situation is settled and no one would ever
  189. # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be
  190. # safely deleted.
  191. self._write_def("MEGDNN_DISABLE_FLOT16", 1)
  192. self._write_def("MEGDNN_DISABLE_FLOAT16", 1)
  193. def _write_hash(self):
  194. if self._file_without_hash:
  195. print(
  196. "WARNING: network info has no graph hash. Using json file "
  197. "generated by MegBrain >= 7.28.0 is recommended"
  198. )
  199. else:
  200. defs = "ULL,".join(self._graph_hashes) + "ULL"
  201. self._write_def("MGB_BINREDUCE_GRAPH_HASHES", defs)
  202. def _write_def(self, name, val):
  203. if isinstance(val, list):
  204. val = "\n".join(val)
  205. val = str(val).strip().replace("\n", " \\\n")
  206. self._fout.write("#define {} {}\n".format(name, val))
  207. def _write_midout(self):
  208. if not self._midout_files:
  209. return
  210. gen = os.path.join(
  211. self.get_megengine_root(), "third_party", "midout", "gen_header.py"
  212. )
  213. if self.get_megvii3_root():
  214. gen = os.path.join(
  215. self.get_megvii3_root(), "brain", "midout", "gen_header.py"
  216. )
  217. print("use {} to gen bin_reduce header".format(gen))
  218. cvt = subprocess.run(
  219. [gen] + self._midout_files, stdout=subprocess.PIPE, check=True,
  220. ).stdout.decode("utf-8")
  221. self._fout.write("// midout \n")
  222. self._fout.write(cvt)
  223. if cvt.find(" half,") > 0:
  224. change = open(self._fout.name).read().replace(" half,", " __fp16,")
  225. with open("fix_fp16_bin_reduce.h", "w") as fix_fp16:
  226. fix_fp16.write(change)
  227. msg = (
  228. "WARNING:\n"
  229. "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n"
  230. "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n"
  231. "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!"
  232. )
  233. print(msg)
  234. def main():
  235. parser = argparse.ArgumentParser(
  236. description="generate header file for reducing binary size by "
  237. "stripping unused oprs in a particular network; output file would "
  238. "be written to bin_reduce.h",
  239. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  240. )
  241. parser.add_argument(
  242. "inputs",
  243. nargs="+",
  244. help="input files that describe specific traits of the network; "
  245. "can be one of the following:"
  246. " 1. json files generated by "
  247. "megbrain.serialize_comp_graph_to_file() in python; "
  248. " 2. trace files generated by midout library",
  249. )
  250. default_file = os.path.join(
  251. HeaderGen.get_megengine_root(), "src", "bin_reduce_cmake.h"
  252. )
  253. is_megvii3 = HeaderGen.get_megvii3_root()
  254. if is_megvii3:
  255. default_file = os.path.join(
  256. HeaderGen.get_megvii3_root(), "utils", "bin_reduce.h"
  257. )
  258. parser.add_argument("-o", "--output", help="output file", default=default_file)
  259. args = parser.parse_args()
  260. print("config output file: {}".format(args.output))
  261. gen = HeaderGen()
  262. for i in args.inputs:
  263. print("==== processing {}".format(i))
  264. with open(i) as fin:
  265. if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
  266. gen.extend_midout(i)
  267. gen.extend_elemwise_mode_info(i)
  268. else:
  269. fin.seek(0)
  270. gen.extend_netinfo(json.loads(fin.read()))
  271. with open(args.output, "w") as fout:
  272. gen.generate(fout)
  273. if __name__ == "__main__":
  274. main()