|
|
@@ -82,8 +82,9 @@ void RemoteSend::scn_do_execute() { |
|
|
|
for (size_t i = 0; i < ishp.ndim; i++) { |
|
|
|
data_size *= ishp[i]; |
|
|
|
} |
|
|
|
data_size *= tensor.dtype().size(); |
|
|
|
auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx); |
|
|
|
auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, |
|
|
|
get_megray_dtype(tensor.dtype()), |
|
|
|
1, m_megray_ctx); |
|
|
|
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); |
|
|
|
|
|
|
|
if (m_is_grad) { |
|
|
@@ -192,8 +193,9 @@ void RemoteRecv::scn_do_execute() { |
|
|
|
for (size_t i = 0; i < ishp.ndim; i++) { |
|
|
|
data_size *= ishp[i]; |
|
|
|
} |
|
|
|
data_size *= tensor.dtype().size(); |
|
|
|
auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, 0, m_megray_ctx); |
|
|
|
auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, |
|
|
|
get_megray_dtype(tensor.dtype()), |
|
|
|
0, m_megray_ctx); |
|
|
|
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed"); |
|
|
|
} |
|
|
|
|
|
|
|