Browse Source

fix(mgb/tensor): do tensor overlap check only when d2d and h2h

GitOrigin-RevId: 9125936a35
release-1.2
Megvii Engine Team 4 years ago
parent
commit
cf53d9e0f8
7 changed files with 84 additions and 22 deletions
  1. +3
    -1
      src/core/impl/comp_node/atlas/comp_node.h
  2. +3
    -2
      src/core/impl/comp_node/cambricon/comp_node.h
  3. +2
    -1
      src/core/impl/comp_node/cpu/comp_node.h
  4. +3
    -2
      src/core/impl/comp_node/cuda/comp_node.h
  5. +3
    -1
      src/core/impl/comp_node/rocm/comp_node.h
  6. +66
    -15
      src/core/impl/tensor.cpp
  7. +4
    -0
      src/core/include/megbrain/comp_node.h

+ 3
- 1
src/core/impl/comp_node/atlas/comp_node.h View File

@@ -18,7 +18,9 @@
namespace mgb {
class AtlasCompNode final : public CompNodeImplHelper {
public:
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED |
Flag::HAS_COPY_STREAM |
Flag::SUPPORT_UNIFIED_ADDRESS;

class CompNodeImpl;
class EventImpl;


+ 3
- 2
src/core/impl/comp_node/cambricon/comp_node.h View File

@@ -16,8 +16,9 @@
namespace mgb {
class CambriconCompNode final: public CompNodeImplHelper {
public:
static constexpr Flag sm_flag =
Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED |
Flag::HAS_COPY_STREAM |
Flag::SUPPORT_UNIFIED_ADDRESS;

class CompNodeImpl;
class EventImpl;


+ 2
- 1
src/core/impl/comp_node/cpu/comp_node.h View File

@@ -38,7 +38,8 @@ namespace mgb {
static constexpr Flag sm_flag =
Flag::SUPPORT_RECORDER |
Flag::RECORDER_SUPPORT_DYNAMIC_ALLOC |
Flag::EVENT_DTOR_UNSAFE;
Flag::EVENT_DTOR_UNSAFE |
Flag::SUPPORT_UNIFIED_ADDRESS;

//! base class for comp nodes that can be dispatched on CPU.
//! This is currently used by CPU, FPGA and CADENCE


+ 3
- 2
src/core/impl/comp_node/cuda/comp_node.h View File

@@ -16,8 +16,9 @@
namespace mgb {
class CudaCompNode final: public CompNodeImplHelper {
public:
static constexpr Flag sm_flag =
Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED |
Flag::HAS_COPY_STREAM |
Flag::SUPPORT_UNIFIED_ADDRESS;

class CompNodeImpl;
class EventImpl;


+ 3
- 1
src/core/impl/comp_node/rocm/comp_node.h View File

@@ -16,7 +16,9 @@
namespace mgb {
class ROCmCompNode final : public CompNodeImplHelper {
public:
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED |
Flag::HAS_COPY_STREAM |
Flag::SUPPORT_UNIFIED_ADDRESS;

class CompNodeImpl;
class EventImpl;


+ 66
- 15
src/core/impl/tensor.cpp View File

@@ -518,6 +518,60 @@ DEF(sub, )(const SubTensorSpec &spec) const {
// def }

/* ===================== TensorND::copy_from ===================== */
namespace {
/**
* \brief determine whether to check overlap of two tensors.
* \return true : when HostStorage || (DeviceStorage && SUPPORT_UNIFIED_ADDRESS)
* \note when both support unified address, we can treat them both on CPU. So,
* overlap check should be done
*/
template <typename TensorStorage, typename RStorage>
inline bool should_check_overlap(const TensorND<TensorStorage>& dst,
const TensorND<RStorage>& src) {
return true;
}

template <>
inline bool should_check_overlap<HostTensorStorage, DeviceTensorStorage>(
const HostTensorND& dst, const DeviceTensorND& src) {
return src.comp_node().contain_flag(
CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
}

template <>
inline bool should_check_overlap<DeviceTensorStorage, HostTensorStorage>(
const DeviceTensorND& dst, const HostTensorND& src) {
return dst.comp_node().contain_flag(
CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
}

/**
* \brief D2D tensor copy should check overlap when
* 1. They are on the same mem node. But note that the address must be logical
* comparable. i.e. the original address alloc on enflame is uncomparable.
* 2. They both support unified address, so can be treated as CPU address.
*/
template <>
inline bool should_check_overlap<DeviceTensorStorage, DeviceTensorStorage>(
const DeviceTensorND& dst, const DeviceTensorND& src) {
bool is_same_memnode =
dst.comp_node().mem_node() == src.comp_node().mem_node();
bool unified_address = src.comp_node().contain_flag(
CompNode::Flag::SUPPORT_UNIFIED_ADDRESS) &&
dst.comp_node().contain_flag(
CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
return is_same_memnode || unified_address;
}

/**
* \brief check overlap of two tensors. throw exception when overlapped
*/
inline void check_overlapped(const dt_byte* dst_min, const dt_byte* dst_max,
const dt_byte* src_min, const dt_byte* src_max) {
mgb_throw_if(src_min < dst_max && dst_min < src_max, TensorCopyOverlapError,
"cound not perform copy between overlapped tensors");
}
} // namespace

template<class TensorStorage>
template<class RStorage>
@@ -539,12 +593,12 @@ TensorND<TensorStorage>::copy_from(const TensorND<RStorage> &src) {
return static_cast<ChainReturnType&>(*this);
}
if (src.layout().is_physical_contiguous()) {
const dt_byte
*dst_min = m_storage.ptr(), *dst_max = dst_min + size_bytes,
*src_min = src.storage().ptr(), *src_max = src_min + size_bytes;
mgb_throw_if(src_max > dst_min && dst_max > src_min,
TensorCopyOverlapError,
"cound not perform copy between overlapped tensors");
if (should_check_overlap(*this, src)) {
check_overlapped(m_storage.ptr(),
m_storage.ptr() + size_bytes,
src.storage().ptr(),
src.storage().ptr() + size_bytes);
}
m_storage.copy_from(src.storage(), size_bytes);
return static_cast<ChainReturnType&>(*this);
}
@@ -574,15 +628,12 @@ TensorND<TensorStorage>::copy_from_fixlayout(
src_span = src.layout().span(),
dst_span = layout().span();

const dt_byte
*src_ptr_min = src.raw_ptr() + src_span.low_byte,
*src_ptr_max = src.raw_ptr() + src_span.high_byte,
*dst_ptr_min = this->raw_ptr() + dst_span.low_byte,
*dst_ptr_max = this->raw_ptr() + dst_span.high_byte;

mgb_throw_if(src_ptr_max > dst_ptr_min && dst_ptr_max > src_ptr_min,
TensorCopyOverlapError,
"cound not perform copy between overlapped tensors");
if (should_check_overlap(*this, src)) {
check_overlapped(this->raw_ptr() + dst_span.low_byte,
this->raw_ptr() + dst_span.high_byte,
src.raw_ptr() + src_span.low_byte,
src.raw_ptr() + src_span.high_byte);
}

bool self_contig = m_layout.is_physical_contiguous(),
src_contig = src.layout().is_physical_contiguous();


+ 4
- 0
src/core/include/megbrain/comp_node.h View File

@@ -436,6 +436,10 @@ class CompNode {
//! MGB_HAVE_THREAD=0. Usually this means that execution on the
//! CompNode is synchronous, i.e. behaves like cpu:default
SUPPORT_NO_THREAD = 1 << 5,

//! Whether this comp node supports unified address. i.e. CPU and
//! CUDA supports unified address.
SUPPORT_UNIFIED_ADDRESS = 1 << 6,
};

bool contain_flag(Flag flag) {


Loading…
Cancel
Save