
perf(interpreter): try put device value with host to reduce d2h

GitOrigin-RevId: 63d36e7706
Tag: tags/v1.6.0-rc1
Author: Megvii Engine Team (3 years ago)
Parent commit: 48db45d123
4 changed files with 7 additions and 8 deletions:

  1. imperative/python/src/tensor.cpp                           (+3 / -4)
  2. imperative/src/impl/interpreter/interpreter_impl.cpp       (+2 / -2)
  3. imperative/src/impl/interpreter/interpreter_impl.h         (+1 / -1)
  4. imperative/src/include/megbrain/imperative/interpreter.h   (+1 / -1)
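
The change in one sentence: the device-tensor overload of Channel::put now also
accepts a HostTensorND, so a caller that already holds the host value can hand
it to the interpreter and spare a later device-to-host (d2h) readback. The
following self-contained C++ sketch shows only the caching idea; DeviceValue,
HostValue and CachedValue are hypothetical stand-ins, not MegEngine types.

    #include <iostream>
    #include <optional>
    #include <string>

    struct DeviceValue { std::string data; };   // stand-in for DeviceTensorND
    struct HostValue {
        std::string data;
        bool empty() const { return data.empty(); }
    };

    struct CachedValue {                        // stand-in for the interpreter's Tensor
        DeviceValue dev;
        std::optional<HostValue> host;          // filled only when the caller had it

        CachedValue(DeviceValue d, HostValue h) : dev(std::move(d)) {
            if (!h.empty())
                host = std::move(h);            // keep the host copy the caller provided
        }

        HostValue get_value() {
            if (!host) {
                std::cout << "d2h copy\n";      // stands in for a real device->host transfer
                host = HostValue{dev.data};
            }
            return *host;                       // cached copy: no transfer needed
        }
    };

    int main() {
        CachedValue with_host({"x"}, {"x"});
        with_host.get_value();                  // silent: reuses the attached host value
        CachedValue without_host({"y"}, {});
        without_host.get_value();               // prints "d2h copy"
    }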

imperative/python/src/tensor.cpp (+3 / -4)

@@ -44,7 +44,7 @@ PyObject *cpp_apply_backward_varnode;
 
 std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
     if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
-        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
+        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor(), value->get_value()));
     }
     py::tuple tup(6);
     auto data = value->get_value();
@@ -248,7 +248,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
         // for DeviceTensorND
         if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
             auto dv = py::handle(arg0).cast<DeviceTensorND>();
-            interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv);
+            interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv, {});
             m_tensor = std::make_shared<Tensor>(handle);
         } else {
             throw py::type_error("single argument is not tensor, varnode or devicetensor");
@@ -347,7 +347,6 @@ SET_GET_NAME(user_custom_name)
 SET_GET_NAME(automatic_name)
 #undef SET_GET_NAME
 
-
 PyObject* TensorWrapper::handle() {
     return py::cast(m_tensor->m_handle).release().ptr();
 }
@@ -532,7 +531,7 @@ PyObject* TensorWrapper::_dev_tensor(){
 
     // set m_handle to make it a real tensor
     auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
-    auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
+    auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>(), {});
     m_tensor->m_handle = std::move(SharedHandle(sh));
 
     // compiled info is useless after m_handle is set
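
The three call sites above show both calling conventions of the new overload:
make_const already holds the host value (value->get_value()), so forwarding it
is free, while the DeviceTensorND constructor path and _dev_tensor only have a
device value and pass a default-constructed HostTensorND. Restated in
isolation (a sketch; chan, dv and value stand for the objects in the diff):

    // host copy already on hand: attach it so later reads skip the d2h
    auto h1 = chan->put(value->dev_tensor(), value->get_value());
    // only a device value exists: pass an empty HostTensorND
    auto h2 = chan->put(dv, {});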


imperative/src/impl/interpreter/interpreter_impl.cpp (+2 / -2)

@@ -135,7 +135,7 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
     return info;
 }
 
-Handle ChannelImpl::put(const DeviceTensorND& data) {
+Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
     MGB_LOCK_GUARD(m_spin);
     auto& state = get_channel_state();
     mgb_assert(check_available(), "Channel already closed");
@@ -144,7 +144,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
     RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put);
     init(info, {data.layout(), data.comp_node()});
     info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
-    info->ptr = Tensor::make(data);
+    info->ptr = Tensor::make(data, hvalue);
     RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr());
     info->status = TensorInfo::Produced;
     RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandFinishEvent::Put);
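
Inside the channel, hvalue is simply forwarded to Tensor::make. Presumably the
Tensor keeps that host copy next to the device buffer so a later get_value()
can return it directly instead of synchronizing with the device, which is the
d2h saving the commit title refers to. A hedged sketch of such a constructor
(hypothetical; MegEngine's actual Tensor is more involved):

    // hypothetical sketch, not MegEngine's real Tensor
    Tensor(const DeviceTensorND& dev, const HostTensorND& hv) : m_dev(dev) {
        if (!hv.empty())
            m_host = hv;   // reused by get_value(), avoiding a d2h transfer
    }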


imperative/src/impl/interpreter/interpreter_impl.h (+1 / -1)

@@ -42,7 +42,7 @@ struct ChannelImpl : Interpreter::Channel {
     ~ChannelImpl() override;
 
     Handle put(const HostTensorND& value, bool no_cache) override;
-    Handle put(const DeviceTensorND& value) override;
+    Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) override;
 
     void del(Handle) override;
     void swap_in(Handle) override;


imperative/src/include/megbrain/imperative/interpreter.h (+1 / -1)

@@ -23,7 +23,7 @@ struct Interpreter {
         virtual ~Channel() = default;
 
         virtual Handle put(const HostTensorND& value, bool no_cache) = 0;
-        virtual Handle put(const DeviceTensorND& value) = 0;
+        virtual Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) = 0;
 
         virtual void del(Handle) = 0;
         virtual void swap_in(Handle) = 0;
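
Because put is pure virtual here, every Channel implementation must accept the
extra argument. A caller-side summary (a sketch; chan, hnd and dnd are
illustrative names for a channel and a host/device tensor pair):

    Handle a = chan->put(hnd, /*no_cache=*/false);  // host path, unchanged
    Handle b = chan->put(dnd, hnd);                 // device path, host copy attached
    Handle c = chan->put(dnd, {});                  // device path, no host copy available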

