GitOrigin-RevId: 5dddc68a84
tags/v1.3.0
@@ -420,7 +420,7 @@ CompNode::Impl* AtlasCompNode::load_atlas(const Locator& locator, | |||
for (int i = 0; i < sd.nr_node; ++i) { | |||
auto&& cur = sd.node[i]; | |||
if (cur.m_initialized) { | |||
if (cur.m_locator_logical == locator_logical) { | |||
if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { | |||
return &cur; | |||
} | |||
} else { | |||
@@ -604,7 +604,7 @@ CompNode::Impl* CambriconCompNode::load_cambricon( | |||
for (int i = 0; i < sd.nr_node; ++i) { | |||
auto&& cur = sd.node[i]; | |||
if (cur.m_initialized) { | |||
if (cur.m_locator_logical == locator_logical) { | |||
if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { | |||
return &cur; | |||
} | |||
} else { | |||
@@ -250,6 +250,10 @@ void CompNode::Locator::set_device_map(DeviceType type, int from, int to) { | |||
void CompNode::Locator::set_unspec_device_type(DeviceType type) { | |||
mgb_assert(type != DeviceType::UNSPEC); | |||
if (type != DeviceType::CPU && type != DeviceType::CUDA) { | |||
mgb_log_warn("to resolve unspec device type as one except " | |||
"CUDA and CPU may lead to unknown problems."); | |||
} | |||
g_unspec_locator_type = type; | |||
} | |||
@@ -723,12 +723,13 @@ struct CpuCompNode::Pool { | |||
impl_storage[MAX_NR_COMP_NODE]; | |||
size_t nr_used_impl_storage = 0; | |||
ThinHashMap<std::pair<int, int>, | |||
std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>> logical2impl; | |||
std::unordered_map<CompNode::LocatorPairHashKey, | |||
std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>, | |||
CompNode::LocatorPairHashKey::Hash> locator2impl; | |||
ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> physical2queue; | |||
ThinHashMap<std::pair<int, int>, | |||
std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>> | |||
logical2impl_multi_thread; | |||
std::unordered_map<CompNode::LocatorPairHashKey, | |||
std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>, | |||
CompNode::LocatorPairHashKey::Hash> locator2impl_multi_thread; | |||
ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> | |||
physical2queue_multithead; | |||
}; | |||
@@ -792,14 +793,9 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, | |||
MGB_LOCK_GUARD(sm_pool->mtx); | |||
// encode both device ID and type into a int | |||
int compact_logical_device = locator_logical.device; | |||
mgb_assert(compact_logical_device >= -1 || | |||
compact_logical_device <= Locator::DEVICE_CPU_DEFAULT); | |||
if (locator_logical.type == CompNode::DeviceType::UNSPEC) { | |||
compact_logical_device += std::numeric_limits<int>::min() + 1; | |||
mgb_assert(compact_logical_device < | |||
Locator::DEVICE_MULTITHREAD_DEFAULT); | |||
} else { | |||
mgb_assert(locator_logical.device >= -1 || | |||
locator_logical.device <= Locator::DEVICE_CPU_DEFAULT); | |||
if (locator_logical.type != CompNode::DeviceType::UNSPEC) { | |||
mgb_assert(locator_logical.type == CompNode::DeviceType::CPU || | |||
locator_logical.type == CompNode::DeviceType::MULTITHREAD); | |||
} | |||
@@ -811,8 +807,8 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, | |||
pqueue = std::make_shared<WorkerQueue>(locator); | |||
pqueue_weak = pqueue; | |||
} | |||
auto&& pimpl = sm_pool->logical2impl[{compact_logical_device, | |||
locator_logical.stream}]; | |||
auto&& pimpl = sm_pool->locator2impl[{locator, | |||
locator_logical}]; | |||
if (!pimpl) { | |||
mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, | |||
"too many cpu comp nodes; max %d allowed", | |||
@@ -833,8 +829,8 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, | |||
pqueue = std::make_shared<WorkerQueue>(locator); | |||
pqueue_weak = pqueue; | |||
} | |||
auto&& pimpl = sm_pool->logical2impl_multi_thread[{ | |||
compact_logical_device, locator_logical.nr_threads}]; | |||
auto&& pimpl = sm_pool->locator2impl_multi_thread[{ | |||
locator, locator_logical}]; | |||
if (!pimpl) { | |||
mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, | |||
"too many cpu multithread comp nodes; max %d allowed", | |||
@@ -854,9 +850,9 @@ void CpuCompNode::sync_all() { | |||
return; | |||
MGB_LOCK_GUARD(sm_pool->mtx); | |||
for (auto &&i: sm_pool->logical2impl) | |||
for (auto &&i: sm_pool->locator2impl) | |||
i.second->sync(); | |||
for (auto&& i : sm_pool->logical2impl_multi_thread) | |||
for (auto&& i : sm_pool->locator2impl_multi_thread) | |||
i.second->sync(); | |||
} | |||
@@ -718,7 +718,7 @@ CompNode::Impl* CudaCompNode::load_cuda( | |||
for (int i = 0; i < sd.nr_node; ++ i) { | |||
auto &&cur = sd.node[i]; | |||
if (cur.m_initialized) { | |||
if (cur.m_locator_logical == locator_logical) { | |||
if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { | |||
return &cur; | |||
} | |||
} else { | |||
@@ -606,7 +606,7 @@ CompNode::Impl* ROCmCompNode::load_rocm(const Locator& locator, | |||
for (int i = 0; i < sd.nr_node; ++i) { | |||
auto&& cur = sd.node[i]; | |||
if (cur.m_initialized) { | |||
if (cur.m_locator_logical == locator_logical) { | |||
if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { | |||
return &cur; | |||
} | |||
} else { | |||
@@ -168,6 +168,22 @@ class CompNode { | |||
return type == rhs.type && device == rhs.device && | |||
stream == rhs.stream; | |||
} | |||
}; | |||
struct LocatorPairHashKey { | |||
Locator locator, locator_logical; | |||
bool operator==(const LocatorPairHashKey& rhs) const { | |||
return locator == rhs.locator && locator_logical == rhs.locator_logical; | |||
} | |||
struct Hash { | |||
size_t operator()(const LocatorPairHashKey& k) const { | |||
return hash_pair_combine(mgb::hash(k.locator), | |||
mgb::hash(k.locator_logical)); | |||
} | |||
}; | |||
}; | |||
//! predefined special streams | |||
@@ -537,6 +553,7 @@ class CompNode { | |||
friend class CompNodeEnv; | |||
friend struct HashTrait<CompNode>; | |||
friend struct HashTrait<CompNode::Locator>; | |||
friend class CompNodeImplHelper; | |||
public: | |||
CompNode(ImplBase* impl) : m_impl{impl} {} | |||
@@ -686,6 +703,15 @@ struct HashTrait<CompNode> { | |||
} | |||
}; | |||
template<> | |||
struct HashTrait<CompNode::Locator> { | |||
static size_t eval(const CompNode::Locator &val) { | |||
return static_cast<size_t>(val.device) | |||
+ (static_cast<size_t>(val.type) << 4) | |||
+ (static_cast<size_t>(val.stream) << 8); | |||
} | |||
}; | |||
namespace comp_node_detail { | |||
/*! | |||
@@ -86,19 +86,34 @@ TEST(TestCompNode, SetDefaultDev) { | |||
CompNode::finalize(); | |||
using L = CompNode::Locator; | |||
auto orig_dt = L::parse("xpu").to_physical(), | |||
orig_gpu = L::parse("gpux").to_physical(); | |||
orig_gpu = L::parse("gpux").to_physical(), | |||
orig_cpu = L::parse("cpux").to_physical(); | |||
constexpr auto CUDA = CompNode::DeviceType::CUDA; | |||
constexpr auto CPU = CompNode::DeviceType::CPU; | |||
L::set_unspec_device_type(CUDA); | |||
L::set_device_map(CUDA, -1, 2); | |||
auto run = []() { | |||
ASSERT_EQ(CompNode::load("xpu").locator(), L::parse("gpu2")); | |||
auto run = [](int device) { | |||
ASSERT_EQ(CompNode::load("xpu").locator(), | |||
L::parse("gpu" + std::to_string(device))); | |||
}; | |||
auto run_cpu = [](int device) { | |||
ASSERT_EQ(CompNode::load("cpux").locator(), | |||
L::parse("cpu" + std::to_string(device))); | |||
}; | |||
MGB_TRY { | |||
run(); | |||
L::set_device_map(CUDA, -1, 2); | |||
run(2); | |||
L::set_device_map(CUDA, -1, 1); | |||
run(1); | |||
L::set_device_map(CPU, -1, 2); | |||
run_cpu(2); | |||
L::set_device_map(CPU, -1, 1); | |||
run_cpu(1); | |||
} MGB_FINALLY({ | |||
L::set_unspec_device_type(orig_dt.type); | |||
L::set_device_map(CUDA, -1, orig_gpu.device); | |||
L::set_device_map(CPU, -1, orig_cpu.device); | |||
}); | |||
CompNode::finalize(); | |||
} | |||