You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

var_init.py 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Initialize.
  17. """
  18. import math
  19. from functools import reduce
  20. import numpy as np
  21. import mindspore.nn as nn
  22. from mindspore.common import initializer as init
  23. def _calculate_gain(nonlinearity, param=None):
  24. r"""
  25. Return the recommended gain value for the given nonlinearity function.
  26. The values are as follows:
  27. ================= ====================================================
  28. nonlinearity gain
  29. ================= ====================================================
  30. Linear / Identity :math:`1`
  31. Conv{1,2,3}D :math:`1`
  32. Sigmoid :math:`1`
  33. Tanh :math:`\frac{5}{3}`
  34. ReLU :math:`\sqrt{2}`
  35. Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
  36. ================= ====================================================
  37. Args:
  38. nonlinearity: the non-linear function
  39. param: optional parameter for the non-linear function
  40. Examples:
  41. >>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
  42. """
  43. linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
  44. if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
  45. return 1
  46. if nonlinearity == 'tanh':
  47. return 5.0 / 3
  48. if nonlinearity == 'relu':
  49. return math.sqrt(2.0)
  50. if nonlinearity == 'leaky_relu':
  51. if param is None:
  52. negative_slope = 0.01
  53. elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
  54. negative_slope = param
  55. else:
  56. raise ValueError("negative_slope {} not a valid number".format(param))
  57. return math.sqrt(2.0 / (1 + negative_slope**2))
  58. raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
  59. def _assignment(arr, num):
  60. """Assign the value of `num` to `arr`."""
  61. if arr.shape == ():
  62. arr = arr.reshape((1))
  63. arr[:] = num
  64. arr = arr.reshape(())
  65. else:
  66. if isinstance(num, np.ndarray):
  67. arr[:] = num[:]
  68. else:
  69. arr[:] = num
  70. return arr
  71. def _calculate_in_and_out(arr):
  72. """
  73. Calculate n_in and n_out.
  74. Args:
  75. arr (Array): Input array.
  76. Returns:
  77. Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
  78. """
  79. dim = len(arr.shape)
  80. if dim < 2:
  81. raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")
  82. n_in = arr.shape[1]
  83. n_out = arr.shape[0]
  84. if dim > 2:
  85. counter = reduce(lambda x, y: x*y, arr.shape[2:])
  86. n_in *= counter
  87. n_out *= counter
  88. return n_in, n_out
  89. def _select_fan(array, mode):
  90. mode = mode.lower()
  91. valid_modes = ['fan_in', 'fan_out']
  92. if mode not in valid_modes:
  93. raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
  94. fan_in, fan_out = _calculate_in_and_out(array)
  95. return fan_in if mode == 'fan_in' else fan_out
  96. class KaimingInit(init.Initializer):
  97. r"""
  98. Base Class. Initialize the array with He kaiming algorithm.
  99. Args:
  100. a: the negative slope of the rectifier used after this layer (only
  101. used with ``'leaky_relu'``)
  102. mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
  103. preserves the magnitude of the variance of the weights in the
  104. forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
  105. backwards pass.
  106. nonlinearity: the non-linear function, recommended to use only with
  107. ``'relu'`` or ``'leaky_relu'`` (default).
  108. """
  109. def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
  110. super(KaimingInit, self).__init__()
  111. self.mode = mode
  112. self.gain = _calculate_gain(nonlinearity, a)
  113. def _initialize(self, arr):
  114. pass
  115. class KaimingUniform(KaimingInit):
  116. r"""
  117. Initialize the array with He kaiming uniform algorithm. The resulting tensor will
  118. have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
  119. .. math::
  120. \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
  121. Input:
  122. arr (Array): The array to be assigned.
  123. Returns:
  124. Array, assigned array.
  125. Examples:
  126. >>> w = np.empty(3, 5)
  127. >>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
  128. """
  129. def _initialize(self, arr):
  130. fan = _select_fan(arr, self.mode)
  131. bound = math.sqrt(3.0)*self.gain / math.sqrt(fan)
  132. np.random.seed(0)
  133. data = np.random.uniform(-bound, bound, arr.shape)
  134. _assignment(arr, data)
  135. class KaimingNormal(KaimingInit):
  136. r"""
  137. Initialize the array with He kaiming normal algorithm. The resulting tensor will
  138. have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
  139. .. math::
  140. \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
  141. Input:
  142. arr (Array): The array to be assigned.
  143. Returns:
  144. Array, assigned array.
  145. Examples:
  146. >>> w = np.empty(3, 5)
  147. >>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
  148. """
  149. def _initialize(self, arr):
  150. fan = _select_fan(arr, self.mode)
  151. std = self.gain / math.sqrt(fan)
  152. np.random.seed(0)
  153. data = np.random.normal(0, std, arr.shape)
  154. _assignment(arr, data)
  155. def default_recurisive_init(custom_cell):
  156. """default_recurisive_init"""
  157. for _, cell in custom_cell.cells_and_names():
  158. if isinstance(cell, nn.Conv2d):
  159. cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
  160. cell.weight.shape,
  161. cell.weight.dtype)
  162. if cell.bias is not None:
  163. fan_in, _ = _calculate_in_and_out(cell.weight)
  164. bound = 1 / math.sqrt(fan_in)
  165. np.random.seed(0)
  166. cell.bias.default_input = init.initializer(init.Uniform(bound),
  167. cell.bias.shape,
  168. cell.bias.dtype)
  169. elif isinstance(cell, nn.Dense):
  170. cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
  171. cell.weight.shape,
  172. cell.weight.dtype)
  173. if cell.bias is not None:
  174. fan_in, _ = _calculate_in_and_out(cell.weight)
  175. bound = 1 / math.sqrt(fan_in)
  176. np.random.seed(0)
  177. cell.bias.default_input = init.initializer(init.Uniform(bound),
  178. cell.bias.shape,
  179. cell.bias.dtype)
  180. elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
  181. pass

MindArmour关注AI的安全和隐私问题。致力于增强模型的安全可信、保护用户的数据隐私。主要包含3个模块:对抗样本鲁棒性模块、Fuzz Testing模块、隐私保护与评估模块。 对抗样本鲁棒性模块 对抗样本鲁棒性模块用于评估模型对于对抗样本的鲁棒性,并提供模型增强方法用于增强模型抗对抗样本攻击的能力,提升模型鲁棒性。对抗样本鲁棒性模块包含了4个子模块:对抗样本的生成、对抗样本的检测、模型防御、攻防评估。