|
|
@@ -65,6 +65,25 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
|
return {{{out_layout, out_cn}}, false}; |
|
|
|
} |
|
|
|
} |
|
|
|
// copy from megdnn::ElemwiseForward::check_dtype |
|
|
|
switch (out_dt.category()) { |
|
|
|
case DTypeCategory::FLOAT: |
|
|
|
mgb_assert(trait.allow_float, "unsupport mode %s for float\n", |
|
|
|
trait.name); |
|
|
|
break; |
|
|
|
case DTypeCategory::INT: |
|
|
|
mgb_assert(trait.allow_int, "unsupport mode %s for int\n", |
|
|
|
trait.name); |
|
|
|
break; |
|
|
|
case DTypeCategory::BOOL: |
|
|
|
mgb_assert(trait.allow_bool, "unsupport mode %s for bool\n", |
|
|
|
trait.name); |
|
|
|
break; |
|
|
|
default: |
|
|
|
// Quantized Dtype could also be handled by this op, |
|
|
|
// but scales need to be the same. |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes); |
|
|
|
return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true}; |
|
|
|