|
|
@@ -41,26 +41,22 @@ AtlasComputingContext::~AtlasComputingContext() { |
|
|
|
void AtlasComputingContext::memcpy(void* dst, const void* src, |
|
|
|
size_t size_in_bytes, |
|
|
|
megcoreMemcpyKind_t kind) { |
|
|
|
aclrtMemcpyKind atlas_kind; |
|
|
|
switch (kind) { |
|
|
|
case megcoreMemcpyDeviceToHost: |
|
|
|
atlas_kind = ACL_MEMCPY_DEVICE_TO_HOST; |
|
|
|
acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes, |
|
|
|
ACL_MEMCPY_DEVICE_TO_HOST)); |
|
|
|
break; |
|
|
|
case megcoreMemcpyHostToDevice: |
|
|
|
atlas_kind = ACL_MEMCPY_HOST_TO_DEVICE; |
|
|
|
acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes, |
|
|
|
ACL_MEMCPY_HOST_TO_DEVICE)); |
|
|
|
break; |
|
|
|
case megcoreMemcpyDeviceToDevice: |
|
|
|
atlas_kind = ACL_MEMCPY_DEVICE_TO_DEVICE; |
|
|
|
acl_check(aclrtMemcpyAsync(dst, size_in_bytes, src, size_in_bytes, |
|
|
|
ACL_MEMCPY_DEVICE_TO_DEVICE, m_ctx.stream)); |
|
|
|
break; |
|
|
|
default: |
|
|
|
megdnn_throw("bad atlas memcpy kind"); |
|
|
|
} |
|
|
|
#if MGB_USE_ATLAS_ASYNC_API |
|
|
|
acl_check(aclrtMemcpyAsync(dst, size_in_bytes, src, size_in_bytes, |
|
|
|
atlas_kind, m_ctx.stream)); |
|
|
|
#else |
|
|
|
acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes, atlas_kind)); |
|
|
|
#endif |
|
|
|
} |
|
|
|
|
|
|
|
void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) { |
|
|
@@ -69,11 +65,7 @@ void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) { |
|
|
|
} |
|
|
|
|
|
|
|
void AtlasComputingContext::synchronize() { |
|
|
|
#if MGB_USE_ATLAS_ASYNC_API |
|
|
|
acl_check(aclrtSynchronizeStream(m_ctx.stream)); |
|
|
|
#else |
|
|
|
return; |
|
|
|
#endif |
|
|
|
} |
|
|
|
|
|
|
|
// vim: syntax=cpp.doxygen |