From 9122ff2e40ef746948f9bc175ee80a67b2c900e6 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 18 May 2021 11:33:18 +0800 Subject: [PATCH] fix(mge/dtr): filter bad ops GitOrigin-RevId: 380262dca0ea11bfa28fca50f4643963bd9ad7d9 --- imperative/src/impl/interpreter/interpreter_impl.cpp | 9 ++++++++- imperative/src/impl/interpreter/interpreter_impl.h | 7 +++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 1843c6b2..848fbdee 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -763,7 +763,14 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { break; } } - if (!is_inplace && !cross_cn) { + // FIXME: do not use opname as identifier + auto get_name = [](const OpDef& opdef) { + if (auto attr = opdef.try_cast_final()) { + return attr->type.c_str(); + } + return opdef.dyn_typeinfo()->name; + }; + if (!is_inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) { TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs); size_t detach_cnt = 0; for (auto output : cmd.outputs) { diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h index 0d11c1c7..d10cc5c4 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.h +++ b/imperative/src/impl/interpreter/interpreter_impl.h @@ -308,6 +308,13 @@ private: //! whether the warning message has been printed bool warn_printed = false; + + bool is_bad_op(std::string op_name) { + return std::find(op_blacklist.begin(), op_blacklist.end(), op_name) != op_blacklist.end(); + } + + std::vector op_blacklist = {"CollectiveComm", "InplaceAdd", + "ParamPackSplit", "ParamPackConcat", "GaussianRNG"}; } m_dtr; //! automatically evict an optimal tensor