|
|
@@ -518,6 +518,60 @@ DEF(sub, )(const SubTensorSpec &spec) const { |
|
|
|
// def } |
|
|
|
|
|
|
|
/* ===================== TensorND::copy_from ===================== */ |
|
|
|
namespace { |
|
|
|
/** |
|
|
|
* \brief determine whether to check overlap of two tensors. |
|
|
|
* \return true : when HostStorage || (DeviceStorage && SUPPORT_UNIFIED_ADDRESS) |
|
|
|
* \note when both support unified address, we can treat them both on CPU. So, |
|
|
|
* overlap check should be done |
|
|
|
*/ |
|
|
|
template <typename TensorStorage, typename RStorage> |
|
|
|
inline bool should_check_overlap(const TensorND<TensorStorage>& dst, |
|
|
|
const TensorND<RStorage>& src) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
template <> |
|
|
|
inline bool should_check_overlap<HostTensorStorage, DeviceTensorStorage>( |
|
|
|
const HostTensorND& dst, const DeviceTensorND& src) { |
|
|
|
return src.comp_node().contain_flag( |
|
|
|
CompNode::Flag::SUPPORT_UNIFIED_ADDRESS); |
|
|
|
} |
|
|
|
|
|
|
|
template <> |
|
|
|
inline bool should_check_overlap<DeviceTensorStorage, HostTensorStorage>( |
|
|
|
const DeviceTensorND& dst, const HostTensorND& src) { |
|
|
|
return dst.comp_node().contain_flag( |
|
|
|
CompNode::Flag::SUPPORT_UNIFIED_ADDRESS); |
|
|
|
} |
|
|
|
|
|
|
|
/** |
|
|
|
* \brief D2D tensor copy should check overlap when |
|
|
|
* 1. They are on the same mem node. But note that the address must be logical |
|
|
|
* comparable. i.e. the original address alloc on enflame is uncomparable. |
|
|
|
* 2. They both support unified address, so can be treated as CPU address. |
|
|
|
*/ |
|
|
|
template <> |
|
|
|
inline bool should_check_overlap<DeviceTensorStorage, DeviceTensorStorage>( |
|
|
|
const DeviceTensorND& dst, const DeviceTensorND& src) { |
|
|
|
bool is_same_memnode = |
|
|
|
dst.comp_node().mem_node() == src.comp_node().mem_node(); |
|
|
|
bool unified_address = src.comp_node().contain_flag( |
|
|
|
CompNode::Flag::SUPPORT_UNIFIED_ADDRESS) && |
|
|
|
dst.comp_node().contain_flag( |
|
|
|
CompNode::Flag::SUPPORT_UNIFIED_ADDRESS); |
|
|
|
return is_same_memnode || unified_address; |
|
|
|
} |
|
|
|
|
|
|
|
/** |
|
|
|
* \brief check overlap of two tensors. throw exception when overlapped |
|
|
|
*/ |
|
|
|
inline void check_overlapped(const dt_byte* dst_min, const dt_byte* dst_max, |
|
|
|
const dt_byte* src_min, const dt_byte* src_max) { |
|
|
|
mgb_throw_if(src_min < dst_max && dst_min < src_max, TensorCopyOverlapError, |
|
|
|
"cound not perform copy between overlapped tensors"); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
template<class TensorStorage> |
|
|
|
template<class RStorage> |
|
|
@@ -539,12 +593,12 @@ TensorND<TensorStorage>::copy_from(const TensorND<RStorage> &src) { |
|
|
|
return static_cast<ChainReturnType&>(*this); |
|
|
|
} |
|
|
|
if (src.layout().is_physical_contiguous()) { |
|
|
|
const dt_byte |
|
|
|
*dst_min = m_storage.ptr(), *dst_max = dst_min + size_bytes, |
|
|
|
*src_min = src.storage().ptr(), *src_max = src_min + size_bytes; |
|
|
|
mgb_throw_if(src_max > dst_min && dst_max > src_min, |
|
|
|
TensorCopyOverlapError, |
|
|
|
"cound not perform copy between overlapped tensors"); |
|
|
|
if (should_check_overlap(*this, src)) { |
|
|
|
check_overlapped(m_storage.ptr(), |
|
|
|
m_storage.ptr() + size_bytes, |
|
|
|
src.storage().ptr(), |
|
|
|
src.storage().ptr() + size_bytes); |
|
|
|
} |
|
|
|
m_storage.copy_from(src.storage(), size_bytes); |
|
|
|
return static_cast<ChainReturnType&>(*this); |
|
|
|
} |
|
|
@@ -574,15 +628,12 @@ TensorND<TensorStorage>::copy_from_fixlayout( |
|
|
|
src_span = src.layout().span(), |
|
|
|
dst_span = layout().span(); |
|
|
|
|
|
|
|
const dt_byte |
|
|
|
*src_ptr_min = src.raw_ptr() + src_span.low_byte, |
|
|
|
*src_ptr_max = src.raw_ptr() + src_span.high_byte, |
|
|
|
*dst_ptr_min = this->raw_ptr() + dst_span.low_byte, |
|
|
|
*dst_ptr_max = this->raw_ptr() + dst_span.high_byte; |
|
|
|
|
|
|
|
mgb_throw_if(src_ptr_max > dst_ptr_min && dst_ptr_max > src_ptr_min, |
|
|
|
TensorCopyOverlapError, |
|
|
|
"cound not perform copy between overlapped tensors"); |
|
|
|
if (should_check_overlap(*this, src)) { |
|
|
|
check_overlapped(this->raw_ptr() + dst_span.low_byte, |
|
|
|
this->raw_ptr() + dst_span.high_byte, |
|
|
|
src.raw_ptr() + src_span.low_byte, |
|
|
|
src.raw_ptr() + src_span.high_byte); |
|
|
|
} |
|
|
|
|
|
|
|
bool self_contig = m_layout.is_physical_contiguous(), |
|
|
|
src_contig = src.layout().is_physical_contiguous(); |
|
|
|