diff --git a/dnn/src/common/elemwise/opr_impl.cpp b/dnn/src/common/elemwise/opr_impl.cpp
index 3850190c..475d0814 100644
--- a/dnn/src/common/elemwise/opr_impl.cpp
+++ b/dnn/src/common/elemwise/opr_impl.cpp
@@ -17,6 +17,9 @@
 
 #include "midout.h"
 MIDOUT_DECL(megdnn_common_elemwise)
+//! this tag will be used at tools/gen_header_for_bin_reduce.py
+//! please do not modify it
+MIDOUT_DECL(megdnn_common_elemwise_mode)
 
 #include
 #include
@@ -154,6 +157,88 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
 #if !MEGDNN_ELEMWISE_MODE_ENABLE_ALL
     megdnn_assert(ret.arity);
 #endif
+
+    //! Some DNN backend OPRS will use proxy OPRS. For example, softmax@cpu Naive imp
+    //! will call elemwise OPR. In the model dump stage, we have no information about
+    //! this logic, which will lead to the loss of elemwise mode. As a solution, we
+    //! record the elemwise mode information by adding the 'midout' case flag in the run
+    //! stage.
+#define CB_MODE(mode)                                                              \
+    case mode:                                                                     \
+        MIDOUT_BEGIN(megdnn_common_elemwise_mode, midout_iv(mode)) { return ret; } \
+        MIDOUT_END();                                                              \
+        break;
+
+    switch (mode) {
+        CB_MODE(Mode::RELU);
+        CB_MODE(Mode::ABS);
+        CB_MODE(Mode::ACOS);
+        CB_MODE(Mode::ASIN);
+        CB_MODE(Mode::CEIL);
+        CB_MODE(Mode::COS);
+        CB_MODE(Mode::EXP);
+        CB_MODE(Mode::EXPM1);
+        CB_MODE(Mode::FLOOR);
+        CB_MODE(Mode::LOG);
+        CB_MODE(Mode::LOG1P);
+        CB_MODE(Mode::NEGATE);
+        CB_MODE(Mode::SIGMOID);
+        CB_MODE(Mode::SIN);
+        CB_MODE(Mode::TANH);
+        CB_MODE(Mode::ABS_GRAD);
+        CB_MODE(Mode::ADD);
+        CB_MODE(Mode::FLOOR_DIV);
+        CB_MODE(Mode::MAX);
+        CB_MODE(Mode::MIN);
+        CB_MODE(Mode::MOD);
+        CB_MODE(Mode::MUL);
+        CB_MODE(Mode::POW);
+        CB_MODE(Mode::SIGMOID_GRAD);
+        CB_MODE(Mode::SUB);
+        CB_MODE(Mode::SWITCH_GT0);
+        CB_MODE(Mode::TANH_GRAD);
+        CB_MODE(Mode::TRUE_DIV);
+        CB_MODE(Mode::LOG_SUM_EXP);
+        CB_MODE(Mode::LT);
+        CB_MODE(Mode::LEQ);
+        CB_MODE(Mode::EQ);
+        CB_MODE(Mode::SHL);
+        CB_MODE(Mode::SHR);
+        CB_MODE(Mode::COND_LEQ_MOV);
+        CB_MODE(Mode::FUSE_MUL_ADD3);
+        CB_MODE(Mode::FUSE_MUL_ADD4);
+        CB_MODE(Mode::FUSE_ADD_RELU);
+        CB_MODE(Mode::FUSE_ADD_SIGMOID);
+        CB_MODE(Mode::FUSE_ADD_TANH);
+        CB_MODE(Mode::FAST_TANH);
+        CB_MODE(Mode::FAST_TANH_GRAD);
+        CB_MODE(Mode::ROUND);
+        CB_MODE(Mode::RMULH);
+        CB_MODE(Mode::ATAN2);
+        CB_MODE(Mode::ERF);
+        CB_MODE(Mode::ERFINV);
+        CB_MODE(Mode::ERFC);
+        CB_MODE(Mode::ERFCINV);
+        CB_MODE(Mode::H_SWISH);
+        CB_MODE(Mode::H_SWISH_GRAD);
+        CB_MODE(Mode::FUSE_ADD_H_SWISH);
+        CB_MODE(Mode::NOT);
+        CB_MODE(Mode::AND);
+        CB_MODE(Mode::OR);
+        CB_MODE(Mode::XOR);
+        CB_MODE(Mode::SILU);
+        CB_MODE(Mode::SILU_GRAD);
+        CB_MODE(Mode::GELU);
+        CB_MODE(Mode::GELU_GRAD);
+        default:
+            megdnn_assert(
+                    0,
+                    "code issue happened!!, please add new elemwise to switch mode.");
+            return ret;
+
+#undef CB_MODE
+    }
+
     return ret;
 }
 
diff --git a/tools/gen_header_for_bin_reduce.py b/tools/gen_header_for_bin_reduce.py
index 0414db40..952dc8a9 100755
--- a/tools/gen_header_for_bin_reduce.py
+++ b/tools/gen_header_for_bin_reduce.py
@@ -77,18 +77,50 @@ class HeaderGen:
             self._dtypes.add(i)
         for i in data["opr_types"]:
             self._oprs.add(i)
-        for i in data["elemwise_modes"]:
-            self._elemwise_modes.add(i)
 
     def extend_midout(self, fname):
         self._midout_files.append(fname)
 
+    def extend_elemwise_mode_info(self, fname):
+        """Record the elemwise modes found in the midout trace ``fname``.
+
+        Lines carrying the ``megdnn_common_elemwise_mode`` tag are demangled
+        with ``c++filt``; the first all-digit token of the result is added
+        to ``self._elemwise_modes``.
+        """
+        # tag write in dnn/src/common/elemwise/opr_impl.cpp
+        tag = "megdnn_common_elemwise_mode"
+        with open(fname) as fin:
+            for line in fin:
+                # ``in`` also matches a tag at column 0, unlike ``find() > 0``
+                if tag not in line:
+                    continue
+                # argv list instead of a shell string: trace content is never
+                # interpreted by a shell
+                demangle = subprocess.check_output(
+                    ["c++filt", "-t"] + line.split()
+                ).decode("utf-8")
+                mode_id = next(
+                    (i for i in demangle.replace(">", "").split() if i.isnumeric()),
+                    None,
+                )
+                assert (
+                    mode_id is not None
+                ), "code issue happened!! can not find elemwise mode in: {}".format(
+                    line
+                )
+                self._elemwise_modes.add(mode_id)
+
     def generate(self, fout):
         self._fout = fout
-        self._write_def("MGB_BINREDUCE_VERSION", "20190219")
+        self._write_def("MGB_BINREDUCE_VERSION", "20220507")
         if self._has_netinfo:
             self._write_dtype()
+
+        if self._elemwise_modes:
             self._write_elemwise_modes()
+
+        if self._has_netinfo:
             self._write_oprs()
         self._write_hash()
         self._write_midout()
@@ -156,22 +188,32 @@ class HeaderGen:
             with open(fpath) as fin:
                 mode_list = [i.strip() for i in fin]
 
+        all_elemwise_modes = set()
         for i in mode_list:
-            i = i.split(" ")[0].split("=")[0]
-            if i in self._elemwise_modes:
-                content = "_cb({})".format(i)
+            i_type = i.replace(" ", "").replace("=", " ").split()[0]
+            i_id = i.replace(" ", "").replace("=", " ").split()[1]
+            all_elemwise_modes.add(i_id)
+
+            if i_id in self._elemwise_modes:
+                content = "_cb({})".format(i_type)
             else:
                 content = ""
             self._write_def(
-                "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(
-                    i.split(" ")[0].split("=")[0]
-                ),
-                content,
+                "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(i_type), content,
             )
+
+        # write end of elemwise macro
         self._write_def(
             "MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)",
             "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)",
         )
+        # finally check all self._elemwise_modes is in all_elemwise_modes
+        for i in self._elemwise_modes:
+            assert (
+                i in all_elemwise_modes
+            ), "code issue happened, can not find elemwise mode: {} in {}".format(
+                i, all_elemwise_modes
+            )
 
     def _write_dtype(self):
         if "Float16" not in self._dtypes:
@@ -267,6 +309,7 @@ def main():
         with open(i) as fin:
             if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
                 gen.extend_midout(i)
+                gen.extend_elemwise_mode_info(i)
             else:
                 fin.seek(0)
                 gen.extend_netinfo(json.loads(fin.read()))