Browse Source

feat(mgb): add enflame comp node

GitOrigin-RevId: 478c8538aa
tags/v1.3.1
Megvii Engine Team 4 years ago
parent
commit
9f2af2099c
10 changed files with 27 additions and 10 deletions
  1. +1
    -0
      dnn/src/common/handle.cpp
  2. +1
    -0
      dnn/src/common/megcore/common/computing_context.cpp
  3. +3
    -2
      src/core/impl/comp_node/comp_node.cpp
  4. +6
    -7
      src/core/impl/comp_node/cpu/comp_node.cpp
  5. +4
    -0
      src/core/impl/comp_node_env.cpp
  6. +1
    -0
      src/core/impl/exception.cpp
  7. +1
    -1
      src/core/include/megbrain/comp_node.h
  8. +3
    -0
      src/core/include/megbrain/comp_node_env.h
  9. +5
    -0
      src/core/include/megbrain/exception.h
  10. +2
    -0
      src/core/test/comp_node.cpp

+ 1
- 0
dnn/src/common/handle.cpp View File

@@ -31,6 +31,7 @@
#endif #endif





#if MEGDNN_WITH_CUDA #if MEGDNN_WITH_CUDA
#include "src/cuda/handle.h" #include "src/cuda/handle.h"
#endif #endif


+ 1
- 0
dnn/src/common/megcore/common/computing_context.cpp View File

@@ -18,6 +18,7 @@
#endif #endif





#if MEGDNN_WITH_ROCM #if MEGDNN_WITH_ROCM
#include "src/rocm/megcore/computing_context.hpp" #include "src/rocm/megcore/computing_context.hpp"
#endif #endif


+ 3
- 2
src/core/impl/comp_node/comp_node.cpp View File

@@ -182,7 +182,8 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
} }
dev_type = DeviceType::MULTITHREAD; dev_type = DeviceType::MULTITHREAD;
ptr += 11; ptr += 11;
} else {
}
else {
if (ptr[1] != 'p' || ptr[2] != 'u') { if (ptr[1] != 'p' || ptr[2] != 'u') {
err(); err();
} }
@@ -237,7 +238,7 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
//! num_steam store the nr_thread //! num_steam store the nr_thread
std::swap(num_dev, num_stream); std::swap(num_dev, num_stream);
} }
return {dev_type, num_dev, {num_stream}}; return {dev_type, num_dev, {num_stream}};
} }




+ 6
- 7
src/core/impl/comp_node/cpu/comp_node.cpp View File

@@ -1021,13 +1021,12 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(


{ {
auto type = cn_impl->env().property().type; auto type = cn_impl->env().property().type;
mgb_throw_if(
type != CompNode::DeviceType::CPU &&
type != CompNode::DeviceType::CUDA
&& type != CompNode::DeviceType::ATLAS &&
type != CompNode::DeviceType::CAMBRICON,
MegBrainError,
"currently CPU can only wait for CPU, CUDA, ATLAS, CAMBRICON"
mgb_throw_if(type != CompNode::DeviceType::CPU
&& type != CompNode::DeviceType::CUDA
&& type != CompNode::DeviceType::ATLAS
,
MegBrainError,
"currently CPU can only wait for CPU, CUDA, ATLAS"
); );
} }




+ 4
- 0
src/core/impl/comp_node_env.cpp View File

@@ -36,6 +36,7 @@
#endif #endif





using namespace mgb; using namespace mgb;


/* =================== MegDNNHandle =================== */ /* =================== MegDNNHandle =================== */
@@ -232,6 +233,7 @@ void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node,
} }
#endif #endif



#if MGB_ATLAS #if MGB_ATLAS


void mgb::_on_atlas_error(const char* expr, int err, const char* file, void mgb::_on_atlas_error(const char* expr, int err, const char* file,
@@ -421,6 +423,7 @@ void CompNodeEnv::fini() {
MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream)); MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream));
} }
#endif #endif

#if MGB_ROCM #if MGB_ROCM
if (m_property.type == DeviceType::ROCM) { if (m_property.type == DeviceType::ROCM) {
m_rocm_env.activate(); m_rocm_env.activate();
@@ -440,6 +443,7 @@ void CompNodeEnv::fini() {
MGB_ATLAS_CHECK(aclrtDestroyStream(m_atlas_env.stream)); MGB_ATLAS_CHECK(aclrtDestroyStream(m_atlas_env.stream));
} }
#endif #endif

} }


#if MGB_ENABLE_COMP_NODE_ASYNC_INIT #if MGB_ENABLE_COMP_NODE_ASYNC_INIT


+ 1
- 0
src/core/impl/exception.cpp View File

@@ -73,6 +73,7 @@ std::string CudaError::get_cuda_extra_info() {
#endif #endif
} }



AtlasError::AtlasError(const std::string &msg): AtlasError::AtlasError(const std::string &msg):
SystemError(msg) SystemError(msg)
{ {


+ 1
- 1
src/core/include/megbrain/comp_node.h View File

@@ -82,7 +82,7 @@ class CompNode {
CAMBRICON = 3, CAMBRICON = 3,
ROCM = 8, ROCM = 8,
ATLAS = 9, ATLAS = 9,
MULTITHREAD,
MULTITHREAD = 11,
MAX_DEVICE_ID, MAX_DEVICE_ID,
}; };
static constexpr size_t NR_DEVICE_TYPE = static constexpr size_t NR_DEVICE_TYPE =


+ 3
- 0
src/core/include/megbrain/comp_node_env.h View File

@@ -63,6 +63,7 @@
#endif //MGB_ENABLE_LOGGING #endif //MGB_ENABLE_LOGGING
#endif //MGB_CUDA #endif //MGB_CUDA



#if MGB_ATLAS #if MGB_ATLAS
#include "megcore_atlas.h" #include "megcore_atlas.h"
#include <atomic> #include <atomic>
@@ -205,6 +206,7 @@ namespace mgb {
#endif #endif





#if MGB_ROCM #if MGB_ROCM
[[noreturn]] void _on_hip_error(const char* expr, hipError_t err, [[noreturn]] void _on_hip_error(const char* expr, hipError_t err,
const char* file, const char* func, int line); const char* file, const char* func, int line);
@@ -369,6 +371,7 @@ public:
const ContinuationCtx<cudaStream_t>& cont); const ContinuationCtx<cudaStream_t>& cont);
#endif #endif



#if MGB_ATLAS #if MGB_ATLAS
struct AtlasEnv { struct AtlasEnv {
int device = -1; int device = -1;


+ 5
- 0
src/core/include/megbrain/exception.h View File

@@ -139,6 +139,11 @@ public:
CudaError(const std::string& msg); CudaError(const std::string& msg);
}; };


class EnFlameError final : public SystemError {
public:
EnFlameError(const std::string& msg);
};

class AtlasError final: public SystemError { class AtlasError final: public SystemError {
public: public:
AtlasError(const std::string& msg); AtlasError(const std::string& msg);


+ 2
- 0
src/core/test/comp_node.cpp View File

@@ -166,6 +166,7 @@ TEST(TestCompNode, Load) {
ASSERT_NE(atlas0, atlas1); ASSERT_NE(atlas0, atlas1);
#endif #endif



} }


TEST(TestCompNode, FreeAfterFinalize) { TEST(TestCompNode, FreeAfterFinalize) {
@@ -754,6 +755,7 @@ TEST(TestCompNodeCambricon, P2PCopy) {
#endif #endif
#endif // MGB_CAMBRICON #endif // MGB_CAMBRICON



#if MGB_ATLAS #if MGB_ATLAS


TEST(TestCompNodeAtlas, D2DCopy) { TEST(TestCompNodeAtlas, D2DCopy) {


Loading…
Cancel
Save