diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 1843c6b2..848fbdee 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -763,7 +763,14 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { break; } } - if (!is_inplace && !cross_cn) { + // FIXME: do not use opname as identifier + auto get_name = [](const OpDef& opdef) { + if (auto attr = opdef.try_cast_final()) { + return attr->type.c_str(); + } + return opdef.dyn_typeinfo()->name; + }; + if (!is_inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) { TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs); size_t detach_cnt = 0; for (auto output : cmd.outputs) { diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h index 0d11c1c7..d10cc5c4 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.h +++ b/imperative/src/impl/interpreter/interpreter_impl.h @@ -308,6 +308,13 @@ private: //! whether the warning message has been printed bool warn_printed = false; + + bool is_bad_op(std::string op_name) { + return std::find(op_blacklist.begin(), op_blacklist.end(), op_name) != op_blacklist.end(); + } + + std::vector op_blacklist = {"CollectiveComm", "InplaceAdd", + "ParamPackSplit", "ParamPackConcat", "GaussianRNG"}; } m_dtr; //! automatically evict an optimal tensor