Browse Source

fix(mgb/opr-mm): fix m_megray_ctx init

GitOrigin-RevId: 7804fbe2ef
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
2bbce2f924
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      src/opr-mm/impl/collective_comm.cpp

+ 2
- 1
src/opr-mm/impl/collective_comm.cpp View File

@@ -408,7 +408,6 @@ CollectiveComm::CollectiveComm(
"CollectiveComm inputs should not contain duplicated input device"); "CollectiveComm inputs should not contain duplicated input device");


ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn); ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));


const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG"); const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG");
if (c_debug != nullptr and strcmp(c_debug, "1") == 0) { if (c_debug != nullptr and strcmp(c_debug, "1") == 0) {
@@ -469,6 +468,8 @@ void CollectiveComm::opr_register() {
hash, m_key, m_nr_devices, m_rank, hash, m_key, m_nr_devices, m_rank,
get_megray_backend(m_backend), m_group_client); get_megray_backend(m_backend), m_group_client);


m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));

m_init = true; m_init = true;
} }




Loading…
Cancel
Save