
refactor(interpreter): wrap accesses to channel/worker state
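
Channel state (options, profiler, scopes) belongs to the thread that issues commands, and worker state belongs to the worker thread that executes them; before this change both were read through the bare members, and several call sites used the wrong one (swap_in/swap_out/drop consulted m_worker_state.options from the channel thread, and real_free recorded profiler events through m_channel_state from the worker thread). All accesses now go through get_channel_state()/get_worker_state(), which assert the calling thread's identity, and the raw member names are then poisoned with #define so that any direct use further down the .cpp fails to compile. Along the way, produce_tensor acquires its conditional lock via std::defer_lock instead of a ternary, and the StopProfile handler captures the worker tid by value rather than capturing this.

A minimal, self-contained sketch of the pattern (not MegEngine code: the Channel class, its fields, and the startup handshake below are invented for illustration; the real implementation uses mgb_assert and records the worker tid when the worker loop starts):

    #include <cassert>
    #include <thread>

    class Channel {
    public:
        Channel() : m_worker([this] {
            // Simplified: the worker publishes its tid on startup.
            m_worker_state.tid = std::this_thread::get_id();
            work();
        }) {}
        ~Channel() { m_worker.join(); }

    private:
        struct ChannelState { int enqueued = 0; };
        struct WorkerState { std::thread::id tid; int executed = 0; };

        ChannelState m_channel_state;
        WorkerState m_worker_state;
        std::thread m_worker;

        std::thread::id get_worker_tid() { return m_worker_state.tid; }

        // Each accessor asserts that the caller runs on the owning thread.
        ChannelState& get_channel_state() {
            assert(std::this_thread::get_id() != get_worker_tid());
            return m_channel_state;
        }
        WorkerState& get_worker_state() {
            assert(std::this_thread::get_id() == get_worker_tid());
            return m_worker_state;
        }

    // Poison the raw names: from here on, `m_worker_state.executed` expands
    // to `.executed` and no longer compiles, forcing the checked accessors.
    #define m_channel_state
    #define m_worker_state

        // Runs on the worker thread, so get_worker_state() passes its assert.
        void work() { get_worker_state().executed++; }
    };

    int main() { Channel c; }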

GitOrigin-RevId: 1d58f2c876

Branch: release-1.5
Author: Megvii Engine Team (huangxinda), 4 years ago
Parent commit: 1ce40b5bf7

2 changed files with 162 additions and 97 deletions:
  1. imperative/src/impl/interpreter/interpreter_impl.cpp (+153, -96)
  2. imperative/src/impl/interpreter/interpreter_impl.h (+9, -1)

imperative/src/impl/interpreter/interpreter_impl.cpp (+153, -96)

@@ -23,6 +23,23 @@ using namespace imperative;
 using namespace interpreter;
 using namespace interpreter::intl;

+std::thread::id ChannelImpl::get_worker_tid() {
+    return m_worker_state.tid;
+}
+
+ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
+    assert_in_channel();
+    return m_channel_state;
+}
+
+ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
+    assert_in_worker();
+    return m_worker_state;
+}
+
+#define m_channel_state
+#define m_worker_state
+
 std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
     return std::make_unique<ChannelImpl>();
 }
@@ -48,13 +65,14 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
 }

 Handle ChannelImpl::put(const DeviceTensorND& data) {
+    auto& state = get_channel_state();
     mgb_assert(check_available(), "Channel already closed");
     auto info = alloc();
     info->desc.layout = data.layout();
     info->desc.comp_node = data.comp_node();
     info->ptr = Tensor::make(data);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node);
     }
     return info;
 }
@@ -71,7 +89,8 @@ void ChannelImpl::del(Handle handle) {

 void ChannelImpl::swap_in(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_worker_state.options.enable_swap) {
+    auto& state = get_channel_state();
+    if (state.options.enable_swap) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
         auto* info = reinterpret_cast<TensorInfo*>(handle);
@@ -81,7 +100,8 @@ void ChannelImpl::swap_in(Handle handle) {

 void ChannelImpl::swap_out(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_worker_state.options.enable_swap) {
+    auto& state = get_channel_state();
+    if (state.options.enable_swap) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
         auto* info = reinterpret_cast<TensorInfo*>(handle);
@@ -91,7 +111,8 @@ void ChannelImpl::swap_out(Handle handle) {

 void ChannelImpl::drop(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_worker_state.options.enable_drop) {
+    auto& state = get_channel_state();
+    if (state.options.enable_drop) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
         auto* info = reinterpret_cast<TensorInfo*>(handle);
@@ -104,6 +125,7 @@ void ChannelImpl::dispatch_default_cpu(
         const SmallVector<TensorInfo*>& input_infos,
         const SmallVector<LogicalTensorDesc>& input_descs,
         SmallVector<Handle>* outputs) {
+    auto& state = get_channel_state();
     auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
     MGB_MARK_USED_VAR(validated);

@@ -147,8 +169,8 @@ void ChannelImpl::dispatch_default_cpu(
         return tid;
     };
     OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}};
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<HostOpExecuteEvent>(event_data);
     }

     OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
@@ -166,8 +188,8 @@ void ChannelImpl::dispatch_default_cpu(
     }

     event_data.outputs = tinfo_to_tid(output_infos);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<HostOpFinishEvent>(event_data);
     }
 }

@@ -176,6 +198,7 @@ void ChannelImpl::dispatch_kernel(
         const SmallVector<TensorInfo*>& input_infos,
         const SmallVector<LogicalTensorDesc>& input_descs,
         SmallVector<Handle>* outputs) {
+    auto& state = get_channel_state();
     auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);

     ApplyOp cmd{std::move(op)};
@@ -194,9 +217,9 @@ void ChannelImpl::dispatch_kernel(
         outputs->push_back(info);
     }
     m_buffer.enqueue(std::move(cmd));
-    if (!validated && m_channel_state.options.async_level == 1) {
+    if (!validated && state.options.async_level == 1) {
         sync();
-    } else if (m_channel_state.options.async_level == 0) {
+    } else if (state.options.async_level == 0) {
         sync();
         // check device error
         for (auto&& oup : *outputs) {
@@ -210,6 +233,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
         std::shared_ptr<OpDef> op,
         const SmallVector<Handle>& inputs) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     for (auto i : inputs) {
         mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
                 "invalid handle: %p", i);
@@ -229,7 +253,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
     }

     SmallVector<Handle> outputs;
-    DispatchMode dispatch_mode = m_channel_state.options.enable_host_compute
+    DispatchMode dispatch_mode = state.options.enable_host_compute
             ? OpDef::decide_dispatch_mode(*op, input_descs)
             : DispatchMode::KERNEL;
     switch (dispatch_mode) {
@@ -247,6 +271,7 @@ SmallVector<Handle> ChannelImpl::apply_op(

 HostTensorND ChannelImpl::get_value(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     // TODO: maybe get_value should be done on host. i.e. delete GetValue
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
@@ -262,16 +287,16 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
     if (!value_fetched()) {
         m_waitee = info;
         m_buffer.enqueue(GetValue{info});
-        if (m_channel_state.profiler->is_profiling()) {
-            m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue);
+        if (state.profiler->is_profiling()) {
+            state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue);
         }
         m_cv.wait(lock, [&]() {
             check_worker_exc_unsafe();
             tensor_ptr = info->ptr;
             return value_fetched();
         });
-        if (m_channel_state.profiler->is_profiling()) {
-            m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue);
+        if (state.profiler->is_profiling()) {
+            state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue);
         }
         m_waitee = nullptr;
     }
@@ -280,6 +305,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) {

 TensorShape ChannelImpl::get_shape(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -290,15 +316,15 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
     mgb_assert(!m_waitee);
     m_waitee = info;
     m_buffer.flush();
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape);
     }
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
         return static_cast<bool>(info->ptr);
     });
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape);
     }
     m_waitee = nullptr;
     TensorShape ret = info->ptr->layout();
@@ -308,11 +334,12 @@ TensorShape ChannelImpl::get_shape(Handle handle) {

 DType ChannelImpl::get_dtype(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType);
     }
     auto ret = info->desc.layout.dtype;
     mgb_assert(ret.valid());
@@ -321,11 +348,12 @@ DType ChannelImpl::get_dtype(Handle handle) {

 CompNode ChannelImpl::get_device(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device);
     }
     auto ret = info->desc.comp_node;
     mgb_assert(ret.valid());
@@ -334,6 +362,7 @@ CompNode ChannelImpl::get_device(Handle handle) {

 DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -341,15 +370,15 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     mgb_assert(!m_waitee);
     m_waitee = info;
     m_buffer.flush();
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue);
     }
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
         return static_cast<bool>(info->ptr);
     });
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue);
     }
     m_waitee = nullptr;
     return info->ptr->dev_tensor();
@@ -357,14 +386,15 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {

 void ChannelImpl::sync() {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     m_buffer.flush();
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<SyncStartEvent>();
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<SyncStartEvent>();
     }
     m_worker.wait_all_task_finish();
     CompNode::sync_all();
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<SyncFinishEvent>();
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<SyncFinishEvent>();
     }
     MGB_LOCK_GUARD(m_mutex);
     check_worker_exc_unsafe();
@@ -386,22 +416,25 @@ void ChannelImpl::close() {

 size_t ChannelImpl::get_option(std::string name) {
     mgb_assert(check_available(), "Channel already closed");
-    return m_channel_state.options.get_option(name);
+    auto& state = get_channel_state();
+    return state.options.get_option(name);
 }

 void ChannelImpl::set_option(std::string name, size_t value) {
     mgb_assert(check_available(), "Channel already closed");
-    m_channel_state.options.set_option(name, value);
+    auto& state = get_channel_state();
+    state.options.set_option(name, value);
     m_buffer.enqueue(SetOption{name, value});
 }

 TensorInfo* ChannelImpl::alloc() {
+    auto& state = get_channel_state();
     MGB_LOCK_GUARD(m_mutex);
     auto info = m_pool.alloc();
     m_valid_handle.insert(info);
     info->id = m_last_id++;
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorDeclareEvent>(info->id);
     }
     return info;
 }
@@ -422,7 +455,8 @@ void ChannelImpl::do_drop(TensorInfo* ptr, bool user=false) {
 }

 void ChannelImpl::free(TensorInfo* ptr) {
-    if (m_worker_state.options.enable_dtr_auto_drop) {
+    auto& state = get_worker_state();
+    if (state.options.enable_dtr_auto_drop) {
         // Evicting a tensor, rather than freeing it, can avoid pinning
         // potentially exploding amounts of memory and allow us to save
         // more memory.
@@ -455,11 +489,12 @@ void ChannelImpl::recursive_free(TensorInfo* ptr) {
 }

 void ChannelImpl::real_free(TensorInfo* ptr) {
+    auto& state = get_worker_state();
     MGB_LOCK_GUARD(m_mutex);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorEraseEvent>(ptr->id);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorEraseEvent>(ptr->id);
     }
-    if (ptr->size_exceeds_thd(m_worker_state.options.dtr_evictee_minimum_size)) {
+    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
         m_dtr.erase_candidate(ptr);
     }
     detach_users(ptr);
@@ -474,11 +509,14 @@ ChannelImpl::~ChannelImpl() {
 }

 void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=true) {
-    auto lock = notice ? std::unique_lock<std::mutex>(m_mutex)
-                       : std::unique_lock<std::mutex>();
+    auto& state = get_worker_state();
+    auto lock = std::unique_lock<std::mutex>(m_mutex, std::defer_lock);
+    if (notice) {
+        lock.lock();
+    }
     m_dtr.update_used_time(dest);
-    if (notice && m_worker_state.profiler->is_profiling()) {
-        m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node());
+    if (notice && state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node());
     }
     dest->value_fetched = ptr->value_fetched();
     // update tensor desc for static infer
@@ -487,7 +525,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=tr
     dest->memory = ptr->blob()->size();
     dest->ptr = std::move(ptr);
     dest->evict_type = EvictType::NONE;
-    if (notice && dest->size_exceeds_thd(m_worker_state.options.dtr_evictee_minimum_size)) {
+    if (notice && dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
         m_dtr.insert_candidate(dest);
     }
     if (notice && m_waitee == dest) {
@@ -509,6 +547,7 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
 }

 void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
+    auto& state = get_worker_state();
     SmallVector<TensorPtr> inputs;
     inputs.reserve(path->inputs.size());
     m_dtr.pin(path->inputs);
@@ -519,7 +558,7 @@ void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
         inputs.push_back(i->ptr);
         m_dtr.update_used_time(i);
     }
-    if (m_worker_state.options.enable_dtr_auto_drop && m_worker_state.options.dtr_eviction_threshold > 0) {
+    if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
         auto_evict();
     }
     auto outputs = OpDef::apply_on_physical_tensor(*path->op, inputs);
@@ -531,7 +570,7 @@ void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
         o->recompute_times ++;
         if (!o->ptr) {
             produce_tensor(o, std::move(outputs[i]), false);
-            if (m_worker_state.options.enable_dtr_auto_drop) {
+            if (state.options.enable_dtr_auto_drop) {
                 m_dtr.update_dsu_after_recompute(o);
             }
         }
@@ -540,11 +579,12 @@ void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
 }

 void ChannelImpl::auto_evict() {
+    auto& state = get_worker_state();
     if (!m_dtr.comp_node.valid()) {
         return;
     }
     size_t current_memory = m_dtr.comp_node.get_used_memory();
-    while (current_memory > m_worker_state.options.dtr_eviction_threshold) {
+    while (current_memory > state.options.dtr_eviction_threshold) {
         auto best = m_dtr.find_best_tensor();
         if (!best) {
             if (!m_dtr.warn_printed) {
@@ -592,13 +632,14 @@ bool ChannelImpl::check_available() {
 }

 void ChannelImpl::sync_device_scope(CompNode device) {
-    auto& prev = m_worker_state.device_scope_map[device];
-    auto& current = m_worker_state.scopes;
+    auto& state = get_worker_state();
+    auto& prev = state.device_scope_map[device];
+    auto& current = state.scopes;
     auto push_scope = [&](std::string name) {
-        m_worker_state.profiler->record_device<DeviceBeginScope>(device, name);
+        state.profiler->record_device<DeviceBeginScope>(device, name);
     };
     auto pop_scope = [&](std::string name) {
-        m_worker_state.profiler->record_device<DeviceEndScope>(device, name);
+        state.profiler->record_device<DeviceEndScope>(device, name);
     };
     size_t similarity = 0;
     for (size_t i = 0; i < prev.size() && i < current.size(); i++) {
@@ -619,16 +660,17 @@ void ChannelImpl::sync_device_scope(CompNode device) {
 }

 void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
-    if (m_worker_state.profiler->is_profiling()) {
-        m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd);
+    auto& state = get_worker_state();
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<CommandExecuteEvent>(icmd);
     }
     bool finished = false;
     auto do_finish_command = [&]{
         if (finished) {
             return;
         }
-        if (m_worker_state.profiler->is_profiling()) {
-            m_worker_state.profiler->record_host<CommandFinishEvent>(icmd);
+        if (state.profiler->is_profiling()) {
+            state.profiler->record_host<CommandFinishEvent>(icmd);
         }
         finished = true;
     };
@@ -642,7 +684,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
            uint64_t apply_id = ++m_last_id;
            SmallVector<TensorPtr> tensor_inputs;
            SmallVector<CompNode> devices;
-            if (m_worker_state.options.enable_dtr_auto_drop) {
+            if (state.options.enable_dtr_auto_drop) {
                m_dtr.pin(cmd.inputs);
            }
            for (auto i : cmd.inputs) {
@@ -660,7 +702,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
            }
            // Begin profiling operator
            OpEvent event_data;
-            if (m_worker_state.profiler->is_profiling()) {
+            if (state.profiler->is_profiling()) {
                auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
                    SmallVector<uint64_t> tid;
                    for (auto* ptinfo: tinfo) {
@@ -689,14 +731,14 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
            // Before wait
            //TODO: split operator wait and execute so that OpWait could be correctly recorded.
            // Before execute
-            if (m_worker_state.profiler->is_profiling()) {
-                m_worker_state.profiler->record_host<HostOpExecuteEvent>(event_data);
+            if (state.profiler->is_profiling()) {
+                state.profiler->record_host<HostOpExecuteEvent>(event_data);
                for (auto&& device: devices) {
                    sync_device_scope(device);
-                    m_worker_state.profiler->record_device<DeviceOpExecuteEvent>(device, event_data);
+                    state.profiler->record_device<DeviceOpExecuteEvent>(device, event_data);
                }
            }
-            if (m_worker_state.options.enable_dtr_auto_drop && m_worker_state.options.dtr_eviction_threshold > 0) {
+            if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
                auto_evict();
            }
            // Apply op
@@ -704,15 +746,15 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
            auto tensor_outputs = OpDef::apply_on_physical_tensor(
                    *cmd.op, std::move(tensor_inputs));
            // After execute
-            if (m_worker_state.profiler->is_profiling()) {
-                m_worker_state.profiler->record_host<HostOpFinishEvent>(event_data);
+            if (state.profiler->is_profiling()) {
+                state.profiler->record_host<HostOpFinishEvent>(event_data);
                for (auto&& device: devices) {
-                    m_worker_state.profiler->record_device<DeviceOpFinishEvent>(device, event_data);
+                    state.profiler->record_device<DeviceOpFinishEvent>(device, event_data);
                }
            }
            // End profiling operator
            double estimate_compute_time = 0;
-            if (m_worker_state.options.enable_dtr_auto_drop) {
+            if (state.options.enable_dtr_auto_drop) {
                for (auto i : cmd.inputs) {
                    estimate_compute_time += i->memory;
                }
@@ -735,12 +777,12 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
                    continue;
                }
                produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
-                if (m_worker_state.options.enable_dtr_auto_drop) {
+                if (state.options.enable_dtr_auto_drop) {
                    cmd.outputs[i]->dsu_ptr = std::make_shared<DsuNode>(estimate_compute_time);
                }
            }
-            if (m_worker_state.options.enable_drop == 1
-                    && m_worker_state.options.record_computing_path == 1) {
+            if (state.options.enable_drop == 1
+                    && state.options.record_computing_path == 1) {
                bool is_inplace = false;
                bool cross_cn = false;
                for (auto input : cmd.inputs) {
@@ -774,7 +816,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
                    TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
                size_t detach_cnt = 0;
                for (auto output : cmd.outputs) {
-                    if (!output->size_exceeds_thd(m_worker_state.options.dtr_evictee_minimum_size)) {
+                    if (!output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
                        output->detach_producer();
                        detach_cnt ++;
                    }
@@ -808,21 +850,22 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
        } else if constexpr (std::is_same_v<T, Drop>) {
            do_drop(cmd.dest, true);
        } else if constexpr (std::is_same_v<T, SetOption>) {
-            m_worker_state.options.set_option(cmd.key, cmd.value);
+            state.options.set_option(cmd.key, cmd.value);
        } else if constexpr (std::is_same_v<T, StartProfile>) {
            CompNode::sync_all();
-            m_worker_state.profiler.reset(cmd.profiler);
+            state.profiler.reset(cmd.profiler);
        } else if constexpr (std::is_same_v<T, StopProfile>) {
-            for (auto&& [device, scopes]: m_worker_state.device_scope_map) {
+            for (auto&& [device, scopes]: state.device_scope_map) {
                MGB_MARK_USED_VAR(scopes);
                sync_device_scope(device);
            }
            do_finish_command();
            auto profiler = std::make_unique<InterpreterProfiler>();
-            std::swap(profiler, m_worker_state.profiler);
+            std::swap(profiler, state.profiler);
            auto records = profiler->stop();
-            auto host_map = [this](std::thread::id tid) {
-                if (tid == m_worker_state.tid) {
+            auto worker_tid = get_worker_tid();
+            auto host_map = [worker_tid](std::thread::id tid) {
+                if (tid == worker_tid) {
                    return "worker";
                } else {
                    return "unknown";
@@ -830,21 +873,21 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
            };
            InterpreterProfiler::dump_data(cmd.basename, cmd.format, records, profiler->get_option(), host_map);
        } else if constexpr (std::is_same_v<T, PushScope>) {
-            m_worker_state.scopes.push_back(cmd.scope_name);
+            state.scopes.push_back(cmd.scope_name);
            do_finish_command();
-            m_worker_state.profiler->record_host<WorkerBeginScope>(cmd.scope_name);
+            state.profiler->record_host<WorkerBeginScope>(cmd.scope_name);
        } else if constexpr (std::is_same_v<T, PopScope>) {
-            mgb_assert(m_worker_state.scopes.back() == cmd.scope_name, "scope name mismatch");
-            m_worker_state.scopes.pop_back();
+            mgb_assert(state.scopes.back() == cmd.scope_name, "scope name mismatch");
+            state.scopes.pop_back();
            do_finish_command();
-            m_worker_state.profiler->record_host<WorkerEndScope>(cmd.scope_name);
+            state.profiler->record_host<WorkerEndScope>(cmd.scope_name);
        } else {
            static_assert(!std::is_same_v<T, T>);
        }
    };
    std::visit([&](const auto& cmd){
        using T = std::decay_t<decltype(cmd)>;
-        if (!m_worker_state.options.catch_worker_execption) {
+        if (!state.options.catch_worker_execption) {
            cmd_visitor(cmd);
            return;
        }
@@ -891,11 +934,12 @@ void ChannelImpl::CommandBuffer::flush() {
 }

 void ChannelImpl::CommandBuffer::flush(Handle pos) {
+    auto& state = m_owner->get_channel_state();
     for (auto iter = m_commands.begin(); iter != pos; ++iter) {
         // mgb_log_debug("%s Flushed", to_string(*iter).c_str());
         IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)};
-        if (m_owner->m_channel_state.profiler->is_profiling()) {
-            m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd);
+        if (state.profiler->is_profiling()) {
+            state.profiler->record_host<CommandEnqueueEvent>(icmd);
         }
         m_owner->m_worker.add_task(std::move(icmd));
     }
@@ -903,7 +947,8 @@ void ChannelImpl::CommandBuffer::flush(Handle pos) {
 }

 auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
-    return std::visit([this](const auto& cmd) {
+    auto& state = m_owner->get_channel_state();
+    return std::visit([&, this](const auto& cmd) {
         using T = std::decay_t<decltype(cmd)>;
         if constexpr (std::is_same_v<T, ApplyOp>) {
             auto* op_type = cmd.op->dyn_typeinfo();
@@ -917,7 +962,7 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
         } else if constexpr (std::is_same_v<T, GetValue>) {
             return m_commands.end();
         }
-        size_t buffer_length = m_owner->m_channel_state.options.buffer_length;
+        size_t buffer_length = state.options.buffer_length;
         if (m_commands.size() > buffer_length) {
             return m_commands.begin() + (m_commands.size() - buffer_length);
         }
@@ -993,42 +1038,54 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)

 void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     auto profiler_option = InterpreterProfiler::Option::from_dict(option);
     auto profiler = std::make_unique<InterpreterProfiler>();
     profiler->set_option(profiler_option);
     profiler->start(InterpreterProfiler::topic_to_mask(profiler_option.topic));
-    std::swap(profiler, m_channel_state.profiler);
-    m_buffer.enqueue(StartProfile{m_channel_state.profiler.get()});
+    std::swap(profiler, state.profiler);
+    m_buffer.enqueue(StartProfile{state.profiler.get()});
 }

 void ChannelImpl::stop_profile(std::string basename, std::string format) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     m_buffer.flush();
     auto profiler = std::make_unique<InterpreterProfiler>();
-    std::swap(profiler, m_channel_state.profiler);
+    std::swap(profiler, state.profiler);
     profiler.release();
     m_buffer.enqueue(StopProfile{basename, format});
 }

 void ChannelImpl::push_scope(std::string name) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<ChannelBeginScope>(name);
-        m_channel_state.scopes.push_back(name);
+    auto& state = get_channel_state();
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<ChannelBeginScope>(name);
+        state.scopes.push_back(name);
         m_buffer.enqueue(PushScope{name});
     }
 }

 void ChannelImpl::pop_scope(std::string name) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_channel_state.profiler->is_profiling()) {
-        mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
-        m_channel_state.scopes.pop_back();
-        m_channel_state.profiler->record_host<ChannelEndScope>(name);
+    auto& state = get_channel_state();
+    if (state.profiler->is_profiling()) {
+        mgb_assert((!state.scopes.empty()) && state.scopes.back() == name, "scope name mismatch");
+        state.scopes.pop_back();
+        state.profiler->record_host<ChannelEndScope>(name);
         m_buffer.enqueue(PopScope{name});
     }
 }

+void ChannelImpl::assert_in_channel() {
+    mgb_assert(get_worker_tid() != std::this_thread::get_id(), "this method cannot be called in worker thread");
+}
+
+void ChannelImpl::assert_in_worker() {
+    mgb_assert(get_worker_tid() == std::this_thread::get_id(), "this method can only be called in worker thread");
+}
+
 void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
     for (auto i : vec) {
         i->pin();

imperative/src/impl/interpreter/interpreter_impl.h (+9, -1)

@@ -90,7 +90,6 @@ private:

     void regenerate(TensorInfo* dest);
     void recompute(TensorInfo::ComputePath* path);
-

     void dispatch_default_cpu(
             std::shared_ptr<OpDef> op,
@@ -105,6 +104,10 @@ private:

     bool check_available();

+    void assert_in_channel();
+    void assert_in_worker();
+    std::thread::id get_worker_tid();
+
     void sync_device_scope(CompNode device);

     template <typename TCommand>
@@ -319,6 +322,11 @@ private:

     //! automatically evict an optimal tensor
     void auto_evict();

+    // assert thread id when calling get_xxx_state to avoid misuse
+    ChannelState& get_channel_state();
+    WorkerState& get_worker_state();
+
 };

 } // namespace mgb::imperative::interpreter::intl
