Browse Source

fix(imperative/src): fix empty_tensor bug of convbwd&rng

GitOrigin-RevId: 4c948f41f0
release-1.10
Megvii Engine Team 3 years ago
parent
commit
e59b6e13a3
2 changed files with 8 additions and 3 deletions
  1. +2
    -2
      imperative/src/impl/ops/convolution.cpp
  2. +6
    -1
      imperative/src/impl/ops/rng.cpp

+ 2
- 2
imperative/src/impl/ops/convolution.cpp View File

@@ -354,8 +354,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
TensorLayout diff = inputs[1].layout;
size_t filter_ndim = filter.ndim;
size_t diff_ndim = diff.ndim;
if (filter_ndim == 0) {
desc.layout = filter;
if (diff_ndim == 0) {
desc.layout = diff;
return {dests, false};
}



+ 6
- 1
imperative/src/impl/ops/rng.cpp View File

@@ -548,6 +548,7 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;
LogicalTensorDesc dest;
auto&& xxx_rng_def = def.cast_final_safe<Op>();
size_t nr_inp = inputs.size();
@@ -558,7 +559,11 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
xxx_rng_def.dyn_typeinfo()->name, nr_inp);
}
dest.comp_node = inputs[0].comp_node;
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
if (success) {
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
} else {
dest.layout = TensorLayout(inputs[0].layout.dtype);
}
return {{dest}, inputs[0].layout.ndim != 0};
}



Loading…
Cancel
Save