|
|
@@ -0,0 +1,214 @@ |
|
|
|
#!/usr/bin/env python3 |
|
|
|
# -*- coding: utf-8 -*- |
|
|
|
|
|
|
|
import sys |
|
|
|
import re |
|
|
|
|
|
|
|
# Refuse to run on anything other than Python 3 with a minor version of at
# least 5; this check sits above the remaining imports so it runs before them.
if not (sys.version_info[0] == 3 and sys.version_info[1] >= 5):
    print('This script requires Python version 3.5')
    sys.exit(1)
|
|
|
|
|
|
|
import argparse |
|
|
|
import json |
|
|
|
import os |
|
|
|
import subprocess |
|
|
|
import tempfile |
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
# Leading text that identifies an input file as a midout trace; main() reads
# this many characters from each input and treats non-matching files as JSON.
MIDOUT_TRACE_MAGIC = "midout_trace v1\n"
|
|
|
|
|
|
|
class HeaderGen:
    """Collect network traits and emit them as C preprocessor definitions.

    Traits are accumulated through :meth:`extend_netinfo` (parsed JSON
    network descriptions) and :meth:`extend_midout` (midout trace file
    names); :meth:`generate` then writes the resulting ``#define`` lines to a
    file object.

    The generated text is deterministic for a given set of inputs: every
    collection that is stored as a ``set`` is sorted before being written.
    """

    # Class-level placeholders; the real containers are created per instance
    # in __init__.
    _dtypes = None
    _oprs = None
    _fout = None
    _elemwise_modes = None
    _graph_hashes = None
    _has_netinfo = False
    _midout_files = None

    # Set as soon as one network description lacks a 'hash' entry.
    _file_without_hash = False

    def __init__(self):
        self._dtypes = set()
        self._oprs = set()
        self._elemwise_modes = set()
        self._graph_hashes = set()
        self._midout_files = []

    # Cached result of get_megvii3_root(), shared through the class.
    _megvii3_root_cache = None

    @classmethod
    def get_megvii3_root(cls):
        """Return the nearest ancestor directory containing a WORKSPACE file.

        The search starts at the directory of this script and walks upwards;
        the filesystem root itself is not examined. The result is cached on
        the class.

        Raises:
            RuntimeError: if no ancestor directory contains ``WORKSPACE``.
        """
        if cls._megvii3_root_cache is not None:
            return cls._megvii3_root_cache
        wd = Path(__file__).resolve().parent
        while wd.parent != wd:
            workspace_file = wd / 'WORKSPACE'
            if workspace_file.is_file():
                cls._megvii3_root_cache = str(wd)
                return cls._megvii3_root_cache
            wd = wd.parent
        raise RuntimeError('This script is supposed to run in megvii3.')

    def extend_netinfo(self, data):
        """Merge one parsed network description into the collected traits.

        Args:
            data: mapping with the keys ``'dtypes'``, ``'opr_types'`` and
                ``'elemwise_modes'`` (iterables of hashable items) and an
                optional ``'hash'`` entry.

        Raises:
            KeyError: if one of the three mandatory keys is missing.
        """
        self._has_netinfo = True
        if 'hash' not in data:
            self._file_without_hash = True
        else:
            self._graph_hashes.add(str(data['hash']))
        self._dtypes.update(data['dtypes'])
        self._oprs.update(data['opr_types'])
        self._elemwise_modes.update(data['elemwise_modes'])

    def extend_midout(self, fname):
        """Remember a midout trace file name for :meth:`generate`."""
        self._midout_files.append(fname)

    def generate(self, fout):
        """Write all definitions to ``fout``.

        Args:
            fout: writable text file object.

        Network-derived definitions are only written when at least one
        network description was added; the midout section is written
        whenever trace files were added.
        """
        self._fout = fout
        try:
            self._write_def('MGB_BINREDUCE_VERSION', '20190219')
            if self._has_netinfo:
                self._write_dtype()
                self._write_elemwise_modes()
                self._write_oprs()
                self._write_hash()
            self._write_midout()
        finally:
            # Drop the reference even when a writer fails, so a stale file
            # object is never left on the instance.
            del self._fout

    def strip_opr_name_with_version(self, name):
        """Return ``name`` without a trailing ``V<digits>`` version suffix."""
        return re.sub(r'V\d+$', '', name)

    def _write_oprs(self):
        """Write the MGB_OPR_REGISTRY_CALLER_SPECIALIZE definition.

        The macro body forward-declares every collected operator class in
        namespace ``opr`` and specializes ``OprRegistryCaller`` for each of
        them. Version suffixes are stripped and duplicates removed.
        """
        # De-duplicate once, walking the raw names in sorted order, so both
        # the declarations and the specializations come out in the same,
        # reproducible order.
        seen = set()
        names = []
        for raw in sorted(self._oprs):
            stripped = self.strip_opr_name_with_version(raw)
            if stripped not in seen:
                seen.add(stripped)
                names.append(stripped)

        defs = ['}', 'namespace opr {']
        defs.extend('class {};'.format(i) for i in names)
        defs.append('}')
        defs.append('namespace serialization {')
        defs.append("""
            template<class Opr, class Callee>
            struct OprRegistryCaller {
            }; """)
        for i in names:
            defs.append("""
                template<class Callee>
                struct OprRegistryCaller<opr::{}, Callee>: public
                    OprRegistryCallerDefaultImpl<Callee> {{
                }}; """.format(i))
        self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs)

    def _write_elemwise_modes(self):
        """Write one enable macro per elemwise mode plus the dispatcher.

        The list of modes is obtained by running ``gen_param_defs.py`` from
        the megvii3 root and reading the file it writes. Modes that were
        collected from the network expand to ``_cb(<mode>)``; all others
        expand to nothing.
        """
        with tempfile.NamedTemporaryFile() as ftmp:
            fpath = os.path.realpath(ftmp.name)
            subprocess.check_call(
                ['./brain/megbrain/dnn/scripts/gen_param_defs.py',
                 '--write-enum-items', 'Elemwise:Mode',
                 './brain/megbrain/dnn/scripts/opr_param_defs.py',
                 fpath],
                cwd=self.get_megvii3_root()
            )

            with open(fpath) as fin:
                mode_list = [i.strip() for i in fin]

        for i in mode_list:
            if i in self._elemwise_modes:
                content = '_cb({})'.format(i)
            else:
                content = ''
            self._write_def(
                '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i), content)
        self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)',
                        '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)')

    def _write_dtype(self):
        """Write the float16-disabling macros when Float16 is unused."""
        if 'Float16' not in self._dtypes:
            # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
            # support in the past; however `FLOT16' is really a typo. We plan
            # to change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
            # To prevent issues in the transition, we decide to define both
            # macros (`FLOT16' and `FLOAT16') here.
            #
            # In the future when the situation is settled and no one would
            # ever use legacy MegBrain/MegDNN, the `FLOT16' macro definition
            # can be safely deleted.
            self._write_def('MEGDNN_DISABLE_FLOT16', 1)
            self._write_def('MEGDNN_DISABLE_FLOAT16', 1)

    def _write_hash(self):
        """Write MGB_BINREDUCE_GRAPH_HASHES, or warn if a hash is missing.

        When any network description lacked a hash, no definition is written
        and a warning is printed instead.
        """
        if self._file_without_hash:
            print('WARNING: network info has no graph hash. Using json file '
                  'generated by MegBrain >= 7.28.0 is recommended')
        else:
            # Sorted so that the header is identical from run to run.
            defs = 'ULL,'.join(sorted(self._graph_hashes)) + 'ULL'
            self._write_def('MGB_BINREDUCE_GRAPH_HASHES', defs)

    def _write_def(self, name, val):
        """Write ``#define name val`` to the current output file.

        Args:
            name: macro name, including a parameter list if any.
            val: macro body; a list is joined with newlines, anything else is
                converted with ``str``. Surrounding whitespace is stripped
                and inner newlines become backslash line continuations.
        """
        if isinstance(val, list):
            val = '\n'.join(val)
        val = str(val).strip().replace('\n', ' \\\n')
        self._fout.write('#define {} {}\n'.format(name, val))

    def _write_midout(self):
        """Append the output of midout's gen_header.py for the trace files.

        Does nothing when no trace file was added.
        """
        if not self._midout_files:
            return

        gen = os.path.join(self.get_megvii3_root(), 'brain', 'midout',
                           'gen_header.py')
        cvt = subprocess.run(
            [gen] + self._midout_files,
            stdout=subprocess.PIPE, check=True,
        ).stdout.decode('utf-8')
        self._fout.write('// midout \n')
        self._fout.write(cvt)
|
|
|
|
|
|
|
def main():
    """Parse the command line, collect traits from every input, write header.

    Each input whose leading text equals ``MIDOUT_TRACE_MAGIC`` is registered
    as a midout trace; every other input is parsed as JSON and merged as a
    network description. The generated header goes to the ``--output`` path.
    """
    arg_parser = argparse.ArgumentParser(
        description='generate header file for reducing binary size by '
        'stripping unused oprs in a particular network; output file would '
        'be written to bin_reduce.h',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    arg_parser.add_argument(
        'inputs', nargs='+',
        help='input files that describe specific traits of the network; '
        'can be one of the following:'
        ' 1. json files generated by '
        'megbrain.serialize_comp_graph_to_file() in python; '
        ' 2. trace files generated by midout library')
    default_output = os.path.join(
        HeaderGen.get_megvii3_root(), 'utils', 'bin_reduce.h')
    arg_parser.add_argument(
        '-o', '--output', help='output file', default=default_output)
    options = arg_parser.parse_args()

    header = HeaderGen()
    magic_len = len(MIDOUT_TRACE_MAGIC)
    for path in options.inputs:
        print('==== processing {}'.format(path))
        with open(path) as src:
            if src.read(magic_len) == MIDOUT_TRACE_MAGIC:
                header.extend_midout(path)
                continue
            # Not a trace: rewind and treat the whole file as JSON.
            src.seek(0)
            header.extend_netinfo(json.load(src))

    with open(options.output, 'w') as dest:
        header.generate(dest)
|
|
|
|
|
|
|
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()