Browse Source

fix(mge/dtr): filter bad ops

GitOrigin-RevId: 380262dca0
release-1.4
Megvii Engine Team 4 years ago
parent
commit
9122ff2e40
2 changed files with 15 additions and 1 deletions
  1. +8
    -1
      imperative/src/impl/interpreter/interpreter_impl.cpp
  2. +7
    -0
      imperative/src/impl/interpreter/interpreter_impl.h

+ 8
- 1
imperative/src/impl/interpreter/interpreter_impl.cpp View File

@@ -763,7 +763,14 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
break;
}
}
if (!is_inplace && !cross_cn) {
// FIXME: do not use opname as identifier
auto get_name = [](const OpDef& opdef) {
if (auto attr = opdef.try_cast_final<OprAttr>()) {
return attr->type.c_str();
}
return opdef.dyn_typeinfo()->name;
};
if (!is_inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
size_t detach_cnt = 0;
for (auto output : cmd.outputs) {


+ 7
- 0
imperative/src/impl/interpreter/interpreter_impl.h View File

@@ -308,6 +308,13 @@ private:

//! whether the warning message has been printed
bool warn_printed = false;

bool is_bad_op(std::string op_name) {
return std::find(op_blacklist.begin(), op_blacklist.end(), op_name) != op_blacklist.end();
}

std::vector<std::string> op_blacklist = {"CollectiveComm", "InplaceAdd",
"ParamPackSplit", "ParamPackConcat", "GaussianRNG"};
} m_dtr;

//! automatically evict an optimal tensor


Loading…
Cancel
Save