|
|
@@ -1194,6 +1194,13 @@ def batch_norm( |
|
|
|
eps: a value added to the denominator for numerical stability. Default: 1e-5 |
|
|
|
inplace: whether to update ``running_mean`` and ``running_var`` |
|
|
|
inplace or return new tensors. Default: True |
|
|
|
compute_mode: When set to 'default', no special requirements will be |
|
|
|
placed on the precision of intermediate results. When set to 'float32', |
|
|
|
float32 would be used for accumulator and intermediate result, but only |
|
|
|
effective when input and output are of float16 dtype. |
|
|
|
param_dim: a value indicating in which format the parameters are. |
|
|
|
Default: 'dim_1c11', which means NCHW format. |
|
|
|
And 'dim_111c' means NHWC format. |
|
|
|
""" |
|
|
|
if inp.ndim != 4: |
|
|
|
raise NotImplementedError("batch_norm for ndim != 4") |
|
|
|