You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_header_for_bin_reduce.py 9.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. #
  5. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. #
  7. # Unless required by applicable law or agreed to in writing,
  8. # software distributed under the License is distributed on an
  9. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. import argparse
  11. import json
  12. import os
  13. import re
  14. import subprocess
  15. import sys
  16. import tempfile
  17. from pathlib import Path
  18. if sys.version_info[0] != 3 or sys.version_info[1] < 5:
  19. print("This script requires Python version 3.5")
  20. sys.exit(1)
  21. MIDOUT_TRACE_MAGIC = "midout_trace v1\n"
  22. class HeaderGen:
  23. _dtypes = None
  24. _oprs = None
  25. _fout = None
  26. _elemwise_modes = None
  27. _has_netinfo = False
  28. _midout_files = None
  29. _file_without_hash = False
  30. def __init__(self):
  31. self._dtypes = set()
  32. self._oprs = set()
  33. self._elemwise_modes = set()
  34. self._graph_hashes = set()
  35. self._midout_files = []
  36. _megvii3_root_cache = None
  37. @classmethod
  38. def get_megvii3_root(cls):
  39. if cls._megvii3_root_cache is not None:
  40. return cls._megvii3_root_cache
  41. wd = Path(__file__).resolve().parent
  42. while wd.parent != wd:
  43. workspace_file = wd / "WORKSPACE"
  44. if workspace_file.is_file():
  45. cls._megvii3_root_cache = str(wd)
  46. return cls._megvii3_root_cache
  47. wd = wd.parent
  48. return None
  49. _megengine_root_cache = None
  50. @classmethod
  51. def get_megengine_root(cls):
  52. if cls._megengine_root_cache is not None:
  53. return cls._megengine_root_cache
  54. wd = Path(__file__).resolve().parent.parent
  55. cls._megengine_root_cache = str(wd)
  56. return cls._megengine_root_cache
  57. def extend_netinfo(self, data):
  58. self._has_netinfo = True
  59. if "hash" not in data:
  60. self._file_without_hash = True
  61. else:
  62. self._graph_hashes.add(str(data["hash"]))
  63. for i in data["dtypes"]:
  64. self._dtypes.add(i)
  65. for i in data["opr_types"]:
  66. self._oprs.add(i)
  67. for i in data["elemwise_modes"]:
  68. self._elemwise_modes.add(i)
  69. def extend_midout(self, fname):
  70. self._midout_files.append(fname)
  71. def generate(self, fout):
  72. self._fout = fout
  73. self._write_def("MGB_BINREDUCE_VERSION", "20190219")
  74. if self._has_netinfo:
  75. self._write_dtype()
  76. self._write_elemwise_modes()
  77. self._write_oprs()
  78. self._write_hash()
  79. self._write_midout()
  80. del self._fout
  81. def strip_opr_name_with_version(self, name):
  82. pos = len(name)
  83. t = re.search(r"V\d+$", name)
  84. if t:
  85. pos = t.start()
  86. return name[:pos]
  87. def _write_oprs(self):
  88. defs = ["}", "namespace opr {"]
  89. already_declare = set()
  90. already_instance = set()
  91. for i in self._oprs:
  92. i = self.strip_opr_name_with_version(i)
  93. if i in already_declare:
  94. continue
  95. else:
  96. already_declare.add(i)
  97. defs.append("class {};".format(i))
  98. defs.append("}")
  99. defs.append("namespace serialization {")
  100. defs.append(
  101. """
  102. template<class Opr, class Callee>
  103. struct OprRegistryCaller {
  104. }; """
  105. )
  106. for i in sorted(self._oprs):
  107. i = self.strip_opr_name_with_version(i)
  108. if i in already_instance:
  109. continue
  110. else:
  111. already_instance.add(i)
  112. defs.append(
  113. """
  114. template<class Callee>
  115. struct OprRegistryCaller<opr::{}, Callee>: public
  116. OprRegistryCallerDefaultImpl<Callee> {{
  117. }}; """.format(
  118. i
  119. )
  120. )
  121. self._write_def("MGB_OPR_REGISTRY_CALLER_SPECIALIZE", defs)
  122. def _write_elemwise_modes(self):
  123. with tempfile.NamedTemporaryFile() as ftmp:
  124. fpath = os.path.realpath(ftmp.name)
  125. subprocess.check_call(
  126. [
  127. "./dnn/scripts/gen_param_defs.py",
  128. "--write-enum-items",
  129. "Elemwise:Mode",
  130. "./dnn/scripts/opr_param_defs.py",
  131. fpath,
  132. ],
  133. cwd=self.get_megengine_root(),
  134. )
  135. with open(fpath) as fin:
  136. mode_list = [i.strip() for i in fin]
  137. for i in mode_list:
  138. i = i.split(" ")[0].split("=")[0]
  139. if i in self._elemwise_modes:
  140. content = "_cb({})".format(i)
  141. else:
  142. content = ""
  143. self._write_def(
  144. "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(
  145. i.split(" ")[0].split("=")[0]
  146. ),
  147. content,
  148. )
  149. self._write_def(
  150. "MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)",
  151. "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)",
  152. )
  153. def _write_dtype(self):
  154. if "Float16" not in self._dtypes:
  155. # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
  156. # support in the past; however `FLOT16' is really a typo. We plan to
  157. # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
  158. # To prevent issues in the transition, we decide to define both
  159. # macros (`FLOT16' and `FLOAT16') here.
  160. #
  161. # In the future when the situation is settled and no one would ever
  162. # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be
  163. # safely deleted.
  164. self._write_def("MEGDNN_DISABLE_FLOT16", 1)
  165. self._write_def("MEGDNN_DISABLE_FLOAT16", 1)
  166. def _write_hash(self):
  167. if self._file_without_hash:
  168. print(
  169. "WARNING: network info has no graph hash. Using json file "
  170. "generated by MegBrain >= 7.28.0 is recommended"
  171. )
  172. else:
  173. defs = "ULL,".join(self._graph_hashes) + "ULL"
  174. self._write_def("MGB_BINREDUCE_GRAPH_HASHES", defs)
  175. def _write_def(self, name, val):
  176. if isinstance(val, list):
  177. val = "\n".join(val)
  178. val = str(val).strip().replace("\n", " \\\n")
  179. self._fout.write("#define {} {}\n".format(name, val))
  180. def _write_midout(self):
  181. if not self._midout_files:
  182. return
  183. gen = os.path.join(
  184. self.get_megengine_root(), "third_party", "midout", "gen_header.py"
  185. )
  186. if self.get_megvii3_root():
  187. gen = os.path.join(
  188. self.get_megvii3_root(), "brain", "midout", "gen_header.py"
  189. )
  190. print("use {} to gen bin_reduce header".format(gen))
  191. cvt = subprocess.run(
  192. [gen] + self._midout_files, stdout=subprocess.PIPE, check=True,
  193. ).stdout.decode("utf-8")
  194. self._fout.write("// midout \n")
  195. self._fout.write(cvt)
  196. if cvt.find(" half,") > 0:
  197. change = open(self._fout.name).read().replace(" half,", " __fp16,")
  198. with open("fix_fp16_bin_reduce.h", "w") as fix_fp16:
  199. fix_fp16.write(change)
  200. msg = (
  201. "WARNING:\n"
  202. "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n"
  203. "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n"
  204. "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!"
  205. )
  206. print(msg)
  207. def main():
  208. parser = argparse.ArgumentParser(
  209. description="generate header file for reducing binary size by "
  210. "stripping unused oprs in a particular network; output file would "
  211. "be written to bin_reduce.h",
  212. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  213. )
  214. parser.add_argument(
  215. "inputs",
  216. nargs="+",
  217. help="input files that describe specific traits of the network; "
  218. "can be one of the following:"
  219. " 1. json files generated by "
  220. "megbrain.serialize_comp_graph_to_file() in python; "
  221. " 2. trace files generated by midout library",
  222. )
  223. default_file = os.path.join(
  224. HeaderGen.get_megengine_root(), "src", "bin_reduce_cmake.h"
  225. )
  226. is_megvii3 = HeaderGen.get_megvii3_root()
  227. if is_megvii3:
  228. default_file = os.path.join(
  229. HeaderGen.get_megvii3_root(), "utils", "bin_reduce.h"
  230. )
  231. parser.add_argument("-o", "--output", help="output file", default=default_file)
  232. args = parser.parse_args()
  233. print("config output file: {}".format(args.output))
  234. gen = HeaderGen()
  235. for i in args.inputs:
  236. print("==== processing {}".format(i))
  237. with open(i) as fin:
  238. if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
  239. gen.extend_midout(i)
  240. else:
  241. fin.seek(0)
  242. gen.extend_netinfo(json.loads(fin.read()))
  243. with open(args.output, "w") as fout:
  244. gen.generate(fout)
  245. if __name__ == "__main__":
  246. main()