Browse Source

modify normalize_value

pull/360/head
ZhidanLiu 3 years ago
parent
commit
831c48d850
1 changed files with 6 additions and 3 deletions
  1. +6
    -3
      mindarmour/utils/_check_param.py

+ 6
- 3
mindarmour/utils/_check_param.py View File

@@ -218,13 +218,16 @@ def normalize_value(value, norm_level):

Args:
value (numpy.ndarray): Inputs.
norm_level (Union[int, str]): Normalized level. Option: [1, 2, np.inf, 'np.inf', '1', '2', 'inf', 'l1', 'l2']
norm_level (Union[int, str]): Normalized level. Option: [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf',
'np.inf', np.inf].


Returns:
numpy.ndarray, normalized value.

Raises:
NotImplementedError: If norm_level is not in [1, 2 , np.inf, 'np.inf', '1', '2', 'inf', 'l1', 'l2']
NotImplementedError: If norm_level is not in [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf', np.inf].

"""
norm_level = check_norm_level(norm_level)
ori_shape = value.shape
@@ -236,7 +239,7 @@ def normalize_value(value, norm_level):
elif norm_level in (2, '2', 'l2'):
norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
elif norm_level in (np.inf, 'inf', 'np.inf'):
elif norm_level in (np.inf, 'inf', 'np.inf', 'linf'):
norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
else:


Loading…
Cancel
Save