|
|
@@ -232,16 +232,18 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() { |
|
|
|
|
|
|
|
#ifdef MGB_ENABLE_GRAD |
|
|
|
MGB_IMPL_OPR_GRAD(BatchNormForward) { |
|
|
|
mgb_assert(wrt_idx < 5); |
|
|
|
if (wrt_idx < 3) { |
|
|
|
SymbolVarArray grad = BatchNormBackward::make( |
|
|
|
mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING, |
|
|
|
"batch norm could only take grad in training mode"); |
|
|
|
mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx); |
|
|
|
VarNodeArray ret(opr.input().size(), nullptr); |
|
|
|
SymbolVarArray grad = BatchNormBackward::make( |
|
|
|
opr.input(0), out_grad[4], |
|
|
|
opr.output(2), opr.output(3), |
|
|
|
opr.input(1), opr.param()); |
|
|
|
return grad[(wrt_idx + 2) % 3].node(); |
|
|
|
} else { |
|
|
|
return nullptr; |
|
|
|
for (size_t i = 0; i < 3; ++ i) { |
|
|
|
ret[i] = grad[(i + 2) % 3].node(); |
|
|
|
} |
|
|
|
return ret; |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|