Browse Source

perf(interpreter): try put device value with host to reduce d2h

GitOrigin-RevId: 63d36e7706
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
48db45d123
4 changed files with 7 additions and 8 deletions
  1. +3
    -4
      imperative/python/src/tensor.cpp
  2. +2
    -2
      imperative/src/impl/interpreter/interpreter_impl.cpp
  3. +1
    -1
      imperative/src/impl/interpreter/interpreter_impl.h
  4. +1
    -1
      imperative/src/include/megbrain/imperative/interpreter.h

+ 3
- 4
imperative/python/src/tensor.cpp View File

@@ -44,7 +44,7 @@ PyObject *cpp_apply_backward_varnode;

std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor(), value->get_value()));
}
py::tuple tup(6);
auto data = value->get_value();
@@ -248,7 +248,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
// for DeviceTensorND
if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
auto dv = py::handle(arg0).cast<DeviceTensorND>();
interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv);
interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv, {});
m_tensor = std::make_shared<Tensor>(handle);
} else {
throw py::type_error("single argument is not tensor, varnode or devicetensor");
@@ -347,7 +347,6 @@ SET_GET_NAME(user_custom_name)
SET_GET_NAME(automatic_name)
#undef SET_GET_NAME


PyObject* TensorWrapper::handle() {
return py::cast(m_tensor->m_handle).release().ptr();
}
@@ -532,7 +531,7 @@ PyObject* TensorWrapper::_dev_tensor(){

// set m_handle to make it a real tensor
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>(), {});
m_tensor->m_handle = std::move(SharedHandle(sh));

// compiled info is useless after m_handle is set


+ 2
- 2
imperative/src/impl/interpreter/interpreter_impl.cpp View File

@@ -135,7 +135,7 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
return info;
}

Handle ChannelImpl::put(const DeviceTensorND& data) {
Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
MGB_LOCK_GUARD(m_spin);
auto& state = get_channel_state();
mgb_assert(check_available(), "Channel already closed");
@@ -144,7 +144,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put);
init(info, {data.layout(), data.comp_node()});
info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
info->ptr = Tensor::make(data);
info->ptr = Tensor::make(data, hvalue);
RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr());
info->status = TensorInfo::Produced;
RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandFinishEvent::Put);


+ 1
- 1
imperative/src/impl/interpreter/interpreter_impl.h View File

@@ -42,7 +42,7 @@ struct ChannelImpl : Interpreter::Channel {
~ChannelImpl() override;

Handle put(const HostTensorND& value, bool no_cache) override;
Handle put(const DeviceTensorND& value) override;
Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) override;

void del(Handle) override;
void swap_in(Handle) override;


+ 1
- 1
imperative/src/include/megbrain/imperative/interpreter.h View File

@@ -23,7 +23,7 @@ struct Interpreter {
virtual ~Channel() = default;

virtual Handle put(const HostTensorND& value, bool no_cache) = 0;
virtual Handle put(const DeviceTensorND& value) = 0;
virtual Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) = 0;

virtual void del(Handle) = 0;
virtual void swap_in(Handle) = 0;


Loading…
Cancel
Save