GitOrigin-RevId: 36787a08a5
tags/v1.0.0-rc1
@@ -63,13 +63,12 @@ class CpuCompNode::WorkerQueue final
#endif
        }
        sys::set_thread_name(m_locator.to_string());
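        //! activate the shared thread pool when this worker thread starts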
        if (m_thread_pool)
            m_thread_pool->active();
    }
    void on_sync_all_task_finish() override {
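        //! all queued tasks are done; put the pool workers back to sleep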
        if (m_thread_pool)
        if (m_thread_pool) {
            m_thread_pool->deactive();
        }
    }
public:
@@ -436,6 +435,8 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
        }
    }
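    //! expose the pool so EventImpl can deactivate it after host-side waits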
    ThreadPool* get_thread_pool() const { return m_thread_pool.get(); }
    void* mgb_aligned_alloc(size_t size) {
        auto alignment = get_mem_addr_alignment();
#ifdef WIN32
@@ -546,6 +547,9 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
        } else if (m_worker_queue) {
            m_worker_queue->wait_all_task_finish();
        }
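        //! after synchronization, also put the thread-pool workers to sleep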
        if (m_thread_pool) {
            m_thread_pool->deactive();
        }
    }
    void dispatch(Task &&task) override {
@@ -893,6 +897,11 @@ bool CpuCompNode::CpuDispatchableBase::EventImpl::do_finished() {
void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
    for (size_t i = 0, it = SCQueueSynchronizer::max_spin() / 20; i < it; ++i) {
        if (finished()) {
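            //! the event finished during the spin wait: deactivate the comp
            //! node's thread pool before returning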
            auto thread_pool = static_cast<CpuCompNodeImpl*>(m_comp_node_impl)
                                       ->get_thread_pool();
            if (thread_pool) {
                thread_pool->deactive();
            }
            return;
        }
    }
@@ -906,6 +915,11 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
        m_dev_wait_cv.wait(lock);
    }
    m_dev_wait_nr_waiter.fetch_sub(1, std::memory_order_release);
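    //! the blocking wait is over; deactivate the thread pool here as well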
    auto thread_pool =
            static_cast<CpuCompNodeImpl*>(m_comp_node_impl)->get_thread_pool();
    if (thread_pool) {
        thread_pool->deactive();
    }
}
CpuCompNode::CpuDispatchableBase::EventImpl::~EventImpl() noexcept {
@@ -74,6 +74,7 @@ void ThreadPool::add_task(const TaskElem& task_elem) {
    //! Make sure the main thread has been bound
    if (m_main_affinity_flag &&
        m_core_binding_function != nullptr) {
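        //! hold m_mutex_task so the binding state cannot race with a
        //! concurrent set_affinity() call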
        std::lock_guard<std::mutex> lock(m_mutex_task);
        m_core_binding_function(m_nr_threads - 1);
        m_main_affinity_flag = false;
    }
@@ -85,10 +86,10 @@ void ThreadPool::add_task(const TaskElem& task_elem) {
        }
        return;
    } else {
        std::lock_guard<std::mutex> lock(m_mutex_task);
        mgb_assert(m_task_iter.load(std::memory_order_acquire) <= 0,
                   "The init value of m_all_sub_task is not zero.");
        active();
        std::lock_guard<std::mutex> lock(m_mutex_task);
        //! Set the task number, task iter and task
        m_nr_parallelism = parallelism;
        m_task_iter.exchange(parallelism, std::memory_order_relaxed);
@@ -113,6 +114,7 @@ void ThreadPool::add_task(const TaskElem& task_elem) {
void ThreadPool::set_affinity(AffinityCallBack affinity_cb) {
    mgb_assert(affinity_cb, "The affinity callback must not be nullptr");
    std::lock_guard<std::mutex> lock(m_mutex_task);
    m_core_binding_function = affinity_cb;
    for (size_t i = 0; i < m_nr_threads - 1; i++) {
        m_workers[i]->affinity_flag = true;
@@ -147,10 +149,12 @@ void ThreadPool::active() {
    }
}
void ThreadPool::deactive() {
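    //! take m_mutex_task first so deactivation cannot interleave with an
    //! in-flight add_task() call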
    std::lock_guard<std::mutex> lock_task(m_mutex_task);
    std::unique_lock<std::mutex> lock(m_mutex);
    m_active = false;
}
ThreadPool::~ThreadPool() {
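    //! likewise, serialize destruction against concurrent add_task() calls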
    std::lock_guard<std::mutex> lock_task(m_mutex_task);
    {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_stop = true;
@@ -80,7 +80,7 @@ public:
    ~ThreadPool();
private:
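    //! total thread count (worker threads plus the calling thread), fixed
    //! once the pool is constructed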
    size_t m_nr_threads = 0;
    const size_t m_nr_threads = 0;
    //! Indicate whether the main thread has been bound
    bool m_main_affinity_flag;
    //! The callback binding the threads to cores
@@ -12,6 +12,8 @@
#include "megbrain/comp_node.h"
#include "megbrain/system.h"
#include "megbrain/test/helper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"
#include <atomic>
#include <random>
@@ -59,6 +61,73 @@ TEST(TestThreadPool, BASIC) {
        ASSERT_EQ(dst1[i], truth[i]);
    }
}
TEST(TestGraph, ParallelRunMultithreadMode) {
    // check race conditions when graphs are executed on multiple threads
    std::atomic_size_t sync_counter{0};
    constexpr size_t NR_RUN = 50;
    size_t nr_worker = std::max<size_t>(4, sys::get_cpu_count() / 4);
    if (auto setting = MGB_GETENV("TestGraphParallelRun_nr_worker")) {
        nr_worker = std::stoul(setting);
    }
    mgb_log("use %zu workers", nr_worker);
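    // simple spin barrier: each worker bumps the shared counter and then
    // busy-waits until all nr_worker threads have reached the same step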
    auto sync_barrier = [&sync_counter, nr_worker](size_t& cnt) {
        ++sync_counter;
        ++cnt;
        while (sync_counter < cnt * nr_worker)
            ;
    };
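    // one round of work: build a small graph that applies y = y * 2 + 3 five
    // times to the input and compare it against the host-side reference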
    auto do_worker = [&sync_barrier](size_t& sync_cnt) {
        auto cn = CompNode::load("multithread2:0");
        HostTensorGenerator<> gen;
        auto host_x = gen({23}, cn);
        HostTensorND host_y, y_expect;
        y_expect.copy_from(*host_x);
        {
            auto py = y_expect.ptr<float>();
            for (int i = 0; i < 23; ++i) {
                for (int j = 0; j < 5; ++j) {
                    py[i] = py[i] * 2 + 3;
                }
            }
        }
        sync_barrier(sync_cnt);
        auto graph = ComputingGraph::make();
        auto x = opr::Host2DeviceCopy::make(*graph, host_x), y = x;
        for (int i = 0; i < 5; ++i) {
            y = y * 2 + 3;
        }
        sync_barrier(sync_cnt);
        auto func = graph->compile({make_callback_copy(y, host_y)});
        sync_barrier(sync_cnt);
        func->execute();
        MGB_ASSERT_TENSOR_EQ(y_expect, host_y);
        memset(host_y.raw_ptr(), -1, 23 * sizeof(float));
        sync_barrier(sync_cnt);
        func->execute();
        MGB_ASSERT_TENSOR_EQ(y_expect, host_y);
        func->wait();
    };
    auto worker = [&]() {
        size_t scnt = 0;
        for (size_t run_id = 0; run_id < NR_RUN; ++run_id) {
            do_worker(scnt);
        }
    };
    std::vector<std::thread> workers;
    for (size_t i = 0; i < nr_worker; ++i)
        workers.emplace_back(worker);
    for (auto&& i : workers)
        i.join();
}
#else
#pragma message "tests are disabled as thread is not enabled."
#endif // MGB_HAVE_THREAD