diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp index 454b1422..30d2ae6e 100644 --- a/imperative/src/impl/transformations/format.cpp +++ b/imperative/src/impl/transformations/format.cpp @@ -379,8 +379,11 @@ ValueRefList concat_rule( ValueRefList identity_rule_helper( const OpDef& op, const Span& inputs, const FormatTransformation& t) { // mgb_assert(inputs.size() == 1); - auto& src = inputs[0].cast(t.value_type()); - return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src.format()); + if (auto& src = inputs[0].as_ref(t.value_type())) { + return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src->format()); + } else { + return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs))); + } } ValueRefList batchnorm_rule(