|
|
@@ -47,9 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) {
     }
 }
 
-cudaStream_t get_stream(VarNode* var) {
-    return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
-}
 } // anonymous namespace
 
 /* ================= ModeTrait ================= */
@@ -519,8 +516,6 @@ CollectiveComm::CollectiveComm(
     // add input
     mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size());
     if (inputs.size() > 0) {
-        mgb_assert(inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA,
-                "CollectiveComm currectly only supports CUDA");
         add_input({inputs[0]});
     }
 
@@ -531,8 +526,6 @@ CollectiveComm::CollectiveComm(
     const auto& cns = config.comp_node();
     mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size());
     if (cns.size() > 0) {
-        mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA,
-                "CollectiveComm currectly only supports CUDA");
         output(0)->comp_node(cns[0]);
     } else {
         output(0)->comp_node(inputs[0]->comp_node());
@@ -609,7 +602,7 @@ void CollectiveComm::opr_register() {
             reg_info.hash, m_key, m_nr_devices, m_rank,
             get_megray_backend(m_backend), m_group_client);
 
-    m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
+    m_megray_ctx = get_megray_context(output(0)->comp_node());
 
    m_init = true;
 }
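
The last hunk replaces direct construction of a MegRay::CudaContext from a locally
resolved CUDA stream with a shared helper keyed on the output's comp node. Below is
a minimal sketch of what such a helper could look like, assembled only from pieces
visible in this patch (the removed get_stream() body and the new call site). The
function name, namespace qualifiers, and header paths are assumptions for
illustration; the real get_megray_context() in the tree may differ (for instance it
may cache one context per comp node).

    // Hypothetical sketch, not the in-tree implementation.
    #include <memory>
    #include <cuda_runtime.h>
    #include "megbrain/comp_node_env.h"  // assumed header path for CompNodeEnv
    #include "megray.h"                  // assumed header path for MegRay

    std::shared_ptr<MegRay::Context> get_megray_context_sketch(mgb::CompNode comp_node) {
        // Resolve the CUDA stream bound to this comp node, exactly as the
        // removed get_stream() helper did.
        cudaStream_t stream =
                mgb::CompNodeEnv::from_comp_node(comp_node).cuda_env().stream;
        // Wrap it in a MegRay CUDA context, as the old call site did directly.
        return MegRay::CudaContext::make(stream);
    }

Routing context creation through one helper keeps the stream-resolution logic in a
single place instead of duplicating it at each MM operator's registration path.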
|
|
|