|
|
@@ -51,7 +51,7 @@ void CpuCompNode::CpuDispatchableBase::add_callback(Task&& task) { |
|
|
|
class CpuCompNode::WorkerQueue final |
|
|
|
: public AsyncQueueSC<TaskElem, WorkerQueue> { |
|
|
|
const Locator m_locator; |
|
|
|
ThreadPool* m_thread_pool = nullptr; |
|
|
|
std::shared_ptr<ThreadPool> m_thread_pool = nullptr; |
|
|
|
|
|
|
|
void on_async_queue_worker_thread_start() override { |
|
|
|
mgb_assert(m_locator.device >= 0); |
|
|
@@ -74,7 +74,7 @@ public: |
|
|
|
|
|
|
|
explicit WorkerQueue(Locator locator) : m_locator(locator) {} |
|
|
|
|
|
|
|
void attach_thread_pool(ThreadPool* thread_pool) { |
|
|
|
void attach_thread_pool(std::shared_ptr<ThreadPool> thread_pool) { |
|
|
|
m_thread_pool = thread_pool; |
|
|
|
} |
|
|
|
|
|
|
@@ -92,7 +92,7 @@ public: |
|
|
|
return m_thread_pool ? m_thread_pool->nr_threads() : 1_z; |
|
|
|
} |
|
|
|
|
|
|
|
ThreadPool* get_thread_pool() { return m_thread_pool; } |
|
|
|
ThreadPool* get_thread_pool() { return m_thread_pool.get(); } |
|
|
|
}; |
|
|
|
|
|
|
|
class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder { |
|
|
@@ -102,7 +102,7 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder { |
|
|
|
SeqRecorderImpl** const m_self_pointer; |
|
|
|
|
|
|
|
std::vector<TaskElem> m_tasks; |
|
|
|
ThreadPool* m_thread_pool = nullptr; |
|
|
|
std::shared_ptr<ThreadPool> m_thread_pool = nullptr; |
|
|
|
const CompNode m_record_compnode; |
|
|
|
/*! |
|
|
|
* \brief use to check the all ther recording tasks are its self CompNode |
|
|
@@ -118,7 +118,8 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder { |
|
|
|
} |
|
|
|
|
|
|
|
public: |
|
|
|
SeqRecorderImpl(SeqRecorderImpl** self_pointer, ThreadPool* thread_pool, |
|
|
|
SeqRecorderImpl(SeqRecorderImpl** self_pointer, |
|
|
|
std::shared_ptr<ThreadPool> thread_pool, |
|
|
|
const CompNode& comp_node) |
|
|
|
: m_self_pointer{self_pointer}, |
|
|
|
m_thread_pool{thread_pool}, |
|
|
@@ -239,7 +240,7 @@ public: |
|
|
|
return m_thread_pool ? m_thread_pool->nr_threads() : 1_z; |
|
|
|
} |
|
|
|
|
|
|
|
ThreadPool* get_thread_pool() { return m_thread_pool; } |
|
|
|
ThreadPool* get_thread_pool() { return m_thread_pool.get(); } |
|
|
|
}; |
|
|
|
|
|
|
|
using CompNodeBaseImpl = CpuCompNode::CompNodeBaseImpl; |
|
|
@@ -404,14 +405,14 @@ public: |
|
|
|
//! implementation of InplaceCPUDispatcher |
|
|
|
class InplaceCPUDispatcher final : public CPUDispatcher { |
|
|
|
std::atomic_size_t m_nr_task{0}; |
|
|
|
ThreadPool* m_thread_pool = nullptr; |
|
|
|
std::shared_ptr<ThreadPool> m_thread_pool = nullptr; |
|
|
|
//! InplaceCPUDispatcher may used by both type of compnodes, so |
|
|
|
//! m_comp_node's type should be base class. |
|
|
|
CompNodeBaseImpl* const m_comp_node; |
|
|
|
|
|
|
|
public: |
|
|
|
InplaceCPUDispatcher(CompNodeBaseImpl* comp_node, |
|
|
|
ThreadPool* thread_pool = nullptr) |
|
|
|
std::shared_ptr<ThreadPool> thread_pool = nullptr) |
|
|
|
: m_thread_pool(thread_pool), m_comp_node(comp_node) {} |
|
|
|
|
|
|
|
void dispatch(Task&& task) override { |
|
|
@@ -558,7 +559,7 @@ CompNodeDefaultImpl* CompNodeDefaultImpl::sm_default_cpu_comp_node_ptr = |
|
|
|
//! ==================== CompNodeRecorderImpl ====================== |
|
|
|
class CpuCompNode::CompNodeRecorderImpl final : public CompNodeBaseImpl { |
|
|
|
MGB_DYN_TYPE_OBJ_FINAL_DECL; |
|
|
|
std::unique_ptr<ThreadPool> m_thread_pool; |
|
|
|
std::shared_ptr<ThreadPool> m_thread_pool; |
|
|
|
std::shared_ptr<WorkerQueue> m_worker_queue; |
|
|
|
|
|
|
|
//! used during comp node seq rec |
|
|
@@ -629,7 +630,7 @@ public: |
|
|
|
m_worker_queue(worker_queue) { |
|
|
|
auto cn = make_comp_node_from_impl(this); |
|
|
|
if (locator.type == DeviceType::MULTITHREAD) { |
|
|
|
m_thread_pool = std::unique_ptr<ThreadPool>( |
|
|
|
m_thread_pool = std::shared_ptr<ThreadPool>( |
|
|
|
new ThreadPool(static_cast<size_t>(locator.nr_threads))); |
|
|
|
mgb_assert(m_thread_pool, "ThradPool create failed"); |
|
|
|
} |
|
|
@@ -645,10 +646,10 @@ public: |
|
|
|
} else if (locator.type == DeviceType::MULTITHREAD) { |
|
|
|
if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) { |
|
|
|
m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>( |
|
|
|
this, m_thread_pool.get())}, |
|
|
|
this, m_thread_pool)}, |
|
|
|
cn); |
|
|
|
} else { |
|
|
|
m_worker_queue->attach_thread_pool(m_thread_pool.get()); |
|
|
|
m_worker_queue->attach_thread_pool(m_thread_pool); |
|
|
|
m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>( |
|
|
|
m_worker_queue, this)}, |
|
|
|
cn); |
|
|
@@ -807,7 +808,7 @@ public: |
|
|
|
std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder( |
|
|
|
cg::ComputingGraph*) override { |
|
|
|
return std::make_unique<SeqRecorderImpl>(&sm_cur_recorder, |
|
|
|
m_thread_pool.get(), this); |
|
|
|
m_thread_pool, this); |
|
|
|
} |
|
|
|
|
|
|
|
SeqRecorderImpl* cur_recorder() const override { return sm_cur_recorder; } |
|
|
|