From 1aba1918a6ed27d4fdc0ed8b1931930fc7635574 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 11 May 2021 17:36:06 +0800 Subject: [PATCH] fix(imperative/ops): add check_dtype for Elemwise in infer_attrs GitOrigin-RevId: c7778557537f94648e1e706af00fe2d917f18fbd --- imperative/src/impl/ops/elemwise.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp index 1ee91e2a..422e72b1 100644 --- a/imperative/src/impl/ops/elemwise.cpp +++ b/imperative/src/impl/ops/elemwise.cpp @@ -65,6 +65,25 @@ std::tuple, bool> infer_output_attrs_fallible( return {{{out_layout, out_cn}}, false}; } } + // copy from megdnn::ElemwiseForward::check_dtype + switch (out_dt.category()) { + case DTypeCategory::FLOAT: + mgb_assert(trait.allow_float, "unsupport mode %s for float\n", + trait.name); + break; + case DTypeCategory::INT: + mgb_assert(trait.allow_int, "unsupport mode %s for int\n", + trait.name); + break; + case DTypeCategory::BOOL: + mgb_assert(trait.allow_bool, "unsupport mode %s for bool\n", + trait.name); + break; + default: + // Quantized Dtype could also be handled by this op, + // but scales need to be the same. + break; + } auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes); return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};