|
|
@@ -23,6 +23,23 @@ using namespace imperative;
 using namespace interpreter;
 using namespace interpreter::intl;
 
+std::thread::id ChannelImpl::get_worker_tid() {
+    return m_worker_state.tid;
+}
+
+ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
+    assert_in_channel();
+    return m_channel_state;
+}
+
+ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
+    assert_in_worker();
+    return m_worker_state;
+}
+
+#define m_channel_state
+#define m_worker_state
+
 std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
     return std::make_unique<ChannelImpl>();
 }
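Note on the hunk above: the two bare #defines are deliberate poison macros, not leftovers. get_channel_state() and get_worker_state() are defined before the macros, so they compile against the real members and become the only way to reach them, and each accessor asserts it runs on the right thread (see assert_in_channel()/assert_in_worker() near the end of this patch). Any direct mention of m_channel_state or m_worker_state after this point expands to nothing, so the surrounding member access no longer compiles. A minimal standalone sketch of the trick, with hypothetical names (Channel, State, get_state) that are not part of this patch:

#include <cstdio>

struct Channel {
    struct State { int value = 0; };
    State m_state;
    State& get_state() {
        // a thread-affinity check would go here, cf. assert_in_channel()
        return m_state;
    }
};

#define m_state  // poison: `c.m_state.value` now expands to `c..value`, a compile error

int main() {
    Channel c;
    c.get_state().value = 1;   // OK: access goes through the accessor
    // c.m_state.value = 2;    // would fail to compile after the #define
    std::printf("%d\n", c.get_state().value);
    return 0;
}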
|
|
@@ -48,13 +65,14 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
 }
 
 Handle ChannelImpl::put(const DeviceTensorND& data) {
+    auto& state = get_channel_state();
     mgb_assert(check_available(), "Channel already closed");
     auto info = alloc();
     info->desc.layout = data.layout();
     info->desc.comp_node = data.comp_node();
     info->ptr = Tensor::make(data);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node);
     }
     return info;
 }
@@ -71,7 +89,8 @@ void ChannelImpl::del(Handle handle) {
 
 void ChannelImpl::swap_in(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_worker_state.options.enable_swap) {
+    auto& state = get_channel_state();
+    if (state.options.enable_swap) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
         auto* info = reinterpret_cast<TensorInfo*>(handle);
@@ -81,7 +100,8 @@ void ChannelImpl::swap_in(Handle handle) {
 
 void ChannelImpl::swap_out(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_worker_state.options.enable_swap) {
+    auto& state = get_channel_state();
+    if (state.options.enable_swap) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
         auto* info = reinterpret_cast<TensorInfo*>(handle);
@@ -91,7 +111,8 @@ void ChannelImpl::swap_out(Handle handle) {
 
 void ChannelImpl::drop(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_worker_state.options.enable_drop) {
+    auto& state = get_channel_state();
+    if (state.options.enable_drop) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
         auto* info = reinterpret_cast<TensorInfo*>(handle);
@@ -104,6 +125,7 @@ void ChannelImpl::dispatch_default_cpu(
         const SmallVector<TensorInfo*>& input_infos,
         const SmallVector<LogicalTensorDesc>& input_descs,
         SmallVector<Handle>* outputs) {
+    auto& state = get_channel_state();
     auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
     MGB_MARK_USED_VAR(validated);
 
@@ -147,8 +169,8 @@ void ChannelImpl::dispatch_default_cpu(
         return tid;
     };
     OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}};
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<HostOpExecuteEvent>(event_data);
     }
 
     OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
@@ -166,8 +188,8 @@ void ChannelImpl::dispatch_default_cpu(
     }
 
     event_data.outputs = tinfo_to_tid(output_infos);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<HostOpFinishEvent>(event_data);
     }
 }
 
@@ -176,6 +198,7 @@ void ChannelImpl::dispatch_kernel(
         const SmallVector<TensorInfo*>& input_infos,
         const SmallVector<LogicalTensorDesc>& input_descs,
         SmallVector<Handle>* outputs) {
+    auto& state = get_channel_state();
     auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
 
     ApplyOp cmd{std::move(op)};
@@ -194,9 +217,9 @@ void ChannelImpl::dispatch_kernel(
         outputs->push_back(info);
     }
     m_buffer.enqueue(std::move(cmd));
-    if (!validated && m_channel_state.options.async_level == 1) {
+    if (!validated && state.options.async_level == 1) {
         sync();
-    } else if (m_channel_state.options.async_level == 0) {
+    } else if (state.options.async_level == 0) {
         sync();
         // check device error
         for (auto&& oup : *outputs) {
@@ -210,6 +233,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
         std::shared_ptr<OpDef> op,
         const SmallVector<Handle>& inputs) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
    for (auto i : inputs) {
         mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
                 "invalid handle: %p", i);
@@ -229,7 +253,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
     }
 
     SmallVector<Handle> outputs;
-    DispatchMode dispatch_mode = m_channel_state.options.enable_host_compute
+    DispatchMode dispatch_mode = state.options.enable_host_compute
             ? OpDef::decide_dispatch_mode(*op, input_descs)
             : DispatchMode::KERNEL;
     switch (dispatch_mode) {
@@ -247,6 +271,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
 
 HostTensorND ChannelImpl::get_value(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     // TODO: maybe get_value should be done on host. i.e. delete GetValue
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
@@ -262,16 +287,16 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
     if (!value_fetched()) {
         m_waitee = info;
         m_buffer.enqueue(GetValue{info});
-        if (m_channel_state.profiler->is_profiling()) {
-            m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue);
+        if (state.profiler->is_profiling()) {
+            state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue);
         }
         m_cv.wait(lock, [&]() {
             check_worker_exc_unsafe();
             tensor_ptr = info->ptr;
             return value_fetched();
         });
-        if (m_channel_state.profiler->is_profiling()) {
-            m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue);
+        if (state.profiler->is_profiling()) {
+            state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue);
         }
         m_waitee = nullptr;
     }
@@ -280,6 +305,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
 
 TensorShape ChannelImpl::get_shape(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -290,15 +316,15 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
     mgb_assert(!m_waitee);
     m_waitee = info;
     m_buffer.flush();
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape);
     }
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
         return static_cast<bool>(info->ptr);
     });
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape);
     }
     m_waitee = nullptr;
     TensorShape ret = info->ptr->layout();
@@ -308,11 +334,12 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
 
 DType ChannelImpl::get_dtype(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType);
     }
     auto ret = info->desc.layout.dtype;
     mgb_assert(ret.valid());
@@ -321,11 +348,12 @@ DType ChannelImpl::get_dtype(Handle handle) {
 
 CompNode ChannelImpl::get_device(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device);
     }
     auto ret = info->desc.comp_node;
     mgb_assert(ret.valid());
@@ -334,6 +362,7 @@ CompNode ChannelImpl::get_device(Handle handle) {
 
 DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -341,15 +370,15 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     mgb_assert(!m_waitee);
     m_waitee = info;
     m_buffer.flush();
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue);
     }
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
         return static_cast<bool>(info->ptr);
     });
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue);
     }
     m_waitee = nullptr;
     return info->ptr->dev_tensor();
@@ -357,14 +386,15 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
 
 void ChannelImpl::sync() {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     m_buffer.flush();
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<SyncStartEvent>();
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<SyncStartEvent>();
     }
     m_worker.wait_all_task_finish();
     CompNode::sync_all();
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<SyncFinishEvent>();
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<SyncFinishEvent>();
     }
     MGB_LOCK_GUARD(m_mutex);
     check_worker_exc_unsafe();
@@ -386,22 +416,25 @@ void ChannelImpl::close() {
 
 size_t ChannelImpl::get_option(std::string name) {
     mgb_assert(check_available(), "Channel already closed");
-    return m_channel_state.options.get_option(name);
+    auto& state = get_channel_state();
+    return state.options.get_option(name);
 }
 
 void ChannelImpl::set_option(std::string name, size_t value) {
     mgb_assert(check_available(), "Channel already closed");
-    m_channel_state.options.set_option(name, value);
+    auto& state = get_channel_state();
+    state.options.set_option(name, value);
     m_buffer.enqueue(SetOption{name, value});
 }
 
 TensorInfo* ChannelImpl::alloc() {
+    auto& state = get_channel_state();
     MGB_LOCK_GUARD(m_mutex);
     auto info = m_pool.alloc();
     m_valid_handle.insert(info);
     info->id = m_last_id++;
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorDeclareEvent>(info->id);
     }
     return info;
 }
@@ -422,7 +455,8 @@ void ChannelImpl::do_drop(TensorInfo* ptr, bool user=false) {
 }
 
 void ChannelImpl::free(TensorInfo* ptr) {
-    if (m_worker_state.options.enable_dtr_auto_drop) {
+    auto& state = get_worker_state();
+    if (state.options.enable_dtr_auto_drop) {
         // Evicting a tensor, rather than freeing it, can avoid pinning
         // potentially exploding amounts of memory and allow us to save
         // more memory.
@@ -455,11 +489,12 @@ void ChannelImpl::recursive_free(TensorInfo* ptr) {
 }
 
 void ChannelImpl::real_free(TensorInfo* ptr) {
+    auto& state = get_worker_state();
     MGB_LOCK_GUARD(m_mutex);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<TensorEraseEvent>(ptr->id);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorEraseEvent>(ptr->id);
     }
-    if (ptr->size_exceeds_thd(m_worker_state.options.dtr_evictee_minimum_size)) {
+    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
         m_dtr.erase_candidate(ptr);
     }
     detach_users(ptr);
@@ -474,11 +509,14 @@ ChannelImpl::~ChannelImpl() {
 }
 
 void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=true) {
-    auto lock = notice ? std::unique_lock<std::mutex>(m_mutex)
-                       : std::unique_lock<std::mutex>();
+    auto& state = get_worker_state();
+    auto lock = std::unique_lock<std::mutex>(m_mutex, std::defer_lock);
+    if (notice) {
+        lock.lock();
+    }
     m_dtr.update_used_time(dest);
-    if (notice && m_worker_state.profiler->is_profiling()) {
-        m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node());
+    if (notice && state.profiler->is_profiling()) {
+        state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node());
     }
     dest->value_fetched = ptr->value_fetched();
     // update tensor desc for static infer
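A note on the locking change in produce_tensor above: the removed ternary default-constructed a std::unique_lock with no associated mutex on the !notice path, which is legal but easy to misread; the replacement always associates the lock with m_mutex via std::defer_lock and takes it only when notice is set, so the destructor still unlocks exactly when needed. A minimal standalone sketch of the idiom, with hypothetical names (g_mutex, do_work) that are not part of this patch:

#include <mutex>

std::mutex g_mutex;

void do_work(bool notice) {
    // Tied to g_mutex from construction, but not yet locked.
    std::unique_lock<std::mutex> lock(g_mutex, std::defer_lock);
    if (notice) {
        lock.lock();  // acquire the mutex only on this path
    }
    // ... work that needs the mutex only when notice is true ...
}   // unlocks on scope exit iff owns_lock() is true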
|
|
@@ -487,7 +525,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=tr
     dest->memory = ptr->blob()->size();
     dest->ptr = std::move(ptr);
     dest->evict_type = EvictType::NONE;
-    if (notice && dest->size_exceeds_thd(m_worker_state.options.dtr_evictee_minimum_size)) {
+    if (notice && dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
         m_dtr.insert_candidate(dest);
     }
     if (notice && m_waitee == dest) {
@@ -509,6 +547,7 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
 }
 
 void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
+    auto& state = get_worker_state();
     SmallVector<TensorPtr> inputs;
     inputs.reserve(path->inputs.size());
     m_dtr.pin(path->inputs);
@@ -519,7 +558,7 @@ void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
         inputs.push_back(i->ptr);
         m_dtr.update_used_time(i);
     }
-    if (m_worker_state.options.enable_dtr_auto_drop && m_worker_state.options.dtr_eviction_threshold > 0) {
+    if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
         auto_evict();
     }
     auto outputs = OpDef::apply_on_physical_tensor(*path->op, inputs);
@@ -531,7 +570,7 @@ void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
         o->recompute_times ++;
         if (!o->ptr) {
             produce_tensor(o, std::move(outputs[i]), false);
-            if (m_worker_state.options.enable_dtr_auto_drop) {
+            if (state.options.enable_dtr_auto_drop) {
                 m_dtr.update_dsu_after_recompute(o);
             }
         }
@@ -540,11 +579,12 @@ void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
 }
 
 void ChannelImpl::auto_evict() {
+    auto& state = get_worker_state();
     if (!m_dtr.comp_node.valid()) {
         return;
     }
     size_t current_memory = m_dtr.comp_node.get_used_memory();
-    while (current_memory > m_worker_state.options.dtr_eviction_threshold) {
+    while (current_memory > state.options.dtr_eviction_threshold) {
         auto best = m_dtr.find_best_tensor();
         if (!best) {
             if (!m_dtr.warn_printed) {
@@ -592,13 +632,14 @@ bool ChannelImpl::check_available() {
 }
 
 void ChannelImpl::sync_device_scope(CompNode device) {
-    auto& prev = m_worker_state.device_scope_map[device];
-    auto& current = m_worker_state.scopes;
+    auto& state = get_worker_state();
+    auto& prev = state.device_scope_map[device];
+    auto& current = state.scopes;
     auto push_scope = [&](std::string name) {
-        m_worker_state.profiler->record_device<DeviceBeginScope>(device, name);
+        state.profiler->record_device<DeviceBeginScope>(device, name);
     };
     auto pop_scope = [&](std::string name) {
-        m_worker_state.profiler->record_device<DeviceEndScope>(device, name);
+        state.profiler->record_device<DeviceEndScope>(device, name);
     };
     size_t similarity = 0;
     for (size_t i = 0; i < prev.size() && i < current.size(); i++) {
@@ -619,16 +660,17 @@ void ChannelImpl::sync_device_scope(CompNode device) {
 }
 
 void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
-    if (m_worker_state.profiler->is_profiling()) {
-        m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd);
+    auto& state = get_worker_state();
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<CommandExecuteEvent>(icmd);
     }
     bool finished = false;
     auto do_finish_command = [&]{
         if (finished) {
             return;
         }
-        if (m_worker_state.profiler->is_profiling()) {
-            m_worker_state.profiler->record_host<CommandFinishEvent>(icmd);
+        if (state.profiler->is_profiling()) {
+            state.profiler->record_host<CommandFinishEvent>(icmd);
         }
         finished = true;
     };
@@ -642,7 +684,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
             uint64_t apply_id = ++m_last_id;
             SmallVector<TensorPtr> tensor_inputs;
             SmallVector<CompNode> devices;
-            if (m_worker_state.options.enable_dtr_auto_drop) {
+            if (state.options.enable_dtr_auto_drop) {
                 m_dtr.pin(cmd.inputs);
             }
             for (auto i : cmd.inputs) {
@@ -660,7 +702,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
             }
             // Begin profiling operator
            OpEvent event_data;
-            if (m_worker_state.profiler->is_profiling()) {
+            if (state.profiler->is_profiling()) {
                 auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
                     SmallVector<uint64_t> tid;
                     for (auto* ptinfo: tinfo) {
@@ -689,14 +731,14 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
             // Before wait
             //TODO: split operator wait and execute so that OpWait could be corrected recorded.
             // Before execute
-            if (m_worker_state.profiler->is_profiling()) {
-                m_worker_state.profiler->record_host<HostOpExecuteEvent>(event_data);
+            if (state.profiler->is_profiling()) {
+                state.profiler->record_host<HostOpExecuteEvent>(event_data);
                 for (auto&& device: devices) {
                     sync_device_scope(device);
-                    m_worker_state.profiler->record_device<DeviceOpExecuteEvent>(device, event_data);
+                    state.profiler->record_device<DeviceOpExecuteEvent>(device, event_data);
                 }
             }
-            if (m_worker_state.options.enable_dtr_auto_drop && m_worker_state.options.dtr_eviction_threshold > 0) {
+            if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
                 auto_evict();
             }
             // Apply op
@@ -704,15 +746,15 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
             auto tensor_outputs = OpDef::apply_on_physical_tensor(
                     *cmd.op, std::move(tensor_inputs));
             // After execute
-            if (m_worker_state.profiler->is_profiling()) {
-                m_worker_state.profiler->record_host<HostOpFinishEvent>(event_data);
+            if (state.profiler->is_profiling()) {
+                state.profiler->record_host<HostOpFinishEvent>(event_data);
                 for (auto&& device: devices) {
-                    m_worker_state.profiler->record_device<DeviceOpFinishEvent>(device, event_data);
+                    state.profiler->record_device<DeviceOpFinishEvent>(device, event_data);
                 }
             }
             // End profiling operator
             double estimate_compute_time = 0;
-            if (m_worker_state.options.enable_dtr_auto_drop) {
+            if (state.options.enable_dtr_auto_drop) {
                 for (auto i : cmd.inputs) {
                     estimate_compute_time += i->memory;
                 }
@@ -735,12 +777,12 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
                     continue;
                 }
                 produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
-                if (m_worker_state.options.enable_dtr_auto_drop) {
+                if (state.options.enable_dtr_auto_drop) {
                     cmd.outputs[i]->dsu_ptr = std::make_shared<DsuNode>(estimate_compute_time);
                 }
             }
-            if (m_worker_state.options.enable_drop == 1
-                && m_worker_state.options.record_computing_path == 1){
+            if (state.options.enable_drop == 1
+                && state.options.record_computing_path == 1){
                 bool is_inplace = false;
                 bool cross_cn = false;
                 for (auto input : cmd.inputs) {
@@ -774,7 +816,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
                 TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
                 size_t detach_cnt = 0;
                 for (auto output : cmd.outputs) {
-                    if (!output->size_exceeds_thd(m_worker_state.options.dtr_evictee_minimum_size)) {
+                    if (!output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
                         output->detach_producer();
                         detach_cnt ++;
                     }
@@ -808,21 +850,22 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
         } else if constexpr (std::is_same_v<T, Drop>) {
             do_drop(cmd.dest, true);
         } else if constexpr (std::is_same_v<T, SetOption>) {
-            m_worker_state.options.set_option(cmd.key, cmd.value);
+            state.options.set_option(cmd.key, cmd.value);
         } else if constexpr (std::is_same_v<T, StartProfile>) {
             CompNode::sync_all();
-            m_worker_state.profiler.reset(cmd.profiler);
+            state.profiler.reset(cmd.profiler);
         } else if constexpr (std::is_same_v<T, StopProfile>) {
-            for (auto&& [device, scopes]: m_worker_state.device_scope_map) {
+            for (auto&& [device, scopes]: state.device_scope_map) {
                 MGB_MARK_USED_VAR(scopes);
                 sync_device_scope(device);
             }
             do_finish_command();
             auto profiler = std::make_unique<InterpreterProfiler>();
-            std::swap(profiler, m_worker_state.profiler);
+            std::swap(profiler, state.profiler);
             auto records = profiler->stop();
-            auto host_map = [this](std::thread::id tid) {
-                if (tid == m_worker_state.tid) {
+            auto worker_tid = get_worker_tid();
+            auto host_map = [worker_tid](std::thread::id tid) {
+                if (tid == worker_tid) {
                     return "worker";
                 } else {
                     return "unknown";
@@ -830,21 +873,21 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
             };
             InterpreterProfiler::dump_data(cmd.basename, cmd.format, records, profiler->get_option(), host_map);
         } else if constexpr (std::is_same_v<T, PushScope>) {
-            m_worker_state.scopes.push_back(cmd.scope_name);
+            state.scopes.push_back(cmd.scope_name);
             do_finish_command();
-            m_worker_state.profiler->record_host<WorkerBeginScope>(cmd.scope_name);
+            state.profiler->record_host<WorkerBeginScope>(cmd.scope_name);
         } else if constexpr (std::is_same_v<T, PopScope>) {
-            mgb_assert(m_worker_state.scopes.back() == cmd.scope_name, "scope name mismatch");
-            m_worker_state.scopes.pop_back();
+            mgb_assert(state.scopes.back() == cmd.scope_name, "scope name mismatch");
+            state.scopes.pop_back();
             do_finish_command();
-            m_worker_state.profiler->record_host<WorkerEndScope>(cmd.scope_name);
+            state.profiler->record_host<WorkerEndScope>(cmd.scope_name);
         } else {
             static_assert(!std::is_same_v<T, T>);
         }
     };
     std::visit([&](const auto& cmd){
         using T = std::decay_t<decltype(cmd)>;
-        if (!m_worker_state.options.catch_worker_execption) {
+        if (!state.options.catch_worker_execption) {
             cmd_visitor(cmd);
             return;
         }
@@ -891,11 +934,12 @@ void ChannelImpl::CommandBuffer::flush() {
 }
 
 void ChannelImpl::CommandBuffer::flush(Handle pos) {
+    auto& state = m_owner->get_channel_state();
     for (auto iter = m_commands.begin(); iter != pos; ++iter) {
         // mgb_log_debug("%s Flushed", to_string(*iter).c_str());
         IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)};
-        if (m_owner->m_channel_state.profiler->is_profiling()) {
-            m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd);
+        if (state.profiler->is_profiling()) {
+            state.profiler->record_host<CommandEnqueueEvent>(icmd);
         }
         m_owner->m_worker.add_task(std::move(icmd));
     }
@@ -903,7 +947,8 @@ void ChannelImpl::CommandBuffer::flush(Handle pos) {
 }
 
 auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
-    return std::visit([this](const auto& cmd) {
+    auto& state = m_owner->get_channel_state();
+    return std::visit([&, this](const auto& cmd) {
         using T = std::decay_t<decltype(cmd)>;
         if constexpr (std::is_same_v<T, ApplyOp>) {
             auto* op_type = cmd.op->dyn_typeinfo();
@@ -917,7 +962,7 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
         } else if constexpr (std::is_same_v<T, GetValue>) {
            return m_commands.end();
         }
-        size_t buffer_length = m_owner->m_channel_state.options.buffer_length;
+        size_t buffer_length = state.options.buffer_length;
        if (m_commands.size() > buffer_length) {
            return m_commands.begin() + (m_commands.size() - buffer_length);
        }
@@ -993,42 +1038,54 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
 
 void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     auto profiler_option = InterpreterProfiler::Option::from_dict(option);
     auto profiler = std::make_unique<InterpreterProfiler>();
     profiler->set_option(profiler_option);
     profiler->start(InterpreterProfiler::topic_to_mask(profiler_option.topic));
-    std::swap(profiler, m_channel_state.profiler);
-    m_buffer.enqueue(StartProfile{m_channel_state.profiler.get()});
+    std::swap(profiler, state.profiler);
+    m_buffer.enqueue(StartProfile{state.profiler.get()});
 }
 
 void ChannelImpl::stop_profile(std::string basename, std::string format) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     m_buffer.flush();
     auto profiler = std::make_unique<InterpreterProfiler>();
-    std::swap(profiler, m_channel_state.profiler);
+    std::swap(profiler, state.profiler);
     profiler.release();
     m_buffer.enqueue(StopProfile{basename, format});
 }
 
 void ChannelImpl::push_scope(std::string name) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host<ChannelBeginScope>(name);
-        m_channel_state.scopes.push_back(name);
+    auto& state = get_channel_state();
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host<ChannelBeginScope>(name);
+        state.scopes.push_back(name);
         m_buffer.enqueue(PushScope{name});
     }
 }
 
 void ChannelImpl::pop_scope(std::string name) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_channel_state.profiler->is_profiling()) {
-        mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
-        m_channel_state.scopes.pop_back();
-        m_channel_state.profiler->record_host<ChannelEndScope>(name);
+    auto& state = get_channel_state();
+    if (state.profiler->is_profiling()) {
+        mgb_assert((!state.scopes.empty()) && state.scopes.back() == name, "scope name mismatch");
+        state.scopes.pop_back();
+        state.profiler->record_host<ChannelEndScope>(name);
         m_buffer.enqueue(PopScope{name});
     }
 }
 
+void ChannelImpl::assert_in_channel() {
+    mgb_assert(get_worker_tid() != std::this_thread::get_id(), "this method cannot be called in worker thread");
+}
+
+void ChannelImpl::assert_in_worker() {
+    mgb_assert(get_worker_tid() == std::this_thread::get_id(), "this method can only be called in worker thread");
+}
+
 void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
     for (auto i : vec) {
         i->pin();