@@ -31,6 +31,7 @@ | |||||
#endif | #endif | ||||
#if MEGDNN_WITH_CUDA | #if MEGDNN_WITH_CUDA | ||||
#include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
#endif | #endif | ||||
@@ -18,6 +18,7 @@ | |||||
#endif | #endif | ||||
#if MEGDNN_WITH_ROCM | #if MEGDNN_WITH_ROCM | ||||
#include "src/rocm/megcore/computing_context.hpp" | #include "src/rocm/megcore/computing_context.hpp" | ||||
#endif | #endif | ||||
@@ -182,7 +182,8 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { | |||||
} | } | ||||
dev_type = DeviceType::MULTITHREAD; | dev_type = DeviceType::MULTITHREAD; | ||||
ptr += 11; | ptr += 11; | ||||
} else { | |||||
} | |||||
else { | |||||
if (ptr[1] != 'p' || ptr[2] != 'u') { | if (ptr[1] != 'p' || ptr[2] != 'u') { | ||||
err(); | err(); | ||||
} | } | ||||
@@ -237,7 +238,7 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { | |||||
//! num_steam store the nr_thread | //! num_steam store the nr_thread | ||||
std::swap(num_dev, num_stream); | std::swap(num_dev, num_stream); | ||||
} | } | ||||
return {dev_type, num_dev, {num_stream}}; | return {dev_type, num_dev, {num_stream}}; | ||||
} | } | ||||
@@ -1021,13 +1021,12 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by( | |||||
{ | { | ||||
auto type = cn_impl->env().property().type; | auto type = cn_impl->env().property().type; | ||||
mgb_throw_if( | |||||
type != CompNode::DeviceType::CPU && | |||||
type != CompNode::DeviceType::CUDA | |||||
&& type != CompNode::DeviceType::ATLAS && | |||||
type != CompNode::DeviceType::CAMBRICON, | |||||
MegBrainError, | |||||
"currently CPU can only wait for CPU, CUDA, ATLAS, CAMBRICON" | |||||
mgb_throw_if(type != CompNode::DeviceType::CPU | |||||
&& type != CompNode::DeviceType::CUDA | |||||
&& type != CompNode::DeviceType::ATLAS | |||||
, | |||||
MegBrainError, | |||||
"currently CPU can only wait for CPU, CUDA, ATLAS" | |||||
); | ); | ||||
} | } | ||||
@@ -36,6 +36,7 @@ | |||||
#endif | #endif | ||||
using namespace mgb; | using namespace mgb; | ||||
/* =================== MegDNNHandle =================== */ | /* =================== MegDNNHandle =================== */ | ||||
@@ -232,6 +233,7 @@ void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node, | |||||
} | } | ||||
#endif | #endif | ||||
#if MGB_ATLAS | #if MGB_ATLAS | ||||
void mgb::_on_atlas_error(const char* expr, int err, const char* file, | void mgb::_on_atlas_error(const char* expr, int err, const char* file, | ||||
@@ -421,6 +423,7 @@ void CompNodeEnv::fini() { | |||||
MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream)); | MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream)); | ||||
} | } | ||||
#endif | #endif | ||||
#if MGB_ROCM | #if MGB_ROCM | ||||
if (m_property.type == DeviceType::ROCM) { | if (m_property.type == DeviceType::ROCM) { | ||||
m_rocm_env.activate(); | m_rocm_env.activate(); | ||||
@@ -440,6 +443,7 @@ void CompNodeEnv::fini() { | |||||
MGB_ATLAS_CHECK(aclrtDestroyStream(m_atlas_env.stream)); | MGB_ATLAS_CHECK(aclrtDestroyStream(m_atlas_env.stream)); | ||||
} | } | ||||
#endif | #endif | ||||
} | } | ||||
#if MGB_ENABLE_COMP_NODE_ASYNC_INIT | #if MGB_ENABLE_COMP_NODE_ASYNC_INIT | ||||
@@ -73,6 +73,7 @@ std::string CudaError::get_cuda_extra_info() { | |||||
#endif | #endif | ||||
} | } | ||||
AtlasError::AtlasError(const std::string &msg): | AtlasError::AtlasError(const std::string &msg): | ||||
SystemError(msg) | SystemError(msg) | ||||
{ | { | ||||
@@ -82,7 +82,7 @@ class CompNode { | |||||
CAMBRICON = 3, | CAMBRICON = 3, | ||||
ROCM = 8, | ROCM = 8, | ||||
ATLAS = 9, | ATLAS = 9, | ||||
MULTITHREAD, | |||||
MULTITHREAD = 11, | |||||
MAX_DEVICE_ID, | MAX_DEVICE_ID, | ||||
}; | }; | ||||
static constexpr size_t NR_DEVICE_TYPE = | static constexpr size_t NR_DEVICE_TYPE = | ||||
@@ -63,6 +63,7 @@ | |||||
#endif //MGB_ENABLE_LOGGING | #endif //MGB_ENABLE_LOGGING | ||||
#endif //MGB_CUDA | #endif //MGB_CUDA | ||||
#if MGB_ATLAS | #if MGB_ATLAS | ||||
#include "megcore_atlas.h" | #include "megcore_atlas.h" | ||||
#include <atomic> | #include <atomic> | ||||
@@ -205,6 +206,7 @@ namespace mgb { | |||||
#endif | #endif | ||||
#if MGB_ROCM | #if MGB_ROCM | ||||
[[noreturn]] void _on_hip_error(const char* expr, hipError_t err, | [[noreturn]] void _on_hip_error(const char* expr, hipError_t err, | ||||
const char* file, const char* func, int line); | const char* file, const char* func, int line); | ||||
@@ -369,6 +371,7 @@ public: | |||||
const ContinuationCtx<cudaStream_t>& cont); | const ContinuationCtx<cudaStream_t>& cont); | ||||
#endif | #endif | ||||
#if MGB_ATLAS | #if MGB_ATLAS | ||||
struct AtlasEnv { | struct AtlasEnv { | ||||
int device = -1; | int device = -1; | ||||
@@ -139,6 +139,11 @@ public: | |||||
CudaError(const std::string& msg); | CudaError(const std::string& msg); | ||||
}; | }; | ||||
class EnFlameError final : public SystemError { | |||||
public: | |||||
EnFlameError(const std::string& msg); | |||||
}; | |||||
class AtlasError final: public SystemError { | class AtlasError final: public SystemError { | ||||
public: | public: | ||||
AtlasError(const std::string& msg); | AtlasError(const std::string& msg); | ||||
@@ -166,6 +166,7 @@ TEST(TestCompNode, Load) { | |||||
ASSERT_NE(atlas0, atlas1); | ASSERT_NE(atlas0, atlas1); | ||||
#endif | #endif | ||||
} | } | ||||
TEST(TestCompNode, FreeAfterFinalize) { | TEST(TestCompNode, FreeAfterFinalize) { | ||||
@@ -754,6 +755,7 @@ TEST(TestCompNodeCambricon, P2PCopy) { | |||||
#endif | #endif | ||||
#endif // MGB_CAMBRICON | #endif // MGB_CAMBRICON | ||||
#if MGB_ATLAS | #if MGB_ATLAS | ||||
TEST(TestCompNodeAtlas, D2DCopy) { | TEST(TestCompNodeAtlas, D2DCopy) { | ||||