Browse Source

fix(mgb/opr): fix take CpuDispatchableBase::EventImpl as CpuEventImpl

GitOrigin-RevId: 07aa850837
release-1.1
Megvii Engine Team 4 years ago
parent
commit
10106341fe
2 changed files with 49 additions and 24 deletions
  1. +38
    -11
      src/core/impl/comp_node/cpu/comp_node.cpp
  2. +11
    -13
      src/core/impl/comp_node/cpu/comp_node.h

+ 38
- 11
src/core/impl/comp_node/cpu/comp_node.cpp View File

@@ -351,6 +351,7 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {

//! used during comp node seq rec
class CompSeqRecEventImpl;
class CpuEventImpl;

SeqRecorderImpl* m_cur_recorder = nullptr;
std::mutex m_cur_recorder_mtx;
@@ -633,6 +634,42 @@ public:
using EventImpl::EventImpl;
};

class CpuCompNodeImpl::CpuEventImpl final
: public CpuDispatchableBase::EventImpl {
#if MGB_HAVE_THREAD
void host_wait_cv() override {
for (size_t i = 0, it = SCQueueSynchronizer::max_spin() / 20; i < it;
++i) {
if (finished()) {
auto thread_pool =
static_cast<CpuCompNodeImpl*>(m_comp_node_impl)
->get_thread_pool();
if (thread_pool) {
thread_pool->deactive();
}
return;
}
}
m_dev_wait_nr_waiter.fetch_add(1, std::memory_order_release);
for (;;) {
std::unique_lock<std::mutex> lock{m_dev_wait_mtx};
if (finished()) {
break;
}
m_dev_wait_cv.wait(lock);
}
m_dev_wait_nr_waiter.fetch_sub(1, std::memory_order_release);
auto thread_pool = static_cast<CpuCompNodeImpl*>(m_comp_node_impl)
->get_thread_pool();
if (thread_pool) {
thread_pool->deactive();
}
}
#endif
public:
using EventImpl::EventImpl;
};

std::unique_ptr<CompNode::Event> CpuCompNodeImpl::create_event(size_t flags) {
if (m_worker_queue) {
m_worker_queue->check_exception();
@@ -640,7 +677,7 @@ std::unique_ptr<CompNode::Event> CpuCompNodeImpl::create_event(size_t flags) {
if (m_cur_recorder) {
return std::make_unique<CompSeqRecEventImpl>(this, flags);
} else {
return std::make_unique<EventImpl>(this, flags);
return std::make_unique<CpuEventImpl>(this, flags);
}
}

@@ -921,11 +958,6 @@ bool CpuCompNode::CpuDispatchableBase::EventImpl::do_finished() {
void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
for (size_t i = 0, it = SCQueueSynchronizer::max_spin() / 20; i < it; ++i) {
if (finished()) {
auto thread_pool = static_cast<CpuCompNodeImpl*>(m_comp_node_impl)
->get_thread_pool();
if (thread_pool) {
thread_pool->deactive();
}
return;
}
}
@@ -939,11 +971,6 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
m_dev_wait_cv.wait(lock);
}
m_dev_wait_nr_waiter.fetch_sub(1, std::memory_order_release);
auto thread_pool =
static_cast<CpuCompNodeImpl*>(m_comp_node_impl)->get_thread_pool();
if (thread_pool) {
thread_pool->deactive();
}
}

CpuCompNode::CpuDispatchableBase::EventImpl::~EventImpl() noexcept {


+ 11
- 13
src/core/impl/comp_node/cpu/comp_node.h View File

@@ -64,9 +64,8 @@ namespace mgb {

//! implement Event on CpuDispatchableBase comp nodes
class CpuCompNode::CpuDispatchableBase::EventImpl: public EventImplHelper {
protected:
TimeSpec m_prev_finish_time;

#if MGB_HAVE_THREAD
std::atomic_size_t
m_record_nr_req{0}, m_record_nr_finish{0},
@@ -83,22 +82,21 @@ namespace mgb {

void host_wait_cv() override;

protected:
void do_record() override;
void do_record() override;

//! incr m_record_nr_req; this is used in do_record()
void incr_nr_req() {
//! incr m_record_nr_req; this is used in do_record()
void incr_nr_req() {
#if MGB_HAVE_THREAD
m_record_nr_req.fetch_add(1, std::memory_order_relaxed);
m_record_nr_req.fetch_add(1, std::memory_order_relaxed);
#endif
}
}

//! callback to be dispatched to comp node
void on_finish();
//! callback to be dispatched to comp node
void on_finish();

public:
using EventImplHelper::EventImplHelper;
~EventImpl() noexcept;
public:
using EventImplHelper::EventImplHelper;
~EventImpl() noexcept;
};
}



Loading…
Cancel
Save