From 4f77509ea6df5a2af8077e6e3cc81e7e8e8156c1 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 28 Apr 2020 17:52:50 +0800 Subject: [PATCH] feat(mgb/opr): allow empty ImmutableTensor Fixes MGE-675. GitOrigin-RevId: c6771740fc48226f1b7c79d519de61445e671290 --- src/gopt/impl/framework.cpp | 3 ++- src/opr/impl/io.cpp | 26 ++++++++++++++++---------- src/opr/test/io.cpp | 11 +++++++++++ 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index 570e05f7..2703fffd 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -761,7 +761,8 @@ ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr( if (is_const_var(m_const_var_type, opr)) { auto sz = var_mem_size(opr->output(0)); - mgb_assert(sz); + mgb_assert(sz || opr->output(0)->contain_flag( + VarNode::Flag::ALLOW_EMPTY_SHAPE)); info.is_const = true; info.max_size = sz; return make_ret(); diff --git a/src/opr/impl/io.cpp b/src/opr/impl/io.cpp index dd8b5dfd..c535c837 100644 --- a/src/opr/impl/io.cpp +++ b/src/opr/impl/io.cpp @@ -382,7 +382,7 @@ class ImmutableTensor::Value { void setup(CompNode cn, const HostTensorND &val); bool initialized() const { - return !m_dev.empty(); + return m_dev.shape_valid(); } //! value on comp node @@ -400,8 +400,9 @@ class ImmutableTensor::Value { }; void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND &val) { - mgb_assert(m_dev.empty() && !val.empty()); + mgb_assert(m_dev.empty() && !m_dev.shape_valid()); m_dev.comp_node(cn).copy_from(val).sync(); + mgb_assert(val.empty() == m_dev.empty()); auto one_elem = [](const TensorShape& shape) { for (size_t i = 0; i < shape.ndim; ++i) { @@ -446,6 +447,7 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData { HostTensorND m_val_ref; const dt_byte* val_ptr() const { + mgb_assert(m_trait.size_bytes); return m_val.empty() ? m_val_ref.raw_ptr() : m_val.data(); } @@ -454,9 +456,8 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData { TensorKey(const HostTensorND &v): m_val_ref{v} { - mgb_assert(v.layout().is_contiguous()); + mgb_assert(v.layout().is_contiguous() || v.layout().is_empty()); m_trait.size_bytes = v.layout().span().high_byte; - mgb_assert(m_trait.size_bytes); auto &&layout = m_trait.layout; // zero to enable byte-comparison @@ -467,15 +468,19 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData { layout.shape[i] = v.layout().shape[i]; layout.stride[i] = v.layout().stride[i]; } - m_trait.hash = XXHash{}. - update(v.raw_ptr(), m_trait.size_bytes). - update(&m_trait.layout, sizeof(m_trait.layout)). - digest(); + XXHash hasher; + if (!v.empty()) { + hasher.update(v.raw_ptr(), m_trait.size_bytes); + } + hasher.update(&m_trait.layout, sizeof(m_trait.layout)); + m_trait.hash = hasher.digest(); } bool operator == (const TensorKey &rhs) const { return !memcmp(&m_trait, &rhs.m_trait, sizeof(Trait)) && - !memcmp(val_ptr(), rhs.val_ptr(), m_trait.size_bytes); + ((m_trait.size_bytes == 0 && + rhs.m_trait.size_bytes == 0) || + !memcmp(val_ptr(), rhs.val_ptr(), m_trait.size_bytes)); } size_t hash() const { @@ -485,6 +490,7 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData { //! copy from m_val_ref to m_val, to avoid refed value being //! modified void copy_val_permanent() { + if (m_trait.size_bytes == 0) return; mgb_assert(m_val.empty()); m_val.resize(m_trait.size_bytes); memcpy(m_val.data(), m_val_ref.raw_ptr(), m_trait.size_bytes); @@ -544,7 +550,6 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData { } const Value& get(const HostTensorND &tensor) { - mgb_assert(!tensor.empty()); if (tensor.shape().is_scalar()) { return get(DTypeScalar::make_from_raw( tensor.dtype(), tensor.raw_ptr())); @@ -595,6 +600,7 @@ ImmutableTensor::ImmutableTensor(ComputingGraph &graph, add_output(value.dev().dtype()); add_equivalence_component>(&value); + output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); } ImmutableTensor::~ImmutableTensor() noexcept = default; diff --git a/src/opr/test/io.cpp b/src/opr/test/io.cpp index b9f7b591..08b60cdb 100644 --- a/src/opr/test/io.cpp +++ b/src/opr/test/io.cpp @@ -177,6 +177,17 @@ TEST(TestOprIO, ImmutableTensorLarge) { } } +TEST(TestOprIO, ImmutableTensorEmpty) { + HostTensorGenerator<> gen; + auto graph = ComputingGraph::make(); + auto host_x = gen({1, 9, 1, 9, 8, 1, 0}); + auto x = opr::ImmutableTensor::make(*graph, *host_x); + HostTensorND host_x2; + auto func = graph->compile({make_callback_copy(x, host_x2)}); + func->execute(); + ASSERT_TRUE(host_x2.shape().is_empty()); +} + TEST(TestOprIO, SharedDeviceTensor) { HostTensorGenerator<> gen; auto hv = gen({123});