diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp index 8b57e45c..edc39f3e 100644 --- a/lite/test/test_network.cpp +++ b/lite/test/test_network.cpp @@ -223,8 +223,10 @@ void test_multi_thread(bool multi_thread_compnode) { std::string model_path = "./shufflenet.mge"; size_t nr_threads = 2; - std::vector thread_ids(nr_threads); + std::vector thread_ids_user(nr_threads); + std::vector thread_ids_worker(nr_threads); auto runner = [&](size_t i) { + thread_ids_user[i] = std::hash{}(std::this_thread::get_id()); std::shared_ptr network = std::make_shared(config); Runtime::set_cpu_inplace_mode(network); if (multi_thread_compnode) { @@ -232,11 +234,18 @@ void test_multi_thread(bool multi_thread_compnode) { } network->load_model(model_path); - Runtime::set_runtime_thread_affinity(network, [&thread_ids, i](int id) { - if (id == 0) { - thread_ids[i] = std::this_thread::get_id(); - } - }); + Runtime::set_runtime_thread_affinity( + network, [&multi_thread_compnode, &thread_ids_worker, i](int id) { + if (multi_thread_compnode) { + if (id == 1) { + thread_ids_worker[i] = std::hash{}( + std::this_thread::get_id()); + } + } else { + thread_ids_worker[i] = std::hash{}( + std::this_thread::get_id()); + } + }); std::shared_ptr input_tensor = network->get_input_tensor(0); auto src_ptr = lite_tensor->get_memory_ptr(); @@ -250,11 +259,11 @@ void test_multi_thread(bool multi_thread_compnode) { std::vector threads; for (size_t i = 0; i < nr_threads; i++) { threads.emplace_back(runner, i); + threads[i].join(); } for (size_t i = 0; i < nr_threads; i++) { - threads[i].join(); + ASSERT_EQ(thread_ids_user[i], thread_ids_worker[i]); } - ASSERT_NE(thread_ids[0], thread_ids[1]); } } // namespace