diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp index 1ee91e2a..422e72b1 100644 --- a/imperative/src/impl/ops/elemwise.cpp +++ b/imperative/src/impl/ops/elemwise.cpp @@ -65,6 +65,25 @@ std::tuple, bool> infer_output_attrs_fallible( return {{{out_layout, out_cn}}, false}; } } + // copy from megdnn::ElemwiseForward::check_dtype + switch (out_dt.category()) { + case DTypeCategory::FLOAT: + mgb_assert(trait.allow_float, "unsupport mode %s for float\n", + trait.name); + break; + case DTypeCategory::INT: + mgb_assert(trait.allow_int, "unsupport mode %s for int\n", + trait.name); + break; + case DTypeCategory::BOOL: + mgb_assert(trait.allow_bool, "unsupport mode %s for bool\n", + trait.name); + break; + default: + // Quantized Dtype could also be handled by this op, + // but scales need to be the same. + break; + } auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes); return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};