|
@@ -408,7 +408,6 @@ CollectiveComm::CollectiveComm( |
|
|
"CollectiveComm inputs should not contain duplicated input device"); |
|
|
"CollectiveComm inputs should not contain duplicated input device"); |
|
|
|
|
|
|
|
|
ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn); |
|
|
ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn); |
|
|
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); |
|
|
|
|
|
|
|
|
|
|
|
const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG"); |
|
|
const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG"); |
|
|
if (c_debug != nullptr and strcmp(c_debug, "1") == 0) { |
|
|
if (c_debug != nullptr and strcmp(c_debug, "1") == 0) { |
|
@@ -469,6 +468,8 @@ void CollectiveComm::opr_register() { |
|
|
hash, m_key, m_nr_devices, m_rank, |
|
|
hash, m_key, m_nr_devices, m_rank, |
|
|
get_megray_backend(m_backend), m_group_client); |
|
|
get_megray_backend(m_backend), m_group_client); |
|
|
|
|
|
|
|
|
|
|
|
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); |
|
|
|
|
|
|
|
|
m_init = true; |
|
|
m_init = true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|