diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 0b90cc04..84f6283c 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -957,14 +957,14 @@ std::tuple, bool> tuple2vector(py::object shape) { } bool enable_fastpath(py::handle inp) { - // FIXME: the way to judge whether it is in traced module is inaccurate + auto&& tm_tr = TransformationManager::get_instance() + .segments[TransformationManager::Segment::ModuleTrace]; if (!TensorWrapper::try_cast(inp.ptr()) || TransformationManager::get_instance() .segments[TransformationManager::Segment::Trace] .size() > 0 || - TransformationManager::get_instance() - .segments[TransformationManager::Segment::ModuleTrace] - .size() > 0) { + (tm_tr.size() > 0 && + reinterpret_cast(tm_tr[0].get())->enabled())) { return false; } return true;