diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py index 9f12f2f..b8ebc54 100644 --- a/mindarmour/utils/_check_param.py +++ b/mindarmour/utils/_check_param.py @@ -218,14 +218,14 @@ def normalize_value(value, norm_level): Args: value (numpy.ndarray): Inputs. - norm_level (Union[int, str]): Normalized level. Option: [1, 2, np.inf, '1', '2', 'inf', 'l1', 'l2'] + norm_level (Union[int, str]): Normalized level. Option: [1, 2, '1', '2', 'l1', 'l2', np.inf, 'np.inf', + 'linf', 'inf']. Returns: numpy.ndarray, normalized value. Raises: - NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', - 'inf', 'l1', 'l2'] + NotImplementedError: If norm_level is not in [1, 2, '1', '2', 'l1', 'l2', np.inf, 'np.inf', 'linf', 'inf']. """ norm_level = check_norm_level(norm_level) ori_shape = value.shape @@ -237,7 +237,7 @@ def normalize_value(value, norm_level): elif norm_level in (2, '2', 'l2'): norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div norm_value = value_reshape / norm - elif norm_level in (np.inf, 'inf', 'np.inf'): + elif norm_level in (np.inf, 'inf', 'np.inf', 'linf'): norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div norm_value = value_reshape / norm else: