|
|
@@ -33,6 +33,7 @@ Interpreter& Interpreter::inst() { |
|
|
|
} |
|
|
|
|
|
|
|
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
auto info = alloc(); |
|
|
|
info->desc.layout = value.layout(); |
|
|
|
info->desc.comp_node = value.comp_node(); |
|
|
@@ -47,6 +48,7 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) { |
|
|
|
} |
|
|
|
|
|
|
|
Handle ChannelImpl::put(const DeviceTensorND& data) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
auto info = alloc(); |
|
|
|
info->desc.layout = data.layout(); |
|
|
|
info->desc.comp_node = data.comp_node(); |
|
|
@@ -58,6 +60,9 @@ Handle ChannelImpl::put(const DeviceTensorND& data) { |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::del(Handle handle) { |
|
|
|
if (!check_available()){ |
|
|
|
return; |
|
|
|
} |
|
|
|
mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle); |
|
|
|
auto* info = reinterpret_cast<TensorInfo*>(handle); |
|
|
|
m_valid_handle.erase(handle); |
|
|
@@ -65,6 +70,7 @@ void ChannelImpl::del(Handle handle) { |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::swap_in(Handle handle) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
if (m_worker_state.options.enable_swap) { |
|
|
|
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), |
|
|
|
"invalid handle: %p", handle); |
|
|
@@ -74,6 +80,7 @@ void ChannelImpl::swap_in(Handle handle) { |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::swap_out(Handle handle) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
if (m_worker_state.options.enable_swap) { |
|
|
|
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), |
|
|
|
"invalid handle: %p", handle); |
|
|
@@ -83,6 +90,7 @@ void ChannelImpl::swap_out(Handle handle) { |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::drop(Handle handle) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
if (m_worker_state.options.enable_drop) { |
|
|
|
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), |
|
|
|
"invalid handle: %p", handle); |
|
|
@@ -201,6 +209,7 @@ void ChannelImpl::dispatch_kernel( |
|
|
|
SmallVector<Handle> ChannelImpl::apply_op( |
|
|
|
std::shared_ptr<OpDef> op, |
|
|
|
const SmallVector<Handle>& inputs) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
for (auto i : inputs) { |
|
|
|
mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(), |
|
|
|
"invalid handle: %p", i); |
|
|
@@ -237,6 +246,7 @@ SmallVector<Handle> ChannelImpl::apply_op( |
|
|
|
} |
|
|
|
|
|
|
|
HostTensorND ChannelImpl::get_value(Handle handle) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
// TODO: maybe get_value should be done on host. i.e. delete GetValue |
|
|
|
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), |
|
|
|
"invalid handle: %p", handle); |
|
|
@@ -269,6 +279,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) { |
|
|
|
} |
|
|
|
|
|
|
|
TensorShape ChannelImpl::get_shape(Handle handle) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), |
|
|
|
"invalid handle: %p", handle); |
|
|
|
auto info = reinterpret_cast<TensorInfo*>(handle); |
|
|
@@ -296,6 +307,7 @@ TensorShape ChannelImpl::get_shape(Handle handle) { |
|
|
|
} |
|
|
|
|
|
|
|
DType ChannelImpl::get_dtype(Handle handle) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), |
|
|
|
"invalid handle: %p", handle); |
|
|
|
auto info = reinterpret_cast<TensorInfo*>(handle); |
|
|
@@ -308,6 +320,7 @@ DType ChannelImpl::get_dtype(Handle handle) { |
|
|
|
} |
|
|
|
|
|
|
|
CompNode ChannelImpl::get_device(Handle handle) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), |
|
|
|
"invalid handle: %p", handle); |
|
|
|
auto info = reinterpret_cast<TensorInfo*>(handle); |
|
|
@@ -320,6 +333,7 @@ CompNode ChannelImpl::get_device(Handle handle) { |
|
|
|
} |
|
|
|
|
|
|
|
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), |
|
|
|
"invalid handle: %p", handle); |
|
|
|
auto info = reinterpret_cast<TensorInfo*>(handle); |
|
|
@@ -342,6 +356,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::sync() { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
m_buffer.flush(); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<SyncStartEvent>(); |
|
|
@@ -356,14 +371,26 @@ void ChannelImpl::sync() { |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::close() { |
|
|
|
if (!check_available()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end()); |
|
|
|
for (auto* handle: valid_handles) { |
|
|
|
del(handle); |
|
|
|
} |
|
|
|
mgb_assert(m_valid_handle.empty()); |
|
|
|
mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size()); |
|
|
|
sync(); |
|
|
|
m_closed = true; |
|
|
|
} |
|
|
|
|
|
|
|
size_t ChannelImpl::get_option(std::string name) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
return m_channel_state.options.get_option(name); |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::set_option(std::string name, size_t value) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
m_channel_state.options.set_option(name, value); |
|
|
|
m_buffer.enqueue(SetOption{name, value}); |
|
|
|
} |
|
|
@@ -440,9 +467,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) { |
|
|
|
m_pool.free(ptr); |
|
|
|
} |
|
|
|
|
|
|
|
ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){ |
|
|
|
m_channel_state.tid = std::this_thread::get_id(); |
|
|
|
} |
|
|
|
ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){} |
|
|
|
|
|
|
|
ChannelImpl::~ChannelImpl() { |
|
|
|
close(); |
|
|
@@ -562,6 +587,10 @@ void ChannelImpl::detach_users(TensorInfo* dest) { |
|
|
|
//dest->users.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
bool ChannelImpl::check_available() { |
|
|
|
return !m_closed; |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::sync_device_scope(CompNode device) { |
|
|
|
auto& prev = m_worker_state.device_scope_map[device]; |
|
|
|
auto& current = m_worker_state.scopes; |
|
|
@@ -786,9 +815,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { |
|
|
|
std::swap(profiler, m_worker_state.profiler); |
|
|
|
auto records = profiler->stop(); |
|
|
|
auto host_map = [this](std::thread::id tid) { |
|
|
|
if (tid == m_channel_state.tid) { |
|
|
|
return "channel"; |
|
|
|
} else if (tid == m_worker_state.tid) { |
|
|
|
if (tid == m_worker_state.tid) { |
|
|
|
return "worker"; |
|
|
|
} else { |
|
|
|
return "unknown"; |
|
|
@@ -959,6 +986,7 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range) |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
auto profiler_option = InterpreterProfiler::Option::from_dict(option); |
|
|
|
auto profiler = std::make_unique<InterpreterProfiler>(); |
|
|
|
profiler->set_option(profiler_option); |
|
|
@@ -968,6 +996,7 @@ void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) { |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::stop_profile(std::string basename, std::string format) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
m_buffer.flush(); |
|
|
|
auto profiler = std::make_unique<InterpreterProfiler>(); |
|
|
|
std::swap(profiler, m_channel_state.profiler); |
|
|
@@ -976,6 +1005,7 @@ void ChannelImpl::stop_profile(std::string basename, std::string format) { |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::push_scope(std::string name) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
m_channel_state.profiler->record_host<ChannelBeginScope>(name); |
|
|
|
m_channel_state.scopes.push_back(name); |
|
|
@@ -984,6 +1014,7 @@ void ChannelImpl::push_scope(std::string name) { |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::pop_scope(std::string name) { |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
if (m_channel_state.profiler->is_profiling()) { |
|
|
|
mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch"); |
|
|
|
m_channel_state.scopes.pop_back(); |
|
|
@@ -992,14 +1023,6 @@ void ChannelImpl::pop_scope(std::string name) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::assert_in_channel() { |
|
|
|
mgb_assert(m_channel_state.tid != std::this_thread::get_id()); |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::assert_in_worker() { |
|
|
|
mgb_assert(m_worker_state.tid == std::this_thread::get_id()); |
|
|
|
} |
|
|
|
|
|
|
|
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) { |
|
|
|
for (auto i : vec) { |
|
|
|
i->pin(); |
|
|
|