|
|
@@ -58,6 +58,7 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, |
|
|
|
get_workspace_in_bytes(x, dy, saved_batch_mean, saved_batch_variance, |
|
|
|
bn_scale, d_bn_scale, d_bn_bias, dx); |
|
|
|
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); |
|
|
|
megdnn_assert(param().fwd_mode == Param::FwdMode::TRAINING, "BNBackward only support TRAINING mode"); |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace megdnn |
|
|
|