|
|
@@ -548,6 +548,7 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { |
|
|
|
template <typename Op> |
|
|
|
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
|
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { |
|
|
|
bool success = inputs[0].layout.ndim != 0; |
|
|
|
LogicalTensorDesc dest; |
|
|
|
auto&& xxx_rng_def = def.cast_final_safe<Op>(); |
|
|
|
size_t nr_inp = inputs.size(); |
|
|
@@ -558,7 +559,11 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
|
xxx_rng_def.dyn_typeinfo()->name, nr_inp); |
|
|
|
} |
|
|
|
dest.comp_node = inputs[0].comp_node; |
|
|
|
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def); |
|
|
|
if (success) { |
|
|
|
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def); |
|
|
|
} else { |
|
|
|
dest.layout = TensorLayout(inputs[0].layout.dtype); |
|
|
|
} |
|
|
|
return {{dest}, inputs[0].layout.ndim != 0}; |
|
|
|
} |
|
|
|
|
|
|
|