GitOrigin-RevId: 5dddc68a84
tags/v1.3.0
@@ -420,7 +420,7 @@ CompNode::Impl* AtlasCompNode::load_atlas(const Locator& locator, | |||||
for (int i = 0; i < sd.nr_node; ++i) { | for (int i = 0; i < sd.nr_node; ++i) { | ||||
auto&& cur = sd.node[i]; | auto&& cur = sd.node[i]; | ||||
if (cur.m_initialized) { | if (cur.m_initialized) { | ||||
if (cur.m_locator_logical == locator_logical) { | |||||
if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { | |||||
return &cur; | return &cur; | ||||
} | } | ||||
} else { | } else { | ||||
@@ -604,7 +604,7 @@ CompNode::Impl* CambriconCompNode::load_cambricon( | |||||
for (int i = 0; i < sd.nr_node; ++i) { | for (int i = 0; i < sd.nr_node; ++i) { | ||||
auto&& cur = sd.node[i]; | auto&& cur = sd.node[i]; | ||||
if (cur.m_initialized) { | if (cur.m_initialized) { | ||||
if (cur.m_locator_logical == locator_logical) { | |||||
if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { | |||||
return &cur; | return &cur; | ||||
} | } | ||||
} else { | } else { | ||||
@@ -250,6 +250,10 @@ void CompNode::Locator::set_device_map(DeviceType type, int from, int to) { | |||||
// Records `type` in g_unspec_locator_type.
// `type` must not be DeviceType::UNSPEC; this is enforced by mgb_assert.
// NOTE(review): presumably g_unspec_locator_type is what UNSPEC locators are
// resolved to elsewhere -- confirm against the code that reads it.
void CompNode::Locator::set_unspec_device_type(DeviceType type) { | void CompNode::Locator::set_unspec_device_type(DeviceType type) { | ||||
mgb_assert(type != DeviceType::UNSPEC); | mgb_assert(type != DeviceType::UNSPEC); | ||||
// Types other than CPU and CUDA are still accepted; they only emit a warning.
if (type != DeviceType::CPU && type != DeviceType::CUDA) { | |||||
mgb_log_warn("to resolve unspec device type as one except " | |||||
"CUDA and CPU may lead to unknown problems."); | |||||
} | |||||
g_unspec_locator_type = type; | g_unspec_locator_type = type; | ||||
} | } | ||||
@@ -723,12 +723,13 @@ struct CpuCompNode::Pool { | |||||
impl_storage[MAX_NR_COMP_NODE]; | impl_storage[MAX_NR_COMP_NODE]; | ||||
size_t nr_used_impl_storage = 0; | size_t nr_used_impl_storage = 0; | ||||
ThinHashMap<std::pair<int, int>, | |||||
std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>> logical2impl; | |||||
std::unordered_map<CompNode::LocatorPairHashKey, | |||||
std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>, | |||||
CompNode::LocatorPairHashKey::Hash> locator2impl; | |||||
ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> physical2queue; | ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> physical2queue; | ||||
ThinHashMap<std::pair<int, int>, | |||||
std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>> | |||||
logical2impl_multi_thread; | |||||
std::unordered_map<CompNode::LocatorPairHashKey, | |||||
std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>, | |||||
CompNode::LocatorPairHashKey::Hash> locator2impl_multi_thread; | |||||
ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> | ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> | ||||
physical2queue_multithead; | physical2queue_multithead; | ||||
}; | }; | ||||
@@ -792,14 +793,9 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, | |||||
MGB_LOCK_GUARD(sm_pool->mtx); | MGB_LOCK_GUARD(sm_pool->mtx); | ||||
// encode both device ID and type into a int | // encode both device ID and type into a int | ||||
int compact_logical_device = locator_logical.device; | |||||
mgb_assert(compact_logical_device >= -1 || | |||||
compact_logical_device <= Locator::DEVICE_CPU_DEFAULT); | |||||
if (locator_logical.type == CompNode::DeviceType::UNSPEC) { | |||||
compact_logical_device += std::numeric_limits<int>::min() + 1; | |||||
mgb_assert(compact_logical_device < | |||||
Locator::DEVICE_MULTITHREAD_DEFAULT); | |||||
} else { | |||||
mgb_assert(locator_logical.device >= -1 || | |||||
locator_logical.device <= Locator::DEVICE_CPU_DEFAULT); | |||||
if (locator_logical.type != CompNode::DeviceType::UNSPEC) { | |||||
mgb_assert(locator_logical.type == CompNode::DeviceType::CPU || | mgb_assert(locator_logical.type == CompNode::DeviceType::CPU || | ||||
locator_logical.type == CompNode::DeviceType::MULTITHREAD); | locator_logical.type == CompNode::DeviceType::MULTITHREAD); | ||||
} | } | ||||
@@ -811,8 +807,8 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, | |||||
pqueue = std::make_shared<WorkerQueue>(locator); | pqueue = std::make_shared<WorkerQueue>(locator); | ||||
pqueue_weak = pqueue; | pqueue_weak = pqueue; | ||||
} | } | ||||
auto&& pimpl = sm_pool->logical2impl[{compact_logical_device, | |||||
locator_logical.stream}]; | |||||
auto&& pimpl = sm_pool->locator2impl[{locator, | |||||
locator_logical}]; | |||||
if (!pimpl) { | if (!pimpl) { | ||||
mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, | mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, | ||||
"too many cpu comp nodes; max %d allowed", | "too many cpu comp nodes; max %d allowed", | ||||
@@ -833,8 +829,8 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, | |||||
pqueue = std::make_shared<WorkerQueue>(locator); | pqueue = std::make_shared<WorkerQueue>(locator); | ||||
pqueue_weak = pqueue; | pqueue_weak = pqueue; | ||||
} | } | ||||
auto&& pimpl = sm_pool->logical2impl_multi_thread[{ | |||||
compact_logical_device, locator_logical.nr_threads}]; | |||||
auto&& pimpl = sm_pool->locator2impl_multi_thread[{ | |||||
locator, locator_logical}]; | |||||
if (!pimpl) { | if (!pimpl) { | ||||
mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, | mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, | ||||
"too many cpu multithread comp nodes; max %d allowed", | "too many cpu multithread comp nodes; max %d allowed", | ||||
@@ -854,9 +850,9 @@ void CpuCompNode::sync_all() { | |||||
return; | return; | ||||
MGB_LOCK_GUARD(sm_pool->mtx); | MGB_LOCK_GUARD(sm_pool->mtx); | ||||
for (auto &&i: sm_pool->logical2impl) | |||||
for (auto &&i: sm_pool->locator2impl) | |||||
i.second->sync(); | i.second->sync(); | ||||
for (auto&& i : sm_pool->logical2impl_multi_thread) | |||||
for (auto&& i : sm_pool->locator2impl_multi_thread) | |||||
i.second->sync(); | i.second->sync(); | ||||
} | } | ||||
@@ -718,7 +718,7 @@ CompNode::Impl* CudaCompNode::load_cuda( | |||||
for (int i = 0; i < sd.nr_node; ++ i) { | for (int i = 0; i < sd.nr_node; ++ i) { | ||||
auto &&cur = sd.node[i]; | auto &&cur = sd.node[i]; | ||||
if (cur.m_initialized) { | if (cur.m_initialized) { | ||||
if (cur.m_locator_logical == locator_logical) { | |||||
if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { | |||||
return &cur; | return &cur; | ||||
} | } | ||||
} else { | } else { | ||||
@@ -606,7 +606,7 @@ CompNode::Impl* ROCmCompNode::load_rocm(const Locator& locator, | |||||
for (int i = 0; i < sd.nr_node; ++i) { | for (int i = 0; i < sd.nr_node; ++i) { | ||||
auto&& cur = sd.node[i]; | auto&& cur = sd.node[i]; | ||||
if (cur.m_initialized) { | if (cur.m_initialized) { | ||||
if (cur.m_locator_logical == locator_logical) { | |||||
if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { | |||||
return &cur; | return &cur; | ||||
} | } | ||||
} else { | } else { | ||||
@@ -168,6 +168,22 @@ class CompNode { | |||||
return type == rhs.type && device == rhs.device && | return type == rhs.type && device == rhs.device && | ||||
stream == rhs.stream; | stream == rhs.stream; | ||||
} | } | ||||
}; | |||||
//! Key made of two Locator values, usable with hash containers through the
//! nested Hash functor.
struct LocatorPairHashKey { | |||||
Locator locator, locator_logical; | |||||
//! Two keys are equal only when both locator members compare equal.
bool operator==(const LocatorPairHashKey& rhs) const { | |||||
return locator == rhs.locator && locator_logical == rhs.locator_logical; | |||||
} | |||||
//! Hash functor: hashes each member with mgb::hash and merges the two
//! results with hash_pair_combine.
struct Hash { | |||||
size_t operator()(const LocatorPairHashKey& k) const { | |||||
return hash_pair_combine(mgb::hash(k.locator), | |||||
mgb::hash(k.locator_logical)); | |||||
} | |||||
}; | |||||
}; | }; | ||||
//! predefined special streams | //! predefined special streams | ||||
@@ -537,6 +553,7 @@ class CompNode { | |||||
friend class CompNodeEnv; | friend class CompNodeEnv; | ||||
friend struct HashTrait<CompNode>; | friend struct HashTrait<CompNode>; | ||||
friend struct HashTrait<CompNode::Locator>; | |||||
friend class CompNodeImplHelper; | friend class CompNodeImplHelper; | ||||
public: | public: | ||||
CompNode(ImplBase* impl) : m_impl{impl} {} | CompNode(ImplBase* impl) : m_impl{impl} {} | ||||
@@ -686,6 +703,15 @@ struct HashTrait<CompNode> { | |||||
} | } | ||||
}; | }; | ||||
//! HashTrait specialization that hashes a CompNode::Locator from its
//! device, type and stream fields.
template<> | |||||
struct HashTrait<CompNode::Locator> { | |||||
//! Returns device + (type << 4) + (stream << 8), each term converted to
//! size_t first. The terms are added, not packed into disjoint bit ranges,
//! so distinct locators can produce the same value (e.g. when device >= 16).
//! A negative device wraps around on conversion to size_t.
static size_t eval(const CompNode::Locator &val) { | |||||
return static_cast<size_t>(val.device) | |||||
+ (static_cast<size_t>(val.type) << 4) | |||||
+ (static_cast<size_t>(val.stream) << 8); | |||||
} | |||||
}; | |||||
namespace comp_node_detail { | namespace comp_node_detail { | ||||
/*! | /*! | ||||
@@ -86,19 +86,34 @@ TEST(TestCompNode, SetDefaultDev) { | |||||
CompNode::finalize(); | CompNode::finalize(); | ||||
using L = CompNode::Locator; | using L = CompNode::Locator; | ||||
auto orig_dt = L::parse("xpu").to_physical(), | auto orig_dt = L::parse("xpu").to_physical(), | ||||
orig_gpu = L::parse("gpux").to_physical(); | |||||
orig_gpu = L::parse("gpux").to_physical(), | |||||
orig_cpu = L::parse("cpux").to_physical(); | |||||
constexpr auto CUDA = CompNode::DeviceType::CUDA; | constexpr auto CUDA = CompNode::DeviceType::CUDA; | ||||
constexpr auto CPU = CompNode::DeviceType::CPU; | |||||
L::set_unspec_device_type(CUDA); | L::set_unspec_device_type(CUDA); | ||||
L::set_device_map(CUDA, -1, 2); | |||||
auto run = []() { | |||||
ASSERT_EQ(CompNode::load("xpu").locator(), L::parse("gpu2")); | |||||
auto run = [](int device) { | |||||
ASSERT_EQ(CompNode::load("xpu").locator(), | |||||
L::parse("gpu" + std::to_string(device))); | |||||
}; | |||||
auto run_cpu = [](int device) { | |||||
ASSERT_EQ(CompNode::load("cpux").locator(), | |||||
L::parse("cpu" + std::to_string(device))); | |||||
}; | }; | ||||
MGB_TRY { | MGB_TRY { | ||||
run(); | |||||
L::set_device_map(CUDA, -1, 2); | |||||
run(2); | |||||
L::set_device_map(CUDA, -1, 1); | |||||
run(1); | |||||
L::set_device_map(CPU, -1, 2); | |||||
run_cpu(2); | |||||
L::set_device_map(CPU, -1, 1); | |||||
run_cpu(1); | |||||
} MGB_FINALLY({ | } MGB_FINALLY({ | ||||
L::set_unspec_device_type(orig_dt.type); | L::set_unspec_device_type(orig_dt.type); | ||||
L::set_device_map(CUDA, -1, orig_gpu.device); | L::set_device_map(CUDA, -1, orig_gpu.device); | ||||
L::set_device_map(CPU, -1, orig_cpu.device); | |||||
}); | }); | ||||
CompNode::finalize(); | CompNode::finalize(); | ||||
} | } | ||||