diff --git a/src/opr/impl/io.cpp b/src/opr/impl/io.cpp index e18bfd94..6cbc6c04 100644 --- a/src/opr/impl/io.cpp +++ b/src/opr/impl/io.cpp @@ -100,7 +100,7 @@ void intl::HostIONodeBase::init_output_static_infer_desc() { if (fill_in_static_infer(nullptr)) { auto infer_val = [this](DeviceTensorND& dest, const InpVal&) -> bool { - if (fill_in_static_infer(&dest) && !dest.empty()) { + if (fill_in_static_infer(&dest) && dest.shape_valid()) { return true; } return false; @@ -423,8 +423,8 @@ void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND& val) { DeviceTensorND& ImmutableTensor::Value::static_infer() { MGB_LOCK_GUARD(m_mtx); - if (m_static_infer.empty()) { - mgb_assert(!m_dev.empty()); + if (!m_static_infer.shape_valid()) { + mgb_assert(m_dev.shape_valid()); m_static_infer.comp_node(CompNode::default_cpu()).copy_from(m_dev); } return m_static_infer;