|
@@ -77,18 +77,40 @@ class HeaderGen: |
|
|
self._dtypes.add(i) |
|
|
self._dtypes.add(i) |
|
|
for i in data["opr_types"]: |
|
|
for i in data["opr_types"]: |
|
|
self._oprs.add(i) |
|
|
self._oprs.add(i) |
|
|
for i in data["elemwise_modes"]: |
|
|
|
|
|
self._elemwise_modes.add(i) |
|
|
|
|
|
|
|
|
|
|
|
def extend_midout(self, fname): |
|
|
def extend_midout(self, fname): |
|
|
self._midout_files.append(fname) |
|
|
self._midout_files.append(fname) |
|
|
|
|
|
|
|
|
|
|
|
def extend_elemwise_mode_info(self, fname): |
|
|
|
|
|
for line in open(fname): |
|
|
|
|
|
# tag write in dnn/src/common/elemwise/opr_impl.cpp |
|
|
|
|
|
idx = line.find("megdnn_common_elemwise_mode") |
|
|
|
|
|
if idx > 0: |
|
|
|
|
|
cmd = "c++filt -t {}".format(line) |
|
|
|
|
|
demangle = subprocess.check_output(cmd, shell=True).decode("utf-8") |
|
|
|
|
|
demangle = demangle.replace(">", "").split() |
|
|
|
|
|
is_find_number = False |
|
|
|
|
|
for i in demangle: |
|
|
|
|
|
if i.isnumeric(): |
|
|
|
|
|
self._elemwise_modes.add(i) |
|
|
|
|
|
is_find_number = True |
|
|
|
|
|
break |
|
|
|
|
|
assert ( |
|
|
|
|
|
is_find_number |
|
|
|
|
|
), "code issue happened!! can not find elemwise mode in: {}".format( |
|
|
|
|
|
line |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
def generate(self, fout): |
|
|
def generate(self, fout): |
|
|
self._fout = fout |
|
|
self._fout = fout |
|
|
self._write_def("MGB_BINREDUCE_VERSION", "20190219") |
|
|
|
|
|
|
|
|
self._write_def("MGB_BINREDUCE_VERSION", "20220507") |
|
|
if self._has_netinfo: |
|
|
if self._has_netinfo: |
|
|
self._write_dtype() |
|
|
self._write_dtype() |
|
|
|
|
|
|
|
|
|
|
|
if len(self._elemwise_modes) > 0: |
|
|
self._write_elemwise_modes() |
|
|
self._write_elemwise_modes() |
|
|
|
|
|
|
|
|
|
|
|
if self._has_netinfo: |
|
|
self._write_oprs() |
|
|
self._write_oprs() |
|
|
self._write_hash() |
|
|
self._write_hash() |
|
|
self._write_midout() |
|
|
self._write_midout() |
|
@@ -156,22 +178,32 @@ class HeaderGen: |
|
|
with open(fpath) as fin: |
|
|
with open(fpath) as fin: |
|
|
mode_list = [i.strip() for i in fin] |
|
|
mode_list = [i.strip() for i in fin] |
|
|
|
|
|
|
|
|
|
|
|
all_elemwise_modes = set() |
|
|
for i in mode_list: |
|
|
for i in mode_list: |
|
|
i = i.split(" ")[0].split("=")[0] |
|
|
|
|
|
if i in self._elemwise_modes: |
|
|
|
|
|
content = "_cb({})".format(i) |
|
|
|
|
|
|
|
|
i_type = i.replace(" ", "").replace("=", " ").split()[0] |
|
|
|
|
|
i_id = i.replace(" ", "").replace("=", " ").split()[1] |
|
|
|
|
|
all_elemwise_modes.add(i_id) |
|
|
|
|
|
|
|
|
|
|
|
if i_id in self._elemwise_modes: |
|
|
|
|
|
content = "_cb({})".format(i_type) |
|
|
else: |
|
|
else: |
|
|
content = "" |
|
|
content = "" |
|
|
self._write_def( |
|
|
self._write_def( |
|
|
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format( |
|
|
|
|
|
i.split(" ")[0].split("=")[0] |
|
|
|
|
|
), |
|
|
|
|
|
content, |
|
|
|
|
|
|
|
|
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(i_type), content, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
# write end of elemwise macro |
|
|
self._write_def( |
|
|
self._write_def( |
|
|
"MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)", |
|
|
"MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)", |
|
|
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)", |
|
|
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)", |
|
|
) |
|
|
) |
|
|
|
|
|
# finally check all self._elemwise_modes is in all_elemwise_modes |
|
|
|
|
|
for i in self._elemwise_modes: |
|
|
|
|
|
assert ( |
|
|
|
|
|
i in all_elemwise_modes |
|
|
|
|
|
), "code issue happened, can not find elemwise mode: {} in {}".format( |
|
|
|
|
|
i, all_elemwise_modes |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
def _write_dtype(self): |
|
|
def _write_dtype(self): |
|
|
if "Float16" not in self._dtypes: |
|
|
if "Float16" not in self._dtypes: |
|
@@ -267,6 +299,7 @@ def main(): |
|
|
with open(i) as fin: |
|
|
with open(i) as fin: |
|
|
if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: |
|
|
if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: |
|
|
gen.extend_midout(i) |
|
|
gen.extend_midout(i) |
|
|
|
|
|
gen.extend_elemwise_mode_info(i) |
|
|
else: |
|
|
else: |
|
|
fin.seek(0) |
|
|
fin.seek(0) |
|
|
gen.extend_netinfo(json.loads(fin.read())) |
|
|
gen.extend_netinfo(json.loads(fin.read())) |
|
|