Browse Source

fix(atlas): add MGB_USE_ATLAS_ASYNC_API to enable async api

GitOrigin-RevId: ab821f4966
release-1.1
Megvii Engine Team 4 years ago
parent
commit
b8ddca4c38
3 changed files with 32 additions and 4 deletions
  1. +8
    -0
      dnn/src/atlas/megcore/computing_context.cpp
  2. +20
    -4
      src/core/impl/comp_node/atlas/comp_node.cpp
  3. +4
    -0
      src/core/impl/tensor.cpp

+ 8
- 0
dnn/src/atlas/megcore/computing_context.cpp View File

@@ -55,8 +55,12 @@ void AtlasComputingContext::memcpy(void* dst, const void* src,
default:
megdnn_throw("bad atlas memcpy kind");
}
#if MGB_USE_ATLAS_ASYNC_API
acl_check(aclrtMemcpyAsync(dst, size_in_bytes, src, size_in_bytes,
atlas_kind, m_ctx.stream));
#else
acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes, atlas_kind));
#endif
}

void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) {
@@ -65,7 +69,11 @@ void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) {
}

/*!
 * \brief Synchronize this computing context.
 *
 * When MGB_USE_ATLAS_ASYNC_API is enabled, this calls
 * aclrtSynchronizeStream on the context's stream and checks the result
 * with acl_check. Otherwise the function is a no-op.
 */
void AtlasComputingContext::synchronize() {
#if MGB_USE_ATLAS_ASYNC_API
    acl_check(aclrtSynchronizeStream(m_ctx.stream));
#endif
    // NOTE(review): nothing to do in the non-async build, presumably because
    // the blocking ACL calls used there have already finished -- confirm.
}

// vim: syntax=cpp.doxygen

+ 20
- 4
src/core/impl/comp_node/atlas/comp_node.cpp View File

@@ -104,9 +104,14 @@ public:
//! Copy \p size bytes from \p device_ptr to \p host_ptr.
//!
//! With MGB_USE_ATLAS_ASYNC_API the copy is issued through
//! aclrtMemcpyAsync on this comp node's Atlas stream; otherwise the
//! aclrtMemcpy call is used. Either result is checked by MGB_ATLAS_CHECK.
void copy_to_host(void* host_ptr, const void* device_ptr,
                  size_t size) override {
    activate();
#if MGB_USE_ATLAS_ASYNC_API
    // NOTE(review): no stream synchronization is done here, so the caller
    // presumably has to sync before reading host_ptr -- confirm.
    auto&& atlas_env = m_env.atlas_env();
    MGB_ATLAS_CHECK(aclrtMemcpyAsync(
            host_ptr, size, device_ptr, size, ACL_MEMCPY_DEVICE_TO_HOST,
            atlas_env.stream));
#else
    MGB_ATLAS_CHECK(aclrtMemcpy(
            host_ptr, size, device_ptr, size, ACL_MEMCPY_DEVICE_TO_HOST));
#endif
}

void copy_to_device(void* device_ptr, const void* host_ptr,
@@ -225,9 +230,14 @@ void AtlasCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest,
auto&& src_env = m_env.atlas_env();
activate();
if (dst_env.device == src_env.device) {
MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_DEVICE,
dst_env.stream));
#if MGB_USE_ATLAS_ASYNC_API
MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_DEVICE,
dst_env.stream));
#else
MGB_ATLAS_CHECK(aclrtMemcpy(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_DEVICE));
#endif
} else {
mgb_throw(MegBrainError,
"Atlas does not support peer copy between differents "
@@ -239,12 +249,18 @@ void AtlasCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest,
mgb_assert(dest_impl->env().property().type == DeviceType::CPU,
"cuda peer_copy_to only implemented for CPU");
auto copy = [this, dest, src, size]() {
auto stream = m_env.atlas_env().stream;
m_env.atlas_env().activate();

#if MGB_USE_ATLAS_ASYNC_API
auto stream = m_env.atlas_env().stream;
MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_HOST,
m_env.atlas_env().stream));
MGB_ATLAS_CHECK(aclrtSynchronizeStream(stream));
#else
MGB_ATLAS_CHECK(
aclrtMemcpy(dest, size, src, size, ACL_MEMCPY_DEVICE_TO_HOST));
#endif
};
dest_impl->env().cpu_env().dispatch(copy);



+ 4
- 0
src/core/impl/tensor.cpp View File

@@ -614,8 +614,12 @@ void mgb::dev_tensor_memset(const DeviceTensorND& tensor, int val) {
#endif
#if MGB_ATLAS
case CompNode::DeviceType::ATLAS:
#if MGB_USE_ATLAS_ASYNC_API
MGB_ATLAS_CHECK(aclrtMemsetAsync(ptr, -1, val, size,
env.atlas_env().stream));
#else
MGB_ATLAS_CHECK(aclrtMemset(ptr, -1, val, size));
#endif
break;
#endif
case CompNode::DeviceType::CPU: {


Loading…
Cancel
Save