|
|
@@ -196,6 +196,10 @@ void ChannelImpl::dispatch_default_cpu(
         const SmallVector<LogicalTensorDesc>& input_descs,
         SmallVector<Handle>* outputs) {
     auto& state = get_channel_state();
+    auto name = op->trait()->make_name(*op);
+    state.scopes.push(name);
+
     auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
     RECORD_EVENT(ShapeInferEvent, validated);
@@ -256,6 +260,8 @@ void ChannelImpl::dispatch_default_cpu(
         return op_info;
     };
     RECORD_EVENT(OpDispatchEvent, op_id, op->trait()->name, op_info_getter, tinfo_to_tid(input_infos), tinfo_to_tid(output_infos));
+
+    state.scopes.pop(name);
 }

 void ChannelImpl::dispatch_kernel(
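
The two hunks above bracket `dispatch_default_cpu` with a profiler scope: the op name is pushed on entry and popped again just before the function returns. Since an early return or an exception between the push and the pop would leave the scope stack unbalanced, this pairing is usually wrapped in an RAII guard. A minimal sketch of that pattern, assuming only the `push`/`pop` calls visible in the diff (the `Scopes` and `ScopeGuard` types below are illustrative, not MegEngine API):

```cpp
#include <stack>
#include <string>
#include <utility>

// Illustrative stand-in for the channel-state scope stack used in the patch;
// only push()/pop() mirror the calls visible above.
struct Scopes {
    std::stack<std::string> names;
    void push(std::string name) { names.push(std::move(name)); }
    void pop(const std::string&) { names.pop(); }
};

// RAII guard: pushes a named scope on construction and pops it on
// destruction, so every exit path (return or exception) stays balanced.
class ScopeGuard {
public:
    ScopeGuard(Scopes& scopes, std::string name)
            : m_scopes{scopes}, m_name{std::move(name)} {
        m_scopes.push(m_name);
    }
    ~ScopeGuard() { m_scopes.pop(m_name); }
    ScopeGuard(const ScopeGuard&) = delete;
    ScopeGuard& operator=(const ScopeGuard&) = delete;

private:
    Scopes& m_scopes;
    std::string m_name;
};

void dispatch_example(Scopes& scopes) {
    ScopeGuard guard{scopes, "dispatch_default_cpu"};
    // ... dispatch work; the scope is popped on any return path.
}
```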
|
|
@@ -353,7 +359,6 @@ SmallVector<Handle> ChannelImpl::apply_op(

 HostTensorND ChannelImpl::get_value(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -364,7 +369,6 @@ HostTensorND ChannelImpl::get_value(Handle handle) {

 TensorShape ChannelImpl::get_shape(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -378,7 +382,6 @@ TensorShape ChannelImpl::get_shape(Handle handle) {

 DType ChannelImpl::get_dtype(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -390,7 +393,6 @@ DType ChannelImpl::get_dtype(Handle handle) {

 CompNode ChannelImpl::get_device(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -402,7 +404,6 @@ CompNode ChannelImpl::get_device(Handle handle) {

 DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -411,7 +412,6 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {

 void ChannelImpl::sync() {
     mgb_assert(check_available(), "Channel already closed");
-    auto& state = get_channel_state();
     m_buffer.flush();
     m_worker.wait_all_task_finish();
     MGB_LOCK_GUARD(m_mutex);
@@ -519,7 +519,6 @@ void ChannelImpl::recursive_free(TensorInfo* ptr) {

 void ChannelImpl::real_free(TensorInfo* ptr) {
     auto& state = get_worker_state();
-    MGB_LOCK_GUARD(m_mutex);
     if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
         m_dtr.erase_candidate(ptr);
     }
@@ -531,6 +530,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
     }
     RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
     ptr->status = TensorInfo::Deleted;
+    MGB_LOCK_GUARD(m_mutex);
     m_pool.free(ptr);
 }
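
The two `real_free` hunks above move `MGB_LOCK_GUARD(m_mutex)` from the top of the function down to just before `m_pool.free(ptr)`, so the DTR candidate bookkeeping, the erase event, and the status update run outside the critical section and only the pool mutation itself is serialized. A generic sketch of narrowing a critical section in the same way, using a plain `std::lock_guard` in place of the macro (the pool type and helper below are illustrative):

```cpp
#include <mutex>

struct Pool {
    std::mutex mutex;
    void free_slot(void* /*slot*/) { /* return the slot to the freelist */ }
};

void release(Pool& pool, void* slot) {
    // Bookkeeping that touches only `slot` itself can run unlocked ...
    {
        // ... while the shared pool mutation is confined to a short
        // critical section, reducing contention on pool.mutex.
        std::lock_guard<std::mutex> guard{pool.mutex};
        pool.free_slot(slot);
    }
}
```

This is only safe when the fields updated before the guard are not read concurrently under the old, wider lock elsewhere; the patch presumably relies on that for the lines it leaves unlocked.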
|
|
|
|
|
|
@@ -540,12 +540,9 @@ ChannelImpl::~ChannelImpl() {
     close();
 }

-void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=true) {
+void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
     auto& state = get_worker_state();
-    std::unique_lock<std::mutex> lock{m_mutex, std::defer_lock};
-    if (notice) {
-        lock.lock();
-    }
+    MGB_LOCK_GUARD(m_mutex);
     m_dtr.update_used_time(dest);
     RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr());
     // update tensor desc for static infer
@@ -555,12 +552,10 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=tr
     dest->ptr = std::move(ptr);
     dest->evict_type = EvictType::NONE;
     dest->status = TensorInfo::Produced;
-    if (notice && dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
+    if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
         m_dtr.insert_candidate(dest);
     }
-    if (notice) {
-        notify_tensor_unsafe(dest);
-    }
+    notify_tensor_unsafe(dest);
 }

 void ChannelImpl::release_tensor(TensorInfo* dest) {
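
The two `produce_tensor` hunks drop the `notice` parameter: the old code built a deferred `std::unique_lock`, and locked, registered the DTR candidate, and notified the waiter only when `notice` was true, while the new code always takes the lock and always notifies. A side-by-side sketch of the two locking idioms, standard library only, with the surrounding work elided:

```cpp
#include <mutex>

std::mutex m;

void conditional_lock(bool notice) {
    // Old shape: construct the lock unlocked, then lock only on demand;
    // ~unique_lock releases the mutex only if it was actually acquired.
    std::unique_lock<std::mutex> lock{m, std::defer_lock};
    if (notice) {
        lock.lock();
    }
    // ... work that may or may not run under the mutex ...
}

void unconditional_lock() {
    // New shape: a plain scoped guard held for the whole body.
    std::lock_guard<std::mutex> lock{m};
    // ... work that always runs under the mutex ...
}
```

With only one call pattern left, the unconditional guard is simpler and removes branch-dependent locking, which is presumably the motivation for retiring `notice`.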
|
|
@@ -781,6 +776,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
             }
             if (!value_fetching) {
                 m_buffer.enqueue(GetValue{info});
+                m_buffer.flush();
                 value_fetching = true;
             }
             return false;
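
In the hunk above, the one line the header records as added is most plausibly the `m_buffer.flush()` right after the `GetValue` enqueue: it pushes the fetch request through to the worker before the caller goes back to sleep on the condition variable, and `value_fetching` ensures this happens only once across spurious wakeups. A self-contained sketch of that submit-lazily-from-the-wait-predicate shape (all names below are illustrative, not the interpreter's API):

```cpp
#include <condition_variable>
#include <mutex>
#include <thread>

std::mutex m;
std::condition_variable cv;
bool requested = false;  // has the fetch been handed to the worker?
bool ready = false;      // has the worker published the result?

int main() {
    // Worker: waits for a request, then publishes the result.
    std::thread worker([] {
        std::unique_lock<std::mutex> lock{m};
        cv.wait(lock, [] { return requested; });
        ready = true;
        cv.notify_all();
    });

    {
        std::unique_lock<std::mutex> lock{m};
        bool submitted = false;  // plays the role of `value_fetching`
        cv.wait(lock, [&] {
            if (ready) return true;    // result arrived, stop waiting
            if (!submitted) {
                requested = true;      // stands in for enqueue(GetValue) + flush()
                cv.notify_all();
                submitted = true;
            }
            return false;              // keep waiting until the worker answers
        });
    }
    worker.join();
    return 0;
}
```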
|
|
@@ -789,16 +785,12 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
         }
     });
     RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, m_waitee == nullptr);
-    if (m_waitee != nullptr) {
-        mgb_assert(m_waitee == info, "waitee mismatch");
-        m_waitee = nullptr;
-    }
+    m_waitee = nullptr;
     return info->ptr;
 }

 void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
     if (info == m_waitee) {
         m_waitee = nullptr;
         RECORD_EVENT(TensorNotifyPropEvent, info->id);
         m_cv.notify_all();
     }
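
A note on the waitee reset above: `notify_tensor_unsafe` clears `m_waitee` under the same mutex before signalling, so by the time the wait in `wait_tensor` returns, `m_waitee` is either already `nullptr` (the waited-for tensor was produced) or still `info` (the wakeup came from some other notification). Resetting it unconditionally covers both cases, which is presumably why the guarded reset with its `waitee mismatch` assert was dropped; the `m_waitee == nullptr` argument recorded in `TensorWaitPropFinishEvent` still distinguishes the two outcomes.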
|
|
@@ -809,7 +801,6 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
     for (auto* handle: m_valid_handle) {
         auto* info = reinterpret_cast<TensorInfo*>(handle);
         valid_tensors.insert(info);
-        //TODO: valid_tensors.insert({info, info->status});
     }
     return valid_tensors;
 }
@@ -1005,7 +996,6 @@ void ChannelImpl::CommandBuffer::flush() {
 }

 void ChannelImpl::CommandBuffer::flush(Handle pos) {
-    auto& state = m_owner->get_channel_state();
     for (auto iter = m_commands.begin(); iter != pos; ++iter) {
         if (Profiler::is_profiling()) {
             mgb_log_debug("%s Flushed", to_string(*iter).c_str());