Browse Source

fix(lite): fix lite test error

GitOrigin-RevId: ab608672ec
release-1.10
Megvii Engine Team 3 years ago
parent
commit
6814cf1cd7
1 changed files with 17 additions and 8 deletions
  1. +17
    -8
      lite/test/test_network.cpp

+ 17
- 8
lite/test/test_network.cpp View File

@@ -223,8 +223,10 @@ void test_multi_thread(bool multi_thread_compnode) {
std::string model_path = "./shufflenet.mge";

size_t nr_threads = 2;
std::vector<std::thread::id> thread_ids(nr_threads);
std::vector<size_t> thread_ids_user(nr_threads);
std::vector<size_t> thread_ids_worker(nr_threads);
auto runner = [&](size_t i) {
thread_ids_user[i] = std::hash<std::thread::id>{}(std::this_thread::get_id());
std::shared_ptr<Network> network = std::make_shared<Network>(config);
Runtime::set_cpu_inplace_mode(network);
if (multi_thread_compnode) {
@@ -232,11 +234,18 @@ void test_multi_thread(bool multi_thread_compnode) {
}

network->load_model(model_path);
Runtime::set_runtime_thread_affinity(network, [&thread_ids, i](int id) {
if (id == 0) {
thread_ids[i] = std::this_thread::get_id();
}
});
Runtime::set_runtime_thread_affinity(
network, [&multi_thread_compnode, &thread_ids_worker, i](int id) {
if (multi_thread_compnode) {
if (id == 1) {
thread_ids_worker[i] = std::hash<std::thread::id>{}(
std::this_thread::get_id());
}
} else {
thread_ids_worker[i] = std::hash<std::thread::id>{}(
std::this_thread::get_id());
}
});

std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
auto src_ptr = lite_tensor->get_memory_ptr();
@@ -250,11 +259,11 @@ void test_multi_thread(bool multi_thread_compnode) {
std::vector<std::thread> threads;
for (size_t i = 0; i < nr_threads; i++) {
threads.emplace_back(runner, i);
threads[i].join();
}
for (size_t i = 0; i < nr_threads; i++) {
threads[i].join();
ASSERT_EQ(thread_ids_user[i], thread_ids_worker[i]);
}
ASSERT_NE(thread_ids[0], thread_ids[1]);
}

} // namespace


Loading…
Cancel
Save