Browse Source

modify normalize_value

tags/v1.8.0
ZhidanLiu 3 years ago
parent
commit
1e21ec5bbe
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      mindarmour/utils/_check_param.py

+ 4
- 4
mindarmour/utils/_check_param.py View File

@@ -218,14 +218,14 @@ def normalize_value(value, norm_level):

Args:
value (numpy.ndarray): Inputs.
norm_level (Union[int, str]): Normalized level. Option: [1, 2, np.inf, '1', '2', 'inf', 'l1', 'l2']
norm_level (Union[int, str]): Normalized level. Option: [1, 2, '1', '2', 'l1', 'l2', np.inf, 'np.inf',
'linf', 'inf'].

Returns:
numpy.ndarray, normalized value.

Raises:
NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2',
'inf', 'l1', 'l2']
NotImplementedError: If norm_level is not in [1, 2, '1', '2', 'l1', 'l2', np.inf, 'np.inf', 'linf', 'inf'].
"""
norm_level = check_norm_level(norm_level)
ori_shape = value.shape
@@ -237,7 +237,7 @@ def normalize_value(value, norm_level):
elif norm_level in (2, '2', 'l2'):
norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
elif norm_level in (np.inf, 'inf', 'np.inf'):
elif norm_level in (np.inf, 'inf', 'np.inf', 'linf'):
norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
else:


Loading…
Cancel
Save