|
|
@@ -240,10 +240,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { |
|
|
|
pyf = cpp_apply_const_with_tracing; |
|
|
|
} |
|
|
|
|
|
|
|
auto ret = py::reinterpret_steal<py::object>( |
|
|
|
PyObject_Call(pyf, tup.ptr(), nullptr)); |
|
|
|
auto py_ret = py::reinterpret_borrow<py::list>(ret); |
|
|
|
if (auto* t = try_cast(py_ret[0].ptr())) { |
|
|
|
auto py_ret = PyObject_Call(pyf, tup.ptr(), nullptr); |
|
|
|
if (!py_ret) throw py::error_already_set(); |
|
|
|
auto py_list = py::reinterpret_steal<py::list>(py_ret); |
|
|
|
if (auto* t = try_cast(py_list[0].ptr())) { |
|
|
|
m_tensor = t->m_tensor; |
|
|
|
} |
|
|
|
return; |
|
|
@@ -389,6 +389,7 @@ PyObject* TensorWrapper::device() { |
|
|
|
PyObject* TensorWrapper::numpy() { |
|
|
|
if (m_tensor->m_trace_info.compiled_info != nullptr) { |
|
|
|
PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr); |
|
|
|
if (!np_val) throw py::error_already_set(); |
|
|
|
if (np_val == Py_None) { |
|
|
|
throw TraceReadError("value of this tensor is not read in trace"); |
|
|
|
} |
|
|
@@ -478,6 +479,7 @@ PyObject* TensorWrapper::detach() { |
|
|
|
PyObject* TensorWrapper::_dev_tensor(){ |
|
|
|
if (m_tensor->m_trace_info.compiled_info != nullptr) { |
|
|
|
auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr); |
|
|
|
if (!dev_tensor) throw py::error_already_set(); |
|
|
|
if (dev_tensor == Py_None) { |
|
|
|
throw TraceReadError("raw data of this tensor is not read in trace"); |
|
|
|
} |
|
|
|