diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index ff82a840..a3e1606c 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -81,6 +81,7 @@ std::tuple, bool> infer_output_attrs_fallible( src.layout.ndim); size_t idx = 0; bool input_used[TensorLayout::MAX_NDIM] = {0}; + out_shape.ndim = ds.pattern.size(); for (auto i : ds.pattern) { if (i < 0) { out_shape[idx] = 1;