Browse Source

fix(imperative): fix grad of BatchNorm

GitOrigin-RevId: 1e8d8afaf2
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
54d18115b6
1 changed file with 8 additions and 6 deletions
  1. +8
    -6
      src/opr/impl/dnn/batch_norm.cpp

+ 8
- 6
src/opr/impl/dnn/batch_norm.cpp View File

@@ -232,16 +232,18 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() {

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchNormForward) {
    // Gradient is only defined when the forward op ran in training mode;
    // the backward pass needs the saved statistics (output(2)/output(3))
    // which are only meaningful in training.
    mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING,
            "batch norm could only take grad in training mode");
    mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx);
    // One slot per forward input; inputs beyond the first three get no
    // gradient and remain nullptr.
    VarNodeArray ret(opr.input().size(), nullptr);
    // Build the backward op from: input(0)=x, out_grad[4]=grad of the
    // final forward output, output(2)/output(3)=saved statistics,
    // input(1)=scale. (NOTE(review): argument semantics inferred from the
    // forward op's layout — confirm against BatchNormBackward::make.)
    SymbolVarArray grad = BatchNormBackward::make(
            opr.input(0), out_grad[4],
            opr.output(2), opr.output(3),
            opr.input(1), opr.param());
    // Backward outputs are ordered differently from forward inputs;
    // (i + 2) % 3 maps input index i -> its matching backward output
    // (same mapping the pre-fix code used for a single wrt_idx).
    for (size_t i = 0; i < 3; ++ i) {
        ret[i] = grad[(i + 2) % 3].node();
    }
    return ret;
}
#endif



Loading…
Cancel
Save