Browse Source

fix(midout): formatting midout tools

GitOrigin-RevId: 9aa6a9ec57
release-1.10
Megvii Engine Team 3 years ago
parent
commit
9be8de6025
1 changed file with 110 additions and 79 deletions
  1. +110
    -79
      tools/gen_header_for_bin_reduce.py

+ 110
- 79
tools/gen_header_for_bin_reduce.py View File

@@ -8,21 +8,22 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


import sys
import re

if sys.version_info[0] != 3 or sys.version_info[1] < 5:
print('This script requires Python version 3.5')
sys.exit(1)

import argparse import argparse
import json import json
import os import os
import re
import subprocess import subprocess
import sys
import tempfile import tempfile
from pathlib import Path from pathlib import Path


MIDOUT_TRACE_MAGIC = 'midout_trace v1\n'
if sys.version_info[0] != 3 or sys.version_info[1] < 5:
print("This script requires Python version 3.5")
sys.exit(1)


MIDOUT_TRACE_MAGIC = "midout_trace v1\n"



class HeaderGen: class HeaderGen:
_dtypes = None _dtypes = None
@@ -42,20 +43,22 @@ class HeaderGen:
self._midout_files = [] self._midout_files = []


_megvii3_root_cache = None _megvii3_root_cache = None

@classmethod @classmethod
def get_megvii3_root(cls): def get_megvii3_root(cls):
if cls._megvii3_root_cache is not None: if cls._megvii3_root_cache is not None:
return cls._megvii3_root_cache return cls._megvii3_root_cache
wd = Path(__file__).resolve().parent wd = Path(__file__).resolve().parent
while wd.parent != wd: while wd.parent != wd:
workspace_file = wd / 'WORKSPACE'
if workspace_file.is_file():
cls._megvii3_root_cache = str(wd)
return cls._megvii3_root_cache
wd = wd.parent
workspace_file = wd / "WORKSPACE"
if workspace_file.is_file():
cls._megvii3_root_cache = str(wd)
return cls._megvii3_root_cache
wd = wd.parent
return None return None


_megengine_root_cache = None _megengine_root_cache = None

@classmethod @classmethod
def get_megengine_root(cls): def get_megengine_root(cls):
if cls._megengine_root_cache is not None: if cls._megengine_root_cache is not None:
@@ -66,15 +69,15 @@ class HeaderGen:


def extend_netinfo(self, data): def extend_netinfo(self, data):
self._has_netinfo = True self._has_netinfo = True
if 'hash' not in data:
if "hash" not in data:
self._file_without_hash = True self._file_without_hash = True
else: else:
self._graph_hashes.add(str(data['hash']))
for i in data['dtypes']:
self._graph_hashes.add(str(data["hash"]))
for i in data["dtypes"]:
self._dtypes.add(i) self._dtypes.add(i)
for i in data['opr_types']:
for i in data["opr_types"]:
self._oprs.add(i) self._oprs.add(i)
for i in data['elemwise_modes']:
for i in data["elemwise_modes"]:
self._elemwise_modes.add(i) self._elemwise_modes.add(i)


def extend_midout(self, fname): def extend_midout(self, fname):
@@ -82,7 +85,7 @@ class HeaderGen:


def generate(self, fout): def generate(self, fout):
self._fout = fout self._fout = fout
self._write_def('MGB_BINREDUCE_VERSION', '20190219')
self._write_def("MGB_BINREDUCE_VERSION", "20190219")
if self._has_netinfo: if self._has_netinfo:
self._write_dtype() self._write_dtype()
self._write_elemwise_modes() self._write_elemwise_modes()
@@ -93,13 +96,13 @@ class HeaderGen:


def strip_opr_name_with_version(self, name): def strip_opr_name_with_version(self, name):
pos = len(name) pos = len(name)
t = re.search(r'V\d+$', name)
t = re.search(r"V\d+$", name)
if t: if t:
pos = t.start() pos = t.start()
return name[:pos] return name[:pos]


def _write_oprs(self): def _write_oprs(self):
defs = ['}', 'namespace opr {']
defs = ["}", "namespace opr {"]
already_declare = set() already_declare = set()
already_instance = set() already_instance = set()
for i in self._oprs: for i in self._oprs:
@@ -109,13 +112,15 @@ class HeaderGen:
else: else:
already_declare.add(i) already_declare.add(i)


defs.append('class {};'.format(i))
defs.append('}')
defs.append('namespace serialization {')
defs.append("""
defs.append("class {};".format(i))
defs.append("}")
defs.append("namespace serialization {")
defs.append(
"""
template<class Opr, class Callee> template<class Opr, class Callee>
struct OprRegistryCaller { struct OprRegistryCaller {
}; """)
}; """
)
for i in sorted(self._oprs): for i in sorted(self._oprs):
i = self.strip_opr_name_with_version(i) i = self.strip_opr_name_with_version(i)
if i in already_instance: if i in already_instance:
@@ -123,40 +128,53 @@ class HeaderGen:
else: else:
already_instance.add(i) already_instance.add(i)


defs.append("""
defs.append(
"""
template<class Callee> template<class Callee>
struct OprRegistryCaller<opr::{}, Callee>: public struct OprRegistryCaller<opr::{}, Callee>: public
OprRegistryCallerDefaultImpl<Callee> {{ OprRegistryCallerDefaultImpl<Callee> {{
}}; """.format(i))
self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs)
}}; """.format(
i
)
)
self._write_def("MGB_OPR_REGISTRY_CALLER_SPECIALIZE", defs)


def _write_elemwise_modes(self): def _write_elemwise_modes(self):
with tempfile.NamedTemporaryFile() as ftmp: with tempfile.NamedTemporaryFile() as ftmp:
fpath = os.path.realpath(ftmp.name) fpath = os.path.realpath(ftmp.name)
subprocess.check_call( subprocess.check_call(
['./dnn/scripts/gen_param_defs.py',
'--write-enum-items', 'Elemwise:Mode',
'./dnn/scripts/opr_param_defs.py',
fpath],
cwd=self.get_megengine_root()
[
"./dnn/scripts/gen_param_defs.py",
"--write-enum-items",
"Elemwise:Mode",
"./dnn/scripts/opr_param_defs.py",
fpath,
],
cwd=self.get_megengine_root(),
) )


with open(fpath) as fin: with open(fpath) as fin:
mode_list = [i.strip() for i in fin] mode_list = [i.strip() for i in fin]


for i in mode_list: for i in mode_list:
i = i.split(' ')[0].split('=')[0]
i = i.split(" ")[0].split("=")[0]
if i in self._elemwise_modes: if i in self._elemwise_modes:
content = '_cb({})'.format(i)
content = "_cb({})".format(i)
else: else:
content = ''
content = ""
self._write_def( self._write_def(
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i.split(' ')[0].split('=')[0]), content)
self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)',
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)')
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(
i.split(" ")[0].split("=")[0]
),
content,
)
self._write_def(
"MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)",
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)",
)


def _write_dtype(self): def _write_dtype(self):
if 'Float16' not in self._dtypes:
if "Float16" not in self._dtypes:
# MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16 # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
# support in the past; however `FLOT16' is really a typo. We plan to # support in the past; however `FLOT16' is really a typo. We plan to
# change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon. # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
@@ -166,74 +184,86 @@ class HeaderGen:
# In the future when the situation is settled and no one would ever # In the future when the situation is settled and no one would ever
# use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be
# safely deleted. # safely deleted.
self._write_def('MEGDNN_DISABLE_FLOT16', 1)
self._write_def('MEGDNN_DISABLE_FLOAT16', 1)
self._write_def("MEGDNN_DISABLE_FLOT16", 1)
self._write_def("MEGDNN_DISABLE_FLOAT16", 1)


def _write_hash(self): def _write_hash(self):
if self._file_without_hash: if self._file_without_hash:
print('WARNING: network info has no graph hash. Using json file '
'generated by MegBrain >= 7.28.0 is recommended')
print(
"WARNING: network info has no graph hash. Using json file "
"generated by MegBrain >= 7.28.0 is recommended"
)
else: else:
defs = 'ULL,'.join(self._graph_hashes) + 'ULL'
self._write_def('MGB_BINREDUCE_GRAPH_HASHES', defs)
defs = "ULL,".join(self._graph_hashes) + "ULL"
self._write_def("MGB_BINREDUCE_GRAPH_HASHES", defs)


def _write_def(self, name, val): def _write_def(self, name, val):
if isinstance(val, list): if isinstance(val, list):
val = '\n'.join(val)
val = str(val).strip().replace('\n', ' \\\n')
self._fout.write('#define {} {}\n'.format(name, val))
val = "\n".join(val)
val = str(val).strip().replace("\n", " \\\n")
self._fout.write("#define {} {}\n".format(name, val))


def _write_midout(self): def _write_midout(self):
if not self._midout_files: if not self._midout_files:
return return


gen = os.path.join(self.get_megengine_root(), 'third_party', 'midout', 'gen_header.py')
gen = os.path.join(
self.get_megengine_root(), "third_party", "midout", "gen_header.py"
)
if self.get_megvii3_root(): if self.get_megvii3_root():
gen = os.path.join(self.get_megvii3_root(), 'brain', 'midout', 'gen_header.py')
print('use {} to gen bin_reduce header'.format(gen))
gen = os.path.join(
self.get_megvii3_root(), "brain", "midout", "gen_header.py"
)
print("use {} to gen bin_reduce header".format(gen))
cvt = subprocess.run( cvt = subprocess.run(
[gen] + self._midout_files,
stdout=subprocess.PIPE, check=True,
).stdout.decode('utf-8')
self._fout.write('// midout \n')
[gen] + self._midout_files, stdout=subprocess.PIPE, check=True,
).stdout.decode("utf-8")
self._fout.write("// midout \n")
self._fout.write(cvt) self._fout.write(cvt)
if cvt.find(" half,") > 0: if cvt.find(" half,") > 0:
change = open(self._fout.name).read().replace(" half,", " __fp16,") change = open(self._fout.name).read().replace(" half,", " __fp16,")
with open("fix_fp16_bin_reduce.h", "w") as fix_fp16: with open("fix_fp16_bin_reduce.h", "w") as fix_fp16:
fix_fp16.write(change) fix_fp16.write(change)
msg = ( msg = (
"WARNING:\n"
"hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n"
"which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n"
"then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!"
)
"WARNING:\n"
"hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n"
"which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n"
"then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!"
)
print(msg) print(msg)




def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='generate header file for reducing binary size by '
'stripping unused oprs in a particular network; output file would '
'be written to bin_reduce.h',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
description="generate header file for reducing binary size by "
"stripping unused oprs in a particular network; output file would "
"be written to bin_reduce.h",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument( parser.add_argument(
'inputs', nargs='+',
help='input files that describe specific traits of the network; '
'can be one of the following:'
' 1. json files generated by '
'megbrain.serialize_comp_graph_to_file() in python; '
' 2. trace files generated by midout library')
default_file=os.path.join(HeaderGen.get_megengine_root(), 'src', 'bin_reduce_cmake.h')
"inputs",
nargs="+",
help="input files that describe specific traits of the network; "
"can be one of the following:"
" 1. json files generated by "
"megbrain.serialize_comp_graph_to_file() in python; "
" 2. trace files generated by midout library",
)
default_file = os.path.join(
HeaderGen.get_megengine_root(), "src", "bin_reduce_cmake.h"
)
is_megvii3 = HeaderGen.get_megvii3_root() is_megvii3 = HeaderGen.get_megvii3_root()
if is_megvii3: if is_megvii3:
default_file=os.path.join(HeaderGen.get_megvii3_root(), 'utils', 'bin_reduce.h')
parser.add_argument('-o', '--output', help='output file', default=default_file)
default_file = os.path.join(
HeaderGen.get_megvii3_root(), "utils", "bin_reduce.h"
)
parser.add_argument("-o", "--output", help="output file", default=default_file)
args = parser.parse_args() args = parser.parse_args()
print('config output file: {}'.format(args.output))
print("config output file: {}".format(args.output))


gen = HeaderGen() gen = HeaderGen()
for i in args.inputs: for i in args.inputs:
print('==== processing {}'.format(i))
print("==== processing {}".format(i))
with open(i) as fin: with open(i) as fin:
if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
gen.extend_midout(i) gen.extend_midout(i)
@@ -241,8 +271,9 @@ def main():
fin.seek(0) fin.seek(0)
gen.extend_netinfo(json.loads(fin.read())) gen.extend_netinfo(json.loads(fin.read()))


with open(args.output, 'w') as fout:
with open(args.output, "w") as fout:
gen.generate(fout) gen.generate(fout)


if __name__ == '__main__':

if __name__ == "__main__":
main() main()

Loading…
Cancel
Save