Browse Source

fix(imperative): remove big tensor from host side

GitOrigin-RevId: 2047982d73
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
f5b8fec4ca
2 changed files with 10 additions and 3 deletions
  1. +6
    -2
      imperative/src/impl/interpreter/interpreter_impl.cpp
  2. +4
    -1
      imperative/src/impl/physical_tensor.cpp

+ 6
- 2
imperative/src/impl/interpreter/interpreter_impl.cpp View File

@@ -149,9 +149,13 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
const_cast<HostTensorND&>(value).reset(value.storage(), layout);
}
auto info = alloc();
init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()});
constexpr int size_threshold = TensorShape::MAX_NDIM;
init(info, {value.layout(), value.comp_node()});
if (value.layout().total_nr_elems() <= size_threshold) {
info->h_value = value;
info->desc.value = value.proxy_to_default_cpu();
}
info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
info->h_value = value;
m_buffer.enqueue(Put{info, value, no_cache});
if (m_async_level == 0) {
sync_impl();


+ 4
- 1
imperative/src/impl/physical_tensor.cpp View File

@@ -130,7 +130,10 @@ Tensor::Tensor(
: m_layout(layout), m_blob(std::move(blob)), m_offset(offset), m_value(hv) {}

Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) {
m_value = hv;
constexpr int size_threshold = TensorShape::MAX_NDIM;
if (hv.layout().total_nr_elems() <= size_threshold) {
m_value = hv;
}
MGB_RECORD_EVENT(
profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(),
dev_tensor().raw_ptr());


Loading…
Cancel
Save