Browse Source

fix(mge/tracing): replace detach as fast path copy

GitOrigin-RevId: d765725d5a
release-1.6
Megvii Engine Team 3 years ago
parent
commit
a829270489
2 changed files with 16 additions and 9 deletions
  1. +3
    -9
      imperative/python/src/tensor.cpp
  2. +13
    -0
      imperative/python/test/unit/jit/test_tracing.py

+ 3
- 9
imperative/python/src/tensor.cpp View File

@@ -541,15 +541,9 @@ PyObject* TensorWrapper::detach() {
PyObject* self = wrap_t::pycast(this);
PyTypeObject* pytype = self->ob_type;

std::shared_ptr<Tensor> new_tensor;
if (m_tensor->m_handle.get()) {
new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);
} else {
new_tensor = std::make_shared<Tensor>(m_tensor->m_var);
}
new_tensor->m_trace_info = m_tensor->m_trace_info;

new_tensor->m_flags = m_tensor->m_flags;
static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy());
auto new_tensor = python::apply(op, m_tensor)[0];
new_tensor->m_grad_info_dict = {};
auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
return ret.release().ptr();
}


+ 13
- 0
imperative/python/test/unit/jit/test_tracing.py View File

@@ -96,6 +96,19 @@ def test_output_copy_trace():


@pytest.mark.parametrize("trace_mode", [False, True])
def test_tensor_detach(trace_mode):
@trace(symbolic=True)
def f(x):
y = x.detach() ** 2
z = y.detach() + 1
return z.detach()

x = tensor([1, 2, 3, 4])
for _ in range(3):
f(x).numpy()


@pytest.mark.parametrize("trace_mode", [False, True])
def test_exclude_from_trace(trace_mode):
@trace(symbolic=trace_mode)
def f(x):


Loading…
Cancel
Save