@@ -102,17 +102,24 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder {
     bool m_fake_exec = false, m_synchronized = false, m_stopped = false,
          m_first_replay = true;
     SeqRecorderImpl** const m_self_pointer;
-    std::mutex* const m_self_pointer_mtx;
 
     std::vector<TaskElem> m_tasks;
     ThreadPool* m_thread_pool = nullptr;
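+    //! the CompNode this recorder was created on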
+    const CompNode m_record_compnode;
+
+    /*!
+     * \brief check that all recorded tasks belong to this recorder's own
+     * CompNode, to avoid hooking tasks of other CompNodes into the recorder
+     */
+    void check_the_same_comp_node(const CompNode& comp_node) const;
 
 public:
-    SeqRecorderImpl(SeqRecorderImpl** self_pointer,
-                    std::mutex* const self_pointer_mtx, ThreadPool* thread_pool)
+    SeqRecorderImpl(SeqRecorderImpl** self_pointer, ThreadPool* thread_pool,
+                    const CompNode& comp_node)
             : m_self_pointer{self_pointer},
-              m_self_pointer_mtx{self_pointer_mtx},
-              m_thread_pool{thread_pool} {
+              m_thread_pool{thread_pool},
+              m_record_compnode{comp_node} {
         mgb_assert(!*m_self_pointer);
         *m_self_pointer = this;
     }
@@ -123,23 +130,26 @@ public:
         }
     }
 
-    void enter_fake_exec() override {
+    void enter_fake_exec(const CompNode& comp_node) override {
+        check_the_same_comp_node(comp_node);
         mgb_assert(!m_stopped && !m_fake_exec);
         m_fake_exec = true;
     }
 
-    void exit_fake_exec() override {
+    void exit_fake_exec(const CompNode& comp_node) override {
+        check_the_same_comp_node(comp_node);
         mgb_assert(!m_stopped && m_fake_exec);
         mgb_assert(m_tasks.empty());
         m_fake_exec = false;
         m_synchronized = false;
     }
 
-    void stop() override {
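+    //! a default-constructed (invalid) comp_node skips the same-CompNode check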
+    void stop(const CompNode& comp_node = {}) override {
+        check_the_same_comp_node(comp_node);
         mgb_assert(*m_self_pointer == this);
         mgb_assert(!m_fake_exec);
         *m_self_pointer = nullptr;
-        m_self_pointer_mtx->unlock();
         m_stopped = true;
     }
 
@@ -175,25 +185,32 @@ public:
         });
     }
 
-    void on_alloc() {
+    void on_alloc(const CompNode& comp_node) {
+        check_the_same_comp_node(comp_node);
         mgb_assert(m_fake_exec,
                    "alloc is disallowed during comp node seq recording");
     }
 
-    void on_free() {
+    void on_free(const CompNode& comp_node) {
+        check_the_same_comp_node(comp_node);
         mgb_assert(m_fake_exec,
                    "free is disallowed during comp node seq recording");
     }
 
-    void on_sync() { m_synchronized = true; }
+    void on_sync(const CompNode& comp_node) {
+        check_the_same_comp_node(comp_node);
+        m_synchronized = true;
+    }
 
-    void dispatch(Task&& task) {
+    void dispatch(Task&& task, const CompNode& comp_node) {
         mgb_assert(!m_synchronized,
                    "no more tasks should be dispatched after synchronization");
         auto kern = [task](size_t, size_t) { task(); };
-        dispatch_allow_after_sync({std::move(kern), static_cast<size_t>(1_z)});
+        dispatch_allow_after_sync({std::move(kern), static_cast<size_t>(1_z)},
+                                  comp_node);
     }
-    void dispatch_allow_after_sync(Task&& task) {
+    void dispatch_allow_after_sync(Task&& task, const CompNode& comp_node) {
+        check_the_same_comp_node(comp_node);
         mgb_assert(!m_stopped,
                    "dispatch should not be called after recording is stopped");
         if (!m_fake_exec) {
@@ -201,151 +218,28 @@ public:
             m_tasks.push_back({std::move(kern), static_cast<size_t>(1_z)});
         }
     }
-    void dispatch(TaskElem&& task_elem) {
+    void dispatch(TaskElem&& task_elem, const CompNode& comp_node) {
         mgb_assert(!m_synchronized,
                    "no more tasks should be dispatched after synchronization");
-        dispatch_allow_after_sync(std::move(task_elem));
+        dispatch_allow_after_sync(std::move(task_elem), comp_node);
     }
-    void dispatch_allow_after_sync(TaskElem&& task_elem) {
+    void dispatch_allow_after_sync(TaskElem&& task_elem,
+                                   const CompNode& comp_node) {
+        check_the_same_comp_node(comp_node);
         mgb_assert(!m_stopped,
                    "dispatch should not be called after recording is stopped");
         if (!m_fake_exec) {
             m_tasks.push_back(task_elem);
         }
     }
-    size_t nr_threads() {
+    size_t nr_threads(const CompNode& comp_node) {
+        check_the_same_comp_node(comp_node);
         return m_thread_pool ? m_thread_pool->nr_threads() : 1_z;
     }
 
     ThreadPool* get_thread_pool() { return m_thread_pool; }
 };
-
-//! implementation of CPUDispatcher that is passed to megdnn via megcore
-class CpuCompNode::WorkerQueue::DispatcherImpl final: public CPUDispatcher {
-    std::atomic_size_t m_nr_task{0};
-    std::shared_ptr<WorkerQueue> m_queue;
-    SeqRecorderImpl** const m_cur_recorder;
-
-public:
-    DispatcherImpl(const std::shared_ptr<WorkerQueue>& queue,
-                   SeqRecorderImpl** recorder)
-            : m_queue{queue}, m_cur_recorder{recorder} {}
-
-    void dispatch(Task&& task) override {
-        if (*m_cur_recorder) {
-            (*m_cur_recorder)->dispatch(std::move(task));
-        } else {
-            m_nr_task.fetch_add(1, std::memory_order_relaxed);
-            auto kern = [task](size_t, size_t) { task(); };
-            m_queue->add_task({kern, static_cast<size_t>(1_z)});
-        }
-    }
-
-    void dispatch(MultiThreadingTask&& task, size_t parallelism) override {
-        if (*m_cur_recorder) {
-            (*m_cur_recorder)->dispatch({std::move(task), parallelism});
-        } else {
-            m_nr_task.fetch_add(1, std::memory_order_relaxed);
-            m_queue->add_task({std::move(task), parallelism});
-        }
-    }
-
-    void sync() override {
-        if (*m_cur_recorder) {
-            (*m_cur_recorder)->on_sync();
-        } else {
-            m_queue->wait_all_task_finish();
-        }
-    }
-
-    size_t nr_threads() override {
-        if (*m_cur_recorder) {
-            return (*m_cur_recorder)->nr_threads();
-        } else {
-            return m_queue->nr_threads();
-        }
-    }
-
-    size_t get_nr_dispatched_tasks() const override {
-        return m_nr_task;
-    }
-
-    void set_affinity(AffinityCallBack&& affinity_cb) override {
-        auto thread_pool = m_queue->get_thread_pool();
-        if(thread_pool){
-            thread_pool->set_affinity(affinity_cb);
-        } else {
-            auto affinity_run = [affinity_cb](size_t, size_t) {
-                affinity_cb(0);
-            };
-            m_queue->add_task({affinity_run, 1_z});
-        }
-    }
-};
-
-//! implementation of InplaceCPUDispatcher
-class InplaceCPUDispatcher final : public CPUDispatcher {
-    std::atomic_size_t m_nr_task{0};
-    ThreadPool* m_thread_pool = nullptr;
-    CpuCompNode::SeqRecorderImpl** const m_cur_recorder;
-
-public:
-    InplaceCPUDispatcher(CpuCompNode::SeqRecorderImpl** recorder,
-                         ThreadPool* thread_pool = nullptr)
-            : m_thread_pool(thread_pool), m_cur_recorder(recorder) {}
-
-    void dispatch(Task&& task) override {
-        if (*m_cur_recorder) {
-            (*m_cur_recorder)->dispatch(std::move(task));
-        } else if (m_thread_pool) {
-            m_nr_task.fetch_add(1, std::memory_order_relaxed);
-            auto kern = [task](size_t, size_t) { task(); };
-            m_thread_pool->add_task({kern, static_cast<size_t>(1_z)});
-        }else {
-            m_nr_task.fetch_add(1, std::memory_order_relaxed);
-            task();
-        }
-    }
-
-    void dispatch(MultiThreadingTask&& task, size_t parallelism) override {
-        if (*m_cur_recorder) {
-            (*m_cur_recorder)->dispatch({std::move(task), parallelism});
-        } else if (m_thread_pool) {
-            m_nr_task.fetch_add(1, std::memory_order_relaxed);
-            m_thread_pool->add_task({task, parallelism});
-        }else{
-            m_nr_task.fetch_add(1, std::memory_order_relaxed);
-            for(size_t i=0; i<parallelism;i++){
-                task(i, 0);
-            }
-        }
-    }
-
-    size_t nr_threads() override {
-        return m_thread_pool ? m_thread_pool->nr_threads() : 1_z;
-    }
-
-    void sync() override {
-        if (*m_cur_recorder) {
-            (*m_cur_recorder)->on_sync();
-        } else if (m_thread_pool) {
-            m_thread_pool->deactive();
-        }
-    }
-
-    size_t get_nr_dispatched_tasks() const override { return m_nr_task; }
-
-    void set_affinity(AffinityCallBack&& affinity_cb) override {
-        if (*m_cur_recorder) {
-            (*m_cur_recorder)->get_thread_pool()->set_affinity(affinity_cb);
-        } else if (m_thread_pool) {
-            m_thread_pool->set_affinity(affinity_cb);
-        }else{
-            affinity_cb(0);
-        }
-    }
-};
 
 class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
     MGB_DYN_TYPE_OBJ_FINAL_DECL;
 
@@ -353,8 +247,8 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
     class CompSeqRecEventImpl;
     class CpuEventImpl;
 
-    SeqRecorderImpl* m_cur_recorder = nullptr;
-    std::mutex m_cur_recorder_mtx;
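+    //! the active recorder of the calling thread; one recorder per thread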
+    static thread_local SeqRecorderImpl* sm_cur_recorder;
     std::shared_ptr<WorkerQueue> m_worker_queue;
     Locator m_locator, m_locator_logical;
     std::unique_ptr<ThreadPool> m_thread_pool;
@@ -375,49 +269,10 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
 
 public:
     CompNodeImpl(const Locator& locator, const Locator& locator_logical,
-                 const std::shared_ptr<WorkerQueue>& worker_queue)
-            : CpuDispatchableBase(static_free_device, static_free_host),
-              m_worker_queue{worker_queue},
-              m_locator(locator),
-              m_locator_logical(locator_logical) {
-        auto cn = make_comp_node_from_impl(this);
-        if (locator.type == DeviceType::MULTITHREAD) {
-            m_thread_pool = std::unique_ptr<ThreadPool>(new ThreadPool(
-                    static_cast<size_t>(locator.nr_threads)));
-            mgb_assert(m_thread_pool, "ThradPool create failed");
-        }
-
-        if (locator.type == DeviceType::CPU) {
-            if(locator.device == Locator::DEVICE_CPU_DEFAULT){
-                sm_default_cpu_comp_node_ptr = this;
-                m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(
-                                       &m_cur_recorder)},
-                               cn);
-            } else {
-                m_env.init_cpu(
-                        {std::make_shared<WorkerQueue::DispatcherImpl>(
-                                m_worker_queue, &m_cur_recorder)},
-                        cn);
-            }
-        } else if (locator.type == DeviceType::MULTITHREAD) {
-            if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) {
-                m_env.init_cpu(
-                        {std::make_shared<InplaceCPUDispatcher>(
-                                &m_cur_recorder, m_thread_pool.get())},
-                        cn);
-            } else {
-                m_worker_queue->attach_thread_pool(m_thread_pool.get());
-                m_env.init_cpu(
-                        {std::make_shared<WorkerQueue::DispatcherImpl>(
-                                m_worker_queue, &m_cur_recorder)},
-                        cn);
-            }
-        }
-    }
-
+                 const std::shared_ptr<WorkerQueue>& worker_queue);
     ~CompNodeImpl() {
-        if (m_cur_recorder) {
-            m_cur_recorder->stop();
+        if (sm_cur_recorder) {
+            sm_cur_recorder->stop();
         }
         if (m_worker_queue) {
             // synchronize before fini
@@ -462,17 +317,17 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
     }
 
     void* alloc_device(size_t size) override {
-        if (m_cur_recorder) {
-            m_cur_recorder->on_alloc();
+        if (sm_cur_recorder) {
+            sm_cur_recorder->on_alloc(this);
         }
         return mgb_aligned_alloc(size);
     }
 
     void free_device(void *ptr) {
-        if (m_cur_recorder || check_global_finalized("free_device()")) {
+        if (sm_cur_recorder || check_global_finalized("free_device()")) {
             mgb_aligned_free(ptr);
-            if (m_cur_recorder) {
-                m_cur_recorder->on_free();
+            if (sm_cur_recorder) {
+                sm_cur_recorder->on_free(this);
             }
             return;
         } else {
@@ -557,8 +412,8 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
     std::unique_ptr<Event> create_event(size_t flags) override;
 
     void sync() override {
-        if (m_cur_recorder) {
-            m_cur_recorder->on_sync();
+        if (sm_cur_recorder) {
+            sm_cur_recorder->on_sync(this);
         } else if (m_worker_queue) {
             m_worker_queue->wait_all_task_finish();
         }
@@ -590,13 +445,12 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
 
     std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(
             cg::ComputingGraph*) override {
-        m_cur_recorder_mtx.lock();
-        return std::make_unique<SeqRecorderImpl>(
-                &m_cur_recorder, &m_cur_recorder_mtx, m_thread_pool.get());
+        return std::make_unique<SeqRecorderImpl>(&sm_cur_recorder,
+                                                 m_thread_pool.get(), this);
     }
 
     //! current sequence recorder
-    SeqRecorderImpl* cur_recorder() const { return m_cur_recorder; }
+    SeqRecorderImpl* cur_recorder() const { return sm_cur_recorder; }
 
     void add_callback(Task &&task) override {
         if (!check_global_finalized("add_callback()")) {
@@ -608,6 +462,183 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
 };
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(CpuCompNodeImpl);
 CpuCompNodeImpl* CpuCompNodeImpl::sm_default_cpu_comp_node_ptr;
+thread_local CpuCompNode::SeqRecorderImpl* CpuCompNodeImpl::sm_cur_recorder =
+        nullptr;
+
+void CpuCompNode::SeqRecorderImpl::check_the_same_comp_node(
+        const CompNode& comp_node) const {
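+    // an invalid comp_node (e.g. from stop()'s default argument) is not checked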
+    if (mgb_unlikely(comp_node.valid())) {
+        mgb_assert(m_record_compnode == comp_node,
+                   "CompNode %s cannot be hooked into CompNode %s during "
+                   "recording\n",
+                   comp_node.locator().to_string().c_str(),
+                   m_record_compnode.locator().to_string().c_str());
+    }
+}
+
+//! implementation of CPUDispatcher that is passed to megdnn via megcore
+class CpuCompNode::WorkerQueue::DispatcherImpl final: public CPUDispatcher {
+    std::atomic_size_t m_nr_task{0};
+    std::shared_ptr<WorkerQueue> m_queue;
+    CpuCompNode::CompNodeImpl* const m_comp_node;
+
+public:
+    DispatcherImpl(const std::shared_ptr<WorkerQueue>& queue,
+                   CpuCompNode::CompNodeImpl* comp_node)
+            : m_queue{queue}, m_comp_node{comp_node} {}
+
+    void dispatch(Task&& task) override {
+        if (auto recorder = m_comp_node->cur_recorder()) {
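+            // while recording, capture the task instead of queueing it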
+            recorder->dispatch(std::move(task), m_comp_node);
+        } else {
+            m_nr_task.fetch_add(1, std::memory_order_relaxed);
+            auto kern = [task](size_t, size_t) { task(); };
+            m_queue->add_task({kern, static_cast<size_t>(1_z)});
+        }
+    }
+
+    void dispatch(MultiThreadingTask&& task, size_t parallelism) override {
+        if (auto recorder = m_comp_node->cur_recorder()) {
+            recorder->dispatch({std::move(task), parallelism}, m_comp_node);
+        } else {
+            m_nr_task.fetch_add(1, std::memory_order_relaxed);
+            m_queue->add_task({std::move(task), parallelism});
+        }
+    }
+
+    void sync() override {
+        if (auto recorder = m_comp_node->cur_recorder()) {
+            recorder->on_sync(m_comp_node);
+        } else {
+            m_queue->wait_all_task_finish();
+        }
+    }
+
+    size_t nr_threads() override {
+        if (auto recorder = m_comp_node->cur_recorder()) {
+            return recorder->nr_threads(m_comp_node);
+        } else {
+            return m_queue->nr_threads();
+        }
+    }
+
+    size_t get_nr_dispatched_tasks() const override { return m_nr_task; }
+
+    void set_affinity(AffinityCallBack&& affinity_cb) override {
+        auto thread_pool = m_queue->get_thread_pool();
+        if (thread_pool) {
+            thread_pool->set_affinity(affinity_cb);
+        } else {
+            auto affinity_run = [affinity_cb](size_t, size_t) {
+                affinity_cb(0);
+            };
+            m_queue->add_task({affinity_run, 1_z});
+        }
+    }
+};
+
+//! implementation of InplaceCPUDispatcher
+class InplaceCPUDispatcher final : public CPUDispatcher {
+    std::atomic_size_t m_nr_task{0};
+    ThreadPool* m_thread_pool = nullptr;
+    CpuCompNode::CompNodeImpl* const m_comp_node;
+
+public:
+    InplaceCPUDispatcher(CpuCompNode::CompNodeImpl* comp_node,
+                         ThreadPool* thread_pool = nullptr)
+            : m_thread_pool(thread_pool), m_comp_node(comp_node) {}
+
+    void dispatch(Task&& task) override {
+        if (auto recorder = m_comp_node->cur_recorder()) {
+            recorder->dispatch(std::move(task), m_comp_node);
+        } else if (m_thread_pool) {
+            m_nr_task.fetch_add(1, std::memory_order_relaxed);
+            auto kern = [task](size_t, size_t) { task(); };
+            m_thread_pool->add_task({kern, static_cast<size_t>(1_z)});
+        } else {
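+            // neither recorder nor thread pool: run the task inline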
+            m_nr_task.fetch_add(1, std::memory_order_relaxed);
+            task();
+        }
+    }
+
+    void dispatch(MultiThreadingTask&& task, size_t parallelism) override {
+        if (auto recorder = m_comp_node->cur_recorder()) {
+            recorder->dispatch({std::move(task), parallelism}, m_comp_node);
+        } else if (m_thread_pool) {
+            m_nr_task.fetch_add(1, std::memory_order_relaxed);
+            m_thread_pool->add_task({task, parallelism});
+        } else {
+            m_nr_task.fetch_add(1, std::memory_order_relaxed);
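+            // no thread pool: execute every shard serially on this thread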
+            for (size_t i = 0; i < parallelism; i++) {
+                task(i, 0);
+            }
+        }
+    }
+
+    size_t nr_threads() override {
+        return m_thread_pool ? m_thread_pool->nr_threads() : 1_z;
+    }
+
+    void sync() override {
+        if (auto recorder = m_comp_node->cur_recorder()) {
+            recorder->on_sync(m_comp_node);
+        } else if (m_thread_pool) {
+            m_thread_pool->deactive();
+        }
+    }
+
+    size_t get_nr_dispatched_tasks() const override { return m_nr_task; }
+
+    void set_affinity(AffinityCallBack&& affinity_cb) override {
+        if (auto recorder = m_comp_node->cur_recorder()) {
+            recorder->get_thread_pool()->set_affinity(affinity_cb);
+        } else if (m_thread_pool) {
+            m_thread_pool->set_affinity(affinity_cb);
+        } else {
+            affinity_cb(0);
+        }
+    }
+};
+
+CpuCompNode::CompNodeImpl::CompNodeImpl(
+        const Locator& locator, const Locator& locator_logical,
+        const std::shared_ptr<WorkerQueue>& worker_queue)
+        : CpuDispatchableBase(static_free_device, static_free_host),
+          m_worker_queue{worker_queue},
+          m_locator(locator),
+          m_locator_logical(locator_logical) {
+    auto cn = make_comp_node_from_impl(this);
+    if (locator.type == DeviceType::MULTITHREAD) {
+        m_thread_pool = std::unique_ptr<ThreadPool>(
+                new ThreadPool(static_cast<size_t>(locator.nr_threads)));
+        mgb_assert(m_thread_pool, "ThreadPool create failed");
+    }
+
+    if (locator.type == DeviceType::CPU) {
+        if (locator.device == Locator::DEVICE_CPU_DEFAULT) {
+            sm_default_cpu_comp_node_ptr = this;
+            m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(this)}, cn);
+        } else {
+            m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
+                                   m_worker_queue, this)},
+                           cn);
+        }
+    } else if (locator.type == DeviceType::MULTITHREAD) {
+        if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) {
+            m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(
+                                   this, m_thread_pool.get())},
+                           cn);
+        } else {
+            m_worker_queue->attach_thread_pool(m_thread_pool.get());
+            m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
+                                   m_worker_queue, this)},
+                           cn);
+        }
+    }
+}
 
 class CpuCompNodeImpl::CompSeqRecEventImpl final
         : public CpuDispatchableBase::EventImpl {
@@ -618,7 +649,7 @@ class CpuCompNodeImpl::CompSeqRecEventImpl final
             incr_nr_req();
             on_finish();
         };
-        rec->dispatch_allow_after_sync(callback);
+        rec->dispatch_allow_after_sync(callback, m_comp_node_impl);
     } else {
        EventImpl::do_record();
    }
@@ -674,7 +705,7 @@ std::unique_ptr<CompNode::Event> CpuCompNodeImpl::create_event(size_t flags) {
    if (m_worker_queue) {
        m_worker_queue->check_exception();
    }
-    if (m_cur_recorder) {
+    if (sm_cur_recorder) {
        return std::make_unique<CompSeqRecEventImpl>(this, flags);
    } else {
        return std::make_unique<CpuEventImpl>(this, flags);