From 71143146d2a2ba100ee7da7593cbafb234fc222a Mon Sep 17 00:00:00 2001 From: ZhidanLiu Date: Wed, 13 Apr 2022 16:16:02 +0800 Subject: [PATCH] fix bug of norm level check for r1.7 --- mindarmour/utils/_check_param.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py index e3cd17b..f941645 100644 --- a/mindarmour/utils/_check_param.py +++ b/mindarmour/utils/_check_param.py @@ -204,7 +204,7 @@ def check_norm_level(norm_level): """Check norm_level of regularization.""" if not isinstance(norm_level, (int, str)): msg = 'Type of norm_level must be in [int, str], but got {}'.format(type(norm_level)) - accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', np.inf] + accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf', np.inf] if norm_level not in accept_norm: msg = 'norm_level must be in {}, but got {}'.format(accept_norm, norm_level) LOGGER.error(TAG, msg) @@ -218,14 +218,13 @@ def normalize_value(value, norm_level): Args: value (numpy.ndarray): Inputs. - norm_level (Union[int, str]): Normalized level. Option: [1, 2, np.inf, '1', '2', 'inf', 'l1', 'l2'] + norm_level (Union[int, str]): Normalized level. Option: [1, 2, np.inf, 'np.inf', '1', '2', 'inf', 'l1', 'l2'] Returns: numpy.ndarray, normalized value. Raises: - NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', - 'inf', 'l1', 'l2'] + NotImplementedError: If norm_level is not in [1, 2 , np.inf, 'np.inf', '1', '2', 'inf', 'l1', 'l2'] """ norm_level = check_norm_level(norm_level) ori_shape = value.shape @@ -237,12 +236,12 @@ def normalize_value(value, norm_level): elif norm_level in (2, '2', 'l2'): norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div norm_value = value_reshape / norm - elif norm_level in (np.inf, 'inf'): + elif norm_level in (np.inf, 'inf', 'np.inf'): norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div norm_value = value_reshape / norm else: - msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are currently not supported, but got {}.' \ - .format(norm_level) + accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf', np.inf] + msg = 'Values of `norm_level` must be in {}, but got {}'.format(accept_norm, norm_level) LOGGER.error(TAG, msg) raise NotImplementedError(msg) return norm_value.reshape(ori_shape)