Browse Source

fix bug of norm level check for r1.7

pull/353/head
ZhidanLiu 3 years ago
parent
commit
71143146d2
1 changed files with 6 additions and 7 deletions
  1. +6
    -7
      mindarmour/utils/_check_param.py

+ 6
- 7
mindarmour/utils/_check_param.py View File

@@ -204,7 +204,7 @@ def check_norm_level(norm_level):
"""Check norm_level of regularization."""
if not isinstance(norm_level, (int, str)):
msg = 'Type of norm_level must be in [int, str], but got {}'.format(type(norm_level))
accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', np.inf]
accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf', np.inf]
if norm_level not in accept_norm:
msg = 'norm_level must be in {}, but got {}'.format(accept_norm, norm_level)
LOGGER.error(TAG, msg)
@@ -218,14 +218,13 @@ def normalize_value(value, norm_level):

Args:
value (numpy.ndarray): Inputs.
norm_level (Union[int, str]): Normalized level. Option: [1, 2, np.inf, '1', '2', 'inf', 'l1', 'l2']
norm_level (Union[int, str]): Normalized level. Option: [1, 2, np.inf, 'np.inf', '1', '2', 'inf', 'l1', 'l2']

Returns:
numpy.ndarray, normalized value.

Raises:
NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2',
'inf', 'l1', 'l2']
NotImplementedError: If norm_level is not in [1, 2 , np.inf, 'np.inf', '1', '2', 'inf', 'l1', 'l2']
"""
norm_level = check_norm_level(norm_level)
ori_shape = value.shape
@@ -237,12 +236,12 @@ def normalize_value(value, norm_level):
elif norm_level in (2, '2', 'l2'):
norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
elif norm_level in (np.inf, 'inf'):
elif norm_level in (np.inf, 'inf', 'np.inf'):
norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
else:
msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are currently not supported, but got {}.' \
.format(norm_level)
accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf', np.inf]
msg = 'Values of `norm_level` must be in {}, but got {}'.format(accept_norm, norm_level)
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
return norm_value.reshape(ori_shape)


Loading…
Cancel
Save