|
@@ -8,21 +8,22 @@ |
|
|
# software distributed under the License is distributed on an |
|
|
# software distributed under the License is distributed on an |
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
|
|
|
|
|
import sys |
|
|
|
|
|
import re |
|
|
|
|
|
|
|
|
|
|
|
if sys.version_info[0] != 3 or sys.version_info[1] < 5: |
|
|
|
|
|
print('This script requires Python version 3.5') |
|
|
|
|
|
sys.exit(1) |
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import argparse |
|
|
import json |
|
|
import json |
|
|
import os |
|
|
import os |
|
|
|
|
|
import re |
|
|
import subprocess |
|
|
import subprocess |
|
|
|
|
|
import sys |
|
|
import tempfile |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
MIDOUT_TRACE_MAGIC = 'midout_trace v1\n' |
|
|
|
|
|
|
|
|
if sys.version_info[0] != 3 or sys.version_info[1] < 5: |
|
|
|
|
|
print("This script requires Python version 3.5") |
|
|
|
|
|
sys.exit(1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MIDOUT_TRACE_MAGIC = "midout_trace v1\n" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HeaderGen: |
|
|
class HeaderGen: |
|
|
_dtypes = None |
|
|
_dtypes = None |
|
@@ -42,20 +43,22 @@ class HeaderGen: |
|
|
self._midout_files = [] |
|
|
self._midout_files = [] |
|
|
|
|
|
|
|
|
_megvii3_root_cache = None |
|
|
_megvii3_root_cache = None |
|
|
|
|
|
|
|
|
@classmethod
def get_megvii3_root(cls):
    """Locate the directory that holds a ``WORKSPACE`` file, if any.

    Starting from the directory containing this script, each ancestor
    directory except the filesystem root is probed for a regular file
    named ``WORKSPACE``.  The first directory that has one is stored in
    ``cls._megvii3_root_cache`` and returned; later calls return the
    stored value without touching the filesystem.

    Returns:
        The directory path as ``str``, or ``None`` when no probed
        directory contains a ``WORKSPACE`` file (nothing is cached in
        that case, so the search is repeated on the next call).
    """
    cached = cls._megvii3_root_cache
    if cached is not None:
        return cached

    here = Path(__file__).resolve().parent
    # ``parents`` ends with the filesystem root, which is never probed.
    candidates = (here, *here.parents)[:-1]
    for folder in candidates:
        if (folder / "WORKSPACE").is_file():
            cls._megvii3_root_cache = str(folder)
            return cls._megvii3_root_cache
    return None
|
|
|
|
|
|
|
|
_megengine_root_cache = None |
|
|
_megengine_root_cache = None |
|
|
|
|
|
|
|
|
@classmethod |
|
|
@classmethod |
|
|
def get_megengine_root(cls): |
|
|
def get_megengine_root(cls): |
|
|
if cls._megengine_root_cache is not None: |
|
|
if cls._megengine_root_cache is not None: |
|
@@ -66,15 +69,15 @@ class HeaderGen: |
|
|
|
|
|
|
|
|
def extend_netinfo(self, data):
    """Merge one parsed network-info mapping into the collected state.

    Args:
        data: mapping that must provide ``dtypes``, ``opr_types`` and
            ``elemwise_modes`` iterables and may provide ``hash``.

    Marks that network info was seen.  When ``hash`` is present its
    string form is added to ``self._graph_hashes``; otherwise
    ``self._file_without_hash`` is set.  Every entry of the three
    iterables is added to ``self._dtypes``, ``self._oprs`` and
    ``self._elemwise_modes`` respectively.
    """
    self._has_netinfo = True

    if "hash" in data:
        self._graph_hashes.add(str(data["hash"]))
    else:
        self._file_without_hash = True

    # (key in ``data``, attribute on ``self`` that collects its entries)
    field_to_bucket = (
        ("dtypes", "_dtypes"),
        ("opr_types", "_oprs"),
        ("elemwise_modes", "_elemwise_modes"),
    )
    for key, attr in field_to_bucket:
        for entry in data[key]:
            getattr(self, attr).add(entry)
|
|
|
|
|
|
|
|
def extend_midout(self, fname): |
|
|
def extend_midout(self, fname): |
|
@@ -82,7 +85,7 @@ class HeaderGen: |
|
|
|
|
|
|
|
|
def generate(self, fout): |
|
|
def generate(self, fout): |
|
|
self._fout = fout |
|
|
self._fout = fout |
|
|
self._write_def('MGB_BINREDUCE_VERSION', '20190219') |
|
|
|
|
|
|
|
|
self._write_def("MGB_BINREDUCE_VERSION", "20190219") |
|
|
if self._has_netinfo: |
|
|
if self._has_netinfo: |
|
|
self._write_dtype() |
|
|
self._write_dtype() |
|
|
self._write_elemwise_modes() |
|
|
self._write_elemwise_modes() |
|
@@ -93,13 +96,13 @@ class HeaderGen: |
|
|
|
|
|
|
|
|
def strip_opr_name_with_version(self, name):
    """Return ``name`` with a trailing ``V<digits>`` suffix cut off.

    Args:
        name: operator name, e.g. ``"ConvolutionV2"``.

    Returns:
        Everything before the suffix matched by ``V\\d+$``; the name is
        returned unchanged when it has no such suffix.
    """
    suffix = re.search(r"V\d+$", name)
    return name[: suffix.start()] if suffix else name[: len(name)]
|
|
|
|
|
|
|
|
def _write_oprs(self): |
|
|
def _write_oprs(self): |
|
|
defs = ['}', 'namespace opr {'] |
|
|
|
|
|
|
|
|
defs = ["}", "namespace opr {"] |
|
|
already_declare = set() |
|
|
already_declare = set() |
|
|
already_instance = set() |
|
|
already_instance = set() |
|
|
for i in self._oprs: |
|
|
for i in self._oprs: |
|
@@ -109,13 +112,15 @@ class HeaderGen: |
|
|
else: |
|
|
else: |
|
|
already_declare.add(i) |
|
|
already_declare.add(i) |
|
|
|
|
|
|
|
|
defs.append('class {};'.format(i)) |
|
|
|
|
|
defs.append('}') |
|
|
|
|
|
defs.append('namespace serialization {') |
|
|
|
|
|
defs.append(""" |
|
|
|
|
|
|
|
|
defs.append("class {};".format(i)) |
|
|
|
|
|
defs.append("}") |
|
|
|
|
|
defs.append("namespace serialization {") |
|
|
|
|
|
defs.append( |
|
|
|
|
|
""" |
|
|
template<class Opr, class Callee> |
|
|
template<class Opr, class Callee> |
|
|
struct OprRegistryCaller { |
|
|
struct OprRegistryCaller { |
|
|
}; """) |
|
|
|
|
|
|
|
|
}; """ |
|
|
|
|
|
) |
|
|
for i in sorted(self._oprs): |
|
|
for i in sorted(self._oprs): |
|
|
i = self.strip_opr_name_with_version(i) |
|
|
i = self.strip_opr_name_with_version(i) |
|
|
if i in already_instance: |
|
|
if i in already_instance: |
|
@@ -123,40 +128,53 @@ class HeaderGen: |
|
|
else: |
|
|
else: |
|
|
already_instance.add(i) |
|
|
already_instance.add(i) |
|
|
|
|
|
|
|
|
defs.append(""" |
|
|
|
|
|
|
|
|
defs.append( |
|
|
|
|
|
""" |
|
|
template<class Callee> |
|
|
template<class Callee> |
|
|
struct OprRegistryCaller<opr::{}, Callee>: public |
|
|
struct OprRegistryCaller<opr::{}, Callee>: public |
|
|
OprRegistryCallerDefaultImpl<Callee> {{ |
|
|
OprRegistryCallerDefaultImpl<Callee> {{ |
|
|
}}; """.format(i)) |
|
|
|
|
|
self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs) |
|
|
|
|
|
|
|
|
}}; """.format( |
|
|
|
|
|
i |
|
|
|
|
|
) |
|
|
|
|
|
) |
|
|
|
|
|
self._write_def("MGB_OPR_REGISTRY_CALLER_SPECIALIZE", defs) |
|
|
|
|
|
|
|
|
def _write_elemwise_modes(self):
    """Write one ``_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_<mode>`` macro per mode.

    Runs ``./dnn/scripts/gen_param_defs.py --write-enum-items
    Elemwise:Mode ./dnn/scripts/opr_param_defs.py <tmpfile>`` with the
    MegEngine root as working directory, then reads the temporary file
    back line by line.  For every line the mode name is the text before
    the first space and before the first ``=``.  Modes contained in
    ``self._elemwise_modes`` expand to ``_cb(<mode>)``; all others
    expand to nothing.  A final dispatcher macro
    ``MEGDNN_ELEMWISE_MODE_ENABLE`` is written last.

    Raises:
        subprocess.CalledProcessError: if the generator script exits
            with a non-zero status.
    """
    with tempfile.NamedTemporaryFile() as scratch:
        listing_path = os.path.realpath(scratch.name)
        command = [
            "./dnn/scripts/gen_param_defs.py",
            "--write-enum-items",
            "Elemwise:Mode",
            "./dnn/scripts/opr_param_defs.py",
            listing_path,
        ]
        subprocess.check_call(command, cwd=self.get_megengine_root())

        # Read while the temporary file still exists.
        with open(listing_path) as listing:
            mode_names = [
                line.strip().split(" ")[0].split("=")[0] for line in listing
            ]

    for mode in mode_names:
        if mode in self._elemwise_modes:
            expansion = "_cb({})".format(mode)
        else:
            expansion = ""
        macro = "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(mode)
        self._write_def(macro, expansion)

    self._write_def(
        "MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)",
        "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)",
    )
|
|
|
|
|
|
|
|
def _write_dtype(self): |
|
|
def _write_dtype(self): |
|
|
if 'Float16' not in self._dtypes: |
|
|
|
|
|
|
|
|
if "Float16" not in self._dtypes: |
|
|
# MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16 |
|
|
# MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16 |
|
|
# support in the past; however `FLOT16' is really a typo. We plan to |
|
|
# support in the past; however `FLOT16' is really a typo. We plan to |
|
|
# change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon. |
|
|
# change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon. |
|
@@ -166,74 +184,86 @@ class HeaderGen: |
|
|
# In the future when the situation is settled and no one would ever |
|
|
# In the future when the situation is settled and no one would ever |
|
|
# use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be |
|
|
# use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be |
|
|
# safely deleted. |
|
|
# safely deleted. |
|
|
self._write_def('MEGDNN_DISABLE_FLOT16', 1) |
|
|
|
|
|
self._write_def('MEGDNN_DISABLE_FLOAT16', 1) |
|
|
|
|
|
|
|
|
self._write_def("MEGDNN_DISABLE_FLOT16", 1) |
|
|
|
|
|
self._write_def("MEGDNN_DISABLE_FLOAT16", 1) |
|
|
|
|
|
|
|
|
def _write_hash(self):
    """Write the ``MGB_BINREDUCE_GRAPH_HASHES`` macro, or warn instead.

    When at least one input lacked a graph hash
    (``self._file_without_hash`` is truthy) only a warning is printed.
    Otherwise the collected hashes are joined as ``<h>ULL,<h>ULL,...``
    and handed to ``self._write_def``.
    """
    if not self._file_without_hash:
        hash_list = "ULL,".join(self._graph_hashes) + "ULL"
        self._write_def("MGB_BINREDUCE_GRAPH_HASHES", hash_list)
        return

    print(
        "WARNING: network info has no graph hash. Using json file "
        "generated by MegBrain >= 7.28.0 is recommended"
    )
|
|
|
|
|
|
|
|
def _write_def(self, name, val):
    """Write a single ``#define`` line to ``self._fout``.

    Args:
        name: macro name, including any parameter list.
        val: macro body.  A list is joined with newlines first; any
            other value is converted with ``str``.

    The body is stripped of surrounding whitespace and every embedded
    newline is turned into a backslash continuation so that the macro
    stays one logical preprocessor line.
    """
    body = "\n".join(val) if isinstance(val, list) else val
    pieces = str(body).strip().split("\n")
    continued = " \\\n".join(pieces)
    self._fout.write("#define {} {}\n".format(name, continued))
|
|
|
|
|
|
|
|
def _write_midout(self):
    """Append the midout-generated section to the output header.

    Does nothing when no midout trace files were registered.  Otherwise
    runs ``gen_header.py`` (taken from ``brain/midout`` under the
    megvii3 root when one is found, else from ``third_party/midout``
    under the MegEngine root) on the trace files and writes its stdout
    to ``self._fout`` after a ``// midout`` marker line.

    When the generated text mentions `` half,`` a second header,
    ``fix_fp16_bin_reduce.h``, is written to the current working
    directory: a copy of everything written to ``self._fout`` so far
    with `` half,`` replaced by `` __fp16,``; a warning is printed.

    Raises:
        subprocess.CalledProcessError: if ``gen_header.py`` exits with a
            non-zero status.
    """
    if not self._midout_files:
        return

    gen = os.path.join(
        self.get_megengine_root(), "third_party", "midout", "gen_header.py"
    )
    megvii3_root = self.get_megvii3_root()
    if megvii3_root:
        gen = os.path.join(megvii3_root, "brain", "midout", "gen_header.py")
    print("use {} to gen bin_reduce header".format(gen))

    cvt = subprocess.run(
        [gen] + self._midout_files, stdout=subprocess.PIPE, check=True,
    ).stdout.decode("utf-8")
    self._fout.write("// midout \n")
    self._fout.write(cvt)

    # Membership test rather than ``find(...) > 0`` so that a match at
    # offset 0 is not missed.
    if " half," in cvt:
        # The header is re-read from disk below, so buffered writes must
        # reach the file first or the copy would be incomplete.
        self._fout.flush()
        with open(self._fout.name) as written:
            change = written.read().replace(" half,", " __fp16,")
        with open("fix_fp16_bin_reduce.h", "w") as fix_fp16:
            fix_fp16.write(change)
        msg = (
            "WARNING:\n"
            "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n"
            "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n"
            "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!"
        )
        print(msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
def main(): |
|
|
parser = argparse.ArgumentParser( |
|
|
parser = argparse.ArgumentParser( |
|
|
description='generate header file for reducing binary size by ' |
|
|
|
|
|
'stripping unused oprs in a particular network; output file would ' |
|
|
|
|
|
'be written to bin_reduce.h', |
|
|
|
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter) |
|
|
|
|
|
|
|
|
description="generate header file for reducing binary size by " |
|
|
|
|
|
"stripping unused oprs in a particular network; output file would " |
|
|
|
|
|
"be written to bin_reduce.h", |
|
|
|
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter, |
|
|
|
|
|
) |
|
|
parser.add_argument( |
|
|
parser.add_argument( |
|
|
'inputs', nargs='+', |
|
|
|
|
|
help='input files that describe specific traits of the network; ' |
|
|
|
|
|
'can be one of the following:' |
|
|
|
|
|
' 1. json files generated by ' |
|
|
|
|
|
'megbrain.serialize_comp_graph_to_file() in python; ' |
|
|
|
|
|
' 2. trace files generated by midout library') |
|
|
|
|
|
default_file=os.path.join(HeaderGen.get_megengine_root(), 'src', 'bin_reduce_cmake.h') |
|
|
|
|
|
|
|
|
"inputs", |
|
|
|
|
|
nargs="+", |
|
|
|
|
|
help="input files that describe specific traits of the network; " |
|
|
|
|
|
"can be one of the following:" |
|
|
|
|
|
" 1. json files generated by " |
|
|
|
|
|
"megbrain.serialize_comp_graph_to_file() in python; " |
|
|
|
|
|
" 2. trace files generated by midout library", |
|
|
|
|
|
) |
|
|
|
|
|
default_file = os.path.join( |
|
|
|
|
|
HeaderGen.get_megengine_root(), "src", "bin_reduce_cmake.h" |
|
|
|
|
|
) |
|
|
is_megvii3 = HeaderGen.get_megvii3_root() |
|
|
is_megvii3 = HeaderGen.get_megvii3_root() |
|
|
if is_megvii3: |
|
|
if is_megvii3: |
|
|
default_file=os.path.join(HeaderGen.get_megvii3_root(), 'utils', 'bin_reduce.h') |
|
|
|
|
|
parser.add_argument('-o', '--output', help='output file', default=default_file) |
|
|
|
|
|
|
|
|
default_file = os.path.join( |
|
|
|
|
|
HeaderGen.get_megvii3_root(), "utils", "bin_reduce.h" |
|
|
|
|
|
) |
|
|
|
|
|
parser.add_argument("-o", "--output", help="output file", default=default_file) |
|
|
args = parser.parse_args() |
|
|
args = parser.parse_args() |
|
|
print('config output file: {}'.format(args.output)) |
|
|
|
|
|
|
|
|
print("config output file: {}".format(args.output)) |
|
|
|
|
|
|
|
|
gen = HeaderGen() |
|
|
gen = HeaderGen() |
|
|
for i in args.inputs: |
|
|
for i in args.inputs: |
|
|
print('==== processing {}'.format(i)) |
|
|
|
|
|
|
|
|
print("==== processing {}".format(i)) |
|
|
with open(i) as fin: |
|
|
with open(i) as fin: |
|
|
if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: |
|
|
if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: |
|
|
gen.extend_midout(i) |
|
|
gen.extend_midout(i) |
|
@@ -241,8 +271,9 @@ def main(): |
|
|
fin.seek(0) |
|
|
fin.seek(0) |
|
|
gen.extend_netinfo(json.loads(fin.read())) |
|
|
gen.extend_netinfo(json.loads(fin.read())) |
|
|
|
|
|
|
|
|
with open(args.output, 'w') as fout: |
|
|
|
|
|
|
|
|
with open(args.output, "w") as fout: |
|
|
gen.generate(fout) |
|
|
gen.generate(fout) |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
main() |