Browse Source

!374 Add Net Class to API Examples and Fix Description Errors

Merge pull request !374 from 张澍坤/master
tags/v1.8.0
i-robot Gitee 3 years ago
parent
commit
333ca2272f
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
31 changed files with 212 additions and 99 deletions
  1. +1
    -0
      docs/api/api_python/mindarmour.adv_robustness.detectors.rst
  2. +1
    -0
      docs/api/api_python/mindarmour.adv_robustness.evaluations.rst
  3. +5
    -4
      docs/api/api_python/mindarmour.privacy.sup_privacy.rst
  4. +4
    -3
      docs/api/api_python/mindarmour.reliability.rst
  5. +5
    -4
      docs/api/api_python/mindarmour.rst
  6. +12
    -1
      mindarmour/adv_robustness/attacks/black/hop_skip_jump_attack.py
  7. +14
    -3
      mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py
  8. +12
    -1
      mindarmour/adv_robustness/attacks/black/pointwise_attack.py
  9. +14
    -1
      mindarmour/adv_robustness/attacks/black/salt_and_pepper_attack.py
  10. +3
    -4
      mindarmour/adv_robustness/attacks/carlini_wagner.py
  11. +1
    -2
      mindarmour/adv_robustness/attacks/deep_fool.py
  12. +12
    -18
      mindarmour/adv_robustness/attacks/gradient_method.py
  13. +10
    -15
      mindarmour/adv_robustness/attacks/iterative_gradient_method.py
  14. +1
    -2
      mindarmour/adv_robustness/attacks/jsma.py
  15. +12
    -1
      mindarmour/adv_robustness/attacks/lbfgs.py
  16. +44
    -11
      mindarmour/adv_robustness/defenses/adversarial_defense.py
  17. +15
    -4
      mindarmour/adv_robustness/defenses/natural_adversarial_defense.py
  18. +15
    -4
      mindarmour/adv_robustness/defenses/projected_adversarial_defense.py
  19. +1
    -2
      mindarmour/adv_robustness/detectors/ensemble_detector.py
  20. +1
    -2
      mindarmour/adv_robustness/detectors/mag_net.py
  21. +1
    -2
      mindarmour/adv_robustness/detectors/region_based_detector.py
  22. +1
    -2
      mindarmour/adv_robustness/detectors/spatial_smoothing.py
  23. +9
    -4
      mindarmour/privacy/diff_privacy/optimizer/optimizer.py
  24. +2
    -1
      mindarmour/privacy/diff_privacy/train/model.py
  25. +2
    -1
      mindarmour/privacy/evaluation/membership_inference.py
  26. +2
    -1
      mindarmour/privacy/sup_privacy/mask_monitor/masker.py
  27. +4
    -2
      mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py
  28. +2
    -1
      mindarmour/privacy/sup_privacy/train/model.py
  29. +2
    -1
      mindarmour/reliability/concept_drift/concept_drift_check_images.py
  30. +2
    -1
      mindarmour/reliability/concept_drift/concept_drift_check_time_series.py
  31. +2
    -1
      mindarmour/reliability/model_fault_injection/fault_injection.py

+ 1
- 0
docs/api/api_python/mindarmour.adv_robustness.detectors.rst View File

@@ -306,6 +306,7 @@ mindarmour.adv_robustness.detectors
- **max_k_neighbor** (int) - 最近邻的最大数量。默认值:1000。
- **chunk_size** (int) - 缓冲区大小。默认值:1000。
- **max_buffer_size** (int) - 最大缓冲区大小。默认值:10000。默认值:False。
- **tuning** (bool) - 计算k个最近邻的平均距离,如果'tuning'为true,k=K。如果为False,k=1,...,K。默认值:False。
- **fpr** (float) - 合法查询序列上的误报率。默认值:0.001

.. py:method:: clear_buffer()


+ 1
- 0
docs/api/api_python/mindarmour.adv_robustness.evaluations.rst View File

@@ -179,6 +179,7 @@ mindarmour.adv_robustness.evaluations
**参数:**

- **metrics_name** (Union[tuple, list]) - 要显示的度量名称数组。每组值对应一条雷达曲线。
- **metrics_data** (numpy.ndarray) - 多个雷达曲线的每个度量的(归一化)值,如[[0.5, 0.8, ...], [0.2,0.6,...], ...]。
- **labels** (Union[tuple, list]) - 所有雷达曲线的图例。
- **title** (str) - 图表的标题。
- **scale** (str) - 用于调整轴刻度的标量,如'hide'、'norm'、'sparse'、'dense'。默认值:'hide'。


+ 5
- 4
docs/api/api_python/mindarmour.privacy.sup_privacy.rst View File

@@ -7,7 +7,8 @@ mindarmour.privacy.sup_privacy

周期性检查抑制隐私功能状态和切换(启动/关闭)抑制操作。

详情请查看: `教程 <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html#%E5%BC%95%E5%85%A5%E6%8A%91%E5%88%B6%E9%9A%90%E7%A7%81%E8%AE%AD%E7%BB%83>`_。
详情请查看: `应用抑制隐私机制保护用户隐私
<https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html#%E5%BC%95%E5%85%A5%E6%8A%91%E5%88%B6%E9%9A%90%E7%A7%81%E8%AE%AD%E7%BB%83>`_。

**参数:**

@@ -26,7 +27,7 @@ mindarmour.privacy.sup_privacy

完整的模型训练功能。抑制隐私函数嵌入到重载的mindspore.train.model.Model中。

有关详细信息,请查看: `教程 <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html>`_。
有关详细信息,请查看: `应用抑制隐私机制保护用户隐私 <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html>`_。

**参数:**

@@ -47,7 +48,7 @@ mindarmour.privacy.sup_privacy

SuppressCtrl机制的工厂类。

详情请查看: `教程 <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html#%E5%BC%95%E5%85%A5%E6%8A%91%E5%88%B6%E9%9A%90%E7%A7%81%E8%AE%AD%E7%BB%83>`_。
详情请查看: `应用抑制隐私机制保护用户隐私 <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html#%E5%BC%95%E5%85%A5%E6%8A%91%E5%88%B6%E9%9A%90%E7%A7%81%E8%AE%AD%E7%BB%83>`_。

.. py:method:: create(networks, mask_layers, policy='local_train', end_epoch=10, batch_num=20, start_epoch=3, mask_times=1000, lr=0.05, sparse_end=0.90, sparse_start=0.0)

@@ -72,7 +73,7 @@ mindarmour.privacy.sup_privacy

完成抑制隐私操作,包括计算抑制比例,找到应该抑制的参数,并永久抑制这些参数。

详情请查看: `教程 <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html#%E5%BC%95%E5%85%A5%E6%8A%91%E5%88%B6%E9%9A%90%E7%A7%81%E8%AE%AD%E7%BB%83>`_。
详情请查看: `应用抑制隐私机制保护用户隐私 <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html#%E5%BC%95%E5%85%A5%E6%8A%91%E5%88%B6%E9%9A%90%E7%A7%81%E8%AE%AD%E7%BB%83>`_。

**参数:**



+ 4
- 3
docs/api/api_python/mindarmour.reliability.rst View File

@@ -7,7 +7,7 @@ MindArmour的可靠性方法。

故障注入模块模拟深度神经网络的各种故障场景,并评估模型的性能和可靠性。

详情请查看 `教程 <https://mindspore.cn/mindarmour/docs/zh-CN/master/fault_injection.html>`_。
详情请查看 `实现模型故障注入评估模型容错性 <https://mindspore.cn/mindarmour/docs/zh-CN/master/fault_injection.html>`_。

**参数:**

@@ -41,7 +41,8 @@ MindArmour的可靠性方法。
.. py:class:: mindarmour.reliability.ConceptDriftCheckTimeSeries(window_size=100, rolling_window=10, step=10, threshold_index=1.5, need_label=False)

概念漂移检查时间序列(ConceptDriftCheckTimeSeries)用于样本序列分布变化检测。
有关详细信息,请查看 `教程 <https://mindspore.cn/mindarmour/docs/zh-CN/master/concept_drift_time_series.html>`_。
有关详细信息,请查看 `实现时序数据概念漂移检测应用
<https://mindspore.cn/mindarmour/docs/zh-CN/master/concept_drift_time_series.html>`_。

**参数:**

@@ -106,7 +107,7 @@ MindArmour的可靠性方法。

训练OOD检测器。提取训练数据特征,得到聚类中心。测试数据特征与聚类中心之间的距离确定图像是否为分布外(OOD)图像。

有关详细信息,请查看 `教程 <https://mindspore.cn/mindarmour/docs/zh-CN/master/concept_drift_images.html>`_。
有关详细信息,请查看 `实现图像数据概念漂移检测应用 <https://mindspore.cn/mindarmour/docs/zh-CN/master/concept_drift_images.html>`_。

**参数:**



+ 5
- 4
docs/api/api_python/mindarmour.rst View File

@@ -223,7 +223,7 @@ MindArmour是MindSpore的工具箱,用于增强模型可信,实现隐私保

这个类就是重载Mindpore.train.model.Model。

详情请查看: `教程 <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_differential_privacy.html#%E5%B7%AE%E5%88%86%E9%9A%90%E7%A7%81>`_。
详情请查看: `应用差分隐私机制保护用户隐私 <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_differential_privacy.html#%E5%B7%AE%E5%88%86%E9%9A%90%E7%A7%81>`_。

**参数:**

@@ -241,7 +241,7 @@ MindArmour是MindSpore的工具箱,用于增强模型可信,实现隐私保

成员推理是由Shokri、Stronati、Song和Shmatikov提出的一种用于推测用户隐私数据的灰盒攻击。它需要训练样本的loss或logits结果。(隐私是指单个用户的一些敏感属性)。

有关详细信息,请参见:`教程 <https://mindspore.cn/mindarmour/docs/zh-CN/master/test_model_security_membership_inference.html>`_。
有关详细信息,请参见:`使用成员推理测试模型安全性 <https://mindspore.cn/mindarmour/docs/zh-CN/master/test_model_security_membership_inference.html>`_。

参考文献:`Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov. Membership Inference Attacks against Machine Learning Models. 2017. <https://arxiv.org/abs/1610.05820v2>`_。

@@ -340,7 +340,8 @@ MindArmour是MindSpore的工具箱,用于增强模型可信,实现隐私保
根据target_features重建图像。

**参数:**

- **target_features** (numpy.ndarray) - 原始图像的深度表示。 `target_features` 的第一个维度应该是img_num。
需要注意的是,如果img_num等于1,则target_features的形状应该是(1, dim2, dim3, ...)。
- **iters** (int) - 逆向攻击的迭代次数,应为正整数。默认值:100。

**返回:**
@@ -356,7 +357,7 @@ MindArmour是MindSpore的工具箱,用于增强模型可信,实现隐私保

概念漂移检查时间序列(ConceptDriftCheckTimeSeries)用于样本序列分布变化检测。

有关详细信息,请查看: `教程 <https://mindspore.cn/mindarmour/docs/zh-CN/master/concept_drift_time_series.html>`_。
有关详细信息,请查看: `实现时序数据概念漂移检测应用 <https://mindspore.cn/mindarmour/docs/zh-CN/master/concept_drift_time_series.html>`_。

**参数:**



+ 12
- 1
mindarmour/adv_robustness/attacks/black/hop_skip_jump_attack.py View File

@@ -77,8 +77,19 @@ class HopSkipJumpAttack(Attack):
Examples:
>>> from mindspore import Tensor
>>> from mindarmour import BlackModel
>>> import mindspore.ops.operations as P
>>> from mindarmour.adv_robustness.attacks import HopSkipJumpAttack
>>> from tests.ut.python.utils.mock_net import Net
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
... self._reduce = P.ReduceSum()
... self._squeeze = P.Squeeze(1)
... def construct(self, inputs):
... out = self._softmax(inputs)
... out = self._reduce(out, 2)
... out = self._squeeze(out)
... return out
>>> class ModelToBeAttacked(BlackModel):
... def __init__(self, network):
... super(ModelToBeAttacked, self).__init__()


+ 14
- 3
mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py View File

@@ -81,15 +81,26 @@ class NES(Attack):
Examples:
>>> from mindspore import Tensor
>>> from mindarmour import BlackModel
>>> import mindspore.ops.operations as P
>>> from mindarmour.adv_robustness.attacks import NES
>>> from tests.ut.python.utils.mock_net import Net
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
... self._reduce = P.ReduceSum()
... self._squeeze = P.Squeeze(1)
... def construct(self, inputs):
... out = self._softmax(inputs)
... out = self._reduce(out, 2)
... out = self._squeeze(out)
... return out
>>> class ModelToBeAttacked(BlackModel):
... def __init__(self, network):
... super(ModelToBeAttacked, self).__init__()
... self._network = network
... def predict(self, inputs):
... if len(inputs.shape) == 3:
... inputs = inputs[np.newaxis, :]
... if len(inputs.shape) == 1:
... inputs = np.expand_dims(inputs, axis=0)
... result = self._network(Tensor(inputs.astype(np.float32)))
... return result.asnumpy()
>>> net = Net()


+ 12
- 1
mindarmour/adv_robustness/attacks/black/pointwise_attack.py View File

@@ -49,8 +49,19 @@ class PointWiseAttack(Attack):
Examples:
>>> from mindspore import Tensor
>>> from mindarmour import BlackModel
>>> import mindspore.ops.operations as P
>>> from mindarmour.adv_robustness.attacks import PointWiseAttack
>>> from tests.ut.python.utils.mock_net import Net
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
... self._reduce = P.ReduceSum()
... self._squeeze = P.Squeeze(1)
... def construct(self, inputs):
... out = self._softmax(inputs)
... out = self._reduce(out, 2)
... out = self._squeeze(out)
... return out
>>> class ModelToBeAttacked(BlackModel):
... def __init__(self, network):
... super(ModelToBeAttacked, self).__init__()


+ 14
- 1
mindarmour/adv_robustness/attacks/black/salt_and_pepper_attack.py View File

@@ -42,13 +42,26 @@ class SaltAndPepperNoiseAttack(Attack):
Examples:
>>> from mindspore import Tensor
>>> from mindarmour import BlackModel
>>> import mindspore.ops.operations as P
>>> from mindarmour.adv_robustness.attacks import SaltAndPepperNoiseAttack
>>> from tests.ut.python.utils.mock_net import Net
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
... self._reduce = P.ReduceSum()
... self._squeeze = P.Squeeze(1)
... def construct(self, inputs):
... out = self._softmax(inputs)
... out = self._reduce(out, 2)
... out = self._squeeze(out)
... return out
>>> class ModelToBeAttacked(BlackModel):
... def __init__(self, network):
... super(ModelToBeAttacked, self).__init__()
... self._network = network
... def predict(self, inputs):
... if len(inputs.shape) == 1:
... inputs = np.expand_dims(inputs, axis=0)
... result = self._network(Tensor(inputs.astype(np.float32)))
... return result.asnumpy()
>>> net = Net()


+ 3
- 4
mindarmour/adv_robustness/attacks/carlini_wagner.py View File

@@ -97,13 +97,12 @@ class CarliniWagnerL2Attack(Attack):
input labels are onehot-coded. Default: True.

Examples:
>>> import mindspore.ops.operations as M
>>> from mindspore.nn import Cell
>>> import mindspore.ops.operations as P
>>> from mindarmour.adv_robustness.attacks import CarliniWagnerL2Attack
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = M.Softmax()
... self._softmax = P.Softmax()
... def construct(self, inputs):
... out = self._softmax(inputs)
... return out


+ 1
- 2
mindarmour/adv_robustness/attacks/deep_fool.py View File

@@ -118,10 +118,9 @@ class DeepFool(Attack):

Examples:
>>> import mindspore.ops.operations as P
>>> from mindspore.nn import Cell
>>> from mindspore import Tensor
>>> from mindarmour.adv_robustness.attacks import DeepFool
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()


+ 12
- 18
mindarmour/adv_robustness/attacks/gradient_method.py View File

@@ -149,9 +149,8 @@ class FastGradientMethod(GradientMethod):
is already equipped with loss function. Default: None.

Examples:
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import FastGradientMethod
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._relu = nn.ReLU()
@@ -162,7 +161,7 @@ class FastGradientMethod(GradientMethod):
>>> labels = np.asarray([2],np.int32)
>>> labels = np.eye(3)[labels].astype(np.float32)
>>> net = Net()
>>> attack = FastGradientMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = FastGradientMethod(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -230,9 +229,8 @@ class RandomFastGradientMethod(FastGradientMethod):
ValueError: eps is smaller than alpha!

Examples:
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import RandomFastGradientMethod
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._relu = nn.ReLU()
@@ -243,7 +241,7 @@ class RandomFastGradientMethod(FastGradientMethod):
>>> inputs = np.asarray([[0.1, 0.2, 0.7]], np.float32)
>>> labels = np.asarray([2],np.int32)
>>> labels = np.eye(3)[labels].astype(np.float32)
>>> attack = RandomFastGradientMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = RandomFastGradientMethod(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -283,9 +281,8 @@ class FastGradientSignMethod(GradientMethod):
is already equipped with loss function. Default: None.

Examples:
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import FastGradientSignMethod
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._relu = nn.ReLU()
@@ -296,7 +293,7 @@ class FastGradientSignMethod(GradientMethod):
>>> inputs = np.asarray([[0.1, 0.2, 0.7]], np.float32)
>>> labels = np.asarray([2],np.int32)
>>> labels = np.eye(3)[labels].astype(np.float32)
>>> attack = FastGradientSignMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = FastGradientSignMethod(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -361,9 +358,8 @@ class RandomFastGradientSignMethod(FastGradientSignMethod):
ValueError: eps is smaller than alpha!

Examples:
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import RandomFastGradientSignMethod
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._relu = nn.ReLU()
@@ -374,7 +370,7 @@ class RandomFastGradientSignMethod(FastGradientSignMethod):
>>> inputs = np.asarray([[0.1, 0.2, 0.7]], np.float32)
>>> labels = np.asarray([2],np.int32)
>>> labels = np.eye(3)[labels].astype(np.float32)
>>> attack = RandomFastGradientSignMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = RandomFastGradientSignMethod(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -410,9 +406,8 @@ class LeastLikelyClassMethod(FastGradientSignMethod):
is already equipped with loss function. Default: None.

Examples:
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import LeastLikelyClassMethod
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._relu = nn.ReLU()
@@ -423,7 +418,7 @@ class LeastLikelyClassMethod(FastGradientSignMethod):
>>> labels = np.asarray([2],np.int32)
>>> labels = np.eye(3)[labels].astype(np.float32)
>>> net = Net()
>>> attack = LeastLikelyClassMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = LeastLikelyClassMethod(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -462,9 +457,8 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod):
ValueError: eps is smaller than alpha!

Examples:
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import RandomLeastLikelyClassMethod
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._relu = nn.ReLU()
@@ -475,7 +469,7 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod):
>>> labels = np.asarray([2],np.int32)
>>> labels = np.eye(3)[labels].astype(np.float32)
>>> net = Net()
>>> attack = RandomLeastLikelyClassMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = RandomLeastLikelyClassMethod(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""



+ 10
- 15
mindarmour/adv_robustness/attacks/iterative_gradient_method.py View File

@@ -179,9 +179,8 @@ class BasicIterativeMethod(IterativeGradientMethod):

Examples:
>>> from mindspore.ops import operations as P
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import BasicIterativeMethod
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
@@ -189,7 +188,7 @@ class BasicIterativeMethod(IterativeGradientMethod):
... out = self._softmax(inputs)
... return out
>>> net = Net()
>>> attack = BasicIterativeMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = BasicIterativeMethod(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
>>> inputs = np.asarray([[0.1, 0.2, 0.7]], np.float32)
>>> labels = np.asarray([2],np.int32)
>>> labels = np.eye(3)[labels].astype(np.float32)
@@ -284,9 +283,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):

Examples:
>>> from mindspore.ops import operations as P
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import MomentumIterativeMethod
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
@@ -294,7 +292,7 @@ class MomentumIterativeMethod(IterativeGradientMethod):
... out = self._softmax(inputs)
... return out
>>> net = Net()
>>> attack = MomentumIterativeMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = MomentumIterativeMethod(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
>>> inputs = np.asarray([[0.1, 0.2, 0.7]], np.float32)
>>> labels = np.asarray([2],np.int32)
>>> labels = np.eye(3)[labels].astype(np.float32)
@@ -428,9 +426,8 @@ class ProjectedGradientDescent(BasicIterativeMethod):

Examples:
>>> from mindspore.ops import operations as P
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import ProjectedGradientDescent
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
@@ -438,7 +435,7 @@ class ProjectedGradientDescent(BasicIterativeMethod):
... out = self._softmax(inputs)
... return out
>>> net = Net()
>>> attack = ProjectedGradientDescent(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = ProjectedGradientDescent(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
>>> inputs = np.asarray([[0.1, 0.2, 0.7]], np.float32)
>>> labels = np.asarray([2],np.int32)
>>> labels = np.eye(3)[labels].astype(np.float32)
@@ -526,9 +523,8 @@ class DiverseInputIterativeMethod(BasicIterativeMethod):

Examples:
>>> from mindspore.ops import operations as P
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import DiverseInputIterativeMethod
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
@@ -536,7 +532,7 @@ class DiverseInputIterativeMethod(BasicIterativeMethod):
... out = self._softmax(inputs)
... return out
>>> net = Net()
>>> attack = DiverseInputIterativeMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = DiverseInputIterativeMethod(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
>>> inputs = np.asarray([[0.1, 0.2, 0.7]], np.float32)
>>> labels = np.asarray([2],np.int32)
>>> labels = np.eye(3)[labels].astype(np.float32)
@@ -584,9 +580,8 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod):

Examples:
>>> from mindspore.ops import operations as P
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import MomentumDiverseInputIterativeMethod
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
@@ -594,7 +589,7 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod):
... out = self._softmax(inputs)
... return out
>>> net = Net()
>>> attack = MomentumDiverseInputIterativeMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = MomentumDiverseInputIterativeMethod(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
>>> inputs = np.asarray([[0.1, 0.2, 0.7]], np.float32)
>>> labels = np.asarray([2],np.int32)
>>> labels = np.eye(3)[labels].astype(np.float32)


+ 1
- 2
mindarmour/adv_robustness/attacks/jsma.py View File

@@ -56,9 +56,8 @@ class JSMAAttack(Attack):
input labels are onehot-coded. Default: True.

Examples:
>>> from mindspore.nn import Cell
>>> from mindarmour.adv_robustness.attacks import JSMAAttack
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._relu = nn.ReLU()


+ 12
- 1
mindarmour/adv_robustness/attacks/lbfgs.py View File

@@ -56,7 +56,18 @@ class LBFGS(Attack):

Examples:
>>> from mindarmour.adv_robustness.attacks import LBFGS
>>> from tests.ut.python.utils.mock_net import Net
>>> import mindspore.ops.operations as P
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
... self._reduce = P.ReduceSum()
... self._squeeze = P.Squeeze(1)
... def construct(self, inputs):
... out = self._softmax(inputs)
... out = self._reduce(out, 2)
... out = self._squeeze(out)
... return out
>>> net = Net()
>>> classes = 10
>>> attack = LBFGS(net, is_targeted=True)


+ 44
- 11
mindarmour/adv_robustness/defenses/adversarial_defense.py View File

@@ -37,17 +37,28 @@ class AdversarialDefense(Defense):

Examples:
>>> from mindspore.nn.optim.momentum import Momentum
>>> import mindspore.ops.operations as P
>>> from mindarmour.adv_robustness.defenses import AdversarialDefense
>>> from tests.ut.python.utils.mock_net import Net
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
... self._dense = nn.Dense(10, 10)
... self._squeeze = P.Squeeze(1)
... def construct(self, inputs):
... out = self._softmax(inputs)
... out = self._dense(out)
... out = self._squeeze(out)
... return out
>>> net = Net()
>>> lr = 0.001
>>> momentum = 0.9
>>> batch_size = 32
>>> batch_size = 16
>>> num_classes = 10
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
>>> optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
>>> adv_defense = AdversarialDefense(net, loss_fn, optimizer)
>>> inputs = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
>>> inputs = np.random.rand(batch_size, 1, 10).astype(np.float32)
>>> labels = np.random.randint(10, size=batch_size).astype(np.int32)
>>> labels = np.eye(num_classes)[labels].astype(np.float32)
>>> adv_defense.defense(inputs, labels)
@@ -106,14 +117,25 @@ class AdversarialDefenseWithAttacks(AdversarialDefense):

Examples:
>>> from mindspore.nn.optim.momentum import Momentum
>>> import mindspore.ops.operations as P
>>> from mindarmour.adv_robustness.attacks import FastGradientSignMethod
>>> from mindarmour.adv_robustness.attacks import ProjectedGradientDescent
>>> from mindarmour.adv_robustness.defenses import AdversarialDefenseWithAttacks
>>> from tests.ut.python.utils.mock_net import Net
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
... self._dense = nn.Dense(10, 10)
... self._squeeze = P.Squeeze(1)
... def construct(self, inputs):
... out = self._softmax(inputs)
... out = self._dense(out)
... out = self._squeeze(out)
... return out
>>> net = Net()
>>> lr = 0.001
>>> momentum = 0.9
>>> batch_size = 32
>>> batch_size = 16
>>> num_classes = 10
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
>>> optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
@@ -121,8 +143,8 @@ class AdversarialDefenseWithAttacks(AdversarialDefense):
>>> pgd = ProjectedGradientDescent(net, loss_fn=loss_fn)
>>> ead = AdversarialDefenseWithAttacks(net, [fgsm, pgd], loss_fn=loss_fn,
... optimizer=optimizer)
>>> inputs = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
>>> labels = np.random.randint(num_classes, size=batch_size).astype(np.int32)
>>> inputs = np.random.rand(batch_size, 1, 10).astype(np.float32)
>>> labels = np.random.randint(10, size=batch_size).astype(np.int32)
>>> labels = np.eye(num_classes)[labels].astype(np.float32)
>>> loss = ead.defense(inputs, labels)
"""
@@ -193,14 +215,25 @@ class EnsembleAdversarialDefense(AdversarialDefenseWithAttacks):

Examples:
>>> from mindspore.nn.optim.momentum import Momentum
>>> import mindspore.ops.operations as P
>>> from mindarmour.adv_robustness.attacks import FastGradientSignMethod
>>> from mindarmour.adv_robustness.attacks import ProjectedGradientDescent
>>> from mindarmour.adv_robustness.defenses import EnsembleAdversarialDefense
>>> from tests.ut.python.utils.mock_net import Net
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
... self._dense = nn.Dense(10, 10)
... self._squeeze = P.Squeeze(1)
... def construct(self, inputs):
... out = self._softmax(inputs)
... out = self._dense(out)
... out = self._squeeze(out)
... return out
>>> net = Net()
>>> lr = 0.001
>>> momentum = 0.9
>>> batch_size = 32
>>> batch_size = 16
>>> num_classes = 10
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
>>> optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
@@ -208,8 +241,8 @@ class EnsembleAdversarialDefense(AdversarialDefenseWithAttacks):
>>> pgd = ProjectedGradientDescent(net, loss_fn=loss_fn)
>>> ead = EnsembleAdversarialDefense(net, [fgsm, pgd], loss_fn=loss_fn,
... optimizer=optimizer)
>>> inputs = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
>>> labels = np.random.randint(num_classes, size=batch_size).astype(np.int32)
>>> inputs = np.random.rand(batch_size, 1, 10).astype(np.float32)
>>> labels = np.random.randint(10, size=batch_size).astype(np.int32)
>>> labels = np.eye(num_classes)[labels].astype(np.float32)
>>> loss = ead.defense(inputs, labels)
"""


+ 15
- 4
mindarmour/adv_robustness/defenses/natural_adversarial_defense.py View File

@@ -37,18 +37,29 @@ class NaturalAdversarialDefense(AdversarialDefenseWithAttacks):

Examples:
>>> from mindspore.nn.optim.momentum import Momentum
>>> import mindspore.ops.operations as P
>>> from mindarmour.adv_robustness.defenses import NaturalAdversarialDefense
>>> from tests.ut.python.utils.mock_net import Net
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
... self._dense = nn.Dense(10, 10)
... self._squeeze = P.Squeeze(1)
... def construct(self, inputs):
... out = self._softmax(inputs)
... out = self._dense(out)
... out = self._squeeze(out)
... return out
>>> net = Net()
>>> lr = 0.001
>>> momentum = 0.9
>>> batch_size = 32
>>> batch_size = 16
>>> num_classes = 10
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
>>> optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
>>> nad = NaturalAdversarialDefense(net, loss_fn=loss_fn, optimizer=optimizer)
>>> inputs = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
>>> labels = np.random.randint(num_classes, size=batch_size).astype(np.int32)
>>> inputs = np.random.rand(batch_size, 1, 10).astype(np.float32)
>>> labels = np.random.randint(10, size=batch_size).astype(np.int32)
>>> labels = np.eye(num_classes)[labels].astype(np.float32)
>>> loss = nad.defense(inputs, labels)
"""


+ 15
- 4
mindarmour/adv_robustness/defenses/projected_adversarial_defense.py View File

@@ -43,18 +43,29 @@ class ProjectedAdversarialDefense(AdversarialDefenseWithAttacks):

Examples:
>>> from mindspore.nn.optim.momentum import Momentum
>>> import mindspore.ops.operations as P
>>> from mindarmour.adv_robustness.defenses import ProjectedAdversarialDefense
>>> from tests.ut.python.utils.mock_net import Net
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
... self._dense = nn.Dense(10, 10)
... self._squeeze = P.Squeeze(1)
... def construct(self, inputs):
... out = self._softmax(inputs)
... out = self._dense(out)
... out = self._squeeze(out)
... return out
>>> net = Net()
>>> lr = 0.001
>>> momentum = 0.9
>>> batch_size = 32
>>> batch_size = 16
>>> num_classes = 10
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
>>> optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
>>> pad = ProjectedAdversarialDefense(net, loss_fn=loss_fn, optimizer=optimizer)
>>> inputs = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
>>> labels = np.random.randint(num_classes, size=batch_size).astype(np.int32)
>>> inputs = np.random.rand(batch_size, 1, 10).astype(np.float32)
>>> labels = np.random.randint(10, size=batch_size).astype(np.int32)
>>> labels = np.eye(num_classes)[labels].astype(np.float32)
>>> loss = pad.defense(inputs, labels)
"""


+ 1
- 2
mindarmour/adv_robustness/detectors/ensemble_detector.py View File

@@ -36,12 +36,11 @@ class EnsembleDetector(Detector):
Default: 'vote'
Examples:
>>> from mindspore.ops.operations import Add
>>> from mindspore.nn import Cell
>>> from mindspore import Model
>>> from mindarmour.adv_robustness.detectors import ErrorBasedDetector
>>> from mindarmour.adv_robustness.detectors import RegionBasedDetector
>>> from mindarmour.adv_robustness.detectors import EnsembleDetector
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.add = Add()


+ 1
- 2
mindarmour/adv_robustness/detectors/mag_net.py View File

@@ -49,10 +49,9 @@ class ErrorBasedDetector(Detector):

Examples:
>>> from mindspore.ops.operations import Add
>>> from mindspore.nn import Cell
>>> from mindspore import Model
>>> from mindarmour.adv_robustness.detectors import ErrorBasedDetector
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.add = Add()


+ 1
- 2
mindarmour/adv_robustness/detectors/region_based_detector.py View File

@@ -55,10 +55,9 @@ class RegionBasedDetector(Detector):

Examples:
>>> from mindspore.ops.operations import Add
>>> from mindspore.nn import Cell
>>> from mindspore import Model
>>> from mindarmour.adv_robustness.detectors import RegionBasedDetector
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.add = Add()


+ 1
- 2
mindarmour/adv_robustness/detectors/spatial_smoothing.py View File

@@ -54,10 +54,9 @@ class SpatialSmoothing(Detector):

Examples:
>>> import mindspore.ops.operations as P
>>> from mindspore.nn import Cell
>>> from mindspore import Model
>>> from mindarmour.adv_robustness.detectors import SpatialSmoothing
>>> class Net(Cell):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()


+ 9
- 4
mindarmour/privacy/diff_privacy/optimizer/optimizer.py View File

@@ -63,7 +63,13 @@ class DPOptimizerClassFactory:

Examples:
>>> from mindarmour.privacy.diff_privacy import DPOptimizerClassFactory
>>> from tests.ut.python.utils.mock_net import Net
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._relu = nn.ReLU()
... def construct(self, inputs):
... out = self._relu(inputs)
... return out
>>> network = Net()
>>> GaussianSGD = DPOptimizerClassFactory(micro_batches=2)
>>> GaussianSGD.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)
@@ -79,9 +85,8 @@ class DPOptimizerClassFactory:

def set_mechanisms(self, policy, *args, **kwargs):
"""
Get noise mechanism object. Policies can be 'sgd', 'momentum'
or 'adam'. Candidate args and kwargs can be seen in class
NoiseMechanismsFactory of mechanisms.py.
Get noise mechanism object. Policies can be 'Gaussian' or 'AdaGaussian'.
Candidate args and kwargs can be seen in class NoiseMechanismsFactory of mechanisms.py.

Args:
policy (str): Choose mechanism type.


+ 2
- 1
mindarmour/privacy/diff_privacy/train/model.py View File

@@ -70,7 +70,8 @@ class DPModel(Model):
DPModel is used for constructing a model for differential privacy training.
This class is overload mindspore.train.model.Model.

For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_differential_privacy.html#%E5%B7%AE%E5%88%86%E9%9A%90%E7%A7%81>`_.
For details, please check `Protecting User Privacy with Differential Privacy Mechanism
<https://mindspore.cn/mindarmour/docs/en/master/protect_user_privacy_with_differential_privacy.html#%E5%B7%AE%E5%88%86%E9%9A%90%E7%A7%81>`_.

Args:
micro_batches (int): The number of small batches split from an original


+ 2
- 1
mindarmour/privacy/evaluation/membership_inference.py View File

@@ -98,7 +98,8 @@ class MembershipInference:
for inferring user's privacy data. It requires loss or logits results of the training samples.
(Privacy refers to some sensitive attributes of a single user).

For details, please refer to the `Tutorial <https://mindspore.cn/mindarmour/docs/en/master/test_model_security_membership_inference.html>`_.
For details, please refer to the `Using Membership Inference to Test Model Security
<https://mindspore.cn/mindarmour/docs/en/master/test_model_security_membership_inference.html>`_.

References: `Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov.
Membership Inference Attacks against Machine Learning Models. 2017.


+ 2
- 1
mindarmour/privacy/sup_privacy/mask_monitor/masker.py View File

@@ -27,7 +27,8 @@ TAG = 'suppress masker'
class SuppressMasker(Callback):
"""
Periodicity check suppress privacy function status and toggle suppress operation.
For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html#%E5%BC%95%E5%85%A5%E6%8A%91%E5%88%B6%E9%9A%90%E7%A7%81%E8%AE%AD%E7%BB%83>`_.
For details, please check `Protecting User Privacy with Suppression Privacy
<https://mindspore.cn/mindarmour/docs/en/master/protect_user_privacy_with_suppress_privacy.html>`_.

Args:
model (SuppressModel): SuppressModel instance.


+ 4
- 2
mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py View File

@@ -35,7 +35,8 @@ class SuppressPrivacyFactory:
"""
Factory class of SuppressCtrl mechanisms.

For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html#%E5%BC%95%E5%85%A5%E6%8A%91%E5%88%B6%E9%9A%90%E7%A7%81%E8%AE%AD%E7%BB%83>`_.
For details, please check `Protecting User Privacy with Suppress Privacy
<https://mindspore.cn/mindarmour/docs/en/master/protect_user_privacy_with_suppress_privacy.html>`_.
"""

def __init__(self):
@@ -118,7 +119,8 @@ class SuppressCtrl(Cell):
finding the parameters that should be suppressed, and suppress these
parameters permanently.

For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html#%E5%BC%95%E5%85%A5%E6%8A%91%E5%88%B6%E9%9A%90%E7%A7%81%E8%AE%AD%E7%BB%83>`_.
For details, please check `Protecting User Privacy with Suppress Privacy
<https://mindspore.cn/mindarmour/docs/en/master/protect_user_privacy_with_suppress_privacy.html>`_.

Args:
networks (Cell): The training network.


+ 2
- 1
mindarmour/privacy/sup_privacy/train/model.py View File

@@ -59,7 +59,8 @@ class SuppressModel(Model):
Complete model train function. The suppress privacy function is embedded into the overload
mindspore.train.model.Model.

For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html>`_.
For details, please check `Protecting User Privacy with Suppress Privacy
<https://mindspore.cn/mindarmour/docs/en/master/protect_user_privacy_with_suppress_privacy.html>`_.

Args:
network (Cell): The training network.


+ 2
- 1
mindarmour/reliability/concept_drift/concept_drift_check_images.py View File

@@ -89,7 +89,8 @@ class OodDetectorFeatureCluster(OodDetector):
the testing data features and the clustering centers determines whether an image is an out-of-distribution(OOD)
image or not.

For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/concept_drift_images.html>`_.
For details, please check `Implementing the Concept Drift Detection Application of Image Data
<https://mindspore.cn/mindarmour/docs/en/master/concept_drift_images.html>`_.

Args:
model (Model):The training model.


+ 2
- 1
mindarmour/reliability/concept_drift/concept_drift_check_time_series.py View File

@@ -23,7 +23,8 @@ from mindarmour.utils._check_param import check_param_type, check_param_in_range
class ConceptDriftCheckTimeSeries:
r"""
ConceptDriftCheckTimeSeries is used for example series distribution change detection.
For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/concept_drift_time_series.html>`_.
For details, please check `Implementing the Concept Drift Detection Application of Time Series Data
<https://mindspore.cn/mindarmour/docs/en/master/concept_drift_time_series.html>`_.

Args:
window_size(int): Size of a concept window, no less than 10. If given the input data,


+ 2
- 1
mindarmour/reliability/model_fault_injection/fault_injection.py View File

@@ -31,7 +31,8 @@ class FaultInjector:
Fault injection module simulates various fault scenarios for deep neural networks and evaluates
performance and reliability of the model.

For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/fault_injection.html>`_.
For details, please check `Implementing the Model Fault Injection and Evaluation
<https://mindspore.cn/mindarmour/docs/en/master/fault_injection.html>`_.

Args:
model (Model): The model need to be evaluated.


Loading…
Cancel
Save