|
|
@@ -6,7 +6,8 @@ |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, |
|
|
|
* software distributed under the License is distributed on an |
|
|
|
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or |
|
|
|
* implied. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "megdnn/basic_types.h" |
|
|
@@ -65,7 +66,7 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle, |
|
|
|
// only enable midout for CPU, becuase CPU might be unused when some |
|
|
|
// other platforms are used |
|
|
|
MIDOUT_BEGIN(HandlePlatform, midout_iv(megcorePlatformCPU)) { |
|
|
|
// CPU |
|
|
|
// CPU |
|
|
|
#if MEGDNN_NAIVE |
|
|
|
return make_unique<naive::HandleImpl>(computing_handle); |
|
|
|
#else |
|
|
@@ -90,91 +91,92 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle, |
|
|
|
} else { |
|
|
|
megdnn_throw("Debug level must be 0/1/2."); |
|
|
|
} |
|
|
|
} |
|
|
|
MIDOUT_END(); |
|
|
|
#endif |
|
|
|
} |
|
|
|
else if (platform == megcorePlatformROCM) { |
|
|
|
MIDOUT_END(); |
|
|
|
|
|
|
|
} |
|
|
|
else if (platform == megcorePlatformROCM) { |
|
|
|
#if MEGDNN_WITH_ROCM |
|
|
|
return make_rocm_handle(computing_handle); |
|
|
|
return make_rocm_handle(computing_handle); |
|
|
|
#else |
|
|
|
return nullptr; |
|
|
|
return nullptr; |
|
|
|
#endif |
|
|
|
} |
|
|
|
else if (platform == megcorePlatformCambricon) { |
|
|
|
} else if (platform == megcorePlatformCambricon) { |
|
|
|
#if MEGDNN_WITH_CAMBRICON |
|
|
|
return make_unique<cambricon::HandleImpl>(computing_handle); |
|
|
|
return make_unique<cambricon::HandleImpl>(computing_handle); |
|
|
|
#else |
|
|
|
return nullptr; |
|
|
|
return nullptr; |
|
|
|
#endif |
|
|
|
} |
|
|
|
else if (platform == megcorePlatformAtlas) { |
|
|
|
} else if (platform == megcorePlatformAtlas) { |
|
|
|
#if MEGDNN_WITH_ATLAS |
|
|
|
return make_unique<atlas::HandleImpl>(computing_handle); |
|
|
|
return make_unique<atlas::HandleImpl>(computing_handle); |
|
|
|
#else |
|
|
|
return nullptr; |
|
|
|
return nullptr; |
|
|
|
#endif |
|
|
|
} |
|
|
|
else { |
|
|
|
// CUDA |
|
|
|
megdnn_throw_if(platform != megcorePlatformCUDA, megdnn_error, |
|
|
|
"platform should be CUDA Platform"); |
|
|
|
} |
|
|
|
else { |
|
|
|
// CUDA |
|
|
|
megdnn_throw_if(platform != megcorePlatformCUDA, megdnn_error, |
|
|
|
"platform should be CUDA Platform"); |
|
|
|
#if MEGDNN_WITH_CUDA |
|
|
|
return make_unique<cuda::HandleImpl>(computing_handle); |
|
|
|
return make_unique<cuda::HandleImpl>(computing_handle); |
|
|
|
#else |
|
|
|
return nullptr; |
|
|
|
#endif |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
#endif |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void Handle::set_destructor(const thin_function<void()>& d) { |
|
|
|
megdnn_assert(!m_destructor, "destructor can be set only once"); |
|
|
|
m_destructor = d; |
|
|
|
} |
|
|
|
|
|
|
|
Handle::~Handle() { |
|
|
|
if (m_destructor) |
|
|
|
m_destructor(); |
|
|
|
m_alive_magic = 0; |
|
|
|
} |
|
|
|
|
|
|
|
size_t Handle::alignment_requirement() const { |
|
|
|
// default to 32 |
|
|
|
return 32; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void Handle::set_destructor(const thin_function<void()>& d) { |
|
|
|
megdnn_assert(!m_destructor, "destructor can be set only once"); |
|
|
|
m_destructor = d; |
|
|
|
} |
|
|
|
|
|
|
|
Handle::~Handle() { |
|
|
|
if (m_destructor) |
|
|
|
m_destructor(); |
|
|
|
m_alive_magic = 0; |
|
|
|
} |
|
|
|
|
|
|
|
size_t Handle::alignment_requirement() const { |
|
|
|
// default to 32 |
|
|
|
return 32; |
|
|
|
} |
|
|
|
|
|
|
|
size_t Handle::image2d_pitch_alignment() const { |
|
|
|
megdnn_throw("image2d tensor format not supported on this handle"); |
|
|
|
} |
|
|
|
|
|
|
|
megdnn::HandleImplHelper::HandleVendorType Handle::vendor_type() const { |
|
|
|
return HandleVendorType::NOT_SPEC; |
|
|
|
} |
|
|
|
|
|
|
|
bool Handle::check_cross_dev_copy_constraint(const TensorLayout& src) { |
|
|
|
return src.is_contiguous(); |
|
|
|
} |
|
|
|
|
|
|
|
void Handle::on_opr_destructed(OperatorBase* opr) { |
|
|
|
if (m_alive_magic != ALIVE_MAGIC) { |
|
|
|
megdnn_log_error( |
|
|
|
"Handle is destructed before opr gets destructed. " |
|
|
|
"Please fix the destruction order as this would cause " |
|
|
|
"undefined memory access. " |
|
|
|
"Abort now to avoid further problems."); |
|
|
|
abort(); |
|
|
|
} |
|
|
|
|
|
|
|
size_t Handle::image2d_pitch_alignment() const { |
|
|
|
megdnn_throw("image2d tensor format not supported on this handle"); |
|
|
|
if (m_on_opr_destructed) { |
|
|
|
m_on_opr_destructed(opr); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
megdnn::HandleImplHelper::HandleVendorType Handle::vendor_type() const { |
|
|
|
return HandleVendorType::NOT_SPEC; |
|
|
|
} |
|
|
|
OperatorBase::~OperatorBase() { |
|
|
|
m_handle->on_opr_destructed(this); |
|
|
|
} |
|
|
|
|
|
|
|
bool Handle::check_cross_dev_copy_constraint(const TensorLayout& src) { |
|
|
|
return src.is_contiguous(); |
|
|
|
} |
|
|
|
|
|
|
|
void Handle::on_opr_destructed(OperatorBase * opr) { |
|
|
|
if (m_alive_magic != ALIVE_MAGIC) { |
|
|
|
megdnn_log_error( |
|
|
|
"Handle is destructed before opr gets destructed. " |
|
|
|
"Please fix the destruction order as this would cause " |
|
|
|
"undefined memory access. " |
|
|
|
"Abort now to avoid further problems."); |
|
|
|
abort(); |
|
|
|
} |
|
|
|
if (m_on_opr_destructed) { |
|
|
|
m_on_opr_destructed(opr); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
OperatorBase::~OperatorBase() { m_handle->on_opr_destructed(this); } |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
std::unique_ptr<Opr> Handle::create_operator() { |
|
|
|
template <typename Opr> |
|
|
|
std::unique_ptr<Opr> Handle::create_operator() { |
|
|
|
#define CASE(etype, nm) \ |
|
|
|
case HandleType::etype: { \ |
|
|
|
MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::etype)) { \ |
|
|
@@ -183,48 +185,47 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle, |
|
|
|
MIDOUT_END(); \ |
|
|
|
} |
|
|
|
|
|
|
|
switch (m_handle_type) { |
|
|
|
CASE(NAIVE, naive); |
|
|
|
switch (m_handle_type) { |
|
|
|
CASE(NAIVE, naive); |
|
|
|
#if !MEGDNN_NAIVE |
|
|
|
CASE(FALLBACK, fallback); |
|
|
|
CASE(FALLBACK, fallback); |
|
|
|
#if MEGDNN_X86 |
|
|
|
CASE(X86, x86); |
|
|
|
CASE(X86, x86); |
|
|
|
#endif |
|
|
|
#if MEGDNN_ARMV7 |
|
|
|
CASE(ARMV7, armv7); |
|
|
|
CASE(ARMV7, armv7); |
|
|
|
#endif |
|
|
|
#if MEGDNN_AARCH64 |
|
|
|
CASE(AARCH64, aarch64); |
|
|
|
CASE(AARCH64, aarch64); |
|
|
|
#endif |
|
|
|
#if MEGDNN_ARMV7 || MEGDNN_AARCH64 |
|
|
|
CASE(ARM_COMMON, arm_common); |
|
|
|
CASE(ARM_COMMON, arm_common); |
|
|
|
#endif |
|
|
|
#endif // !MEGDNN_NAIVE |
|
|
|
#if MEGDNN_WITH_CUDA |
|
|
|
CASE(CUDA,cuda); |
|
|
|
CASE(CUDA, cuda); |
|
|
|
#endif |
|
|
|
#if MEGDNN_WITH_ATLAS |
|
|
|
CASE(ATLAS, atlas); |
|
|
|
CASE(ATLAS, atlas); |
|
|
|
#endif |
|
|
|
#if MEGDNN_WITH_ROCM |
|
|
|
case HandleType::ROCM: { |
|
|
|
MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::ROCM)) { |
|
|
|
return create_rocm_operator<Opr>(); |
|
|
|
} |
|
|
|
MIDOUT_END(); |
|
|
|
case HandleType::ROCM: { |
|
|
|
MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::ROCM)) { |
|
|
|
return create_rocm_operator<Opr>(); |
|
|
|
} |
|
|
|
MIDOUT_END(); |
|
|
|
} |
|
|
|
#endif |
|
|
|
#if MEGDNN_WITH_CAMBRICON |
|
|
|
CASE(CAMBRICON, cambricon); |
|
|
|
#endif |
|
|
|
default: |
|
|
|
megdnn_throw("bad handle type"); |
|
|
|
} |
|
|
|
#undef CASE |
|
|
|
default: |
|
|
|
megdnn_throw("bad handle type"); |
|
|
|
} |
|
|
|
#undef CASE |
|
|
|
} |
|
|
|
|
|
|
|
#define INST(opr) template std::unique_ptr<opr> Handle::create_operator(); |
|
|
|
MEGDNN_FOREACH_OPR_CLASS(INST) |
|
|
|
MEGDNN_FOREACH_OPR_CLASS(INST) |
|
|
|
#undef INST |
|
|
|
// vim: syntax=cpp.doxygen |
|
|
|
|