Browse Source

fix(imperative/amp): fix distributed backward callback for nhwc amp

GitOrigin-RevId: 4d725b0ea4
release-1.10
Megvii Engine Team 3 years ago
parent
commit
ee984e8608
1 changed files with 5 additions and 2 deletions
  1. +5
    -2
      imperative/src/impl/transformations/format.cpp

+ 5
- 2
imperative/src/impl/transformations/format.cpp View File

@@ -379,8 +379,11 @@ ValueRefList concat_rule(
ValueRefList identity_rule_helper(
const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
// mgb_assert(inputs.size() == 1);
auto& src = inputs[0].cast(t.value_type());
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
if (auto& src = inputs[0].as_ref(t.value_type())) {
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src->format());
} else {
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}
}

ValueRefList batchnorm_rule(


Loading…
Cancel
Save