Browse Source

fix(imperative): fix persistent_cache

GitOrigin-RevId: 8f7bb5899f
release-1.2
Megvii Engine Team 4 years ago
parent
commit
08cc10324e
3 changed files with 9 additions and 5 deletions
  1. +0
    -1
      imperative/python/megengine/__init__.py
  2. +1
    -1
      imperative/python/src/helper.h
  3. +8
    -3
      imperative/python/src/tensor.cpp

+ 0
- 1
imperative/python/megengine/__init__.py View File

@@ -93,7 +93,6 @@ _persistent_cache_impl_ins.reg()
atexit.register(_full_sync) atexit.register(_full_sync)


del _set_fork_exec_path_for_timed_func del _set_fork_exec_path_for_timed_func
del _persistent_cache_impl_ins


# subpackages # subpackages
import megengine.autodiff import megengine.autodiff


+ 1
- 1
imperative/python/src/helper.h View File

@@ -366,7 +366,7 @@ namespace detail {
return true; return true;
} }
static handle cast(mgb::PersistentCache::Blob blob, return_value_policy /* policy */, handle /* parent */) { static handle cast(mgb::PersistentCache::Blob blob, return_value_policy /* policy */, handle /* parent */) {
return bytes((const char*)blob.ptr, blob.size);
return bytes((const char*)blob.ptr, blob.size).release();
} }
}; };




+ 8
- 3
imperative/python/src/tensor.cpp View File

@@ -421,8 +421,10 @@ PyObject* TensorWrapper::numpy() {
} }
return np_val.release().ptr(); return np_val.release().ptr();
} }

auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
auto&& hv = [&]() {
py::gil_scoped_release _;
return interpreter_for_py->get_value(m_tensor->m_handle.get());
}();
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
if (!arr) { if (!arr) {
PyErr_SetString(PyExc_ValueError, "tensor invalid"); PyErr_SetString(PyExc_ValueError, "tensor invalid");
@@ -492,7 +494,10 @@ PyObject* TensorWrapper::_dev_tensor(){
if (m_tensor->m_trace_info.recording && !skip_tracing) { if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr()); PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr());
} }
auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
auto dev_tensor = [&](){
py::gil_scoped_release _;
return interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
}();
return py::cast(dev_tensor).release().ptr(); return py::cast(dev_tensor).release().ptr();
} }




Loading…
Cancel
Save