|
@@ -382,7 +382,7 @@ class ImmutableTensor::Value { |
|
|
void setup(CompNode cn, const HostTensorND &val); |
|
|
void setup(CompNode cn, const HostTensorND &val); |
|
|
|
|
|
|
|
|
bool initialized() const { |
|
|
bool initialized() const { |
|
|
return !m_dev.empty(); |
|
|
|
|
|
|
|
|
return m_dev.shape_valid(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
//! value on comp node |
|
|
//! value on comp node |
|
@@ -400,8 +400,9 @@ class ImmutableTensor::Value { |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND &val) { |
|
|
void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND &val) { |
|
|
mgb_assert(m_dev.empty() && !val.empty()); |
|
|
|
|
|
|
|
|
mgb_assert(m_dev.empty() && !m_dev.shape_valid()); |
|
|
m_dev.comp_node(cn).copy_from(val).sync(); |
|
|
m_dev.comp_node(cn).copy_from(val).sync(); |
|
|
|
|
|
mgb_assert(val.empty() == m_dev.empty()); |
|
|
|
|
|
|
|
|
auto one_elem = [](const TensorShape& shape) { |
|
|
auto one_elem = [](const TensorShape& shape) { |
|
|
for (size_t i = 0; i < shape.ndim; ++i) { |
|
|
for (size_t i = 0; i < shape.ndim; ++i) { |
|
@@ -446,6 +447,7 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData { |
|
|
HostTensorND m_val_ref; |
|
|
HostTensorND m_val_ref; |
|
|
|
|
|
|
|
|
const dt_byte* val_ptr() const { |
|
|
const dt_byte* val_ptr() const { |
|
|
|
|
|
mgb_assert(m_trait.size_bytes); |
|
|
return m_val.empty() ? m_val_ref.raw_ptr() : m_val.data(); |
|
|
return m_val.empty() ? m_val_ref.raw_ptr() : m_val.data(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@@ -454,9 +456,8 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData { |
|
|
TensorKey(const HostTensorND &v): |
|
|
TensorKey(const HostTensorND &v): |
|
|
m_val_ref{v} |
|
|
m_val_ref{v} |
|
|
{ |
|
|
{ |
|
|
mgb_assert(v.layout().is_contiguous()); |
|
|
|
|
|
|
|
|
mgb_assert(v.layout().is_contiguous() || v.layout().is_empty()); |
|
|
m_trait.size_bytes = v.layout().span().high_byte; |
|
|
m_trait.size_bytes = v.layout().span().high_byte; |
|
|
mgb_assert(m_trait.size_bytes); |
|
|
|
|
|
|
|
|
|
|
|
auto &&layout = m_trait.layout; |
|
|
auto &&layout = m_trait.layout; |
|
|
// zero to enable byte-comparison |
|
|
// zero to enable byte-comparison |
|
@@ -467,15 +468,19 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData { |
|
|
layout.shape[i] = v.layout().shape[i]; |
|
|
layout.shape[i] = v.layout().shape[i]; |
|
|
layout.stride[i] = v.layout().stride[i]; |
|
|
layout.stride[i] = v.layout().stride[i]; |
|
|
} |
|
|
} |
|
|
m_trait.hash = XXHash{}. |
|
|
|
|
|
update(v.raw_ptr(), m_trait.size_bytes). |
|
|
|
|
|
update(&m_trait.layout, sizeof(m_trait.layout)). |
|
|
|
|
|
digest(); |
|
|
|
|
|
|
|
|
XXHash hasher; |
|
|
|
|
|
if (!v.empty()) { |
|
|
|
|
|
hasher.update(v.raw_ptr(), m_trait.size_bytes); |
|
|
|
|
|
} |
|
|
|
|
|
hasher.update(&m_trait.layout, sizeof(m_trait.layout)); |
|
|
|
|
|
m_trait.hash = hasher.digest(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool operator == (const TensorKey &rhs) const { |
|
|
bool operator == (const TensorKey &rhs) const { |
|
|
return !memcmp(&m_trait, &rhs.m_trait, sizeof(Trait)) && |
|
|
return !memcmp(&m_trait, &rhs.m_trait, sizeof(Trait)) && |
|
|
!memcmp(val_ptr(), rhs.val_ptr(), m_trait.size_bytes); |
|
|
|
|
|
|
|
|
((m_trait.size_bytes == 0 && |
|
|
|
|
|
rhs.m_trait.size_bytes == 0) || |
|
|
|
|
|
!memcmp(val_ptr(), rhs.val_ptr(), m_trait.size_bytes)); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
size_t hash() const { |
|
|
size_t hash() const { |
|
@@ -485,6 +490,7 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData { |
|
|
//! copy from m_val_ref to m_val, to avoid refed value being |
|
|
//! copy from m_val_ref to m_val, to avoid refed value being |
|
|
//! modified |
|
|
//! modified |
|
|
void copy_val_permanent() { |
|
|
void copy_val_permanent() { |
|
|
|
|
|
if (m_trait.size_bytes == 0) return; |
|
|
mgb_assert(m_val.empty()); |
|
|
mgb_assert(m_val.empty()); |
|
|
m_val.resize(m_trait.size_bytes); |
|
|
m_val.resize(m_trait.size_bytes); |
|
|
memcpy(m_val.data(), m_val_ref.raw_ptr(), m_trait.size_bytes); |
|
|
memcpy(m_val.data(), m_val_ref.raw_ptr(), m_trait.size_bytes); |
|
@@ -544,7 +550,6 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
const Value& get(const HostTensorND &tensor) { |
|
|
const Value& get(const HostTensorND &tensor) { |
|
|
mgb_assert(!tensor.empty()); |
|
|
|
|
|
if (tensor.shape().is_scalar()) { |
|
|
if (tensor.shape().is_scalar()) { |
|
|
return get(DTypeScalar::make_from_raw( |
|
|
return get(DTypeScalar::make_from_raw( |
|
|
tensor.dtype(), tensor.raw_ptr())); |
|
|
tensor.dtype(), tensor.raw_ptr())); |
|
@@ -595,6 +600,7 @@ ImmutableTensor::ImmutableTensor(ComputingGraph &graph, |
|
|
|
|
|
|
|
|
add_output(value.dev().dtype()); |
|
|
add_output(value.dev().dtype()); |
|
|
add_equivalence_component<ScalarHash<const void*>>(&value); |
|
|
add_equivalence_component<ScalarHash<const void*>>(&value); |
|
|
|
|
|
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
ImmutableTensor::~ImmutableTensor() noexcept = default; |
|
|
ImmutableTensor::~ImmutableTensor() noexcept = default; |
|
|