|
|
@@ -379,8 +379,11 @@ ValueRefList concat_rule( |
|
|
|
ValueRefList identity_rule_helper( |
|
|
|
const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) { |
|
|
|
// mgb_assert(inputs.size() == 1); |
|
|
|
auto& src = inputs[0].cast(t.value_type()); |
|
|
|
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src.format()); |
|
|
|
if (auto& src = inputs[0].as_ref(t.value_type())) { |
|
|
|
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src->format()); |
|
|
|
} else { |
|
|
|
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs))); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
ValueRefList batchnorm_rule( |
|
|
|