Browse Source

fix(imperative/ops): add check_dtype for Elemwise in infer_attrs

GitOrigin-RevId: c777855753
release-1.4
Megvii Engine Team 4 years ago
parent
commit
1aba1918a6
1 changed files with 19 additions and 0 deletions
  1. +19
    -0
      imperative/src/impl/ops/elemwise.cpp

+ 19
- 0
imperative/src/impl/ops/elemwise.cpp View File

@@ -65,6 +65,25 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{out_layout, out_cn}}, false};
}
}
// copy from megdnn::ElemwiseForward::check_dtype
switch (out_dt.category()) {
case DTypeCategory::FLOAT:
mgb_assert(trait.allow_float, "unsupport mode %s for float\n",
trait.name);
break;
case DTypeCategory::INT:
mgb_assert(trait.allow_int, "unsupport mode %s for int\n",
trait.name);
break;
case DTypeCategory::BOOL:
mgb_assert(trait.allow_bool, "unsupport mode %s for bool\n",
trait.name);
break;
default:
// Quantized Dtype could also be handled by this op,
// but scales need to be the same.
break;
}

auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};


Loading…
Cancel
Save