
fix(mge/dtr): fix DTR allocation and recompute

Install the DTR-aware custom allocator into BlobManager in addition to OpDef, so blob and workspace allocations can trigger eviction; record each op's output descriptors in ApplyOp and ComputePath, and replay dropped tensors with those dispatch-time descs; force outputs contiguous under DTR auto-drop (memory forwarding is disabled to avoid its performance impact); restrict eviction to the interpreter worker thread; adjust the BatchNorm detach check for its six outputs; and skip storage allocation for zero-sized blobs.

GitOrigin-RevId: 2a703f9ee4
tags/v1.9.0
Megvii Engine Team 3 years ago
parent commit 0a6f4a880e
9 changed files with 67 additions and 24 deletions
1. imperative/src/impl/blob_manager_impl.cpp (+18, -0)
2. imperative/src/impl/blob_manager_impl.h (+4, -0)
3. imperative/src/impl/interpreter/commands.h (+1, -0)
4. imperative/src/impl/interpreter/interpreter_impl.cpp (+33, -20)
5. imperative/src/impl/interpreter/interpreter_impl.h (+1, -1)
6. imperative/src/impl/interpreter/tensor_info.h (+4, -1)
7. imperative/src/impl/physical_tensor.cpp (+1, -1)
8. imperative/src/include/megbrain/imperative/blob_manager.h (+4, -0)
9. imperative/src/include/megbrain/imperative/physical_tensor.h (+1, -1)
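
The core of the fix is easiest to see before reading the hunks: the interpreter worker builds one DTR-aware allocator and installs it into both OpDef and BlobManager, where previously only OpDef received it. The following is a minimal standalone sketch of that wiring; CompNode, RawStorage and the two hook registries here are mocks, not MegEngine types:

    #include <cstddef>
    #include <functional>
    #include <memory>

    // Stand-ins for MegEngine's CompNode and DeviceTensorStorage::RawStorage.
    using CompNode = int;
    using RawStorage = std::shared_ptr<unsigned char>;
    using Allocator = std::function<RawStorage(CompNode, std::size_t)>;

    // Two independent allocation sites, as in the patch: OpDef covers kernel
    // workspaces, BlobManager covers tensor blobs. Before this commit only
    // OpDef received the hook, so BlobManager allocations bypassed DTR eviction.
    struct OpDefHook { static inline Allocator allocator; };
    struct BlobManagerHook { static inline Allocator allocator; };

    // Stands in for ChannelImpl::alloc_tensor_with_evict: the real version
    // tries a direct allocation and evicts DTR candidates on failure.
    RawStorage alloc_tensor_with_evict(CompNode, std::size_t size) {
        return RawStorage(new unsigned char[size],
                          std::default_delete<unsigned char[]>());
    }

    // Mirrors on_async_queue_worker_thread_start: one hook, two consumers.
    void on_worker_thread_start() {
        Allocator custom_allocator = [](CompNode cn, std::size_t size) {
            return alloc_tensor_with_evict(cn, size);
        };
        OpDefHook::allocator = custom_allocator;
        BlobManagerHook::allocator = custom_allocator;
    }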

imperative/src/impl/blob_manager_impl.cpp (+18, -0)

@@ -41,6 +41,10 @@ void BlobManagerImpl::unregister_blob(Blob* blob) {
 }
 
 void BlobManagerImpl::alloc_with_defrag(Blob* blob, size_t size) {
+    if (custom_allocator) {
+        blob->m_storage = custom_allocator(blob->m_comp_node, size);
+        return;
+    }
     // try alloc
     MGB_TRY { alloc_direct(blob, size); }
     // if fail, try defrag, alloc again
@@ -61,6 +65,13 @@ void BlobManagerImpl::alloc_direct(Blob* blob, size_t size) {
 DeviceTensorND BlobManagerImpl::alloc_workspace_with_defrag(
         CompNode cn, TensorLayout& layout) {
     DeviceTensorND dev_tensor;
+    if (custom_allocator) {
+        DeviceTensorStorage storage(cn);
+        size_t sz = layout.dtype.size(layout.total_nr_elems());
+        storage.reset(cn, sz, custom_allocator(cn, sz));
+        dev_tensor.reset(storage, layout);
+        return dev_tensor;
+    }
     MGB_TRY { return alloc_workspace(cn, layout); }
     MGB_CATCH(MemAllocError&, {
         mgb_log_warn("memory allocation failed for workspace; try defragmenting");
@@ -78,6 +89,10 @@ DeviceTensorND BlobManagerImpl::alloc_workspace(CompNode cn, TensorLayout layout
     return dev_tensor;
 }
 
+void BlobManagerImpl::set_allocator(allocator_t allocator) {
+    custom_allocator = allocator;
+}
+
 void BlobManagerImpl::defrag(const CompNode& cn) {
     BlobSetWithMux* blobs_set_ptr;
     {
@@ -159,6 +174,9 @@ struct BlobManagerStub : BlobManager {
     void defrag(const CompNode& cn) {
         mgb_assert(0, "prohibited after global variable destruction");
     };
+    virtual void set_allocator(allocator_t allocator) {
+        mgb_assert(0, "prohibited after global variable destruction");
+    };
 };
 
 BlobManager* BlobManager::inst() {
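
In prose, alloc_with_defrag now tries three things in order: a registered custom allocator wins outright; otherwise allocate directly; on OOM defragment and retry. A compilable sketch of that policy under mock types (try/catch stands in for MGB_TRY/MGB_CATCH):

    #include <cstddef>
    #include <functional>
    #include <new>

    using RawStorage = void*;  // mock of DeviceTensorStorage::RawStorage
    using Allocator = std::function<RawStorage(int cn, std::size_t size)>;

    struct BlobManagerSketch {
        Allocator custom_allocator;  // empty until set_allocator() is called

        RawStorage alloc_with_defrag(int cn, std::size_t size) {
            // 1. the DTR hook, when installed, takes precedence
            if (custom_allocator)
                return custom_allocator(cn, size);
            // 2. otherwise try a direct device allocation ...
            try {
                return alloc_direct(cn, size);
            } catch (const std::bad_alloc&) {
                // 3. ... and on failure defragment once, then retry
                defrag(cn);
                return alloc_direct(cn, size);
            }
        }

        RawStorage alloc_direct(int, std::size_t size) { return ::operator new(size); }
        void defrag(int) { /* compact free device memory */ }
    };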


imperative/src/impl/blob_manager_impl.h (+4, -0)

@@ -45,6 +45,8 @@ class BlobManagerImpl final : public BlobManager {
 
     DeviceTensorND alloc_workspace(CompNode cn, TensorLayout layout);
 
+    BlobManager::allocator_t custom_allocator;
+
 public:
     static BlobManager* inst();
 
@@ -56,6 +58,8 @@ public:
     void register_blob(Blob* blob) override;
 
     void unregister_blob(Blob* blob) override;
+
+    void set_allocator(allocator_t allocator) override;
 };
 
 } // namespace imperative


imperative/src/impl/interpreter/commands.h (+1, -0)

@@ -49,6 +49,7 @@ struct ApplyOp {
     std::shared_ptr<OpDef> op;
     SmallVector<TensorInfo*> inputs;
     SmallVector<TensorInfo*> outputs;
+    SmallVector<LogicalTensorDesc> outputs_descs;
     bool validated = false;
 
     template <typename TFunctor>


imperative/src/impl/interpreter/interpreter_impl.cpp (+33, -20)

@@ -114,11 +114,13 @@ ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
 void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
     sys::set_thread_name("worker");
     m_owner->m_worker_state.tid = std::this_thread::get_id();
-    OpDef::set_allocator([&](CompNode device, size_t size) {
+    auto custom_allocator = [&](CompNode device, size_t size) {
         auto blob = Blob::make(device, size);
         m_owner->alloc_tensor_with_evict(blob.get());
         return blob->storage();
-    });
+    };
+    OpDef::set_allocator(custom_allocator);
+    BlobManager::inst()->set_allocator(custom_allocator);
 }
 
 // Do not use m_xxx_state directly
@@ -353,7 +355,7 @@ void ChannelImpl::dispatch_kernel(
     for (int i = 0; i < output_descs.size(); ++i) {
         auto&& desc = output_descs[i];
         auto info = alloc();
-        init(info, std::move(desc));
+        init(info, desc);
         // make sure desc's value is consistent with h_value
         if (!info->desc.value.empty()) {
             info->h_value = HostTensorND::make_proxy(desc.value)
@@ -362,9 +364,9 @@ void ChannelImpl::dispatch_kernel(
         output_infos.push_back(info);
         outputs->push_back(reinterpret_cast<Handle>(info));
     }
-    ApplyOp cmd{
-            Profiler::next_id(), std::move(op), std::move(input_infos),
-            std::move(output_infos), validated};
+    ApplyOp cmd{Profiler::next_id(),    std::move(op),
+                std::move(input_infos), std::move(output_infos),
+                std::move(output_descs), validated};
     if (Profiler::is_profiling()) {
         auto op_info_getter = [op = cmd.op] {
             std::unordered_map<std::string, std::string> op_info;
@@ -594,7 +596,7 @@ TensorInfo* ChannelImpl::alloc() {
     return info;
 }
 
-void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) {
+void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
     m_valid_handle.insert(reinterpret_cast<Handle>(info));
     MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
     info->status = TensorInfo::Allocated;
@@ -692,6 +694,11 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
                 "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
                 ptr->layout().to_string().c_str());
     }
+    // in order to avoid performance impact,
+    // memory forwarding is disabled when DTR is enabled
+    if (state.options.enable_dtr_auto_drop) {
+        ptr->to_contiguous_inplace();
+    }
     dest->desc.layout = ptr->layout();
     dest->desc.comp_node = ptr->comp_node();
     dest->memory = ptr->blob()->size();
@@ -719,8 +726,9 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
     if (dest->evict_type == EvictType::DROP) {
         auto&& path = dest->producer;
         m_apply_stack.push(
-                {ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
-                 "dtr"});
+                {ApplyOp{path->id, path->op, path->inputs, path->outputs,
+                         path->outputs_descs},
+                 0, dest, "dtr"});
         if (!m_applying)
             flush_apply_stack();
     }
@@ -812,8 +820,8 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
     }
     // Apply op
     SmallVector<LogicalTensorDesc> output_descs;
-    for (auto i : cmd.outputs) {
-        output_descs.push_back(i->desc);
+    for (auto i : cmd.outputs_descs) {
+        output_descs.push_back(i);
     }
     // Here std::move is REQUIRED for removing duplicated references.
     auto outputs = apply_on_physical_tensor(
@@ -1031,6 +1039,7 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
 }
 
 void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
+    bool in_worker = (get_worker_tid() == std::this_thread::get_id());
     auto reserve_size = [&](size_t size) {
         if (!m_dtr.comp_node.valid()) {
             return false;
@@ -1043,17 +1052,21 @@ void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
         return true;
     };
     auto pre_level = set_log_level(LogLevel::NO_LOG);
-    reserve_size(x->size());
+    if (in_worker) {
+        reserve_size(x->size());
+    }
     MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
     MGB_CATCH(MemAllocError&, {
         bool suc = false;
-        while (!suc) {
-            if (!auto_evict(1)) {
-                break;
+        if (in_worker) {
+            while (!suc) {
+                if (!auto_evict(1)) {
+                    break;
+                }
+                MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
+                MGB_CATCH(MemAllocError&, { continue; });
+                suc = true;
             }
-            MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
-            MGB_CATCH(MemAllocError&, { continue; });
-            suc = true;
         }
         if (!suc) {
             set_log_level(pre_level);
@@ -1143,10 +1156,10 @@ void ChannelImpl::process_one_task(Command& icmd) {
 
             if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                 TensorInfo::ComputePath::make(
-                        cmd.id, cmd.op, cmd.inputs, cmd.outputs);
+                        cmd.id, cmd.op, cmd.inputs, cmd.outputs, cmd.outputs_descs);
                 size_t detach_cnt = 0;
                 if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
-                    cmd.outputs.size() == 5) {
+                    cmd.outputs.size() == 6) {
                     cmd.outputs[0]->detach_producer();  // detach running_mean
                     cmd.outputs[1]->detach_producer();  // detach running_var
                     for (auto input : cmd.inputs) {
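
The pattern behind alloc_tensor_with_evict, now explicitly gated on the worker thread, is the usual DTR loop: try to allocate; on OOM evict one droppable tensor and retry until the allocation succeeds or nothing is left to drop. A toy but runnable rendering, where free_bytes, droppable, try_alloc and evict_one are stand-ins for the device allocator and auto_evict(1):

    #include <cstddef>
    #include <vector>

    // Toy device memory model: a fixed budget plus a list of droppable tensors.
    static std::size_t free_bytes = 1024;
    static std::vector<std::size_t> droppable = {256, 256, 512};

    static bool try_alloc(std::size_t size) {
        if (size > free_bytes) return false;  // out of memory
        free_bytes -= size;
        return true;
    }

    static bool evict_one() {  // stands in for ChannelImpl::auto_evict(1)
        if (droppable.empty()) return false;
        free_bytes += droppable.back();  // dropping a tensor frees its storage
        droppable.pop_back();
        return true;
    }

    // Mirrors the control flow of the patched alloc_tensor_with_evict:
    // eviction is attempted only on the interpreter worker thread, because
    // only the worker can replay compute paths to regenerate dropped tensors.
    bool alloc_with_evict(std::size_t size, bool in_worker) {
        if (try_alloc(size)) return true;
        if (!in_worker) return false;  // other threads must not trigger eviction
        while (evict_one()) {
            if (try_alloc(size)) return true;
        }
        return false;  // evicted everything droppable and still OOM
    }

    int main() { return alloc_with_evict(2048, /*in_worker=*/true) ? 0 : 1; }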


imperative/src/impl/interpreter/interpreter_impl.h (+1, -1)

@@ -77,7 +77,7 @@ private:
     struct State;
 
     TensorInfo* alloc();
-    void init(TensorInfo*, LogicalTensorDesc&& desc);
+    void init(TensorInfo*, LogicalTensorDesc desc);
     void free(TensorInfo*);
     void real_free(TensorInfo*);
     void recursive_free(TensorInfo*);


imperative/src/impl/interpreter/tensor_info.h (+4, -1)

@@ -91,6 +91,7 @@ struct TensorInfo {
         SmallVector<TensorInfo*> inputs;
         SmallVector<TensorInfo*> unique_inputs;
         SmallVector<TensorInfo*> outputs;
+        SmallVector<LogicalTensorDesc> outputs_descs;
 
         size_t ref_cnt() {
             return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr);
@@ -98,12 +99,14 @@ struct TensorInfo {
 
         static ComputePath* make(
                 uint64_t id, std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs,
-                SmallVector<TensorInfo*> outputs) {
+                SmallVector<TensorInfo*> outputs,
+                SmallVector<LogicalTensorDesc> outputs_descs) {
             auto* path = new TensorInfo::ComputePath();
             path->id = id;
             path->op = op;
             path->inputs = inputs;
             path->outputs = outputs;
+            path->outputs_descs = outputs_descs;
             // dedup
             SmallVector<TensorInfo*> unique_inputs = inputs;
             std::sort(unique_inputs.begin(), unique_inputs.end());
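
Why ComputePath must carry outputs_descs: after an output has been dropped, regenerated, or forced contiguous (see the produce_tensor hunk above), its live TensorInfo::desc may no longer match what the op was dispatched with, so replay has to use descriptors captured at dispatch time. A schematic with toy types standing in for OpDef, LogicalTensorDesc and TensorInfo:

    #include <cstdint>
    #include <memory>
    #include <string>
    #include <vector>

    // Toy stand-ins for OpDef / LogicalTensorDesc / TensorInfo.
    struct Op { std::string name; };
    struct Desc { std::string layout; };
    struct Tensor {
        Desc desc;           // may be rewritten later, e.g. made contiguous
        bool dropped = false;
    };

    // Like TensorInfo::ComputePath: just enough state to replay the op.
    struct ComputePath {
        std::uint64_t id;
        std::shared_ptr<Op> op;
        std::vector<Tensor*> inputs;
        std::vector<Tensor*> outputs;
        std::vector<Desc> outputs_descs;  // captured at dispatch time
    };

    // Regeneration replays the op with the *recorded* descriptors; reading
    // outputs[i]->desc here could see a layout mutated after dispatch.
    void regenerate(ComputePath& path,
                    void (*apply)(const Op&, const std::vector<Desc>&)) {
        apply(*path.op, path.outputs_descs);
        for (auto* t : path.outputs)
            if (t) t->dropped = false;
    }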


imperative/src/impl/physical_tensor.cpp (+1, -1)

@@ -87,7 +87,7 @@ Blob::~Blob() {
 }
 
 const Blob::RawStorage& Blob::storage() {
-    if (!m_storage) {
+    if (!m_storage && m_size) {
         BlobManager::inst()->alloc_with_defrag(this, m_size);
     }
     return m_storage;


imperative/src/include/megbrain/imperative/blob_manager.h (+4, -0)

@@ -18,6 +18,8 @@ namespace imperative {
 
 class BlobManager : public NonCopyableObj {
 public:
+    using allocator_t =
+            std::function<DeviceTensorStorage::RawStorage(CompNode, size_t)>;
     virtual ~BlobManager() = default;
 
     static BlobManager* inst();
@@ -26,6 +28,8 @@ public:
 
     virtual void alloc_with_defrag(Blob* blob, size_t size) = 0;
 
+    virtual void set_allocator(allocator_t allocator) = 0;
+
     virtual DeviceTensorND alloc_workspace_with_defrag(
             CompNode cn, TensorLayout& layout) = 0;




imperative/src/include/megbrain/imperative/physical_tensor.h (+1, -1)

@@ -119,7 +119,7 @@ public:
         return make_scalar(value, m_blob->comp_node());
     }
 
-    BlobPtr blob() { return m_blob; }
+    BlobPtr& blob() { return m_blob; }
 
     void fetch_value();
     bool value_fetched();
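
The switch from returning BlobPtr by value to BlobPtr& presumably lets callers re-bind the tensor's blob in place (and avoids a refcount bump per call); with the by-value getter, an assignment would only touch a temporary copy. A generic illustration, not MegEngine code:

    #include <memory>

    struct Blob {};
    using BlobPtr = std::shared_ptr<Blob>;

    struct Tensor {
        BlobPtr m_blob = std::make_shared<Blob>();
        BlobPtr  blob_by_value() { return m_blob; }  // old: copy, +1 refcount
        BlobPtr& blob()          { return m_blob; }  // new: alias of the member
    };

    int main() {
        Tensor t;
        // Through the reference the member itself can be re-bound, e.g. after
        // DTR dropped the old storage and produced a fresh blob:
        t.blob() = std::make_shared<Blob>();
        // The by-value getter would only have replaced a temporary copy.
        return 0;
    }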

