Browse Source

feat(mgb/opr): allow empty ImmutableTensor

Fixes MGE-675.

GitOrigin-RevId: c6771740fc
tags/v0.5.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
4f77509ea6
3 changed files with 29 additions and 11 deletions
  1. +2
    -1
      src/gopt/impl/framework.cpp
  2. +16
    -10
      src/opr/impl/io.cpp
  3. +11
    -0
      src/opr/test/io.cpp

+ 2
- 1
src/gopt/impl/framework.cpp View File

@@ -761,7 +761,8 @@ ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(


if (is_const_var(m_const_var_type, opr)) { if (is_const_var(m_const_var_type, opr)) {
auto sz = var_mem_size(opr->output(0)); auto sz = var_mem_size(opr->output(0));
mgb_assert(sz);
mgb_assert(sz || opr->output(0)->contain_flag(
VarNode::Flag::ALLOW_EMPTY_SHAPE));
info.is_const = true; info.is_const = true;
info.max_size = sz; info.max_size = sz;
return make_ret(); return make_ret();


+ 16
- 10
src/opr/impl/io.cpp View File

@@ -382,7 +382,7 @@ class ImmutableTensor::Value {
void setup(CompNode cn, const HostTensorND &val); void setup(CompNode cn, const HostTensorND &val);


bool initialized() const { bool initialized() const {
return !m_dev.empty();
return m_dev.shape_valid();
} }


//! value on comp node //! value on comp node
@@ -400,8 +400,9 @@ class ImmutableTensor::Value {
}; };


void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND &val) { void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND &val) {
mgb_assert(m_dev.empty() && !val.empty());
mgb_assert(m_dev.empty() && !m_dev.shape_valid());
m_dev.comp_node(cn).copy_from(val).sync(); m_dev.comp_node(cn).copy_from(val).sync();
mgb_assert(val.empty() == m_dev.empty());


auto one_elem = [](const TensorShape& shape) { auto one_elem = [](const TensorShape& shape) {
for (size_t i = 0; i < shape.ndim; ++i) { for (size_t i = 0; i < shape.ndim; ++i) {
@@ -446,6 +447,7 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
HostTensorND m_val_ref; HostTensorND m_val_ref;


const dt_byte* val_ptr() const { const dt_byte* val_ptr() const {
mgb_assert(m_trait.size_bytes);
return m_val.empty() ? m_val_ref.raw_ptr() : m_val.data(); return m_val.empty() ? m_val_ref.raw_ptr() : m_val.data();
} }


@@ -454,9 +456,8 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
TensorKey(const HostTensorND &v): TensorKey(const HostTensorND &v):
m_val_ref{v} m_val_ref{v}
{ {
mgb_assert(v.layout().is_contiguous());
mgb_assert(v.layout().is_contiguous() || v.layout().is_empty());
m_trait.size_bytes = v.layout().span().high_byte; m_trait.size_bytes = v.layout().span().high_byte;
mgb_assert(m_trait.size_bytes);


auto &&layout = m_trait.layout; auto &&layout = m_trait.layout;
// zero to enable byte-comparison // zero to enable byte-comparison
@@ -467,15 +468,19 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
layout.shape[i] = v.layout().shape[i]; layout.shape[i] = v.layout().shape[i];
layout.stride[i] = v.layout().stride[i]; layout.stride[i] = v.layout().stride[i];
} }
m_trait.hash = XXHash{}.
update(v.raw_ptr(), m_trait.size_bytes).
update(&m_trait.layout, sizeof(m_trait.layout)).
digest();
XXHash hasher;
if (!v.empty()) {
hasher.update(v.raw_ptr(), m_trait.size_bytes);
}
hasher.update(&m_trait.layout, sizeof(m_trait.layout));
m_trait.hash = hasher.digest();
} }


bool operator == (const TensorKey &rhs) const { bool operator == (const TensorKey &rhs) const {
return !memcmp(&m_trait, &rhs.m_trait, sizeof(Trait)) && return !memcmp(&m_trait, &rhs.m_trait, sizeof(Trait)) &&
!memcmp(val_ptr(), rhs.val_ptr(), m_trait.size_bytes);
((m_trait.size_bytes == 0 &&
rhs.m_trait.size_bytes == 0) ||
!memcmp(val_ptr(), rhs.val_ptr(), m_trait.size_bytes));
} }


size_t hash() const { size_t hash() const {
@@ -485,6 +490,7 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
//! copy from m_val_ref to m_val, to avoid refed value being //! copy from m_val_ref to m_val, to avoid refed value being
//! modified //! modified
void copy_val_permanent() { void copy_val_permanent() {
if (m_trait.size_bytes == 0) return;
mgb_assert(m_val.empty()); mgb_assert(m_val.empty());
m_val.resize(m_trait.size_bytes); m_val.resize(m_trait.size_bytes);
memcpy(m_val.data(), m_val_ref.raw_ptr(), m_trait.size_bytes); memcpy(m_val.data(), m_val_ref.raw_ptr(), m_trait.size_bytes);
@@ -544,7 +550,6 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
} }


const Value& get(const HostTensorND &tensor) { const Value& get(const HostTensorND &tensor) {
mgb_assert(!tensor.empty());
if (tensor.shape().is_scalar()) { if (tensor.shape().is_scalar()) {
return get(DTypeScalar::make_from_raw( return get(DTypeScalar::make_from_raw(
tensor.dtype(), tensor.raw_ptr())); tensor.dtype(), tensor.raw_ptr()));
@@ -595,6 +600,7 @@ ImmutableTensor::ImmutableTensor(ComputingGraph &graph,


add_output(value.dev().dtype()); add_output(value.dev().dtype());
add_equivalence_component<ScalarHash<const void*>>(&value); add_equivalence_component<ScalarHash<const void*>>(&value);
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
} }


ImmutableTensor::~ImmutableTensor() noexcept = default; ImmutableTensor::~ImmutableTensor() noexcept = default;


+ 11
- 0
src/opr/test/io.cpp View File

@@ -177,6 +177,17 @@ TEST(TestOprIO, ImmutableTensorLarge) {
} }
} }


TEST(TestOprIO, ImmutableTensorEmpty) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
auto host_x = gen({1, 9, 1, 9, 8, 1, 0});
auto x = opr::ImmutableTensor::make(*graph, *host_x);
HostTensorND host_x2;
auto func = graph->compile({make_callback_copy(x, host_x2)});
func->execute();
ASSERT_TRUE(host_x2.shape().is_empty());
}

TEST(TestOprIO, SharedDeviceTensor) { TEST(TestOprIO, SharedDeviceTensor) {
HostTensorGenerator<> gen; HostTensorGenerator<> gen;
auto hv = gen({123}); auto hv = gen({123});


Loading…
Cancel
Save