Browse Source

fix(imperative/src): fix empty_tensor bug of rng

GitOrigin-RevId: 4c948f41f0
release-1.9
Megvii Engine Team dengzheye 3 years ago
parent
commit
518c7f3781
1 changed files with 6 additions and 1 deletions
  1. +6
    -1
      imperative/src/impl/ops/rng.cpp

+ 6
- 1
imperative/src/impl/ops/rng.cpp View File

@@ -548,6 +548,7 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;
LogicalTensorDesc dest;
auto&& xxx_rng_def = def.cast_final_safe<Op>();
size_t nr_inp = inputs.size();
@@ -558,7 +559,11 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
xxx_rng_def.dyn_typeinfo()->name, nr_inp);
}
dest.comp_node = inputs[0].comp_node;
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
if (success) {
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
} else {
dest.layout = TensorLayout(inputs[0].layout.dtype);
}
return {{dest}, inputs[0].layout.ndim != 0};
}



Loading…
Cancel
Save