Some DNN backend oprs delegate to other (proxy) oprs;
for example, the naive CPU softmax implementation calls the elemwise opr.
At the model dump stage this DNN runtime logic is not visible,
so the elemwise mode info is recorded at the runtime stage instead.
GitOrigin-RevId: 6528b4c85d
release-1.10
@@ -17,6 +17,9 @@ | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_common_elemwise) | |||
//! this tag will be used at tools/gen_header_for_bin_reduce.py | |||
//! please do not modify it | |||
MIDOUT_DECL(megdnn_common_elemwise_mode) | |||
#include <mutex> | |||
#include <vector> | |||
@@ -154,6 +157,88 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||
#if !MEGDNN_ELEMWISE_MODE_ENABLE_ALL | |||
megdnn_assert(ret.arity); | |||
#endif | |||
//! Some DNN backend oprs delegate to proxy oprs. For example, the naive CPU
//! softmax implementation calls the elemwise opr. This logic is not visible at
//! the model dump stage, so the elemwise modes it uses would be lost. As a
//! solution, we record the elemwise mode information by adding a 'midout' case
//! flag at the run stage.
#define CB_MODE(mode) \ | |||
case mode: \ | |||
MIDOUT_BEGIN(megdnn_common_elemwise_mode, midout_iv(mode)) { return ret; } \ | |||
MIDOUT_END(); \ | |||
break; | |||
switch (mode) { | |||
CB_MODE(Mode::RELU); | |||
CB_MODE(Mode::ABS); | |||
CB_MODE(Mode::ACOS); | |||
CB_MODE(Mode::ASIN); | |||
CB_MODE(Mode::CEIL); | |||
CB_MODE(Mode::COS); | |||
CB_MODE(Mode::EXP); | |||
CB_MODE(Mode::EXPM1); | |||
CB_MODE(Mode::FLOOR); | |||
CB_MODE(Mode::LOG); | |||
CB_MODE(Mode::LOG1P); | |||
CB_MODE(Mode::NEGATE); | |||
CB_MODE(Mode::SIGMOID); | |||
CB_MODE(Mode::SIN); | |||
CB_MODE(Mode::TANH); | |||
CB_MODE(Mode::ABS_GRAD); | |||
CB_MODE(Mode::ADD); | |||
CB_MODE(Mode::FLOOR_DIV); | |||
CB_MODE(Mode::MAX); | |||
CB_MODE(Mode::MIN); | |||
CB_MODE(Mode::MOD); | |||
CB_MODE(Mode::MUL); | |||
CB_MODE(Mode::POW); | |||
CB_MODE(Mode::SIGMOID_GRAD); | |||
CB_MODE(Mode::SUB); | |||
CB_MODE(Mode::SWITCH_GT0); | |||
CB_MODE(Mode::TANH_GRAD); | |||
CB_MODE(Mode::TRUE_DIV); | |||
CB_MODE(Mode::LOG_SUM_EXP); | |||
CB_MODE(Mode::LT); | |||
CB_MODE(Mode::LEQ); | |||
CB_MODE(Mode::EQ); | |||
CB_MODE(Mode::SHL); | |||
CB_MODE(Mode::SHR); | |||
CB_MODE(Mode::COND_LEQ_MOV); | |||
CB_MODE(Mode::FUSE_MUL_ADD3); | |||
CB_MODE(Mode::FUSE_MUL_ADD4); | |||
CB_MODE(Mode::FUSE_ADD_RELU); | |||
CB_MODE(Mode::FUSE_ADD_SIGMOID); | |||
CB_MODE(Mode::FUSE_ADD_TANH); | |||
CB_MODE(Mode::FAST_TANH); | |||
CB_MODE(Mode::FAST_TANH_GRAD); | |||
CB_MODE(Mode::ROUND); | |||
CB_MODE(Mode::RMULH); | |||
CB_MODE(Mode::ATAN2); | |||
CB_MODE(Mode::ERF); | |||
CB_MODE(Mode::ERFINV); | |||
CB_MODE(Mode::ERFC); | |||
CB_MODE(Mode::ERFCINV); | |||
CB_MODE(Mode::H_SWISH); | |||
CB_MODE(Mode::H_SWISH_GRAD); | |||
CB_MODE(Mode::FUSE_ADD_H_SWISH); | |||
CB_MODE(Mode::NOT); | |||
CB_MODE(Mode::AND); | |||
CB_MODE(Mode::OR); | |||
CB_MODE(Mode::XOR); | |||
CB_MODE(Mode::SILU); | |||
CB_MODE(Mode::SILU_GRAD); | |||
CB_MODE(Mode::GELU); | |||
CB_MODE(Mode::GELU_GRAD); | |||
default: | |||
megdnn_assert( | |||
0, | |||
"code issue happened!!, please add new elemwise to switch mode."); | |||
return ret; | |||
#undef CB_MODE | |||
} | |||
return ret; | |||
} | |||
@@ -77,18 +77,40 @@ class HeaderGen: | |||
self._dtypes.add(i) | |||
for i in data["opr_types"]: | |||
self._oprs.add(i) | |||
for i in data["elemwise_modes"]: | |||
self._elemwise_modes.add(i) | |||
def extend_midout(self, fname):
    """Register *fname* as one more midout trace file.

    The name is appended to ``self._midout_files``; the file itself is not
    opened here.
    """
    midout_files = self._midout_files
    midout_files.append(fname)
def extend_elemwise_mode_info(self, fname):
    """Collect the elemwise mode ids recorded in a midout trace file.

    Each line of *fname* that contains the ``megdnn_common_elemwise_mode``
    tag is demangled with ``c++filt -t``. All ``>`` characters are removed
    from the demangled text, and the first purely numeric whitespace-separated
    token is added (as a string) to ``self._elemwise_modes``.

    Args:
        fname: path of the midout trace file to scan.

    Raises:
        AssertionError: a tagged line yields no numeric token once demangled.
        subprocess.CalledProcessError: ``c++filt`` exits with a non-zero
            status.
    """
    # tag written in dnn/src/common/elemwise/opr_impl.cpp
    tag = "megdnn_common_elemwise_mode"
    # Context manager: the trace file is closed even if demangling fails.
    with open(fname) as fin:
        for line in fin:
            # NOTE(review): the original test was ``find(...) > 0``, i.e. a
            # tag at the very start of a line is ignored; kept unchanged so
            # that previously skipped lines do not start failing.
            if line.find(tag) <= 0:
                continue
            # Argument list instead of a shell command string: text read from
            # the trace file is never interpreted by a shell. split() mirrors
            # the shell's word splitting and drops the trailing newline.
            cmd = ["c++filt", "-t"] + line.split()
            demangle = subprocess.check_output(cmd).decode("utf-8")
            tokens = demangle.replace(">", "").split()
            mode = next((tok for tok in tokens if tok.isnumeric()), None)
            if mode is None:
                # Raised explicitly (same exception type as the former
                # ``assert``) so the check is not stripped under ``python -O``.
                raise AssertionError(
                    "code issue happened!! can not find elemwise mode in: {}".format(
                        line
                    )
                )
            self._elemwise_modes.add(mode)
def generate(self, fout): | |||
self._fout = fout | |||
self._write_def("MGB_BINREDUCE_VERSION", "20190219") | |||
self._write_def("MGB_BINREDUCE_VERSION", "20220507") | |||
if self._has_netinfo: | |||
self._write_dtype() | |||
if len(self._elemwise_modes) > 0: | |||
self._write_elemwise_modes() | |||
if self._has_netinfo: | |||
self._write_oprs() | |||
self._write_hash() | |||
self._write_midout() | |||
@@ -156,22 +178,32 @@ class HeaderGen: | |||
with open(fpath) as fin: | |||
mode_list = [i.strip() for i in fin] | |||
all_elemwise_modes = set() | |||
for i in mode_list: | |||
i = i.split(" ")[0].split("=")[0] | |||
if i in self._elemwise_modes: | |||
content = "_cb({})".format(i) | |||
i_type = i.replace(" ", "").replace("=", " ").split()[0] | |||
i_id = i.replace(" ", "").replace("=", " ").split()[1] | |||
all_elemwise_modes.add(i_id) | |||
if i_id in self._elemwise_modes: | |||
content = "_cb({})".format(i_type) | |||
else: | |||
content = "" | |||
self._write_def( | |||
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format( | |||
i.split(" ")[0].split("=")[0] | |||
), | |||
content, | |||
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(i_type), content, | |||
) | |||
# write end of elemwise macro | |||
self._write_def( | |||
"MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)", | |||
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)", | |||
) | |||
# finally check all self._elemwise_modes is in all_elemwise_modes | |||
for i in self._elemwise_modes: | |||
assert ( | |||
i in all_elemwise_modes | |||
), "code issue happened, can not find elemwise mode: {} in {}".format( | |||
i, all_elemwise_modes | |||
) | |||
def _write_dtype(self): | |||
if "Float16" not in self._dtypes: | |||
@@ -267,6 +299,7 @@ def main(): | |||
with open(i) as fin: | |||
if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: | |||
gen.extend_midout(i) | |||
gen.extend_elemwise_mode_info(i) | |||
else: | |||
fin.seek(0) | |||
gen.extend_netinfo(json.loads(fin.read())) | |||