Browse Source

!350 fix bug of norm level check

Merge pull request !350 from ZhidanLiu/master
tags/v1.8.0
i-robot Gitee 3 years ago
parent
commit
3e5dc310ec
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      mindarmour/utils/_check_param.py

+ 4
- 4
mindarmour/utils/_check_param.py View File

@@ -204,7 +204,7 @@ def check_norm_level(norm_level):
"""Check norm_level of regularization."""
if not isinstance(norm_level, (int, str)):
msg = 'Type of norm_level must be in [int, str], but got {}'.format(type(norm_level))
accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', np.inf]
accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf', np.inf]
if norm_level not in accept_norm:
msg = 'norm_level must be in {}, but got {}'.format(accept_norm, norm_level)
LOGGER.error(TAG, msg)
@@ -237,12 +237,12 @@ def normalize_value(value, norm_level):
elif norm_level in (2, '2', 'l2'):
norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
elif norm_level in (np.inf, 'inf'):
elif norm_level in (np.inf, 'inf', 'np.inf'):
norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
else:
msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are currently not supported, but got {}.' \
.format(norm_level)
accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf', np.inf]
msg = 'Values of `norm_level`, must be in {}, but got {}'.format(accept_norm, norm_level)
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
return norm_value.reshape(ori_shape)


Loading…
Cancel
Save