|
|
@@ -47,10 +47,6 @@ namespace views = ranges::views; |
|
|
|
|
|
|
|
namespace mgb::imperative::python { |
|
|
|
|
|
|
|
namespace { |
|
|
|
WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map; |
|
|
|
} // namespace |
|
|
|
|
|
|
|
interpreter::Interpreter::Channel* interpreter_for_py = nullptr; |
|
|
|
PyTypeObject* py_tensor_type = nullptr; |
|
|
|
PyTypeObject* py_varnode_type = nullptr; |
|
|
@@ -594,7 +590,9 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { |
|
|
|
} |
|
|
|
|
|
|
|
PyObject* TensorWrapper::module_trace_info() { |
|
|
|
if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) { |
|
|
|
if (auto module_trace_info = |
|
|
|
ModuleTraceTransformation::module_trace_info_map.try_get( |
|
|
|
m_tensor->data())) { |
|
|
|
if (module_trace_info->ptr()) { |
|
|
|
return module_trace_info->inc_ref().ptr(); |
|
|
|
} |
|
|
@@ -608,7 +606,8 @@ PyObject* TensorWrapper::module_trace_info() { |
|
|
|
|
|
|
|
void TensorWrapper::set_module_trace_info(PyObject* obj) { |
|
|
|
// TODO: erase when obj == nullptr |
|
|
|
module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj); |
|
|
|
ModuleTraceTransformation::module_trace_info_map[m_tensor->data()] = |
|
|
|
py::reinterpret_borrow<py::object>(obj); |
|
|
|
} |
|
|
|
|
|
|
|
void TensorWrapper::_set_format(PyObject* dest) { |
|
|
@@ -620,6 +619,7 @@ void TensorWrapper::_set_format(PyObject* dest) { |
|
|
|
void TensorWrapper::_set_name(PyObject* dest) { |
|
|
|
auto py_dest = py::reinterpret_borrow<py::object>(dest); |
|
|
|
auto name = py_dest.cast<std::string>(); |
|
|
|
|
|
|
|
m_tensor->set_name(name); |
|
|
|
} |
|
|
|
|
|
|
|