|
|
@@ -50,8 +50,9 @@ BatchNormForward::BatchNormForward(VarNode *x, |
|
|
|
mgb_throw_if(!(dest_opr->same_type<SharedDeviceTensor>() || |
|
|
|
dest_opr->same_type<VolatileSharedDeviceTensor>()), |
|
|
|
GraphError, |
|
|
|
"mean and variance in BatchNorm must be SharedDeviceTensor " |
|
|
|
"or VolatileSharedDeviceTensor; got %s{%s} actually", |
|
|
|
"mean and variance in training mode BatchNorm must be", |
|
|
|
"SharedDeviceTensor or VolatileSharedDeviceTensor;", |
|
|
|
"got %s{%s} actually", |
|
|
|
dest_opr->cname(), dest_opr->dyn_typeinfo()->name); |
|
|
|
}; |
|
|
|
check_dest(mean); |
|
|
|