|
|
@@ -410,6 +410,11 @@ CollectiveComm::CollectiveComm( |
|
|
|
ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn); |
|
|
|
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); |
|
|
|
|
|
|
|
const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG"); |
|
|
|
if (c_debug != nullptr and strcmp(c_debug, "1") == 0) { |
|
|
|
m_debug_mode = true; |
|
|
|
} |
|
|
|
|
|
|
|
add_equivalence_component<PODHash<Param>>(&m_param); |
|
|
|
add_equivalence_component<PODHash<size_t>>(&m_nr_devices); |
|
|
|
m_hash = XXHash{}.update(key.data(), key.size() * sizeof(char)).digest(); |
|
|
@@ -536,6 +541,11 @@ void CollectiveComm::do_execute(ExecEnv& env) { |
|
|
|
opr_register(); |
|
|
|
cn.activate(); |
|
|
|
|
|
|
|
if (m_debug_mode) { |
|
|
|
mgb_log_debug("collective comm: executing %s, rank = %d, key = %s", |
|
|
|
cname(), rank(), key().c_str()); |
|
|
|
} |
|
|
|
|
|
|
|
owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(this, cn); |
|
|
|
trait.exec(this); |
|
|
|
owner_graph()->event().signal_inplace<cg::event::AfterKernel>(this, cn); |
|
|
|