Browse Source

fix(interpreter): remove `notice` flag from `produce_tensor`

GitOrigin-RevId: ed65d0107f
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
bd62a0a647
4 changed files with 20 additions and 24 deletions
  1. +13
    -23
      imperative/src/impl/interpreter/interpreter_impl.cpp
  2. +1
    -1
      imperative/src/impl/interpreter/interpreter_impl.h
  3. +4
    -0
      imperative/src/impl/profiler/memory_chunk.cpp
  4. +2
    -0
      imperative/src/include/megbrain/imperative/profiler.h

+ 13
- 23
imperative/src/impl/interpreter/interpreter_impl.cpp View File

@@ -196,6 +196,10 @@ void ChannelImpl::dispatch_default_cpu(
const SmallVector<LogicalTensorDesc>& input_descs, const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs) { SmallVector<Handle>* outputs) {
auto& state = get_channel_state(); auto& state = get_channel_state();

auto name = op->trait()->make_name(*op);
state.scopes.push(name);

auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
RECORD_EVENT(ShapeInferEvent, validated); RECORD_EVENT(ShapeInferEvent, validated);


@@ -256,6 +260,8 @@ void ChannelImpl::dispatch_default_cpu(
return op_info; return op_info;
}; };
RECORD_EVENT(OpDispatchEvent, op_id, op->trait()->name, op_info_getter, tinfo_to_tid(input_infos), tinfo_to_tid(output_infos)); RECORD_EVENT(OpDispatchEvent, op_id, op->trait()->name, op_info_getter, tinfo_to_tid(input_infos), tinfo_to_tid(output_infos));

state.scopes.pop(name);
} }


void ChannelImpl::dispatch_kernel( void ChannelImpl::dispatch_kernel(
@@ -353,7 +359,6 @@ SmallVector<Handle> ChannelImpl::apply_op(


HostTensorND ChannelImpl::get_value(Handle handle) { HostTensorND ChannelImpl::get_value(Handle handle) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -364,7 +369,6 @@ HostTensorND ChannelImpl::get_value(Handle handle) {


TensorShape ChannelImpl::get_shape(Handle handle) { TensorShape ChannelImpl::get_shape(Handle handle) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -378,7 +382,6 @@ TensorShape ChannelImpl::get_shape(Handle handle) {


DType ChannelImpl::get_dtype(Handle handle) { DType ChannelImpl::get_dtype(Handle handle) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -390,7 +393,6 @@ DType ChannelImpl::get_dtype(Handle handle) {


CompNode ChannelImpl::get_device(Handle handle) { CompNode ChannelImpl::get_device(Handle handle) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -402,7 +404,6 @@ CompNode ChannelImpl::get_device(Handle handle) {


DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -411,7 +412,6 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {


void ChannelImpl::sync() { void ChannelImpl::sync() {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
m_buffer.flush(); m_buffer.flush();
m_worker.wait_all_task_finish(); m_worker.wait_all_task_finish();
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
@@ -519,7 +519,6 @@ void ChannelImpl::recursive_free(TensorInfo* ptr) {


void ChannelImpl::real_free(TensorInfo* ptr) { void ChannelImpl::real_free(TensorInfo* ptr) {
auto& state = get_worker_state(); auto& state = get_worker_state();
MGB_LOCK_GUARD(m_mutex);
if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) { if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
m_dtr.erase_candidate(ptr); m_dtr.erase_candidate(ptr);
} }
@@ -531,6 +530,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
} }
RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count); RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
ptr->status = TensorInfo::Deleted; ptr->status = TensorInfo::Deleted;
MGB_LOCK_GUARD(m_mutex);
m_pool.free(ptr); m_pool.free(ptr);
} }


@@ -540,12 +540,9 @@ ChannelImpl::~ChannelImpl() {
close(); close();
} }


void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=true) {
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
auto& state = get_worker_state(); auto& state = get_worker_state();
std::unique_lock<std::mutex> lock{m_mutex, std::defer_lock};
if (notice) {
lock.lock();
}
MGB_LOCK_GUARD(m_mutex);
m_dtr.update_used_time(dest); m_dtr.update_used_time(dest);
RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr()); RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr());
// update tensor desc for static infer // update tensor desc for static infer
@@ -555,12 +552,10 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=tr
dest->ptr = std::move(ptr); dest->ptr = std::move(ptr);
dest->evict_type = EvictType::NONE; dest->evict_type = EvictType::NONE;
dest->status = TensorInfo::Produced; dest->status = TensorInfo::Produced;
if (notice && dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
m_dtr.insert_candidate(dest); m_dtr.insert_candidate(dest);
} }
if (notice) {
notify_tensor_unsafe(dest);
}
notify_tensor_unsafe(dest);
} }


void ChannelImpl::release_tensor(TensorInfo* dest) { void ChannelImpl::release_tensor(TensorInfo* dest) {
@@ -781,6 +776,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
} }
if (!value_fetching) { if (!value_fetching) {
m_buffer.enqueue(GetValue{info}); m_buffer.enqueue(GetValue{info});
m_buffer.flush();
value_fetching = true; value_fetching = true;
} }
return false; return false;
@@ -789,16 +785,12 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
} }
}); });
RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, m_waitee == nullptr); RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, m_waitee == nullptr);
if (m_waitee != nullptr) {
mgb_assert(m_waitee == info, "waitee mismatch");
m_waitee = nullptr;
}
m_waitee = nullptr;
return info->ptr; return info->ptr;
} }


void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) { void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
if (info == m_waitee) { if (info == m_waitee) {
m_waitee = nullptr;
RECORD_EVENT(TensorNotifyPropEvent, info->id); RECORD_EVENT(TensorNotifyPropEvent, info->id);
m_cv.notify_all(); m_cv.notify_all();
} }
@@ -809,7 +801,6 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
for (auto* handle: m_valid_handle) { for (auto* handle: m_valid_handle) {
auto* info = reinterpret_cast<TensorInfo*>(handle); auto* info = reinterpret_cast<TensorInfo*>(handle);
valid_tensors.insert(info); valid_tensors.insert(info);
//TODO: valid_tensors.insert({info, info->status});
} }
return valid_tensors; return valid_tensors;
} }
@@ -1005,7 +996,6 @@ void ChannelImpl::CommandBuffer::flush() {
} }


void ChannelImpl::CommandBuffer::flush(Handle pos) { void ChannelImpl::CommandBuffer::flush(Handle pos) {
auto& state = m_owner->get_channel_state();
for (auto iter = m_commands.begin(); iter != pos; ++iter) { for (auto iter = m_commands.begin(); iter != pos; ++iter) {
if (Profiler::is_profiling()) { if (Profiler::is_profiling()) {
mgb_log_debug("%s Flushed", to_string(*iter).c_str()); mgb_log_debug("%s Flushed", to_string(*iter).c_str());


+ 1
- 1
imperative/src/impl/interpreter/interpreter_impl.h View File

@@ -91,7 +91,7 @@ private:


void check_worker_exc_unsafe(); void check_worker_exc_unsafe();


void produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice);
void produce_tensor(TensorInfo* dest, TensorPtr ptr);


void release_tensor(TensorInfo* dest); void release_tensor(TensorInfo* dest);




+ 4
- 0
imperative/src/impl/profiler/memory_chunk.cpp View File

@@ -103,6 +103,7 @@ struct MemoryFlow {
auto addr_begin = std::numeric_limits<uintptr_t>::max(); auto addr_begin = std::numeric_limits<uintptr_t>::max();
auto addr_end = std::numeric_limits<uintptr_t>::min(); auto addr_end = std::numeric_limits<uintptr_t>::min();
for(auto&& [id, chunk]: chunks) { for(auto&& [id, chunk]: chunks) {
MGB_MARK_USED_VAR(id);
if (chunk.empty()) continue; if (chunk.empty()) continue;
addr_begin = std::min(addr_begin, chunk.address[0]); addr_begin = std::min(addr_begin, chunk.address[0]);
addr_end = std::max(addr_end, chunk.address[1]); addr_end = std::max(addr_end, chunk.address[1]);
@@ -114,6 +115,7 @@ struct MemoryFlow {
auto time_begin = std::numeric_limits<uint64_t>::max(); auto time_begin = std::numeric_limits<uint64_t>::max();
auto time_end = std::numeric_limits<uint64_t>::min(); auto time_end = std::numeric_limits<uint64_t>::min();
for(auto&& [id, chunk]: chunks) { for(auto&& [id, chunk]: chunks) {
MGB_MARK_USED_VAR(id);
if (chunk.empty()) continue; if (chunk.empty()) continue;
time_begin = std::min(time_begin, chunk.time[0]); time_begin = std::min(time_begin, chunk.time[0]);
time_end = std::max(time_end, chunk.time[1]); time_end = std::max(time_end, chunk.time[1]);
@@ -124,6 +126,7 @@ struct MemoryFlow {
std::shared_ptr<json::Array> to_json() const { std::shared_ptr<json::Array> to_json() const {
auto results = json::Array::make(); auto results = json::Array::make();
for(auto&& [id, chunk]: chunks) { for(auto&& [id, chunk]: chunks) {
MGB_MARK_USED_VAR(id);
if (chunk.empty()) continue; if (chunk.empty()) continue;
auto address = json::Array::make(); auto address = json::Array::make();
auto time = json::Array::make(); auto time = json::Array::make();
@@ -213,6 +216,7 @@ struct MemoryFlow {
return builder; return builder;
}; };
for (auto&& [id, chunk]: chunks) { for (auto&& [id, chunk]: chunks) {
MGB_MARK_USED_VAR(id);
if (chunk.empty()) continue; if (chunk.empty()) continue;
double left = (chunk.time[0]-time_begin)/time_scale; double left = (chunk.time[0]-time_begin)/time_scale;
double right = (chunk.time[1]-time_begin)/time_scale; double right = (chunk.time[1]-time_begin)/time_scale;


+ 2
- 0
imperative/src/include/megbrain/imperative/profiler.h View File

@@ -131,6 +131,7 @@ public:
MGB_LOCK_GUARD(sm_mutex); MGB_LOCK_GUARD(sm_mutex);
if constexpr (sm_debug) { if constexpr (sm_debug) {
for (auto&& [tid, profiler]: sm_profilers) { for (auto&& [tid, profiler]: sm_profilers) {
MGB_MARK_USED_VAR(tid);
Status expected = Running; Status expected = Running;
mgb_assert(profiler->m_status.compare_exchange_strong(expected, Collecting)); mgb_assert(profiler->m_status.compare_exchange_strong(expected, Collecting));
} }
@@ -149,6 +150,7 @@ public:
}); });
if constexpr (sm_debug) { if constexpr (sm_debug) {
for (auto&& [tid, profiler]: sm_profilers) { for (auto&& [tid, profiler]: sm_profilers) {
MGB_MARK_USED_VAR(tid);
Status expected = Collecting; Status expected = Collecting;
mgb_assert(profiler->m_status.compare_exchange_strong(expected, Running)); mgb_assert(profiler->m_status.compare_exchange_strong(expected, Running));
} }


Loading…
Cancel
Save