Browse Source

fix(midout): fix elemwise crash after midout

some dnn backends opr will use agency opr,
for example: softmax cpu naive imp will call elemwise opr,
at model dump stage, we can not get dnn runtime logic,
so we record elemwise mode info at runtime stage.

GitOrigin-RevId: 6528b4c85d
release-1.10
Megvii Engine Team 3 years ago
parent
commit
05186e7bd9
2 changed files with 128 additions and 10 deletions
  1. +85
    -0
      dnn/src/common/elemwise/opr_impl.cpp
  2. +43
    -10
      tools/gen_header_for_bin_reduce.py

+ 85
- 0
dnn/src/common/elemwise/opr_impl.cpp View File

@@ -17,6 +17,9 @@


#include "midout.h" #include "midout.h"
MIDOUT_DECL(megdnn_common_elemwise) MIDOUT_DECL(megdnn_common_elemwise)
//! this tag will be used at tools/gen_header_for_bin_reduce.py
//! please do not modify it
MIDOUT_DECL(megdnn_common_elemwise_mode)


#include <mutex> #include <mutex>
#include <vector> #include <vector>
@@ -154,6 +157,88 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
#if !MEGDNN_ELEMWISE_MODE_ENABLE_ALL #if !MEGDNN_ELEMWISE_MODE_ENABLE_ALL
megdnn_assert(ret.arity); megdnn_assert(ret.arity);
#endif #endif

//! Some DNN backend OPRS will use proxy OPRS. For example, softmax@cpu Naive imp
//! will call elemwise OPR. In the model dump stage, we have no information about
//! this logic, which will lead to the loss of elemwise mode. As a solution, we
//! record the elemwise mode information by adding the 'midout' case flag in the run
//! stage.
#define CB_MODE(mode) \
case mode: \
MIDOUT_BEGIN(megdnn_common_elemwise_mode, midout_iv(mode)) { return ret; } \
MIDOUT_END(); \
break;

switch (mode) {
CB_MODE(Mode::RELU);
CB_MODE(Mode::ABS);
CB_MODE(Mode::ACOS);
CB_MODE(Mode::ASIN);
CB_MODE(Mode::CEIL);
CB_MODE(Mode::COS);
CB_MODE(Mode::EXP);
CB_MODE(Mode::EXPM1);
CB_MODE(Mode::FLOOR);
CB_MODE(Mode::LOG);
CB_MODE(Mode::LOG1P);
CB_MODE(Mode::NEGATE);
CB_MODE(Mode::SIGMOID);
CB_MODE(Mode::SIN);
CB_MODE(Mode::TANH);
CB_MODE(Mode::ABS_GRAD);
CB_MODE(Mode::ADD);
CB_MODE(Mode::FLOOR_DIV);
CB_MODE(Mode::MAX);
CB_MODE(Mode::MIN);
CB_MODE(Mode::MOD);
CB_MODE(Mode::MUL);
CB_MODE(Mode::POW);
CB_MODE(Mode::SIGMOID_GRAD);
CB_MODE(Mode::SUB);
CB_MODE(Mode::SWITCH_GT0);
CB_MODE(Mode::TANH_GRAD);
CB_MODE(Mode::TRUE_DIV);
CB_MODE(Mode::LOG_SUM_EXP);
CB_MODE(Mode::LT);
CB_MODE(Mode::LEQ);
CB_MODE(Mode::EQ);
CB_MODE(Mode::SHL);
CB_MODE(Mode::SHR);
CB_MODE(Mode::COND_LEQ_MOV);
CB_MODE(Mode::FUSE_MUL_ADD3);
CB_MODE(Mode::FUSE_MUL_ADD4);
CB_MODE(Mode::FUSE_ADD_RELU);
CB_MODE(Mode::FUSE_ADD_SIGMOID);
CB_MODE(Mode::FUSE_ADD_TANH);
CB_MODE(Mode::FAST_TANH);
CB_MODE(Mode::FAST_TANH_GRAD);
CB_MODE(Mode::ROUND);
CB_MODE(Mode::RMULH);
CB_MODE(Mode::ATAN2);
CB_MODE(Mode::ERF);
CB_MODE(Mode::ERFINV);
CB_MODE(Mode::ERFC);
CB_MODE(Mode::ERFCINV);
CB_MODE(Mode::H_SWISH);
CB_MODE(Mode::H_SWISH_GRAD);
CB_MODE(Mode::FUSE_ADD_H_SWISH);
CB_MODE(Mode::NOT);
CB_MODE(Mode::AND);
CB_MODE(Mode::OR);
CB_MODE(Mode::XOR);
CB_MODE(Mode::SILU);
CB_MODE(Mode::SILU_GRAD);
CB_MODE(Mode::GELU);
CB_MODE(Mode::GELU_GRAD);
default:
megdnn_assert(
0,
"code issue happened!!, please add new elemwise to switch mode.");
return ret;

#undef CB_MODE
}

return ret; return ret;
} }




+ 43
- 10
tools/gen_header_for_bin_reduce.py View File

@@ -77,18 +77,40 @@ class HeaderGen:
self._dtypes.add(i) self._dtypes.add(i)
for i in data["opr_types"]: for i in data["opr_types"]:
self._oprs.add(i) self._oprs.add(i)
for i in data["elemwise_modes"]:
self._elemwise_modes.add(i)


def extend_midout(self, fname): def extend_midout(self, fname):
self._midout_files.append(fname) self._midout_files.append(fname)


def extend_elemwise_mode_info(self, fname):
for line in open(fname):
# tag write in dnn/src/common/elemwise/opr_impl.cpp
idx = line.find("megdnn_common_elemwise_mode")
if idx > 0:
cmd = "c++filt -t {}".format(line)
demangle = subprocess.check_output(cmd, shell=True).decode("utf-8")
demangle = demangle.replace(">", "").split()
is_find_number = False
for i in demangle:
if i.isnumeric():
self._elemwise_modes.add(i)
is_find_number = True
break
assert (
is_find_number
), "code issue happened!! can not find elemwise mode in: {}".format(
line
)

def generate(self, fout): def generate(self, fout):
self._fout = fout self._fout = fout
self._write_def("MGB_BINREDUCE_VERSION", "20190219")
self._write_def("MGB_BINREDUCE_VERSION", "20220507")
if self._has_netinfo: if self._has_netinfo:
self._write_dtype() self._write_dtype()

if len(self._elemwise_modes) > 0:
self._write_elemwise_modes() self._write_elemwise_modes()

if self._has_netinfo:
self._write_oprs() self._write_oprs()
self._write_hash() self._write_hash()
self._write_midout() self._write_midout()
@@ -156,22 +178,32 @@ class HeaderGen:
with open(fpath) as fin: with open(fpath) as fin:
mode_list = [i.strip() for i in fin] mode_list = [i.strip() for i in fin]


all_elemwise_modes = set()
for i in mode_list: for i in mode_list:
i = i.split(" ")[0].split("=")[0]
if i in self._elemwise_modes:
content = "_cb({})".format(i)
i_type = i.replace(" ", "").replace("=", " ").split()[0]
i_id = i.replace(" ", "").replace("=", " ").split()[1]
all_elemwise_modes.add(i_id)

if i_id in self._elemwise_modes:
content = "_cb({})".format(i_type)
else: else:
content = "" content = ""
self._write_def( self._write_def(
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(
i.split(" ")[0].split("=")[0]
),
content,
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(i_type), content,
) )

# write end of elemwise macro
self._write_def( self._write_def(
"MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)", "MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)",
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)", "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)",
) )
# finally check all self._elemwise_modes is in all_elemwise_modes
for i in self._elemwise_modes:
assert (
i in all_elemwise_modes
), "code issue happened, can not find elemwise mode: {} in {}".format(
i, all_elemwise_modes
)


def _write_dtype(self): def _write_dtype(self):
if "Float16" not in self._dtypes: if "Float16" not in self._dtypes:
@@ -267,6 +299,7 @@ def main():
with open(i) as fin: with open(i) as fin:
if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
gen.extend_midout(i) gen.extend_midout(i)
gen.extend_elemwise_mode_info(i)
else: else:
fin.seek(0) fin.seek(0)
gen.extend_netinfo(json.loads(fin.read())) gen.extend_netinfo(json.loads(fin.read()))


Loading…
Cancel
Save