From ee984e86087f42f7928795a89489fd33b99d7ef1 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 11 Apr 2022 18:44:25 +0800 Subject: [PATCH] fix(imperative/amp): fix distributed backward callback for nhwc amp GitOrigin-RevId: 4d725b0ea438d078f1a57c1f58dc707e96f314c5 --- imperative/src/impl/transformations/format.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp index 454b1422..30d2ae6e 100644 --- a/imperative/src/impl/transformations/format.cpp +++ b/imperative/src/impl/transformations/format.cpp @@ -379,8 +379,11 @@ ValueRefList concat_rule( ValueRefList identity_rule_helper( const OpDef& op, const Span& inputs, const FormatTransformation& t) { // mgb_assert(inputs.size() == 1); - auto& src = inputs[0].cast(t.value_type()); - return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src.format()); + if (auto& src = inputs[0].as_ref(t.value_type())) { + return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src->format()); + } else { + return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs))); + } } ValueRefList batchnorm_rule(