|
|
@@ -61,17 +61,17 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
|
TensorLayout out_layout = src.layout; |
|
|
|
if (tshp.layout.ndim == 0 || tshp.value.empty()) { |
|
|
|
out_layout.ndim = 0; |
|
|
|
return {{{out_layout, src.comp_node}}, true}; |
|
|
|
return {{{out_layout, src.comp_node}}, false}; |
|
|
|
} |
|
|
|
mgb_assert( |
|
|
|
tshp.layout.ndim == 1, |
|
|
|
"target shape of Broadcast expects ndim=1; got ndim=%lu actually", |
|
|
|
tshp.layout.ndim == 1, |
|
|
|
"target shape of Broadcast expects ndim=1; got ndim=%lu actually", |
|
|
|
tshp.layout.ndim); |
|
|
|
|
|
|
|
size_t target_ndim = tshp.layout.shape[0]; |
|
|
|
out_layout.ndim = target_ndim; |
|
|
|
auto* ptr = tshp.value.ptr<dt_int32>(); |
|
|
|
for(size_t i=0; i<target_ndim; ++i) { |
|
|
|
for (size_t i = 0; i < target_ndim; ++i) { |
|
|
|
out_layout.shape[i] = ptr[i]; |
|
|
|
} |
|
|
|
mgb_assert(valid_broadcast(src.layout, out_layout), |
|
|
|