|
|
@@ -100,7 +100,7 @@ void intl::HostIONodeBase::init_output_static_infer_desc() { |
|
|
|
|
|
|
|
if (fill_in_static_infer(nullptr)) { |
|
|
|
auto infer_val = [this](DeviceTensorND& dest, const InpVal&) -> bool { |
|
|
|
if (fill_in_static_infer(&dest) && !dest.empty()) { |
|
|
|
if (fill_in_static_infer(&dest) && dest.shape_valid()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
@@ -423,8 +423,8 @@ void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND& val) { |
|
|
|
|
|
|
|
DeviceTensorND& ImmutableTensor::Value::static_infer() { |
|
|
|
MGB_LOCK_GUARD(m_mtx); |
|
|
|
if (m_static_infer.empty()) { |
|
|
|
mgb_assert(!m_dev.empty()); |
|
|
|
if (!m_static_infer.shape_valid()) { |
|
|
|
mgb_assert(m_dev.shape_valid()); |
|
|
|
m_static_infer.comp_node(CompNode::default_cpu()).copy_from(m_dev); |
|
|
|
} |
|
|
|
return m_static_infer; |
|
|
|