diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 5b51b509..166bc1ee 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -541,15 +541,9 @@ PyObject* TensorWrapper::detach() { PyObject* self = wrap_t::pycast(this); PyTypeObject* pytype = self->ob_type; - std::shared_ptr new_tensor; - if (m_tensor->m_handle.get()) { - new_tensor = std::make_shared(m_tensor->m_handle); - } else { - new_tensor = std::make_shared(m_tensor->m_var); - } - new_tensor->m_trace_info = m_tensor->m_trace_info; - - new_tensor->m_flags = m_tensor->m_flags; + static std::shared_ptr op = std::shared_ptr(new FastpathCopy()); + auto new_tensor = python::apply(op, m_tensor)[0]; + new_tensor->m_grad_info_dict = {}; auto ret = TensorWrapper::make(pytype, std::move(new_tensor)); return ret.release().ptr(); } diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index 2f833139..6e443e6c 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -96,6 +96,19 @@ def test_output_copy_trace(): @pytest.mark.parametrize("trace_mode", [False, True]) +def test_tensor_detach(trace_mode): + @trace(symbolic=True) + def f(x): + y = x.detach() ** 2 + z = y.detach() + 1 + return z.detach() + + x = tensor([1, 2, 3, 4]) + for _ in range(3): + f(x).numpy() + + +@pytest.mark.parametrize("trace_mode", [False, True]) def test_exclude_from_trace(trace_mode): @trace(symbolic=trace_mode) def f(x):