|
|
@@ -52,7 +52,9 @@ Handle ChannelImpl::put(const DeviceTensorND& data) { |
|
|
|
info->desc.layout = data.layout(); |
|
|
|
info->desc.comp_node = data.comp_node(); |
|
|
|
info->ptr = Tensor::make(data); |
|
|
|
m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node); |
|
|
|
} |
|
|
|
return info; |
|
|
|
} |
|
|
|
|
|
|
@@ -147,8 +149,9 @@ void ChannelImpl::dispatch_default_cpu( |
|
|
|
return tid; |
|
|
|
}; |
|
|
|
OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}}; |
|
|
|
|
|
|
|
m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data); |
|
|
|
} |
|
|
|
|
|
|
|
OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); |
|
|
|
|
|
|
@@ -169,8 +172,9 @@ void ChannelImpl::dispatch_default_cpu( |
|
|
|
} |
|
|
|
|
|
|
|
event_data.outputs = tinfo_to_tid(output_infos); |
|
|
|
|
|
|
|
m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::dispatch_kernel( |
|
|
@@ -267,13 +271,17 @@ HostTensorND ChannelImpl::get_value(Handle handle) { |
|
|
|
m_waitee = info; |
|
|
|
regenerate(info); |
|
|
|
m_buffer.enqueue(GetValue{info}); |
|
|
|
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue); |
|
|
|
} |
|
|
|
m_cv.wait(lock, [&]() { |
|
|
|
check_worker_exc_unsafe(); |
|
|
|
tensor_ptr = info->ptr; |
|
|
|
return value_fetched(); |
|
|
|
}); |
|
|
|
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue); |
|
|
|
} |
|
|
|
m_waitee = nullptr; |
|
|
|
} |
|
|
|
return tensor_ptr->get_value(); |
|
|
@@ -290,12 +298,16 @@ TensorShape ChannelImpl::get_shape(Handle handle) { |
|
|
|
mgb_assert(!m_waitee); |
|
|
|
m_waitee = info; |
|
|
|
m_buffer.flush(); |
|
|
|
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape); |
|
|
|
} |
|
|
|
m_cv.wait(lock, [&]() { |
|
|
|
check_worker_exc_unsafe(); |
|
|
|
return static_cast<bool>(info->ptr); |
|
|
|
}); |
|
|
|
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape); |
|
|
|
} |
|
|
|
m_waitee = nullptr; |
|
|
|
TensorShape ret = info->ptr->layout(); |
|
|
|
mgb_assert(ret.ndim != 0); |
|
|
@@ -306,7 +318,9 @@ DType ChannelImpl::get_dtype(Handle handle) { |
|
|
|
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), |
|
|
|
"invalid handle: %p", handle); |
|
|
|
auto info = reinterpret_cast<TensorInfo*>(handle); |
|
|
|
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType); |
|
|
|
} |
|
|
|
auto ret = info->desc.layout.dtype; |
|
|
|
mgb_assert(ret.valid()); |
|
|
|
return ret; |
|
|
@@ -316,7 +330,9 @@ CompNode ChannelImpl::get_device(Handle handle) { |
|
|
|
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), |
|
|
|
"invalid handle: %p", handle); |
|
|
|
auto info = reinterpret_cast<TensorInfo*>(handle); |
|
|
|
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device); |
|
|
|
} |
|
|
|
auto ret = info->desc.comp_node; |
|
|
|
mgb_assert(ret.valid()); |
|
|
|
return ret; |
|
|
@@ -331,22 +347,30 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { |
|
|
|
m_waitee = info; |
|
|
|
regenerate(info); |
|
|
|
m_buffer.flush(); |
|
|
|
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue); |
|
|
|
} |
|
|
|
m_cv.wait(lock, [&]() { |
|
|
|
check_worker_exc_unsafe(); |
|
|
|
return static_cast<bool>(info->ptr); |
|
|
|
}); |
|
|
|
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue); |
|
|
|
} |
|
|
|
m_waitee = nullptr; |
|
|
|
return info->ptr->dev_tensor(); |
|
|
|
} |
|
|
|
|
|
|
|
// Drains this channel: flushes the command buffer, waits for the worker to
// finish every queued task, then synchronizes all comp nodes.
//
// SyncStartEvent / SyncFinishEvent are recorded exactly once each, and only
// while the profiler reports is_profiling(); the earlier unconditional
// record_host calls are dropped so an event is never emitted twice.
void ChannelImpl::sync() {
    // Hand every buffered command to the worker before waiting on it.
    m_buffer.flush();
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<SyncStartEvent>();
    }
    m_worker.wait_all_task_finish();
    CompNode::sync_all();
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<SyncFinishEvent>();
    }
    // check_worker_exc_unsafe() is called with m_mutex held.
    // NOTE(review): presumably this surfaces an error raised on the worker
    // thread and the "_unsafe" suffix means "caller must hold the lock" —
    // confirm against its definition.
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}
|
|
@@ -369,13 +393,17 @@ TensorInfo* ChannelImpl::alloc() { |
|
|
|
auto info = m_pool.alloc(); |
|
|
|
m_valid_handle.insert(info); |
|
|
|
info->id = m_last_id++; |
|
|
|
m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id); |
|
|
|
} |
|
|
|
return info; |
|
|
|
} |
|
|
|
|
|
|
|
// Returns `ptr` to m_pool while holding m_mutex.
//
// A TensorEraseEvent carrying ptr->id is recorded first, exactly once, and
// only while the profiler reports is_profiling(); the earlier unconditional
// record_host call is dropped so the erase is never logged twice.
//
// @param ptr  tensor info to release; it is dereferenced for its id only when
//             profiling is active, and is always passed to m_pool.free().
void ChannelImpl::free(TensorInfo* ptr) {
    MGB_LOCK_GUARD(m_mutex);
    if (m_channel_state.profiler->is_profiling()) {
        // Read the id before the object goes back to the pool.
        m_channel_state.profiler->record_host<TensorEraseEvent>(ptr->id);
    }
    m_pool.free(ptr);
}
|
|
|
|
|
|
@@ -389,7 +417,9 @@ ChannelImpl::~ChannelImpl() { |
|
|
|
|
|
|
|
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { |
|
|
|
MGB_LOCK_GUARD(m_mutex); |
|
|
|
m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node()); |
|
|
|
if (m_worker_state.profiler->is_profiling()) { |
|
|
|
m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node()); |
|
|
|
} |
|
|
|
dest->value_fetched = ptr->value_fetched(); |
|
|
|
// update tensor desc for static infer |
|
|
|
dest->desc.layout = ptr->layout(); |
|
|
@@ -471,13 +501,17 @@ void ChannelImpl::sync_device_scope(CompNode device) { |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { |
|
|
|
m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd); |
|
|
|
if (m_worker_state.profiler->is_profiling()) { |
|
|
|
m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd); |
|
|
|
} |
|
|
|
bool finished = false; |
|
|
|
auto do_finish_command = [&]{ |
|
|
|
if (finished) { |
|
|
|
return; |
|
|
|
} |
|
|
|
m_worker_state.profiler->record_host<CommandFinishEvent>(icmd); |
|
|
|
if (m_worker_state.profiler->is_profiling()) { |
|
|
|
m_worker_state.profiler->record_host<CommandFinishEvent>(icmd); |
|
|
|
} |
|
|
|
finished = true; |
|
|
|
}; |
|
|
|
//TODO: remove std::visit for support osx 10.12 |
|
|
@@ -498,22 +532,25 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { |
|
|
|
tensor_inputs.push_back(i->ptr); |
|
|
|
} |
|
|
|
// Begin profiling operator |
|
|
|
auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) { |
|
|
|
SmallVector<uint64_t> tid; |
|
|
|
for (auto* ptinfo: tinfo) { |
|
|
|
tid.push_back(ptinfo->id); |
|
|
|
OpEvent event_data; |
|
|
|
if (m_worker_state.profiler->is_profiling()) { |
|
|
|
auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) { |
|
|
|
SmallVector<uint64_t> tid; |
|
|
|
for (auto* ptinfo: tinfo) { |
|
|
|
tid.push_back(ptinfo->id); |
|
|
|
} |
|
|
|
return tid; |
|
|
|
}; |
|
|
|
event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)}; |
|
|
|
// Collecting devices |
|
|
|
for (auto i : cmd.inputs) { |
|
|
|
devices.push_back(i->desc.comp_node); |
|
|
|
} |
|
|
|
return tid; |
|
|
|
}; |
|
|
|
OpEvent event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)}; |
|
|
|
// Collecting devices |
|
|
|
for (auto i : cmd.inputs) { |
|
|
|
devices.push_back(i->desc.comp_node); |
|
|
|
} |
|
|
|
for (auto i : cmd.outputs) { |
|
|
|
devices.push_back(i->desc.comp_node); |
|
|
|
for (auto i : cmd.outputs) { |
|
|
|
devices.push_back(i->desc.comp_node); |
|
|
|
} |
|
|
|
devices.erase(std::unique(devices.begin(), devices.end()), devices.end()); |
|
|
|
} |
|
|
|
devices.erase(std::unique(devices.begin(), devices.end()), devices.end()); |
|
|
|
// Fused by command buffer. @see: CommandBuffer::fuse_del |
|
|
|
// Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by tensor_inputs after Del. |
|
|
|
// Note for exprs like 'y = x op x', inplace is unsupported yet but Del would be also fused. |
|
|
@@ -643,7 +680,7 @@ void ChannelImpl::CommandBuffer::enqueue(Command cmd) { |
|
|
|
if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) { |
|
|
|
return; |
|
|
|
} |
|
|
|
mgb_log_debug("%s Enqueued", to_string(cmd).c_str()); |
|
|
|
// mgb_log_debug("%s Enqueued", to_string(cmd).c_str()); |
|
|
|
m_commands.push_back(std::move(cmd)); |
|
|
|
auto flush_pos = flush_pos_for(m_commands.back()); |
|
|
|
flush(flush_pos); |
|
|
@@ -655,9 +692,11 @@ void ChannelImpl::CommandBuffer::flush() { |
|
|
|
|
|
|
|
void ChannelImpl::CommandBuffer::flush(Handle pos) { |
|
|
|
for (auto iter = m_commands.begin(); iter != pos; ++iter) { |
|
|
|
mgb_log_debug("%s Flushed", to_string(*iter).c_str()); |
|
|
|
// mgb_log_debug("%s Flushed", to_string(*iter).c_str()); |
|
|
|
IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)}; |
|
|
|
m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd); |
|
|
|
if (m_owner->m_channel_state.profiler->is_profiling()) { |
|
|
|
m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd); |
|
|
|
} |
|
|
|
m_owner->m_worker.add_task(std::move(icmd)); |
|
|
|
} |
|
|
|
m_commands.erase(m_commands.begin(), pos); |
|
|
@@ -705,7 +744,7 @@ bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) { |
|
|
|
if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str()); |
|
|
|
// mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str()); |
|
|
|
std::get<ApplyOp>(*apply_iter).dels.push_back(dest); |
|
|
|
return true; |
|
|
|
} |
|
|
@@ -771,16 +810,20 @@ void ChannelImpl::stop_profile(std::string basename, std::string format) { |
|
|
|
} |
|
|
|
|
|
|
|
// Opens a named profiling scope on the channel.
//
// While the profiler reports is_profiling(), this records a ChannelBeginScope
// event, appends `name` to m_channel_state.scopes and enqueues a PushScope
// command. When profiling is off the call does nothing. The earlier
// unconditional copy of the same three statements is dropped so a scope is
// pushed and enqueued once, not twice.
//
// @param name  scope label; pop_scope compares against it when the scope is
//              closed.
void ChannelImpl::push_scope(std::string name) {
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<ChannelBeginScope>(name);
        m_channel_state.scopes.push_back(name);
        m_buffer.enqueue(PushScope{name});
    }
}
|
|
|
|
|
|
|
// Closes the innermost profiling scope on the channel.
//
// While the profiler reports is_profiling(), this asserts that `name` matches
// the last entry of m_channel_state.scopes, removes that entry, records a
// ChannelEndScope event and enqueues a PopScope command. When profiling is off
// the call does nothing. The earlier unconditional copy of the same sequence
// is dropped so one call pops exactly one scope.
//
// NOTE(review): push_scope only tracks scopes while profiling is active, so
// if profiling is switched on between a push and its matching pop the scope
// stack will be empty here and the assertion fires — confirm callers cannot
// toggle profiling inside a scope.
//
// @param name  label of the scope being closed; must equal the most recently
//              pushed, still-open scope.
void ChannelImpl::pop_scope(std::string name) {
    if (m_channel_state.profiler->is_profiling()) {
        mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
        m_channel_state.scopes.pop_back();
        m_channel_state.profiler->record_host<ChannelEndScope>(name);
        m_buffer.enqueue(PopScope{name});
    }
}
|
|
|
|
|
|
|
void ChannelImpl::assert_in_channel() { |
|
|
|