|
|
@@ -45,7 +45,7 @@ |
|
|
|
#endif //MGB_CUDA |
|
|
|
|
|
|
|
#if MGB_ATLAS |
|
|
|
#include "acl/acl.h" |
|
|
|
#include "megcore_atlas.h" |
|
|
|
#include <atomic> |
|
|
|
|
|
|
|
#if MGB_ENABLE_LOGGING |
|
|
@@ -378,7 +378,16 @@ public: |
|
|
|
|
|
|
|
void activate() const { |
|
|
|
init(); |
|
|
|
MGB_ATLAS_CHECK(aclrtSetDevice(device)); |
|
|
|
int32_t device_id = -1; |
|
|
|
auto err = aclrtGetDevice(&device_id); |
|
|
|
if (err == ACL_ERROR_INVALID_DEVICE || device != device_id) { |
|
|
|
MGB_ATLAS_CHECK(aclrtSetDevice(device)); |
|
|
|
} else { |
|
|
|
MGB_ATLAS_CHECK(err); |
|
|
|
mgb_assert(err == ACL_ERROR_NONE, |
|
|
|
"Failed to invoke aclrtGetDevice, get %s(%d)", |
|
|
|
megcore::atlas::get_error_str(err), err); |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|