From 518c7f3781bc33a4fd189562c74bc93be6479574 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 11 Apr 2022 16:20:17 +0800 Subject: [PATCH] fix(imperative/src): fix empty_tensor bug of rng GitOrigin-RevId: 4c948f41f04649620ce7b34c5f3dac69d66705e2 --- imperative/src/impl/ops/rng.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index a59e44f8..561ac269 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -548,6 +548,7 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { template std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { + bool success = inputs[0].layout.ndim != 0; LogicalTensorDesc dest; auto&& xxx_rng_def = def.cast_final_safe(); size_t nr_inp = inputs.size(); @@ -558,7 +559,11 @@ std::tuple, bool> infer_output_attrs_fallible( xxx_rng_def.dyn_typeinfo()->name, nr_inp); } dest.comp_node = inputs[0].comp_node; - dest.layout = _InferLayout::do_infer(inputs[0], xxx_rng_def); + if (success) { + dest.layout = _InferLayout::do_infer(inputs[0], xxx_rng_def); + } else { + dest.layout = TensorLayout(inputs[0].layout.dtype); + } return {{dest}, inputs[0].layout.ndim != 0}; }