GitOrigin-RevId: 63d36e7706
tags/v1.6.0-rc1
@@ -44,7 +44,7 @@ PyObject *cpp_apply_backward_varnode; | |||||
std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) { | std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) { | ||||
if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) { | if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) { | ||||
return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor())); | |||||
return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor(), value->get_value())); | |||||
} | } | ||||
py::tuple tup(6); | py::tuple tup(6); | ||||
auto data = value->get_value(); | auto data = value->get_value(); | ||||
@@ -248,7 +248,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||||
// for DeviceTensorND | // for DeviceTensorND | ||||
if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) { | if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) { | ||||
auto dv = py::handle(arg0).cast<DeviceTensorND>(); | auto dv = py::handle(arg0).cast<DeviceTensorND>(); | ||||
interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv); | |||||
interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv, {}); | |||||
m_tensor = std::make_shared<Tensor>(handle); | m_tensor = std::make_shared<Tensor>(handle); | ||||
} else { | } else { | ||||
throw py::type_error("single argument is not tensor, varnode or devicetensor"); | throw py::type_error("single argument is not tensor, varnode or devicetensor"); | ||||
@@ -347,7 +347,6 @@ SET_GET_NAME(user_custom_name) | |||||
SET_GET_NAME(automatic_name) | SET_GET_NAME(automatic_name) | ||||
#undef SET_GET_NAME | #undef SET_GET_NAME | ||||
PyObject* TensorWrapper::handle() { | PyObject* TensorWrapper::handle() { | ||||
return py::cast(m_tensor->m_handle).release().ptr(); | return py::cast(m_tensor->m_handle).release().ptr(); | ||||
} | } | ||||
@@ -532,7 +531,7 @@ PyObject* TensorWrapper::_dev_tensor(){ | |||||
// set m_handle to make it a real tensor | // set m_handle to make it a real tensor | ||||
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor); | auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor); | ||||
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>()); | |||||
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>(), {}); | |||||
m_tensor->m_handle = std::move(SharedHandle(sh)); | m_tensor->m_handle = std::move(SharedHandle(sh)); | ||||
// compiled info is useless after m_handle is set | // compiled info is useless after m_handle is set | ||||
@@ -135,7 +135,7 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) { | |||||
return info; | return info; | ||||
} | } | ||||
Handle ChannelImpl::put(const DeviceTensorND& data) { | |||||
Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) { | |||||
MGB_LOCK_GUARD(m_spin); | MGB_LOCK_GUARD(m_spin); | ||||
auto& state = get_channel_state(); | auto& state = get_channel_state(); | ||||
mgb_assert(check_available(), "Channel already closed"); | mgb_assert(check_available(), "Channel already closed"); | ||||
@@ -144,7 +144,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) { | |||||
RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put); | RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put); | ||||
init(info, {data.layout(), data.comp_node()}); | init(info, {data.layout(), data.comp_node()}); | ||||
info->mem_desc.id = StorageIdentifier::make(++m_storage_id); | info->mem_desc.id = StorageIdentifier::make(++m_storage_id); | ||||
info->ptr = Tensor::make(data); | |||||
info->ptr = Tensor::make(data, hvalue); | |||||
RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr()); | RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr()); | ||||
info->status = TensorInfo::Produced; | info->status = TensorInfo::Produced; | ||||
RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandFinishEvent::Put); | RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandFinishEvent::Put); | ||||
@@ -42,7 +42,7 @@ struct ChannelImpl : Interpreter::Channel { | |||||
~ChannelImpl() override; | ~ChannelImpl() override; | ||||
Handle put(const HostTensorND& value, bool no_cache) override; | Handle put(const HostTensorND& value, bool no_cache) override; | ||||
Handle put(const DeviceTensorND& value) override; | |||||
Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) override; | |||||
void del(Handle) override; | void del(Handle) override; | ||||
void swap_in(Handle) override; | void swap_in(Handle) override; | ||||
@@ -23,7 +23,7 @@ struct Interpreter { | |||||
virtual ~Channel() = default; | virtual ~Channel() = default; | ||||
virtual Handle put(const HostTensorND& value, bool no_cache) = 0; | virtual Handle put(const HostTensorND& value, bool no_cache) = 0; | ||||
virtual Handle put(const DeviceTensorND& value) = 0; | |||||
virtual Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) = 0; | |||||
virtual void del(Handle) = 0; | virtual void del(Handle) = 0; | ||||
virtual void swap_in(Handle) = 0; | virtual void swap_in(Handle) = 0; | ||||