diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/README.md b/examples/model_security/model_attacks/cv/faster_rcnn/README.md
new file mode 100644
index 0000000..92317da
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/README.md
@@ -0,0 +1,47 @@
+# Dataset
+
+Dataset used: [COCO2017](https://cocodataset.org/#download)
+
+- Dataset size: 19G
+    - Train: 18G, 118000 images
+    - Val: 1G, 5000 images
+    - Annotations: 241M (instances, captions, person_keypoints, etc.)
+- Data format: image and json files
+    - Note: Data will be processed in `src/dataset.py`
+
+# Environment Requirements
+
+- Install [MindSpore](https://www.mindspore.cn/install/en).
+
+- Download the dataset COCO2017.
+
+- We use COCO2017 as the dataset in this example.
+
+ Install Cython and pycocotools; you can also install mmcv to process data.
+
+ ```
+ pip install Cython
+
+ pip install pycocotools
+
+ pip install mmcv==0.2.14
+ ```
+
+ Then set `coco_root` and any other settings you need in `src/config.py`. The directory structure is as follows:
+
+ ```
+ .
+ └─cocodataset
+   ├─annotations
+     ├─instances_train2017.json
+     └─instances_val2017.json
+   ├─val2017
+   └─train2017
+ ```
+
+# Quick start
+You can download the pre-trained model checkpoint file [here]().
+```
+python coco_attack_pgd.py --ann_file [VAL_JSON_FILE] --pre_trained [PRETRAINED_CHECKPOINT_FILE]
+```
+> Adversarial samples will be generated and saved as a pickle file (`adv_samples.pkl`).
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_pgd.py b/examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_pgd.py
new file mode 100755
index 0000000..835f07b
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_pgd.py
@@ -0,0 +1,135 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PGD attack for faster rcnn"""
+import os
+import argparse
+import pickle
+
+from mindspore import context
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.common import set_seed
+from mindspore.nn import Cell
+from mindspore.ops.composite import GradOperation
+
+from mindarmour.adv_robustness.attacks import ProjectedGradientDescent
+
+from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
+from src.config import config
+from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset
+
+# pylint: disable=locally-disabled, unused-argument, redefined-outer-name
+
+set_seed(1)  # fixed seed via mindspore.common.set_seed
+# Command-line options of the attack script.
+parser = argparse.ArgumentParser(description='FasterRCNN attack')
+parser.add_argument('--ann_file', type=str, required=True, help='Ann file path.')
+parser.add_argument('--pre_trained', type=str, required=True, help='pre-trained ckpt file path for target model.')
+parser.add_argument('--device_id', type=int, default=0, help='Device id, default is 0.')
+parser.add_argument('--num', type=int, default=5, help='Number of adversarial examples.')
+args = parser.parse_args()
+# Graph-mode execution on the Ascend device chosen with --device_id.
+context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=args.device_id)
+
+
+class LossNet(Cell):
+ """Loss function: returns x4 + x6 (presumably rpn_reg_loss + rcnn_reg_loss per the backbone docstring - TODO confirm)."""
+ def construct(self, x1, x2, x3, x4, x5, x6):
+ return x4 + x6
+
+
+class WithLossCell(Cell):
+ """Cell that passes the six outputs of `backbone` to `loss_fn` and returns the result."""
+ def __init__(self, backbone, loss_fn):
+ super(WithLossCell, self).__init__(auto_prefix=False)
+ self._backbone = backbone
+ self._loss_fn = loss_fn
+
+ def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_num):
+ loss1, loss2, loss3, loss4, loss5, loss6 = self._backbone(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
+ return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6)
+
+ @property
+ def backbone_network(self):
+ return self._backbone  # the network passed to __init__ as `backbone`
+
+
+class GradWrapWithLoss(Cell):
+ """
+ Construct a network that differentiates the output of `network` with
+ GradOperation(get_all=True) and returns only element 0 of the result.
+ """
+ def __init__(self, network):
+ super(GradWrapWithLoss, self).__init__()
+ self._grad_all = GradOperation(get_all=True, sens_param=False)
+ self._network = network
+
+ def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_num):
+ gout = self._grad_all(self._network)(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
+ return gout[0]  # presumably the gradient for img_data (first input) - confirm against GradOperation docs
+
+
+if __name__ == '__main__':
+    # Build the eval MindRecord when it is missing, then attack the first batches with PGD.
+    prefix = 'FasterRcnn_eval.mindrecord'
+    mindrecord_dir = config.mindrecord_dir
+    mindrecord_file = os.path.join(mindrecord_dir, prefix)
+    pre_trained = args.pre_trained
+
+    print("CHECKING MINDRECORD FILES ...")
+    if not os.path.exists(mindrecord_file):
+        if not os.path.isdir(config.coco_root):
+            # Fail here with a clear message instead of later inside dataset creation.
+            raise FileNotFoundError("coco_root does not exist: {}".format(config.coco_root))
+        os.makedirs(mindrecord_dir, exist_ok=True)
+        print("Create Mindrecord. It may take some time.")
+        data_to_mindrecord_byte_image("coco", False, prefix, file_num=1)
+        print("Create Mindrecord Done, at {}".format(mindrecord_dir))
+
+    print('Start generate adversarial samples.')
+
+    # build network and dataset
+    ds = create_fasterrcnn_dataset(mindrecord_file, batch_size=config.test_batch_size, \
+                                   repeat_num=1, is_training=True)
+    net = Faster_Rcnn_Resnet50(config)
+    param_dict = load_checkpoint(pre_trained)
+    load_param_into_net(net, param_dict)
+    net = net.set_train()
+
+    # build attacker
+    with_loss_cell = WithLossCell(net, LossNet())
+    grad_with_loss_net = GradWrapWithLoss(with_loss_cell)
+    attack = ProjectedGradientDescent(grad_with_loss_net, bounds=None, eps=0.1)
+
+    # generate adversarial samples
+    num_batches = args.num // config.test_batch_size
+    adv_samples = []
+    # One pass over the iterator yields every batch of the dataset, so the loop
+    # has to stop by itself once `num_batches` batches have been attacked.
+    for batch_id, data in enumerate(ds.create_dict_iterator(num_epochs=1)):
+        if batch_id >= num_batches:
+            break
+        img_data = data['image']
+        img_metas = data['image_shape']
+        gt_bboxes = data['box']
+        gt_labels = data['label']
+        gt_num = data['valid_num']
+
+        adv_img = attack.generate(img_data.asnumpy(), \
+                                  (img_metas.asnumpy(), gt_bboxes.asnumpy(), gt_labels.asnumpy(), gt_num.asnumpy()))
+        adv_samples.extend(adv_img)
+
+    # Close the output file deterministically instead of leaking the handle.
+    with open('adv_samples.pkl', 'wb') as adv_file:
+        pickle.dump(adv_samples, adv_file)
+    print('Generate adversarial samples complete.')
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/__init__.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/__init__.py
new file mode 100644
index 0000000..cbc0a27
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FasterRcnn Init."""
+
+from .resnet50 import ResNetFea, ResidualBlockUsing
+from .bbox_assign_sample import BboxAssignSample
+from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
+from .fpn_neck import FeatPyramidNeck
+from .proposal_generator import Proposal
+from .rcnn import Rcnn
+from .rpn import RPN
+from .roi_align import SingleRoIExtractor
+from .anchor_generator import AnchorGenerator
+
+__all__ = [
+ "ResNetFea", "BboxAssignSample", "BboxAssignSampleForRcnn",
+ "FeatPyramidNeck", "Proposal", "Rcnn",
+ "RPN", "SingleRoIExtractor", "AnchorGenerator", "ResidualBlockUsing"
+]
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/anchor_generator.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/anchor_generator.py
new file mode 100644
index 0000000..666508c
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/anchor_generator.py
@@ -0,0 +1,84 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FasterRcnn anchor generator."""
+
+import numpy as np
+
+class AnchorGenerator():
+    """
+    Generate anchor boxes for FasterRcnn.
+
+    Args:
+        base_size (int): Side length of the square the base anchors are built on.
+        scales: Multipliers applied to the base size.
+        ratios: Ratios; heights scale by sqrt(ratio) and widths by 1 / sqrt(ratio).
+        scale_major (bool): Selects the order in which ratios and scales are combined.
+        ctr (tuple): Optional (x, y) anchor centre; defaults to the centre of the base square.
+    """
+    def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
+        """Store the settings and precompute `base_anchors`."""
+        self.base_size = base_size
+        self.scales = np.array(scales)
+        self.ratios = np.array(ratios)
+        self.scale_major = scale_major
+        self.ctr = ctr
+        self.base_anchors = self.gen_base_anchors()
+
+    def gen_base_anchors(self):
+        """Return rounded (x1, y1, x2, y2) anchors, one row per ratio/scale combination."""
+        side = self.base_size
+        if self.ctr is not None:
+            center_x, center_y = self.ctr
+        else:
+            center_x = center_y = 0.5 * (side - 1)
+
+        sqrt_ratios = np.sqrt(self.ratios)
+        inv_sqrt_ratios = 1 / sqrt_ratios
+        if self.scale_major:
+            widths = (side * inv_sqrt_ratios[:, None] * self.scales[None, :]).reshape(-1)
+            heights = (side * sqrt_ratios[:, None] * self.scales[None, :]).reshape(-1)
+        else:
+            widths = (side * self.scales[:, None] * inv_sqrt_ratios[None, :]).reshape(-1)
+            heights = (side * self.scales[:, None] * sqrt_ratios[None, :]).reshape(-1)
+
+        # Half of (size - 1), shared by the two corners on each axis.
+        half_w = 0.5 * (widths - 1)
+        half_h = 0.5 * (heights - 1)
+        corners = [center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h]
+        return np.stack(corners, axis=-1).round()
+
+    def _meshgrid(self, x, y, row_major=True):
+        """Return the flattened x/y coordinates of the len(y) x len(x) grid."""
+        grid_x = np.tile(x, len(y))
+        grid_y = np.repeat(y, len(x))
+        if not row_major:
+            return grid_y, grid_x
+        return grid_x, grid_y
+
+    def grid_anchors(self, featmap_size, stride=16):
+        """
+        Shift the base anchors to every position of a feature map.
+
+        `featmap_size` is (height, width); the result has shape (K*A, 4) for
+        K positions and A base anchors.
+        """
+        rows, cols = featmap_size
+        offsets_x, offsets_y = self._meshgrid(np.arange(0, cols) * stride, np.arange(0, rows) * stride)
+        # One (dx, dy, dx, dy) row per position; the first `cols` rows cover the first feature-map row.
+        offsets = np.stack([offsets_x, offsets_y, offsets_x, offsets_y], axis=-1)
+        offsets = offsets.astype(self.base_anchors.dtype)
+        # Broadcast A anchors (1, A, 4) against K offsets (K, 1, 4) -> (K, A, 4), then flatten.
+        shifted = self.base_anchors[None, :, :] + offsets[:, None, :]
+        return shifted.reshape(-1, 4)
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample.py
new file mode 100644
index 0000000..2645edf
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample.py
@@ -0,0 +1,166 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FasterRcnn positive and negative sample screening for RPN."""
+
+import numpy as np
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+import mindspore.common.dtype as mstype
+
+# pylint: disable=locally-disabled, invalid-name, missing-docstring
+
+
+class BboxAssignSample(nn.Cell):
+    """
+    Bbox assigner and sampler definition for the RPN stage.
+
+    Args:
+        config: Config object whose thresholds and sample counts are read by attribute.
+        batch_size (int): Batch size (stored, not read by `construct`).
+        num_bboxes (int): Number of anchors.
+        add_gt_as_proposals (bool): Flag for adding gt bboxes as proposals.
+
+    Returns:
+        Tuple of four tensors without a batch dimension.
+        bbox_targets: encoded bbox targets, (num_bboxes, 4)
+        bbox_weights: bool mask of sampled positive bboxes, (num_bboxes,)
+        labels: labels scattered to the sampled positive bboxes, (num_bboxes,)
+        label_weights: bool mask of sampled positive and negative bboxes, (num_bboxes,)
+
+    Examples:
+        BboxAssignSample(config, 2, 1024, True)
+    """
+
+    def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
+        super(BboxAssignSample, self).__init__()
+        cfg = config
+        self.batch_size = batch_size
+        # IoU thresholds as float16 tensors, compared against the overlaps in construct.
+        self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16)
+        self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16)
+        self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16)
+        self.zero_thr = Tensor(0.0, mstype.float16)
+
+        self.num_bboxes = num_bboxes
+        self.num_gts = cfg.num_gts
+        self.num_expected_pos = cfg.num_expected_pos
+        self.num_expected_neg = cfg.num_expected_neg
+        self.add_gt_as_proposals = add_gt_as_proposals
+
+        if self.add_gt_as_proposals:
+            self.label_inds = Tensor(np.arange(1, self.num_gts + 1))
+
+        self.concat = P.Concat(axis=0)
+        self.max_gt = P.ArgMaxWithValue(axis=0)
+        self.max_anchor = P.ArgMaxWithValue(axis=1)
+        self.sum_inds = P.ReduceSum()
+        self.iou = P.IOU()
+        self.greaterequal = P.GreaterEqual()
+        self.greater = P.Greater()
+        self.select = P.Select()
+        self.gatherND = P.GatherNd()
+        self.squeeze = P.Squeeze()
+        self.cast = P.Cast()
+        self.logicaland = P.LogicalAnd()
+        self.less = P.Less()
+        self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
+        self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
+        self.reshape = P.Reshape()
+        self.equal = P.Equal()
+        self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
+        self.scatterNdUpdate = P.ScatterNdUpdate()
+        self.scatterNd = P.ScatterNd()
+        self.logicalnot = P.LogicalNot()
+        self.tile = P.Tile()
+        self.zeros_like = P.ZerosLike()
+
+        self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
+        self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
+        self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
+        self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
+        self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
+        # np.bool_ instead of np.bool: that alias is missing from NumPy 1.24-1.26.
+        self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool_))
+        self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
+        self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
+        self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
+
+
+    def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
+        gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
+                                  (self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one)
+        bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
+                             (self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two)
+        # Above: invalid gt boxes become all -1 and invalid anchors all -2 before the IoU is taken.
+        overlaps = self.iou(bboxes, gt_bboxes_i)
+
+        max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
+        _, max_overlaps_w_ac = self.max_anchor(overlaps)
+
+        neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \
+                                              self.less(max_overlaps_w_gt, self.neg_iou_thr))
+        assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
+        # Assignment codes: -1 = unassigned/ignored, 0 = negative, k > 0 = matched to gt k - 1.
+        pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
+        assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
+                                        max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
+        assigned_gt_inds4 = assigned_gt_inds3
+        for j in range(self.num_gts):
+            max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
+            overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::])
+            # gt j also claims the anchors that reach its best IoU, when that IoU >= min_pos_iou.
+            pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \
+                                         self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))
+
+            assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)
+        # Anchors flagged invalid by valid_mask are reset to -1.
+        assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)
+
+        pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
+        # The sampler's own mask is replaced by one derived from the count of positive anchors.
+        pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
+        pos_check_valid = self.sum_inds(pos_check_valid, -1)
+        valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
+        pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
+
+        pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
+        pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
+        pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))
+
+        neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
+        # Unused positive slots extend the number of negative slots that stay valid.
+        num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16)
+        num_pos = self.sum_inds(num_pos, -1)
+        unvalid_pos_index = self.less(self.range_pos_size, num_pos)
+        valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
+
+        pos_bboxes_ = self.gatherND(bboxes, pos_index)
+        pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
+        pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
+
+        pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
+
+        valid_pos_index = self.cast(valid_pos_index, mstype.int32)
+        valid_neg_index = self.cast(valid_neg_index, mstype.int32)
+        bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
+        bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
+        labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
+        total_index = self.concat((pos_index, neg_index))
+        total_valid_index = self.concat((valid_pos_index, valid_neg_index))
+        label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))
+
+        return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
+               labels_total, self.cast(label_weights_total, mstype.bool_)
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample_stage2.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample_stage2.py
new file mode 100644
index 0000000..6fbc075
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample_stage2.py
@@ -0,0 +1,197 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FasterRcnn positive and negative sample screening for Rcnn."""
+
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+
+# pylint: disable=locally-disabled, invalid-name, missing-docstring
+
+
+class BboxAssignSampleForRcnn(nn.Cell):
+    """
+    Bbox assigner and sampler definition for the RCNN stage.
+
+    Args:
+        config: Config object whose `*_stage2` thresholds and sample counts are read by attribute.
+        batch_size (int): Batch size (stored, not read by `construct`).
+        num_bboxes (int): Number of input bboxes, before the gt bboxes are prepended.
+        add_gt_as_proposals (bool): Flag for adding gt bboxes as proposals.
+
+    Returns:
+        Tuple of four tensors with num_expected_pos + num_expected_neg rows.
+        total_bboxes: sampled positive bboxes followed by sampled negative bboxes
+        total_deltas: encoded targets for the positives, zeros for the negatives
+        total_labels: gathered gt labels for the positives, zeros for the negatives
+        total_mask: int32 validity flag of every sampled row, one column
+
+    Examples:
+        BboxAssignSampleForRcnn(config, 2, 1024, True)
+    """
+
+    def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
+        super(BboxAssignSampleForRcnn, self).__init__()
+        cfg = config
+        self.batch_size = batch_size
+        self.neg_iou_thr = cfg.neg_iou_thr_stage2
+        self.pos_iou_thr = cfg.pos_iou_thr_stage2
+        self.min_pos_iou = cfg.min_pos_iou_stage2
+        self.num_gts = cfg.num_gts
+        self.num_bboxes = num_bboxes
+        self.num_expected_pos = cfg.num_expected_pos_stage2
+        self.num_expected_neg = cfg.num_expected_neg_stage2
+        self.num_expected_total = cfg.num_expected_total_stage2
+
+        self.add_gt_as_proposals = add_gt_as_proposals
+        self.label_inds = Tensor(np.arange(1, self.num_gts + 1).astype(np.int32))
+        self.add_gt_as_proposals_valid = Tensor(np.array(self.add_gt_as_proposals * np.ones(self.num_gts),
+                                                         dtype=np.int32))
+
+        self.concat = P.Concat(axis=0)
+        self.max_gt = P.ArgMaxWithValue(axis=0)
+        self.max_anchor = P.ArgMaxWithValue(axis=1)
+        self.sum_inds = P.ReduceSum()
+        self.iou = P.IOU()
+        self.greaterequal = P.GreaterEqual()
+        self.greater = P.Greater()
+        self.select = P.Select()
+        self.gatherND = P.GatherNd()
+        self.squeeze = P.Squeeze()
+        self.cast = P.Cast()
+        self.logicaland = P.LogicalAnd()
+        self.less = P.Less()
+        self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
+        self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
+        self.reshape = P.Reshape()
+        self.equal = P.Equal()
+        self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(0.1, 0.1, 0.2, 0.2))
+        self.concat_axis1 = P.Concat(axis=1)
+        self.logicalnot = P.LogicalNot()
+        self.tile = P.Tile()
+
+        # Check
+        self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
+        self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
+
+        # Init tensor
+        self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
+        self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
+        self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
+        self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
+        self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
+        # check_neg_mask below uses np.bool_: the np.bool alias is missing from NumPy 1.24-1.26.
+        self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32))
+        self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
+        self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool_))
+        self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16))
+        self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8))
+
+        self.reshape_shape_pos = (self.num_expected_pos, 1)
+        self.reshape_shape_neg = (self.num_expected_neg, 1)
+
+        self.scalar_zero = Tensor(0.0, dtype=mstype.float16)
+        self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=mstype.float16)
+        self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=mstype.float16)
+        self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=mstype.float16)
+
+    def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
+        gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
+                                  (self.num_gts, 1)), (1, 4)), mstype.bool_), \
+                                  gt_bboxes_i, self.check_gt_one)
+        bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
+                             (self.num_bboxes, 1)), (1, 4)), mstype.bool_), \
+                             bboxes, self.check_anchor_two)
+        # Above: invalid gt boxes become all -1 and invalid bboxes all -2 before the IoU is taken.
+        overlaps = self.iou(bboxes, gt_bboxes_i)
+
+        max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
+        _, max_overlaps_w_ac = self.max_anchor(overlaps)
+
+        neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt,
+                                                                self.scalar_zero),
+                                              self.less(max_overlaps_w_gt,
+                                                        self.scalar_neg_iou_thr))
+        # Assignment codes: -1 = unassigned/ignored, 0 = negative, k > 0 = matched to gt k - 1.
+        assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
+
+        pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.scalar_pos_iou_thr)
+        assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
+                                        max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
+        # gt j also claims the bboxes that reach its best IoU, when that IoU >= min_pos_iou.
+        for j in range(self.num_gts):
+            max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
+            overlaps_w_ac_j = overlaps[j:j+1:1, ::]
+            temp1 = self.greaterequal(max_overlaps_w_ac_j, self.scalar_min_pos_iou)
+            temp2 = self.squeeze(self.equal(overlaps_w_ac_j, max_overlaps_w_ac_j))
+            pos_mask_j = self.logicaland(temp1, temp2)
+            assigned_gt_inds3 = self.select(pos_mask_j, (j+1)*self.assigned_gt_ones, assigned_gt_inds3)
+
+        assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds3, self.assigned_gt_ignores)
+        # Prepend the gt boxes and their 1-based indices (zeroed when add_gt_as_proposals is false).
+        bboxes = self.concat((gt_bboxes_i, bboxes))
+        label_inds_valid = self.select(gt_valids, self.label_inds, self.gt_ignores)
+        label_inds_valid = label_inds_valid * self.add_gt_as_proposals_valid
+        assigned_gt_inds5 = self.concat((label_inds_valid, assigned_gt_inds5))
+
+        # Get pos index
+        pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
+        # The sampler's own mask is replaced by one derived from the count of positive entries.
+        pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
+        pos_check_valid = self.sum_inds(pos_check_valid, -1)
+        valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
+        pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
+
+        num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), mstype.float16), -1)
+        valid_pos_index = self.cast(valid_pos_index, mstype.int32)
+        pos_index = self.reshape(pos_index, self.reshape_shape_pos)
+        valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
+        pos_index = pos_index * valid_pos_index
+
+        pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
+        pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
+        pos_assigned_gt_index = pos_assigned_gt_index * valid_pos_index
+
+        pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
+
+        # Get neg index
+        neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
+        # Unused positive slots extend the number of negative slots that stay valid.
+        unvalid_pos_index = self.less(self.range_pos_size, num_pos)
+        valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
+        neg_index = self.reshape(neg_index, self.reshape_shape_neg)
+
+        valid_neg_index = self.cast(valid_neg_index, mstype.int32)
+        valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
+        neg_index = neg_index * valid_neg_index
+
+        pos_bboxes_ = self.gatherND(bboxes, pos_index)
+
+        neg_bboxes_ = self.gatherND(bboxes, neg_index)
+        pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
+        pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
+        pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
+        # Positives come first in every output; negatives get zero deltas and zero labels.
+        total_bboxes = self.concat((pos_bboxes_, neg_bboxes_))
+        total_deltas = self.concat((pos_bbox_targets_, self.bboxs_neg_mask))
+        total_labels = self.concat((pos_gt_labels, self.labels_neg_mask))
+
+        valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
+        valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
+        total_mask = self.concat((valid_pos_index, valid_neg_index))
+
+        return total_bboxes, total_deltas, total_labels, total_mask
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py
new file mode 100644
index 0000000..891b030
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py
@@ -0,0 +1,428 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FasterRcnn based on ResNet50."""
+
+import numpy as np
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+import mindspore.common.dtype as mstype
+from mindspore.ops import functional as F
+from .resnet50 import ResNetFea, ResidualBlockUsing
+from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
+from .fpn_neck import FeatPyramidNeck
+from .proposal_generator import Proposal
+from .rcnn import Rcnn
+from .rpn import RPN
+from .roi_align import SingleRoIExtractor
+from .anchor_generator import AnchorGenerator
+
+# pylint: disable=locally-disabled, invalid-name, missing-docstring
+
+
+class Faster_Rcnn_Resnet50(nn.Cell):
+ """
+ FasterRcnn Network.
+
+ Note:
+ backbone = resnet50
+
+ Returns:
+ Tuple, tuple of output tensor.
+ rpn_loss: Scalar, Total loss of RPN subnet.
+ rcnn_loss: Scalar, Total loss of RCNN subnet.
+ rpn_cls_loss: Scalar, Classification loss of RPN subnet.
+ rpn_reg_loss: Scalar, Regression loss of RPN subnet.
+ rcnn_cls_loss: Scalar, Classification loss of RCNN subnet.
+ rcnn_reg_loss: Scalar, Regression loss of RCNN subnet.
+
+ Examples:
+ net = Faster_Rcnn_Resnet50(config)
+ """
+    def __init__(self, config):
+        """
+        Create the sub-networks, operators and constant tensors of the detector.
+
+        Args:
+            config: Configuration object. Every ``config.<name>`` attribute read
+                below must be present (batch sizes, anchor settings, channel
+                sizes of each sub-network, RoI settings and test thresholds).
+
+        Raises:
+            AssertionError: If ``config.feature_shapes`` and
+                ``config.anchor_strides`` do not have the same length.
+        """
+        super(Faster_Rcnn_Resnet50, self).__init__()
+        self.train_batch_size = config.batch_size
+        self.num_classes = config.num_classes
+        self.anchor_scales = config.anchor_scales
+        self.anchor_ratios = config.anchor_ratios
+        self.anchor_strides = config.anchor_strides
+        self.target_means = tuple(config.rcnn_target_means)
+        self.target_stds = tuple(config.rcnn_target_stds)
+
+        # Anchor generator: one per pyramid level, using the level's stride as base size.
+        self.anchor_base_sizes = list(self.anchor_strides)
+        self.anchor_generators = [
+            AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios)
+            for anchor_base in self.anchor_base_sizes
+        ]
+
+        self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
+
+        featmap_sizes = config.feature_shapes
+        assert len(featmap_sizes) == len(self.anchor_generators)
+
+        # Anchors depend only on the (fixed) feature map sizes, so they are built once here.
+        self.anchor_list = self.get_anchors(featmap_sizes)
+
+        # Backbone resnet50
+        self.backbone = ResNetFea(ResidualBlockUsing,
+                                  config.resnet_block,
+                                  config.resnet_in_channels,
+                                  config.resnet_out_channels,
+                                  False)
+
+        # Fpn
+        # NOTE: the attribute name (typo included) is kept as-is; renaming it would
+        # presumably change the parameter names stored in existing checkpoints.
+        self.fpn_ncek = FeatPyramidNeck(config.fpn_in_channels,
+                                        config.fpn_out_channels,
+                                        config.fpn_num_outs)
+
+        # Rpn and rpn loss. The RPN stage is fed all-ones labels of shape
+        # (train_batch_size, num_gts) instead of the real class labels.
+        self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8))
+        self.rpn_with_loss = RPN(config,
+                                 self.train_batch_size,
+                                 config.rpn_in_channels,
+                                 config.rpn_feat_channels,
+                                 config.num_anchors,
+                                 config.rpn_cls_out_channels)
+
+        # Proposal: separate generators for training and inference settings.
+        self.proposal_generator = Proposal(config,
+                                           self.train_batch_size,
+                                           config.activate_num_classes,
+                                           config.use_sigmoid_cls)
+        self.proposal_generator.set_train_local(config, True)
+        self.proposal_generator_test = Proposal(config,
+                                                config.test_batch_size,
+                                                config.activate_num_classes,
+                                                config.use_sigmoid_cls)
+        self.proposal_generator_test.set_train_local(config, False)
+
+        # Assign and sampler stage two
+        self.bbox_assigner_sampler_for_rcnn = BboxAssignSampleForRcnn(config, self.train_batch_size,
+                                                                      config.num_bboxes_stage2, True)
+        # NOTE(review): max_shape is hard-coded to (768, 1280) here, whereas Proposal
+        # reads cfg.img_height / cfg.img_width -- confirm they are meant to agree.
+        self.decode = P.BoundingBoxDecode(max_shape=(768, 1280), means=self.target_means,
+                                          stds=self.target_stds)
+
+        # Roi: separate extractors for training and inference (the latter built with batch size 1).
+        self.roi_align = SingleRoIExtractor(config,
+                                            config.roi_layer,
+                                            config.roi_align_out_channels,
+                                            config.roi_align_featmap_strides,
+                                            self.train_batch_size,
+                                            config.roi_align_finest_scale)
+        self.roi_align.set_train_local(config, True)
+        self.roi_align_test = SingleRoIExtractor(config,
+                                                 config.roi_layer,
+                                                 config.roi_align_out_channels,
+                                                 config.roi_align_featmap_strides,
+                                                 1,
+                                                 config.roi_align_finest_scale)
+        self.roi_align_test.set_train_local(config, False)
+
+        # Rcnn
+        self.rcnn = Rcnn(config, config.rcnn_in_channels * config.roi_layer['out_size'] * config.roi_layer['out_size'],
+                         self.train_batch_size, self.num_classes)
+
+        # Op declare
+        self.squeeze = P.Squeeze()
+        self.cast = P.Cast()
+
+        self.concat = P.Concat(axis=0)
+        self.concat_1 = P.Concat(axis=1)
+        self.concat_2 = P.Concat(axis=2)
+        self.reshape = P.Reshape()
+        self.select = P.Select()
+        self.greater = P.Greater()
+        self.transpose = P.Transpose()
+
+        # Test mode
+        self.test_batch_size = config.test_batch_size
+        self.split = P.Split(axis=0, output_num=self.test_batch_size)
+        self.split_shape = P.Split(axis=0, output_num=4)
+        self.split_scores = P.Split(axis=1, output_num=self.num_classes)
+        self.split_cls = P.Split(axis=0, output_num=self.num_classes-1)
+        self.tile = P.Tile()
+        self.gather = P.GatherNd()
+
+        self.rpn_max_num = config.rpn_max_num
+
+        self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(np.float16))
+        # np.bool_ is used instead of the np.bool alias, which was removed in NumPy 1.24.
+        self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool_)
+        self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool_)
+        # Column pattern (True, False, True, False): get_det_bboxes uses it to pick the
+        # width-scaled value for columns 0/2 and the height-scaled value for columns 1/3.
+        self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask,
+                                                self.ones_mask, self.zeros_mask), axis=1))
+        # Column pattern (True, True, True, True, False).
+        self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask,
+                                                   self.ones_mask, self.ones_mask, self.zeros_mask), axis=1))
+
+        # Per-proposal constants used by multiclass_nms.
+        self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_score_thr)
+        self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * 0)
+        # Despite the name, this tensor is filled with -1 (used for suppressed boxes).
+        self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(np.float16) * -1)
+        self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_iou_thr)
+        self.test_max_per_img = config.test_max_per_img
+        self.nms_test = P.NMSWithMask(config.test_iou_thr)
+        self.softmax = P.Softmax(axis=1)
+        self.logicand = P.LogicalAnd()
+        self.oneslike = P.OnesLike()
+        self.test_topk = P.TopK(sorted=True)
+        self.test_num_proposal = self.test_batch_size * self.rpn_max_num
+
+        # Improve speed: multiclass_nms concatenates the per-class results in two
+        # chunks, split at concat_start.
+        self.concat_start = min(self.num_classes - 2, 55)
+        self.concat_end = (self.num_classes - 1)
+
+        # Init tensor: one column holding the image index of every RoI; it is
+        # prepended to the box coordinates in construct.
+        roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i,
+                                    dtype=np.float16) for i in range(self.train_batch_size)]
+
+        roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=np.float16)
+                                for i in range(self.test_batch_size)]
+
+        self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index))
+        self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test))
+
+    def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
+        """
+        Run the detector on one batch.
+
+        Returns:
+            In training mode, the tuple (rpn_loss, rcnn_loss, rpn_cls_loss,
+            rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss). Otherwise, the
+            (boxes, labels, masks) tuple produced by ``get_det_bboxes``.
+        """
+        # Backbone features followed by the feature pyramid neck.
+        x = self.backbone(img_data)
+        x = self.fpn_ncek(x)
+
+        # The RPN stage receives the all-ones gt_labels_stage1 instead of gt_labels.
+        rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(x,
+                                                                                           img_metas,
+                                                                                           self.anchor_list,
+                                                                                           gt_bboxes,
+                                                                                           self.gt_labels_stage1,
+                                                                                           gt_valids)
+
+        # Training and inference use differently configured proposal generators.
+        if self.training:
+            proposal, proposal_mask = self.proposal_generator(cls_score, bbox_pred, self.anchor_list)
+        else:
+            proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)
+
+        gt_labels = self.cast(gt_labels, mstype.int32)
+        gt_valids = self.cast(gt_valids, mstype.int32)
+        bboxes_tuple = ()
+        deltas_tuple = ()
+        labels_tuple = ()
+        mask_tuple = ()
+        if self.training:
+            # Per image: assign ground truth to the proposals and sample RoIs with
+            # their regression targets, labels and validity mask.
+            for i in range(self.train_batch_size):
+                gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
+
+                gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
+                gt_labels_i = self.cast(gt_labels_i, mstype.uint8)
+
+                gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
+                gt_valids_i = self.cast(gt_valids_i, mstype.bool_)
+
+                # Only the first four columns of each proposal (the box) are passed on;
+                # the fifth column is the score appended by the proposal generator.
+                bboxes, deltas, labels, mask = self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i,
+                                                                                   gt_labels_i,
+                                                                                   proposal_mask[i],
+                                                                                   proposal[i][::, 0:4:1],
+                                                                                   gt_valids_i)
+                bboxes_tuple += (bboxes,)
+                deltas_tuple += (deltas,)
+                labels_tuple += (labels,)
+                mask_tuple += (mask,)
+
+            # Targets and labels are constants for the RCNN loss: no gradient flows through them.
+            bbox_targets = self.concat(deltas_tuple)
+            rcnn_labels = self.concat(labels_tuple)
+            bbox_targets = F.stop_gradient(bbox_targets)
+            rcnn_labels = F.stop_gradient(rcnn_labels)
+            rcnn_labels = self.cast(rcnn_labels, mstype.int32)
+        else:
+            # Inference: no sampling. bbox_targets / rcnn_labels are only filled with the
+            # proposal masks because Rcnn.construct does not read them outside training mode.
+            mask_tuple += proposal_mask
+            bbox_targets = proposal_mask
+            rcnn_labels = proposal_mask
+            for p_i in proposal:
+                bboxes_tuple += (p_i[::, 0:4:1],)
+
+        # Build the RoIs by prepending the per-RoI image index column to the boxes.
+        # Concat is skipped for a single image, where the tuple holds just one tensor.
+        if self.training:
+            if self.train_batch_size > 1:
+                bboxes_all = self.concat(bboxes_tuple)
+            else:
+                bboxes_all = bboxes_tuple[0]
+            rois = self.concat_1((self.roi_align_index_tensor, bboxes_all))
+        else:
+            if self.test_batch_size > 1:
+                bboxes_all = self.concat(bboxes_tuple)
+            else:
+                bboxes_all = bboxes_tuple[0]
+            rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))
+
+
+        rois = self.cast(rois, mstype.float32)
+        rois = F.stop_gradient(rois)
+
+        # RoI feature extraction is done in float32 on the first four pyramid levels.
+        if self.training:
+            roi_feats = self.roi_align(rois,
+                                       self.cast(x[0], mstype.float32),
+                                       self.cast(x[1], mstype.float32),
+                                       self.cast(x[2], mstype.float32),
+                                       self.cast(x[3], mstype.float32))
+        else:
+            roi_feats = self.roi_align_test(rois,
+                                            self.cast(x[0], mstype.float32),
+                                            self.cast(x[1], mstype.float32),
+                                            self.cast(x[2], mstype.float32),
+                                            self.cast(x[3], mstype.float32))
+
+
+        roi_feats = self.cast(roi_feats, mstype.float16)
+        rcnn_masks = self.concat(mask_tuple)
+        rcnn_masks = F.stop_gradient(rcnn_masks)
+        rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_))
+        # In training mode the RCNN head returns losses. Outside training mode the same
+        # positions carry (x_cls, x_cls / 1.0, x_reg, x_cls), so rcnn_cls_loss and
+        # rcnn_reg_loss then hold the raw class and box outputs.
+        rcnn_loss, rcnn_cls_loss, rcnn_reg_loss, _ = self.rcnn(roi_feats,
+                                                               bbox_targets,
+                                                               rcnn_labels,
+                                                               rcnn_mask_squeeze)
+
+        output = ()
+        if self.training:
+            output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss)
+        else:
+            output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, img_metas)
+
+        return output
+
+    def get_det_bboxes(self, cls_logits, reg_logits, mask_logits, rois, img_metas):
+        """
+        Get the actual detection box.
+
+        Decodes the per-class box regressions against the RoIs, rescales the
+        boxes with the factors taken from ``img_metas`` and runs
+        ``multiclass_nms`` on the result.
+
+        Args:
+            cls_logits: Class outputs of the RCNN head; softmax is applied along axis 1.
+            reg_logits: Box regression outputs, four consecutive columns per class.
+            mask_logits: Validity mask of the RoIs.
+            rois: Boxes the regressions are decoded against.
+            img_metas: Per-image meta information, split into four parts per image.
+
+        Returns:
+            Tuple (all_bboxes, all_labels, all_masks) from ``multiclass_nms``.
+        """
+        scores = self.softmax(cls_logits)
+
+        # Decode the four regression columns of every class against the RoIs.
+        boxes_all = ()
+        for i in range(self.num_classes):
+            k = i * 4
+            reg_logits_i = self.squeeze(reg_logits[::, k:k+4:1])
+            out_boxes_i = self.decode(rois, reg_logits_i)
+            boxes_all += (out_boxes_i,)
+
+        # Split everything along axis 0 into test_batch_size per-image parts.
+        img_metas_all = self.split(img_metas)
+        scores_all = self.split(scores)
+        mask_all = self.split(self.cast(mask_logits, mstype.int32))
+
+        boxes_all_with_batchsize = ()
+        for i in range(self.test_batch_size):
+            # Parts 2 and 3 of the meta vector are used as the height and width factors
+            # (presumably the resize ratios applied during preprocessing -- confirm).
+            scale = self.split_shape(self.squeeze(img_metas_all[i]))
+            scale_h = scale[2]
+            scale_w = scale[3]
+            boxes_tuple = ()
+            for j in range(self.num_classes):
+                boxes_tmp = self.split(boxes_all[j])
+                out_boxes_h = boxes_tmp[i] / scale_h
+                out_boxes_w = boxes_tmp[i] / scale_w
+                # bbox_mask is (True, False, True, False) per row: columns 0/2 take the
+                # width-scaled value and columns 1/3 the height-scaled value.
+                boxes_tuple += (self.select(self.bbox_mask, out_boxes_w, out_boxes_h),)
+            boxes_all_with_batchsize += (boxes_tuple,)
+
+        output = self.multiclass_nms(boxes_all_with_batchsize, scores_all, mask_all)
+
+        return output
+
+    def multiclass_nms(self, boxes_all, scores_all, mask_all):
+        """
+        Multi-class postprocessing: per-class score thresholding followed by NMS.
+
+        Args:
+            boxes_all: Per image, a tuple with one box tensor per class.
+            scores_all: Per image, the class score tensor (one column per class).
+            mask_all: Per image, the validity mask of the proposals.
+
+        Returns:
+            Tuple (all_bboxes, all_labels, all_masks). Per image the tensors are
+            reshaped to (1, (num_classes - 1) * rpn_max_num, 5 / 1 / 1) and then
+            concatenated along axis 0 over the test batch.
+        """
+        all_bboxes = ()
+        all_labels = ()
+        all_masks = ()
+
+        for i in range(self.test_batch_size):
+            bboxes = boxes_all[i]
+            scores = scores_all[i]
+            masks = self.cast(mask_all[i], mstype.bool_)
+
+            res_boxes_tuple = ()
+            res_labels_tuple = ()
+            res_masks_tuple = ()
+
+            # Classes 1 .. num_classes - 1 are processed; class 0 is skipped
+            # (presumably the background class -- confirm).
+            for j in range(self.num_classes - 1):
+                k = j + 1
+                _cls_scores = scores[::, k:k + 1:1]
+                _bboxes = self.squeeze(bboxes[k])
+                _mask_o = self.reshape(masks, (self.rpn_max_num, 1))
+
+                # Keep only valid proposals whose score exceeds the test threshold.
+                cls_mask = self.greater(_cls_scores, self.test_score_thresh)
+                _mask = self.logicand(_mask_o, cls_mask)
+
+                # Expand the (rpn_max_num, 1) mask to the four box columns.
+                _reg_mask = self.cast(self.tile(self.cast(_mask, mstype.int32), (1, 4)), mstype.bool_)
+
+                # Rejected entries get a box of -1 and a score of 0.
+                _bboxes = self.select(_reg_mask, _bboxes, self.test_box_zeros)
+                _cls_scores = self.select(_mask, _cls_scores, self.test_score_zeros)
+                __cls_scores = self.squeeze(_cls_scores)
+                # TopK over all rpn_max_num entries: reorders boxes and masks by score.
+                scores_sorted, topk_inds = self.test_topk(__cls_scores, self.rpn_max_num)
+                topk_inds = self.reshape(topk_inds, (self.rpn_max_num, 1))
+                scores_sorted = self.reshape(scores_sorted, (self.rpn_max_num, 1))
+                _bboxes_sorted = self.gather(_bboxes, topk_inds)
+                _mask_sorted = self.gather(_mask, topk_inds)
+
+                # The score is tiled to four columns, appended to the box and the result
+                # cut back to its first five columns, i.e. four box columns plus one score.
+                scores_sorted = self.tile(scores_sorted, (1, 4))
+                cls_dets = self.concat_1((_bboxes_sorted, scores_sorted))
+                cls_dets = P.Slice()(cls_dets, (0, 0), (self.rpn_max_num, 5))
+
+                cls_dets, _index, _mask_nms = self.nms_test(cls_dets)
+                _index = self.reshape(_index, (self.rpn_max_num, 1))
+                _mask_nms = self.reshape(_mask_nms, (self.rpn_max_num, 1))
+
+                # Re-gather the pre-NMS mask with the indices returned by NMS.
+                _mask_n = self.gather(_mask_sorted, _index)
+
+                # A detection is kept only if it passed the threshold and survived NMS.
+                _mask_n = self.logicand(_mask_n, _mask_nms)
+                # Labels are j, i.e. zero-based over the processed (non-zero) classes.
+                cls_labels = self.oneslike(_index) * j
+                res_boxes_tuple += (cls_dets,)
+                res_labels_tuple += (cls_labels,)
+                res_masks_tuple += (_mask_n,)
+
+            # The per-class results are concatenated in two chunks split at concat_start
+            # (marked "Improve speed" where concat_start is defined).
+            res_boxes_start = self.concat(res_boxes_tuple[:self.concat_start])
+            res_labels_start = self.concat(res_labels_tuple[:self.concat_start])
+            res_masks_start = self.concat(res_masks_tuple[:self.concat_start])
+
+            res_boxes_end = self.concat(res_boxes_tuple[self.concat_start:self.concat_end])
+            res_labels_end = self.concat(res_labels_tuple[self.concat_start:self.concat_end])
+            res_masks_end = self.concat(res_masks_tuple[self.concat_start:self.concat_end])
+
+            res_boxes = self.concat((res_boxes_start, res_boxes_end))
+            res_labels = self.concat((res_labels_start, res_labels_end))
+            res_masks = self.concat((res_masks_start, res_masks_end))
+
+            # Add a leading axis of size 1 so the images can be stacked along axis 0.
+            reshape_size = (self.num_classes - 1) * self.rpn_max_num
+            res_boxes = self.reshape(res_boxes, (1, reshape_size, 5))
+            res_labels = self.reshape(res_labels, (1, reshape_size, 1))
+            res_masks = self.reshape(res_masks, (1, reshape_size, 1))
+
+            all_bboxes += (res_boxes,)
+            all_labels += (res_labels,)
+            all_masks += (res_masks,)
+
+        all_bboxes = self.concat(all_bboxes)
+        all_labels = self.concat(all_labels)
+        all_masks = self.concat(all_masks)
+        return all_bboxes, all_labels, all_masks
+
+    def get_anchors(self, featmap_sizes):
+        """Build the anchor tensors for every feature-map level.
+
+        Args:
+            featmap_sizes (list[tuple]): Multi-level feature map sizes.
+
+        Returns:
+            tuple: One float16 Tensor of anchors per level, in level order.
+        """
+        # Feature map sizes are the same for all images, so the anchors are
+        # computed only once.
+        level_anchors = []
+        for level, featmap_size in enumerate(featmap_sizes):
+            grid = self.anchor_generators[level].grid_anchors(featmap_size, self.anchor_strides[level])
+            level_anchors.append(Tensor(grid.astype(np.float16)))
+
+        return tuple(level_anchors)
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/fpn_neck.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/fpn_neck.py
new file mode 100644
index 0000000..73781bd
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/fpn_neck.py
@@ -0,0 +1,114 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FasterRcnn feature pyramid network."""
+
+import numpy as np
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+from mindspore.common import dtype as mstype
+from mindspore.common.initializer import initializer
+
+# pylint: disable=locally-disabled, missing-docstring
+
+# NOTE: this call runs when the module is imported and selects graph mode on
+# the Ascend device target for the MindSpore context.
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+def bias_init_zeros(shape):
+    """Return a float16 Tensor of the given shape filled with zeros (bias initial value)."""
+    zeros_fp32 = np.zeros(shape).astype(np.float32)
+    zeros_fp16 = np.array(zeros_fp32).astype(np.float16)
+    return Tensor(zeros_fp16)
+
+def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
+    """Create an nn.Conv2d with Xavier-uniform float16 weights and a zero-initialised bias."""
+    weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
+    weight_tensor = initializer("XavierUniform", shape=weight_shape, dtype=mstype.float16).to_tensor()
+    bias_tensor = bias_init_zeros((out_channels,))
+    conv = nn.Conv2d(in_channels,
+                     out_channels,
+                     kernel_size=kernel_size,
+                     stride=stride,
+                     padding=padding,
+                     pad_mode=pad_mode,
+                     weight_init=weight_tensor,
+                     has_bias=True,
+                     bias_init=bias_tensor)
+    return conv
+
+class FeatPyramidNeck(nn.Cell):
+    """
+    Feature pyramid network cell, usually used as the network neck.
+
+    Applies a lateral 1x1 convolution to every input feature map, merges the
+    maps top-down with nearest-neighbour upsampling, and passes each merged
+    map through a 3x3 convolution so that all outputs share the same channel
+    size. If the required number of outputs is larger than the number of
+    inputs, extra levels are produced by repeatedly applying a stride-2 max
+    pooling to the most recently produced output.
+
+    Note:
+        The top-down path in ``construct`` is written for exactly four input
+        levels, and the upsampling target sizes are fixed to (48, 80),
+        (96, 160) and (192, 320).
+
+    Args:
+        in_channels (tuple) - Channel size of input feature maps.
+        out_channels (int) - Channel size output.
+        num_outs (int) - Num of output features.
+
+    Returns:
+        Tuple, with tensors of same channel size.
+
+    Examples:
+        neck = FeatPyramidNeck([100, 200, 300, 400], 50, 5)
+        input_data = (normal(0, 0.1, (1, c, 768//(4*2**i), 1280//(4*2**i)),
+                      dtype=np.float32)
+                      for i, c in enumerate([100, 200, 300, 400]))
+        x = neck(input_data)
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 num_outs):
+        super(FeatPyramidNeck, self).__init__()
+        self.num_outs = num_outs
+        self.in_channels = in_channels
+        self.fpn_layer = len(self.in_channels)
+
+        # At least one output per input level is required.
+        assert self.num_outs >= len(in_channels)
+
+        self.lateral_convs_list_ = []
+        self.fpn_convs_ = []
+
+        for channel in in_channels:
+            # 1x1 lateral conv to unify the channel size, 3x3 conv applied after merging.
+            l_conv = _conv(channel, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='valid')
+            fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same')
+            self.lateral_convs_list_.append(l_conv)
+            self.fpn_convs_.append(fpn_conv)
+        self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
+        self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_)
+        # Fixed upsampling targets for the three top-down merge steps.
+        self.interpolate1 = P.ResizeNearestNeighbor((48, 80))
+        self.interpolate2 = P.ResizeNearestNeighbor((96, 160))
+        self.interpolate3 = P.ResizeNearestNeighbor((192, 320))
+        self.maxpool = P.MaxPool(ksize=1, strides=2, padding="same")
+
+    def construct(self, inputs):
+        # Lateral convolutions, one per input level.
+        x = ()
+        for i in range(self.fpn_layer):
+            x += (self.lateral_convs_list[i](inputs[i]),)
+
+        # Top-down pathway: start from level 3 and add each upsampled result to
+        # the next lower level. y is ordered from level 3 down to level 0.
+        y = (x[3],)
+        y = y + (x[2] + self.interpolate1(y[self.fpn_layer - 4]),)
+        y = y + (x[1] + self.interpolate2(y[self.fpn_layer - 3]),)
+        y = y + (x[0] + self.interpolate3(y[self.fpn_layer - 2]),)
+
+        # Reverse y so that z is ordered from level 0 up to level 3 again.
+        z = ()
+        for i in range(self.fpn_layer - 1, -1, -1):
+            z = z + (y[i],)
+
+        outs = ()
+        for i in range(self.fpn_layer):
+            outs = outs + (self.fpn_convs_list[i](z[i]),)
+
+        # Extra levels: pool the most recently produced output, so that every
+        # additional level is downsampled further than the previous one
+        # (previously each extra level pooled outs[3] and was therefore identical).
+        for i in range(self.num_outs - self.fpn_layer):
+            outs = outs + (self.maxpool(outs[self.fpn_layer - 1 + i]),)
+        return outs
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/proposal_generator.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/proposal_generator.py
new file mode 100644
index 0000000..d24fd04
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/proposal_generator.py
@@ -0,0 +1,201 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FasterRcnn proposal generator."""
+
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
+from mindspore import Tensor
+from mindspore import context
+
+# pylint: disable=locally-disabled, invalid-name, missing-docstring
+
+
+# NOTE: this call runs when the module is imported and selects graph mode on
+# the Ascend device target for the MindSpore context.
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Proposal(nn.Cell):
+ """
+ Proposal subnet.
+
+ Args:
+ config (dict): Config.
+ batch_size (int): Batchsize.
+ num_classes (int) - Class number.
+ use_sigmoid_cls (bool) - Select sigmoid or softmax function.
+ target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0).
+ target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0).
+
+ Returns:
+ Tuple, tuple of output tensor,(proposal, mask).
+
+ Examples:
+ Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \
+ target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0))
+ """
+    def __init__(self,
+                 config,
+                 batch_size,
+                 num_classes,
+                 use_sigmoid_cls,
+                 target_means=(.0, .0, .0, .0),
+                 target_stds=(1.0, 1.0, 1.0, 1.0)
+                 ):
+        """Store the settings, create the operators and initialise the training-mode top-k state."""
+        super(Proposal, self).__init__()
+        self.batch_size = batch_size
+        self.num_classes = num_classes
+        self.target_means = target_means
+        self.target_stds = target_stds
+        self.use_sigmoid_cls = use_sigmoid_cls
+
+        # Sigmoid scoring uses num_classes - 1 channels and one score column;
+        # softmax scoring uses num_classes channels and two score columns.
+        if use_sigmoid_cls:
+            out_channels = num_classes - 1
+            activation = P.Sigmoid()
+            score_shape = (-1, 1)
+        else:
+            out_channels = num_classes
+            activation = P.Softmax(axis=1)
+            score_shape = (-1, 2)
+        self.cls_out_channels = out_channels
+        self.activation = activation
+        self.reshape_shape = score_shape
+
+        if out_channels <= 0:
+            raise ValueError('num_classes={} is too small'.format(num_classes))
+
+        # Training-time proposal settings.
+        self.num_pre = config.rpn_proposal_nms_pre
+        self.min_box_size = config.rpn_proposal_min_bbox_size
+        self.nms_thr = config.rpn_proposal_nms_thr
+        self.nms_post = config.rpn_proposal_nms_post
+        self.nms_across_levels = config.rpn_proposal_nms_across_levels
+        self.max_num = config.rpn_proposal_max_num
+        self.num_levels = config.fpn_num_outs
+
+        self.feature_shapes = config.feature_shapes
+        self.transpose_shape = (1, 2, 0)
+
+        # Operators
+        self.squeeze = P.Squeeze()
+        self.reshape = P.Reshape()
+        self.cast = P.Cast()
+        self.decode = P.BoundingBoxDecode(max_shape=(config.img_height, config.img_width),
+                                          means=self.target_means,
+                                          stds=self.target_stds)
+        self.nms = P.NMSWithMask(self.nms_thr)
+        self.concat_axis0 = P.Concat(axis=0)
+        self.concat_axis1 = P.Concat(axis=1)
+        self.split = P.Split(axis=1, output_num=5)
+        self.min = P.Minimum()
+        self.gatherND = P.GatherNd()
+        self.slice = P.Slice()
+        self.select = P.Select()
+        self.greater = P.Greater()
+        self.transpose = P.Transpose()
+        self.tile = P.Tile()
+
+        # Derive the per-level top-k sizes for training mode.
+        self.set_train_local(config, training=True)
+
+        self.multi_10 = Tensor(10.0, mstype.float16)
+
+    def set_train_local(self, config, training=True):
+        """
+        Select the training or inference proposal settings and rebuild dependent state.
+
+        Both modes assign their settings explicitly, so the method can be called
+        repeatedly in any order. The NMS operator is re-created from the selected
+        threshold; otherwise the operator built in ``__init__`` would keep the
+        training threshold in inference mode.
+
+        Args:
+            config: Configuration object providing the ``rpn_proposal_*``
+                (training) and ``rpn_*`` (inference) settings.
+            training (bool): True selects the training settings. Default: True.
+        """
+        self.training_local = training
+
+        cfg = config
+        if self.training_local:
+            self.num_pre = cfg.rpn_proposal_nms_pre
+            self.min_box_size = cfg.rpn_proposal_min_bbox_size
+            self.nms_thr = cfg.rpn_proposal_nms_thr
+            self.nms_post = cfg.rpn_proposal_nms_post
+            self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
+            self.max_num = cfg.rpn_proposal_max_num
+        else:
+            self.num_pre = cfg.rpn_nms_pre
+            self.min_box_size = cfg.rpn_min_bbox_min_size
+            self.nms_thr = cfg.rpn_nms_thr
+            self.nms_post = cfg.rpn_nms_post
+            self.nms_across_levels = cfg.rpn_nms_across_levels
+            self.max_num = cfg.rpn_max_num
+
+        # Keep the NMS operator in sync with the selected threshold.
+        self.nms = P.NMSWithMask(self.nms_thr)
+
+        self.topK_stage1 = ()
+        self.topK_shape = ()
+        total_max_topk_input = 0
+        for shp in self.feature_shapes:
+            # Candidates kept per level: at most num_pre, capped by height * width * 3
+            # (3 is presumably the number of anchors per location -- confirm).
+            k_num = min(self.num_pre, (shp[0] * shp[1] * 3))
+            total_max_topk_input += k_num
+            self.topK_stage1 += (k_num,)
+            self.topK_shape += ((k_num, 1),)
+
+        self.topKv2 = P.TopK(sorted=True)
+        self.topK_shape_stage2 = (self.max_num, 1)
+        # Score assigned to entries rejected by NMS so that they rank last in the final top-k.
+        self.min_float_num = -65536.0
+        self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))
+
+    def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
+        """
+        Generate proposals for every image of the batch.
+
+        Args:
+            rpn_cls_score_total: Per level, the RPN classification output for the batch.
+            rpn_bbox_pred_total: Per level, the RPN box regression output for the batch.
+            anchor_list: Per level, the anchors.
+
+        Returns:
+            Tuple (proposals_tuple, masks_tuple), each holding one entry per image.
+        """
+        proposals_tuple = ()
+        masks_tuple = ()
+        for img_id in range(self.batch_size):
+            cls_score_list = ()
+            bbox_pred_list = ()
+            # Slice out image img_id from every level and drop the batch axis.
+            for i in range(self.num_levels):
+                rpn_cls_score_i = self.squeeze(rpn_cls_score_total[i][img_id:img_id+1:1, ::, ::, ::])
+                rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[i][img_id:img_id+1:1, ::, ::, ::])
+
+                cls_score_list = cls_score_list + (rpn_cls_score_i,)
+                bbox_pred_list = bbox_pred_list + (rpn_bbox_pred_i,)
+
+            proposals, masks = self.get_bboxes_single(cls_score_list, bbox_pred_list, anchor_list)
+            proposals_tuple += (proposals,)
+            masks_tuple += (masks,)
+        return proposals_tuple, masks_tuple
+
+    def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors):
+        """
+        Get proposal boundingbox.
+
+        For one image: per level, keep the top-scoring anchors, decode them and
+        run NMS; then select the ``max_num`` best entries over all levels.
+
+        Args:
+            cls_scores: Per level, the classification output of one image.
+            bbox_preds: Per level, the box regression output of one image.
+            mlvl_anchors: Per level, the anchors.
+
+        Returns:
+            Tuple (proposals, masks): the selected rows (four box columns plus
+            score) and the matching NMS validity mask.
+        """
+        mlvl_proposals = ()
+        mlvl_mask = ()
+        for idx in range(self.num_levels):
+            # Move the channel axis last so that reshaping yields one row per anchor.
+            rpn_cls_score = self.transpose(cls_scores[idx], self.transpose_shape)
+            rpn_bbox_pred = self.transpose(bbox_preds[idx], self.transpose_shape)
+            anchors = mlvl_anchors[idx]
+
+            rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape)
+            rpn_cls_score = self.activation(rpn_cls_score)
+            # NOTE(review): `[::, 0::]` keeps every column, so with the softmax setting
+            # (two score columns) the squeeze would not yield a 1-D score vector --
+            # confirm whether the softmax mode is actually supported here.
+            rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), mstype.float16)
+
+            rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16)
+
+            # Keep the topK_stage1[idx] best-scoring anchors of this level.
+            scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.topK_stage1[idx])
+
+            topk_inds = self.reshape(topk_inds, self.topK_shape[idx])
+
+            bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds)
+            anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16)
+
+            proposals_decode = self.decode(anchors_sorted, bboxes_sorted)
+
+            # Append the score as a fifth column before NMS.
+            proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape[idx])))
+            proposals, _, mask_valid = self.nms(proposals_decode)
+
+            mlvl_proposals = mlvl_proposals + (proposals,)
+            mlvl_mask = mlvl_mask + (mask_valid,)
+
+        proposals = self.concat_axis0(mlvl_proposals)
+        masks = self.concat_axis0(mlvl_mask)
+
+        # Entries rejected by NMS get the score topK_mask (-65536) so they rank last.
+        _, _, _, _, scores = self.split(proposals)
+        scores = self.squeeze(scores)
+        topk_mask = self.cast(self.topK_mask, mstype.float16)
+        scores_using = self.select(masks, scores, topk_mask)
+
+        # Final selection of the max_num best entries across all levels.
+        _, topk_inds = self.topKv2(scores_using, self.max_num)
+
+        topk_inds = self.reshape(topk_inds, self.topK_shape_stage2)
+        proposals = self.gatherND(proposals, topk_inds)
+        masks = self.gatherND(masks, topk_inds)
+        return proposals, masks
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/rcnn.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/rcnn.py
new file mode 100644
index 0000000..3ddca9d
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/rcnn.py
@@ -0,0 +1,173 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FasterRcnn Rcnn network."""
+
+import numpy as np
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
+
+# pylint: disable=locally-disabled, missing-docstring
+
+
+class DenseNoTranpose(nn.Cell):
+    """Dense layer computing ``matmul(x, weight) + bias`` with the weight stored as (in, out)."""
+
+    def __init__(self, input_channels, output_channels, weight_init):
+        super(DenseNoTranpose, self).__init__()
+
+        weight_shape = [input_channels, output_channels]
+        bias_shape = [output_channels]
+        self.weight = Parameter(initializer(weight_init, weight_shape, mstype.float16), name="weight")
+        self.bias = Parameter(initializer("zeros", bias_shape, mstype.float16).to_tensor(), name="bias")
+
+        # The weight is already laid out as (in, out), so no transpose is needed.
+        self.matmul = P.MatMul(transpose_b=False)
+        self.bias_add = P.BiasAdd()
+
+    def construct(self, x):
+        product = self.matmul(x, self.weight)
+        return self.bias_add(product, self.bias)
+
+
+class Rcnn(nn.Cell):
+ """
+ Rcnn subnet.
+
+ Args:
+ config (dict) - Config.
+ representation_size (int) - Channels of shared dense.
+ batch_size (int) - Batchsize.
+ num_classes (int) - Class number.
+ target_means (tuple) - Means for encode function. Default: (0., 0., 0., 0.).
+ target_stds (list) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2).
+
+ Returns:
+ Tuple, tuple of output tensor.
+
+ Examples:
+ Rcnn(config=config, representation_size = 1024, batch_size=2, num_classes = 81, \
+ target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2))
+ """
+    def __init__(self,
+                 config,
+                 representation_size,
+                 batch_size,
+                 num_classes,
+                 target_means=(0., 0., 0., 0.),
+                 target_stds=(0.1, 0.1, 0.2, 0.2)
+                 ):
+        """Create the shared dense layers, the two output heads, the loss operators and constants."""
+        super(Rcnn, self).__init__()
+        self.rcnn_loss_cls_weight = Tensor(np.array(config.rcnn_loss_cls_weight).astype(np.float16))
+        self.rcnn_loss_reg_weight = Tensor(np.array(config.rcnn_loss_reg_weight).astype(np.float16))
+        self.rcnn_fc_out_channels = config.rcnn_fc_out_channels
+        self.target_means = target_means
+        self.target_stds = target_stds
+        self.num_classes = num_classes
+        self.in_channels = config.rcnn_in_channels
+        self.train_batch_size = batch_size
+        self.test_batch_size = config.test_batch_size
+
+        fc_channels = self.rcnn_fc_out_channels
+
+        # Shared dense layers; weights are created directly in (in, out) layout.
+        weights_0 = initializer("XavierUniform", shape=(representation_size, fc_channels),
+                                dtype=mstype.float16).to_tensor()
+        weights_1 = initializer("XavierUniform", shape=(fc_channels, fc_channels),
+                                dtype=mstype.float16).to_tensor()
+        self.shared_fc_0 = DenseNoTranpose(representation_size, fc_channels, weights_0)
+        self.shared_fc_1 = DenseNoTranpose(fc_channels, fc_channels, weights_1)
+
+        # Output heads: one score per class and four regression values per class.
+        cls_weight = initializer('Normal', shape=[fc_channels, num_classes],
+                                 dtype=mstype.float16).to_tensor()
+        reg_weight = initializer('Normal', shape=[fc_channels, num_classes * 4],
+                                 dtype=mstype.float16).to_tensor()
+        self.cls_scores = DenseNoTranpose(fc_channels, num_classes, cls_weight)
+        self.reg_scores = DenseNoTranpose(fc_channels, num_classes * 4, reg_weight)
+
+        # Operators
+        self.flatten = P.Flatten()
+        self.relu = P.ReLU()
+        self.logicaland = P.LogicalAnd()
+        self.loss_cls = P.SoftmaxCrossEntropyWithLogits()
+        self.loss_bbox = P.SmoothL1Loss(beta=1.0)
+        self.reshape = P.Reshape()
+        self.onehot = P.OneHot()
+        self.greater = P.Greater()
+        self.cast = P.Cast()
+        self.sum_loss = P.ReduceSum()
+        self.tile = P.Tile()
+        self.expandims = P.ExpandDims()
+        self.gather = P.GatherNd()
+        self.argmax = P.ArgMaxWithValue(axis=1)
+
+        # Constants
+        self.on_value = Tensor(1.0, mstype.float32)
+        self.off_value = Tensor(0.0, mstype.float32)
+        self.value = Tensor(1.0, mstype.float16)
+
+        self.num_bboxes = (config.num_expected_pos_stage2 + config.num_expected_neg_stage2) * batch_size
+
+        # All-ones (num_bboxes, num_classes) mask whose column 0 is zeroed.
+        first_col_removed = np.ones((self.num_bboxes, self.num_classes))
+        first_col_removed[:, 0] = np.zeros((self.num_bboxes,))
+        self.rmv_first_tensor = Tensor(first_col_removed.astype(np.float16))
+
+        self.num_bboxes_test = config.rpn_max_num * config.test_batch_size
+
+        self.range_max = Tensor(np.arange(self.num_bboxes_test).astype(np.int32))
+
+    def construct(self, featuremap, bbox_targets, labels, mask):
+        """
+        Run the RCNN head.
+
+        Returns:
+            In training mode, (loss, loss_cls, loss_reg, loss_print). Otherwise
+            (x_cls, x_cls / 1.0, x_reg, x_cls), i.e. the raw class and box outputs
+            in the same four positions; bbox_targets, labels and mask are not
+            read in that case.
+        """
+        x = self.flatten(featuremap)
+
+        # Two shared dense layers with ReLU, then the class and box heads.
+        x = self.relu(self.shared_fc_0(x))
+        x = self.relu(self.shared_fc_1(x))
+
+        x_cls = self.cls_scores(x)
+        x_reg = self.reg_scores(x)
+
+        if self.training:
+            # Class id for valid samples with label > 0, and 0 for all others.
+            bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
+            labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), mstype.float16)
+            # Repeat the targets along a new axis 1, once per class.
+            bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1))
+
+            loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask)
+            out = (loss, loss_cls, loss_reg, loss_print)
+        else:
+            # self.value is 1.0, so the second entry equals x_cls.
+            out = (x_cls, (x_cls / self.value), x_reg, x_cls)
+
+        return out
+
+    def loss(self, cls_score, bbox_pred, bbox_targets, bbox_weights, labels, weights):
+        """
+        Loss method.
+
+        Args:
+            cls_score: Class outputs of the head.
+            bbox_pred: Box regression outputs of the head.
+            bbox_targets: Regression targets, repeated once per class.
+            bbox_weights: Per-sample class id selecting which class' boxes are penalised.
+            labels: One-hot class labels.
+            weights: Per-sample validity mask.
+
+        Returns:
+            Tuple (loss, loss_cls, loss_reg, loss_print), where loss is the weighted
+            sum of the two terms and loss_print is the tuple (loss_cls, loss_reg).
+        """
+        loss_print = ()
+        # Only the first output of the cross-entropy operator is used.
+        loss_cls, _ = self.loss_cls(cls_score, labels)
+
+        # Masked mean over the samples.
+        # NOTE(review): the divisor is the sum of the mask; no guard for an all-zero mask is visible here.
+        weights = self.cast(weights, mstype.float16)
+        loss_cls = loss_cls * weights
+        loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,))
+
+        # One-hot over classes, with column 0 zeroed by rmv_first_tensor so that
+        # samples whose weight id is 0 contribute no regression loss.
+        bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
+                                 mstype.float16)
+        bbox_weights = bbox_weights * self.rmv_first_tensor
+
+        # View the predictions as (num_bboxes, classes, 4), sum the smooth-L1 loss
+        # over the four coordinates, keep the selected class and normalise.
+        pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4))
+        loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets)
+        loss_reg = self.sum_loss(loss_reg, (2,))
+        loss_reg = loss_reg * bbox_weights
+        loss_reg = loss_reg / self.sum_loss(weights, (0,))
+        loss_reg = self.sum_loss(loss_reg, (0, 1))
+
+        loss = self.rcnn_loss_cls_weight * loss_cls + self.rcnn_loss_reg_weight * loss_reg
+        loss_print += (loss_cls, loss_reg)
+
+        return loss, loss_cls, loss_reg, loss_print
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/resnet50.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/resnet50.py
new file mode 100644
index 0000000..eb0fd57
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/resnet50.py
@@ -0,0 +1,250 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Resnet50 backbone."""
+
+import numpy as np
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+from mindspore.ops import functional as F
+from mindspore import context
+
+# pylint: disable=locally-disabled, invalid-name, missing-docstring
+
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")  # executed when this module is imported
+
+
+def weight_init_ones(shape):
+    """Return a float16 Tensor of the given shape with every element set to 0.01."""
+    return Tensor(np.full(shape, 0.01, dtype=np.float32).astype(np.float16))
+
+
+def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
+    """Build a bias-free nn.Conv2d whose kernel is initialised from weight_init_ones."""
+    kernel_shape = (out_channels, in_channels, kernel_size, kernel_size)
+    conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, pad_mode=pad_mode)
+    return nn.Conv2d(in_channels, out_channels,
+                     weight_init=weight_init_ones(kernel_shape),
+                     has_bias=False, **conv_kwargs)
+
+
+def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True):
+    """Build an nn.BatchNorm2d whose float16 parameters start at gamma=1, beta=0, mean=0, var=1."""
+    def _filled(fill):
+        return Tensor(fill(out_chls).astype(np.float16))
+
+    return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine,
+                          gamma_init=_filled(np.ones), beta_init=_filled(np.zeros),
+                          moving_mean_init=_filled(np.zeros),
+                          moving_var_init=_filled(np.ones),
+                          use_batch_statistics=use_batch_statistics)
+
+
+class ResNetFea(nn.Cell):
+ """
+ ResNet feature extractor that returns the outputs of its four stages.
+
+ Args:
+ block (Cell): Block for network.
+ layer_nums (list): Numbers of block in different layers.
+ in_channels (list): Input channel in each layer.
+ out_channels (list): Output channel in each layer.
+ weights_update (bool): If False, conv1 and layer1 are frozen and layer1's output is gradient-stopped.
+ Returns:
+ Tuple of four Tensors, the outputs of layer1, layer2, layer3 and layer4.
+
+ Examples:
+ >>> ResNetFea(ResidualBlockUsing,
+ >>> [3, 4, 6, 3],
+ >>> [64, 256, 512, 1024],
+ >>> [256, 512, 1024, 2048],
+ >>> False)
+ """
+ def __init__(self,
+ block,
+ layer_nums,
+ in_channels,
+ out_channels,
+ weights_update=False):
+ super(ResNetFea, self).__init__()
+
+ if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
+ raise ValueError("the length of "
+ "layer_num, inchannel, outchannel list must be 4!")
+
+ bn_training = False  # every BatchNorm below is built with use_batch_statistics=False
+ self.conv1 = _conv(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
+ self.bn1 = _BatchNorm2dInit(64, affine=bn_training, use_batch_statistics=bn_training)
+ self.relu = P.ReLU()
+ self.maxpool = P.MaxPool(ksize=3, strides=2, padding="SAME")
+ self.weights_update = weights_update
+
+ if not self.weights_update:
+ self.conv1.weight.requires_grad = False
+
+ self.layer1 = self._make_layer(block,
+ layer_nums[0],
+ in_channel=in_channels[0],
+ out_channel=out_channels[0],
+ stride=1,
+ training=bn_training,
+ weights_update=self.weights_update)
+ self.layer2 = self._make_layer(block,
+ layer_nums[1],
+ in_channel=in_channels[1],
+ out_channel=out_channels[1],
+ stride=2,
+ training=bn_training,
+ weights_update=True)  # layer2-4 are always built with weights_update=True
+ self.layer3 = self._make_layer(block,
+ layer_nums[2],
+ in_channel=in_channels[2],
+ out_channel=out_channels[2],
+ stride=2,
+ training=bn_training,
+ weights_update=True)
+ self.layer4 = self._make_layer(block,
+ layer_nums[3],
+ in_channel=in_channels[3],
+ out_channel=out_channels[3],
+ stride=2,
+ training=bn_training,
+ weights_update=True)
+
+ def _make_layer(self, block, layer_num, in_channel, out_channel, stride, training=False, weights_update=False):
+ """Stack layer_num blocks; only the first one may change the stride or the channel count."""
+ layers = []
+ down_sample = False
+ if stride != 1 or in_channel != out_channel:
+ down_sample = True
+ resblk = block(in_channel,
+ out_channel,
+ stride=stride,
+ down_sample=down_sample,
+ training=training,
+ weights_update=weights_update)
+ layers.append(resblk)
+
+ for _ in range(1, layer_num):
+ resblk = block(out_channel, out_channel, stride=1, training=training, weights_update=weights_update)
+ layers.append(resblk)
+
+ return nn.SequentialCell(layers)
+
+ def construct(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ c1 = self.maxpool(x)
+
+ c2 = self.layer1(c1)
+ identity = c2
+ if not self.weights_update:
+ identity = F.stop_gradient(c2)  # no gradient flows back into conv1/layer1 when they are frozen
+ c3 = self.layer2(identity)
+ c4 = self.layer3(c3)
+ c5 = self.layer4(c4)
+
+ return identity, c3, c4, c5
+
+
+class ResidualBlockUsing(nn.Cell):
+ """
+ ResNet V1 residual block definition.
+
+ Args:
+ in_channels (int) - Input channel.
+ out_channels (int) - Output channel.
+ stride (int) - Stride size for the 3x3 convolutional layer (and the downsample conv). Default: 1.
+ down_sample (bool) - If to do the downsample in block. Default: False.
+ momentum (float) - Momentum for batchnorm layer. Default: 0.1.
+ training (bool) - Passed to every BatchNorm as use_batch_statistics; True also calls set_train(). Default: False.
+ weights_update (bool) - If False, conv weights are frozen and BatchNorm is built with affine=False. Default: False.
+
+ Returns:
+ Tensor, output tensor.
+
+ Examples:
+ ResidualBlockUsing(3,256,stride=2,down_sample=True)
+ """
+ expansion = 4
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride=1,
+ down_sample=False,
+ momentum=0.1,
+ training=False,
+ weights_update=False):
+ super(ResidualBlockUsing, self).__init__()
+
+ self.affine = weights_update
+
+ out_chls = out_channels // self.expansion  # channel width of the first two convs
+ self.conv1 = _conv(in_channels, out_chls, kernel_size=1, stride=1, padding=0)
+ self.bn1 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training)
+
+ self.conv2 = _conv(out_chls, out_chls, kernel_size=3, stride=stride, padding=1)
+ self.bn2 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training)
+
+ self.conv3 = _conv(out_chls, out_channels, kernel_size=1, stride=1, padding=0)
+ self.bn3 = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine, use_batch_statistics=training)
+
+ if training:
+ self.bn1 = self.bn1.set_train()
+ self.bn2 = self.bn2.set_train()
+ self.bn3 = self.bn3.set_train()
+
+ if not weights_update:
+ self.conv1.weight.requires_grad = False
+ self.conv2.weight.requires_grad = False
+ self.conv3.weight.requires_grad = False
+
+ self.relu = P.ReLU()
+ self.downsample = down_sample
+ if self.downsample:
+ self.conv_down_sample = _conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)
+ self.bn_down_sample = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine,
+ use_batch_statistics=training)
+ if training:
+ self.bn_down_sample = self.bn_down_sample.set_train()
+ if not weights_update:
+ self.conv_down_sample.weight.requires_grad = False
+ self.add = P.TensorAdd()
+
+ def construct(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample:  # project the shortcut with the 1x1 downsample conv before the add
+ identity = self.conv_down_sample(identity)
+ identity = self.bn_down_sample(identity)
+
+ out = self.add(out, identity)
+ out = self.relu(out)
+
+ return out
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/roi_align.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/roi_align.py
new file mode 100644
index 0000000..f174381
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/roi_align.py
@@ -0,0 +1,181 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FasterRcnn ROIAlign module."""
+
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+from mindspore.nn import layer as L
+from mindspore.common.tensor import Tensor
+
+# pylint: disable=locally-disabled, invalid-name, missing-docstring
+
+
+class ROIAlign(nn.Cell):
+    """
+    Thin Cell wrapper around the P.ROIAlign operator for one feature map.
+
+    Args:
+        out_size_h (int) - Height of the pooled RoI output.
+        out_size_w (int) - Width of the pooled RoI output.
+        spatial_scale (float) - Scale handed to P.ROIAlign; converted with float().
+        sample_num (int) - Sample number handed to P.ROIAlign. Default: 0.
+    """
+    def __init__(self,
+                 out_size_h,
+                 out_size_w,
+                 spatial_scale,
+                 sample_num=0):
+        super(ROIAlign, self).__init__()
+
+        self.out_size = (out_size_h, out_size_w)
+        self.spatial_scale = float(spatial_scale)
+        self.sample_num = int(sample_num)
+        self.align_op = P.ROIAlign(self.out_size[0], self.out_size[1],
+                                   self.spatial_scale, self.sample_num)
+
+    def construct(self, features, rois):
+        return self.align_op(features, rois)
+
+    def __repr__(self):
+        # Close the parenthesis so the text reads ClassName(out_size=..., ...).
+        format_str = '{}(out_size={}, spatial_scale={}, sample_num={})'.format(
+            self.__class__.__name__, self.out_size, self.spatial_scale, self.sample_num)
+        return format_str
+
+
+class SingleRoIExtractor(nn.Cell):
+    """
+    Extract RoI features, taking each RoI from one of several feature maps.
+
+    Each RoI is mapped to one of the input feature levels according to its scale.
+
+    Args:
+        config (dict): Config
+        roi_layer (dict): Specify RoI layer type and arguments.
+        out_channels (int): Output channels of RoI layers.
+        featmap_strides (list): Strides of input feature maps, one per level.
+        batch_size (int): Batchsize.
+        finest_scale (int): Scale threshold of mapping to level 0.
+    """
+
+    def __init__(self,
+                 config,
+                 roi_layer,
+                 out_channels,
+                 featmap_strides,
+                 batch_size=1,
+                 finest_scale=56):
+        super(SingleRoIExtractor, self).__init__()
+        self.train_batch_size = batch_size
+        self.out_channels = out_channels
+        self.featmap_strides = featmap_strides
+        self.num_levels = len(self.featmap_strides)
+        self.out_size = roi_layer['out_size']
+        self.sample_num = roi_layer['sample_num']
+        self.roi_layers = self.build_roi_layers(self.featmap_strides)
+        self.roi_layers = L.CellList(self.roi_layers)
+
+        self.sqrt = P.Sqrt()
+        self.log = P.Log()
+        self.finest_scale_ = finest_scale
+        self.clamp = C.clip_by_value
+
+        self.cast = P.Cast()
+        self.equal = P.Equal()
+        self.select = P.Select()
+
+        _mode_16 = False
+        self.dtype = np.float16 if _mode_16 else np.float32
+        self.ms_dtype = mstype.float16 if _mode_16 else mstype.float32
+        self.set_train_local(config, training=True)
+
+    def set_train_local(self, config, training=True):
+        """Set the training flag and rebuild the constant tensors sized for that mode."""
+        self.training_local = training
+
+        # Init tensor: RoI count is roi_sample_num (train) or rpn_max_num (test), times the batch size.
+        self.batch_size = config.roi_sample_num if self.training_local else config.rpn_max_num
+        self.batch_size = self.train_batch_size*self.batch_size \
+            if self.training_local else config.test_batch_size*self.batch_size
+        self.ones = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype))
+        finest_scale = np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * self.finest_scale_
+        self.finest_scale = Tensor(finest_scale)
+        self.epslion = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)*self.dtype(1e-6))
+        self.zeros = Tensor(np.array(np.zeros((self.batch_size, 1)), dtype=np.int32))
+        self.max_levels = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=np.int32)*(self.num_levels-1))
+        self.twos = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * 2)
+        self.res_ = Tensor(np.array(np.zeros((self.batch_size, self.out_channels,
+                                              self.out_size, self.out_size)), dtype=self.dtype))
+
+    def num_inputs(self):
+        return len(self.featmap_strides)
+
+    def init_weights(self):
+        pass
+
+    def log2(self, value):
+        return self.log(value) / self.log(self.twos)
+
+    def build_roi_layers(self, featmap_strides):
+        roi_layers = []
+        for s in featmap_strides:
+            layer_cls = ROIAlign(self.out_size, self.out_size,
+                                 spatial_scale=1 / s,
+                                 sample_num=self.sample_num)
+            roi_layers.append(layer_cls)
+        return roi_layers
+
+    def _c_map_roi_levels(self, rois):
+        """Map rois to corresponding feature levels by scales.
+
+        - scale < finest_scale * 2: level 0
+        - finest_scale * 2 <= scale < finest_scale * 4: level 1
+        - finest_scale * 4 <= scale < finest_scale * 8: level 2
+        - scale >= finest_scale * 8: level 3
+
+        Args:
+            rois (Tensor): Input RoIs, shape (k, 5).
+                k is expected to equal self.batch_size, the row count of self.ones.
+
+        Returns:
+            Tensor: int32 level index (0-based) of each RoI, shape (k, 1).
+        """
+        scale = self.sqrt(rois[::, 3:4:1] - rois[::, 1:2:1] + self.ones) * \
+            self.sqrt(rois[::, 4:5:1] - rois[::, 2:3:1] + self.ones)
+
+        target_lvls = self.log2(scale / self.finest_scale + self.epslion)
+        target_lvls = P.Floor()(target_lvls)
+        target_lvls = self.cast(target_lvls, mstype.int32)
+        target_lvls = self.clamp(target_lvls, self.zeros, self.max_levels)
+
+        return target_lvls
+
+    def construct(self, rois, feat1, feat2, feat3, feat4):
+        feats = (feat1, feat2, feat3, feat4)
+        res = self.res_
+        target_lvls = self._c_map_roi_levels(rois)
+        for i in range(self.num_levels):
+            mask = self.equal(target_lvls, P.ScalarToArray()(i))
+            mask = P.Reshape()(mask, (-1, 1, 1, 1))
+            roi_feats_t = self.roi_layers[i](feats[i], rois)
+            # Tile to the channel/size self.res_ was built with rather than a fixed (256, 7, 7).
+            mask = P.Tile()(self.cast(mask, mstype.int32), (1, self.out_channels, self.out_size, self.out_size))
+            mask = self.cast(mask, mstype.bool_)
+            res = self.select(mask, roi_feats_t, res)
+
+        return res
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/rpn.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/rpn.py
new file mode 100644
index 0000000..5d5b87e
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/rpn.py
@@ -0,0 +1,315 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""RPN for fasterRCNN"""
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
+from mindspore import Tensor
+from mindspore.ops import functional as F
+from mindspore.common.initializer import initializer
+from .bbox_assign_sample import BboxAssignSample
+
+# pylint: disable=locally-disabled, invalid-name, missing-docstring
+
+# pylint: disable=locally-disabled, invalid-name, missing-docstring
+
+
+class RpnRegClsBlock(nn.Cell):
+ """
+ Shared 3x3 conv + ReLU followed by separate 1x1 classification and regression convs.
+
+ Args:
+ in_channels (int) - Input channels of shared convolution.
+ feat_channels (int) - Output channels of shared convolution.
+ num_anchors (int) - The anchor number.
+ cls_out_channels (int) - Output channels of classification convolution.
+ weight_conv (Tensor) - weight init for rpn conv.
+ bias_conv (Tensor) - bias init for rpn conv.
+ weight_cls (Tensor) - weight init for rpn cls conv.
+ bias_cls (Tensor) - bias init for rpn cls conv.
+ weight_reg (Tensor) - weight init for rpn reg conv.
+ bias_reg (Tensor) - bias init for rpn reg conv.
+
+ Returns:
+ Tuple of two Tensors, the classification output and the regression output.
+ """
+ def __init__(self,
+ in_channels,
+ feat_channels,
+ num_anchors,
+ cls_out_channels,
+ weight_conv,
+ bias_conv,
+ weight_cls,
+ bias_cls,
+ weight_reg,
+ bias_reg):
+ super(RpnRegClsBlock, self).__init__()
+ self.rpn_conv = nn.Conv2d(in_channels, feat_channels, kernel_size=3, stride=1, pad_mode='same',
+ has_bias=True, weight_init=weight_conv, bias_init=bias_conv)
+ self.relu = nn.ReLU()
+
+ self.rpn_cls = nn.Conv2d(feat_channels, num_anchors * cls_out_channels, kernel_size=1, pad_mode='valid',
+ has_bias=True, weight_init=weight_cls, bias_init=bias_cls)
+ self.rpn_reg = nn.Conv2d(feat_channels, num_anchors * 4, kernel_size=1, pad_mode='valid',
+ has_bias=True, weight_init=weight_reg, bias_init=bias_reg)  # 4 regression outputs per anchor
+
+ def construct(self, x):
+ x = self.relu(self.rpn_conv(x))
+
+ x1 = self.rpn_cls(x)
+ x2 = self.rpn_reg(x)
+
+ return x1, x2
+
+
+class RPN(nn.Cell):
+ """
+ Region proposal network (RPN) head run on every input feature level.
+
+ Args:
+ config (dict) - Config.
+ batch_size (int) - Batchsize.
+ in_channels (int) - Input channels of shared convolution.
+ feat_channels (int) - Output channels of shared convolution.
+ num_anchors (int) - The anchor number.
+ cls_out_channels (int) - Output channels of classification convolution.
+
+ Returns:
+ Tuple of six outputs; the loss entries are placeholder tensors when the cell is not training.
+
+ Examples:
+ RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024,
+ num_anchors=3, cls_out_channels=512)
+ """
+ def __init__(self,
+ config,
+ batch_size,
+ in_channels,
+ feat_channels,
+ num_anchors,
+ cls_out_channels):
+ super(RPN, self).__init__()
+ cfg_rpn = config
+ self.num_bboxes = cfg_rpn.num_bboxes
+ self.slice_index = ()
+ self.feature_anchor_shape = ()
+ self.slice_index += (0,)
+ index = 0
+ for shape in cfg_rpn.feature_shapes:  # cumulative anchor offsets per level, and per-level anchor counts
+ self.slice_index += (self.slice_index[index] + shape[0] * shape[1] * num_anchors,)
+ self.feature_anchor_shape += (shape[0] * shape[1] * num_anchors * batch_size,)
+ index += 1
+
+ self.num_anchors = num_anchors
+ self.batch_size = batch_size
+ self.test_batch_size = cfg_rpn.test_batch_size
+ self.num_layers = 5  # NOTE(review): hard-coded level count; presumably must match len(config.feature_shapes)
+ self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16))
+
+ self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels,
+ num_anchors, cls_out_channels))
+
+ self.transpose = P.Transpose()
+ self.reshape = P.Reshape()
+ self.concat = P.Concat(axis=0)
+ self.fill = P.Fill()
+ self.placeh1 = Tensor(np.ones((1,)).astype(np.float16))  # placeholder returned in the non-training branch
+
+ self.trans_shape = (0, 2, 3, 1)
+
+ self.reshape_shape_reg = (-1, 4)
+ self.reshape_shape_cls = (-1,)
+ self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16))
+ self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16))
+ self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16))
+ self.num_bboxes = cfg_rpn.num_bboxes  # NOTE(review): repeats the assignment made at the top of __init__
+ self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False)
+ self.CheckValid = P.CheckValid()
+ self.sum_loss = P.ReduceSum()
+ self.loss_cls = P.SigmoidCrossEntropyWithLogits()
+ self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0)
+ self.squeeze = P.Squeeze()
+ self.cast = P.Cast()
+ self.tile = P.Tile()
+ self.zeros_like = P.ZerosLike()
+ self.loss = Tensor(np.zeros((1,)).astype(np.float16))
+ self.clsloss = Tensor(np.zeros((1,)).astype(np.float16))
+ self.regloss = Tensor(np.zeros((1,)).astype(np.float16))
+
+ def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels):
+ """
+ make rpn layer for rpn proposal network
+
+ Args:
+ num_layers (int) - layer num.
+ in_channels (int) - Input channels of shared convolution.
+ feat_channels (int) - Output channels of shared convolution.
+ num_anchors (int) - The anchor number.
+ cls_out_channels (int) - Output channels of classification convolution.
+
+ Returns:
+ List, list of RpnRegClsBlock cells.
+ """
+ rpn_layer = []
+
+ shp_weight_conv = (feat_channels, in_channels, 3, 3)
+ shp_bias_conv = (feat_channels,)
+ weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor()
+ bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor()
+
+ shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1)
+ shp_bias_cls = (num_anchors * cls_out_channels,)
+ weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16).to_tensor()
+ bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16).to_tensor()
+
+ shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1)
+ shp_bias_reg = (num_anchors * 4,)
+ weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16).to_tensor()
+ bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16).to_tensor()
+
+ for i in range(num_layers):  # i is unused: every block is built from the same init tensors
+ rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
+ weight_conv, bias_conv, weight_cls, \
+ bias_cls, weight_reg, bias_reg))
+
+ for i in range(1, num_layers):  # point every later block at block 0's weight and bias parameters
+ rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight
+ rpn_layer[i].rpn_cls.weight = rpn_layer[0].rpn_cls.weight
+ rpn_layer[i].rpn_reg.weight = rpn_layer[0].rpn_reg.weight
+
+ rpn_layer[i].rpn_conv.bias = rpn_layer[0].rpn_conv.bias
+ rpn_layer[i].rpn_cls.bias = rpn_layer[0].rpn_cls.bias
+ rpn_layer[i].rpn_reg.bias = rpn_layer[0].rpn_reg.bias
+
+ return rpn_layer
+
+ def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids):
+ loss_print = ()
+ rpn_cls_score = ()
+ rpn_bbox_pred = ()
+ rpn_cls_score_total = ()
+ rpn_bbox_pred_total = ()
+
+ for i in range(self.num_layers):
+ x1, x2 = self.rpn_convs_list[i](inputs[i])
+
+ rpn_cls_score_total = rpn_cls_score_total + (x1,)  # raw per-level outputs, returned to the caller
+ rpn_bbox_pred_total = rpn_bbox_pred_total + (x2,)
+
+ x1 = self.transpose(x1, self.trans_shape)
+ x1 = self.reshape(x1, self.reshape_shape_cls)
+
+ x2 = self.transpose(x2, self.trans_shape)
+ x2 = self.reshape(x2, self.reshape_shape_reg)
+
+ rpn_cls_score = rpn_cls_score + (x1,)  # flattened copies, used only for the loss below
+ rpn_bbox_pred = rpn_bbox_pred + (x2,)
+
+ loss = self.loss
+ clsloss = self.clsloss
+ regloss = self.regloss
+ bbox_targets = ()
+ bbox_weights = ()
+ labels = ()
+ label_weights = ()
+
+ output = ()
+ if self.training:
+ for i in range(self.batch_size):
+ multi_level_flags = ()
+ anchor_list_tuple = ()
+
+ for j in range(self.num_layers):
+ res = self.cast(self.CheckValid(anchor_list[j], self.squeeze(img_metas[i:i + 1:1, ::])),
+ mstype.int32)
+ multi_level_flags = multi_level_flags + (res,)
+ anchor_list_tuple = anchor_list_tuple + (anchor_list[j],)
+
+ valid_flag_list = self.concat(multi_level_flags)
+ anchor_using_list = self.concat(anchor_list_tuple)
+
+ gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
+ gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
+ gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
+
+ bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i,
+ gt_labels_i,
+ self.cast(valid_flag_list,
+ mstype.bool_),
+ anchor_using_list, gt_valids_i)
+
+ bbox_weight = self.cast(bbox_weight, mstype.float16)
+ label = self.cast(label, mstype.float16)
+ label_weight = self.cast(label_weight, mstype.float16)
+
+ for j in range(self.num_layers):  # split the concatenated targets back into per-level slices
+ begin = self.slice_index[j]
+ end = self.slice_index[j + 1]
+ stride = 1
+ bbox_targets += (bbox_target[begin:end:stride, ::],)
+ bbox_weights += (bbox_weight[begin:end:stride],)
+ labels += (label[begin:end:stride],)
+ label_weights += (label_weight[begin:end:stride],)
+
+ for i in range(self.num_layers):
+ bbox_target_using = ()
+ bbox_weight_using = ()
+ label_using = ()
+ label_weight_using = ()
+
+ for j in range(self.batch_size):  # entry for image j, level i sits at index i + num_layers * j
+ bbox_target_using += (bbox_targets[i + (self.num_layers * j)],)
+ bbox_weight_using += (bbox_weights[i + (self.num_layers * j)],)
+ label_using += (labels[i + (self.num_layers * j)],)
+ label_weight_using += (label_weights[i + (self.num_layers * j)],)
+
+ bbox_target_with_batchsize = self.concat(bbox_target_using)
+ bbox_weight_with_batchsize = self.concat(bbox_weight_using)
+ label_with_batchsize = self.concat(label_using)
+ label_weight_with_batchsize = self.concat(label_weight_using)
+
+ # block gradients through the sampled targets and weights
+ bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
+ bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
+ label_ = F.stop_gradient(label_with_batchsize)
+ label_weight_ = F.stop_gradient(label_weight_with_batchsize)
+
+ cls_score_i = rpn_cls_score[i]
+ reg_score_i = rpn_bbox_pred[i]
+
+ loss_cls = self.loss_cls(cls_score_i, label_)
+ loss_cls_item = loss_cls * label_weight_
+ loss_cls_item = self.sum_loss(loss_cls_item, (0,)) / self.num_expected_total
+
+ loss_reg = self.loss_bbox(reg_score_i, bbox_target_)
+ bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape[i], 1)), (1, 4))
+ loss_reg = loss_reg * bbox_weight_
+ loss_reg_item = self.sum_loss(loss_reg, (1,))
+ loss_reg_item = self.sum_loss(loss_reg_item, (0,)) / self.num_expected_total
+
+ loss_total = self.rpn_loss_cls_weight * loss_cls_item + self.rpn_loss_reg_weight * loss_reg_item
+
+ loss += loss_total
+ loss_print += (loss_total, loss_cls_item, loss_reg_item)
+ clsloss += loss_cls_item
+ regloss += loss_reg_item
+
+ output = (loss, rpn_cls_score_total, rpn_bbox_pred_total, clsloss, regloss, loss_print)
+ else:
+ output = (self.placeh1, rpn_cls_score_total, rpn_bbox_pred_total, self.placeh1, self.placeh1, self.placeh1)
+
+ return output
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/config.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/config.py
new file mode 100644
index 0000000..2aed5e9
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/config.py
@@ -0,0 +1,158 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""
+network config setting, will be used in train.py and eval.py
+"""
+from easydict import EasyDict as ed
+
+config = ed({
+ "img_width": 1280,
+ "img_height": 768,
+ "keep_ratio": False,
+ "flip_ratio": 0.5,
+ "photo_ratio": 0.5,
+ "expand_ratio": 1.0,
+
+ # anchor
+ "feature_shapes": [(192, 320), (96, 160), (48, 80), (24, 40), (12, 20)],  # (img_height, img_width) / each stride
+ "anchor_scales": [8],
+ "anchor_ratios": [0.5, 1.0, 2.0],
+ "anchor_strides": [4, 8, 16, 32, 64],
+ "num_anchors": 3,  # matches len(anchor_scales) * len(anchor_ratios)
+
+ # resnet
+ "resnet_block": [3, 4, 6, 3],
+ "resnet_in_channels": [64, 256, 512, 1024],
+ "resnet_out_channels": [256, 512, 1024, 2048],
+
+ # fpn
+ "fpn_in_channels": [256, 512, 1024, 2048],
+ "fpn_out_channels": 256,
+ "fpn_num_outs": 5,
+
+ # rpn
+ "rpn_in_channels": 256,
+ "rpn_feat_channels": 256,
+ "rpn_loss_cls_weight": 1.0,
+ "rpn_loss_reg_weight": 1.0,
+ "rpn_cls_out_channels": 1,
+ "rpn_target_means": [0., 0., 0., 0.],
+ "rpn_target_stds": [1.0, 1.0, 1.0, 1.0],
+
+ # bbox_assign_sampler
+ "neg_iou_thr": 0.3,
+ "pos_iou_thr": 0.7,
+ "min_pos_iou": 0.3,
+ "num_bboxes": 245520,  # = num_anchors * sum(h * w for (h, w) in feature_shapes)
+ "num_gts": 128,
+ "num_expected_neg": 256,
+ "num_expected_pos": 128,
+
+ # proposal
+ "activate_num_classes": 2,
+ "use_sigmoid_cls": True,
+
+ # roi_align
+ "roi_layer": dict(type='RoIAlign', out_size=7, sample_num=2),
+ "roi_align_out_channels": 256,
+ "roi_align_featmap_strides": [4, 8, 16, 32],
+ "roi_align_finest_scale": 56,
+ "roi_sample_num": 640,  # numerically num_expected_pos_stage2 + num_expected_neg_stage2
+
+ # bbox_assign_sampler_stage2
+ "neg_iou_thr_stage2": 0.5,
+ "pos_iou_thr_stage2": 0.5,
+ "min_pos_iou_stage2": 0.5,
+ "num_bboxes_stage2": 2000,
+ "num_expected_pos_stage2": 128,
+ "num_expected_neg_stage2": 512,
+ "num_expected_total_stage2": 512,
+
+ # rcnn
+ "rcnn_num_layers": 2,
+ "rcnn_in_channels": 256,
+ "rcnn_fc_out_channels": 1024,
+ "rcnn_loss_cls_weight": 1,
+ "rcnn_loss_reg_weight": 1,
+ "rcnn_target_means": [0., 0., 0., 0.],
+ "rcnn_target_stds": [0.1, 0.1, 0.2, 0.2],
+
+ # train proposal
+ "rpn_proposal_nms_across_levels": False,
+ "rpn_proposal_nms_pre": 2000,
+ "rpn_proposal_nms_post": 2000,
+ "rpn_proposal_max_num": 2000,
+ "rpn_proposal_nms_thr": 0.7,
+ "rpn_proposal_min_bbox_size": 0,
+
+ # test proposal
+ "rpn_nms_across_levels": False,
+ "rpn_nms_pre": 1000,
+ "rpn_nms_post": 1000,
+ "rpn_max_num": 1000,
+ "rpn_nms_thr": 0.7,
+ "rpn_min_bbox_min_size": 0,
+ "test_score_thr": 0.05,
+ "test_iou_thr": 0.5,
+ "test_max_per_img": 100,
+ "test_batch_size": 1,
+
+ "rpn_head_loss_type": "CrossEntropyLoss",
+ "rpn_head_use_sigmoid": True,
+ "rpn_head_weight": 1.0,
+
+ # LR
+ "base_lr": 0.02,
+ "base_step": 58633,
+ "total_epoch": 13,
+ "warmup_step": 500,
+ "warmup_mode": "linear",
+ "warmup_ratio": 1/3.0,
+ "sgd_step": [8, 11],
+ "sgd_momentum": 0.9,
+
+ # train
+ "batch_size": 1,
+ "loss_scale": 1,
+ "momentum": 0.91,  # NOTE(review): differs from sgd_momentum (0.9) above - confirm which one is used
+ "weight_decay": 1e-4,
+ "epoch_size": 12,
+ "save_checkpoint": True,
+ "save_checkpoint_epochs": 1,
+ "keep_checkpoint_max": 10,
+ "save_checkpoint_path": "./",
+
+ "mindrecord_dir": "../MindRecord_COCO_TRAIN",
+ "coco_root": "./cocodataset/",
+ "train_data_type": "train2017",
+ "val_data_type": "val2017",
+ "instance_set": "annotations/instances_{}.json",  # "{}" presumably receives train/val_data_type - confirm
+ "coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+ 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+ 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
+ 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
+ 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+ 'kite', 'baseball bat', 'baseball glove', 'skateboard',
+ 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
+ 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+ 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+ 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
+ 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
+ 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
+ 'refrigerator', 'book', 'clock', 'vase', 'scissors',
+ 'teddy bear', 'hair drier', 'toothbrush'),
+ "num_classes": 81  # == len(coco_classes): 'background' plus 80 object categories
+})
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/dataset.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/dataset.py
new file mode 100644
index 0000000..addcc99
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/dataset.py
@@ -0,0 +1,505 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""FasterRcnn dataset"""
+from __future__ import division
+
+import os
+import numpy as np
+from numpy import random
+
+import mmcv
+import mindspore.dataset as de
+import mindspore.dataset.vision.c_transforms as C
+import mindspore.dataset.transforms.c_transforms as CC
+import mindspore.common.dtype as mstype
+from mindspore.mindrecord import FileWriter
+from src.config import config
+
+# pylint: disable=locally-disabled, unused-variable
+
+
def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
    """Compute the pairwise overlap ratio between two sets of boxes.

    Boxes are laid out as ``[x1, y1, x2, y2]``; widths and heights are
    measured with a ``+ 1`` term, i.e. both corners are treated as inclusive.

    Args:
        bboxes1(ndarray): shape (n, 4)
        bboxes2(ndarray): shape (k, 4)
        mode(str): 'iou' divides the intersection by the union of both boxes,
            'iof' divides it by the area of the box taken from ``bboxes1``.

    Returns:
        ious(ndarray): float32 array of shape (n, k)
    """

    assert mode in ['iou', 'iof']

    first = bboxes1.astype(np.float32)
    second = bboxes2.astype(np.float32)
    n_first, n_second = first.shape[0], second.shape[0]
    if n_first * n_second == 0:
        return np.zeros((n_first, n_second), dtype=np.float32)

    area_first = (first[:, 2] - first[:, 0] + 1) * (first[:, 3] - first[:, 1] + 1)
    area_second = (second[:, 2] - second[:, 0] + 1) * (second[:, 3] - second[:, 1] + 1)

    # Broadcast (n, 1) against (1, k) to get every pair in one shot.
    x_start = np.maximum(first[:, None, 0], second[None, :, 0])
    y_start = np.maximum(first[:, None, 1], second[None, :, 1])
    x_end = np.minimum(first[:, None, 2], second[None, :, 2])
    y_end = np.minimum(first[:, None, 3], second[None, :, 3])
    # Disjoint boxes give a negative extent, which is clamped to zero.
    overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum(y_end - y_start + 1, 0)

    if mode == 'iou':
        denominator = area_first[:, None] + area_second[None, :] - overlap
    else:
        denominator = area_first[:, None]
    return overlap / denominator
+
+
class PhotoMetricDistortion:
    """Randomly perturb brightness, contrast, saturation, hue and channel order.

    Each perturbation is applied independently with probability 0.5. Boxes
    and labels are passed through unchanged.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def _maybe_contrast(self, img):
        """Scale ``img`` in place by a random contrast factor half of the time."""
        if random.randint(2):
            img *= random.uniform(self.contrast_lower, self.contrast_upper)

    def __call__(self, img, boxes, labels):
        # Work on a float32 copy so the caller's image is left untouched.
        img = img.astype('float32')

        # random brightness
        if random.randint(2):
            img += random.uniform(-self.brightness_delta, self.brightness_delta)

        # mode == 1 --> contrast is applied before the HSV round trip
        # mode == 0 --> contrast is applied after the HSV round trip
        mode = random.randint(2)
        if mode == 1:
            self._maybe_contrast(img)

        # convert color from BGR to HSV
        img = mmcv.bgr2hsv(img)

        # random saturation
        if random.randint(2):
            img[..., 1] *= random.uniform(self.saturation_lower, self.saturation_upper)

        # random hue, wrapped back into [0, 360]
        if random.randint(2):
            hue = img[..., 0]
            hue += random.uniform(-self.hue_delta, self.hue_delta)
            hue[hue > 360] -= 360
            hue[hue < 0] += 360

        # convert color from HSV to BGR
        img = mmcv.hsv2bgr(img)

        # random contrast
        if mode == 0:
            self._maybe_contrast(img)

        # randomly swap channels
        if random.randint(2):
            img = img[..., random.permutation(3)]

        return img, boxes, labels
+
+
class Expand:
    """Randomly place the image on a larger canvas filled with ``mean``."""

    def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
        # The fill colour is reversed when ``to_rgb`` is set.
        self.mean = mean[::-1] if to_rgb else mean
        self.min_ratio, self.max_ratio = ratio_range

    def __call__(self, img, boxes, labels):
        # Half of the time the sample is returned untouched.
        if not random.randint(2):
            height, width, channels = img.shape
            ratio = random.uniform(self.min_ratio, self.max_ratio)
            canvas = np.full((int(height * ratio), int(width * ratio), channels),
                             self.mean).astype(img.dtype)
            left = int(random.uniform(0, width * ratio - width))
            top = int(random.uniform(0, height * ratio - height))
            canvas[top:top + height, left:left + width] = img
            img = canvas
            # Shift the boxes (in place) by the paste offset.
            boxes += np.tile((left, top), 2)
        return img, boxes, labels
+
+
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Rescale the image with ``mmcv.imrescale`` and scale the boxes to match.

    The scale factor is appended to ``img_shape`` and the boxes are clipped
    to the limits stored in the first two entries of that array.
    """
    resized, factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
    if resized.shape[0] > config.img_height:
        # Still too tall: rescale once more and accumulate the factor.
        resized, second_factor = mmcv.imrescale(resized, (config.img_height, config.img_width),
                                                return_scale=True)
        factor = factor * second_factor
    shape_info = np.asarray(np.append(img_shape, factor), dtype=np.float32)

    scaled_boxes = gt_bboxes * factor
    scaled_boxes[:, 0::2] = np.clip(scaled_boxes[:, 0::2], 0, shape_info[1] - 1)
    scaled_boxes[:, 1::2] = np.clip(scaled_boxes[:, 1::2], 0, shape_info[0] - 1)

    return (resized, shape_info, scaled_boxes, gt_label, gt_num)
+
+
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Resize the image to the configured size and scale the boxes to match.

    The returned shape is ``(config.img_height, config.img_width, 1.0)``; the
    incoming ``img_shape`` is not used.
    """
    resized, w_scale, h_scale = mmcv.imresize(
        img, (config.img_width, config.img_height), return_scale=True)
    box_scale = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
    shape_info = np.asarray((config.img_height, config.img_width, 1.0), dtype=np.float32)

    scaled_boxes = gt_bboxes * box_scale
    # Keep x coordinates within the width and y coordinates within the height.
    scaled_boxes[:, 0::2] = np.clip(scaled_boxes[:, 0::2], 0, shape_info[1] - 1)
    scaled_boxes[:, 1::2] = np.clip(scaled_boxes[:, 1::2], 0, shape_info[0] - 1)

    return (resized, shape_info, scaled_boxes, gt_label, gt_num)
+
+
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Resize the image for evaluation and scale the boxes to match.

    Unlike ``resize_column`` the incoming ``img_shape`` is kept and the pair
    ``(h_scale, w_scale)`` is appended to it.
    """
    resized, w_scale, h_scale = mmcv.imresize(
        img, (config.img_width, config.img_height), return_scale=True)
    box_scale = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
    shape_info = np.asarray(np.append(img_shape, (h_scale, w_scale)), dtype=np.float32)

    scaled_boxes = gt_bboxes * box_scale
    scaled_boxes[:, 0::2] = np.clip(scaled_boxes[:, 0::2], 0, shape_info[1] - 1)
    scaled_boxes[:, 1::2] = np.clip(scaled_boxes[:, 1::2], 0, shape_info[0] - 1)

    return (resized, shape_info, scaled_boxes, gt_label, gt_num)
+
+
def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Pad the image to (config.img_height, config.img_width) and cast it to float32."""
    padded = mmcv.impad(img, (config.img_height, config.img_width)).astype(np.float32)
    return (padded, img_shape, gt_bboxes, gt_label, gt_num)
+
+
def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Normalize the image with fixed per-channel mean/std and cast it to float32."""
    mean = [123.675, 116.28, 103.53]
    std = [58.395, 57.12, 57.375]
    normalized = mmcv.imnormalize(img, mean, std, True).astype(np.float32)
    return (normalized, img_shape, gt_bboxes, gt_label, gt_num)
+
+
def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Flip the image with ``mmcv.imflip`` and mirror the box x coordinates."""
    flipped_img = mmcv.imflip(img)
    width = flipped_img.shape[1]

    mirrored = gt_bboxes.copy()
    # x1 and x2 trade places after mirroring around the image width.
    mirrored[..., 0::4] = width - gt_bboxes[..., 2::4] - 1
    mirrored[..., 2::4] = width - gt_bboxes[..., 0::4] - 1

    return (flipped_img, img_shape, mirrored, gt_label, gt_num)
+
+
def flipped_generation(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Mirror the box x coordinates around the image width.

    The image itself is returned unchanged; only a copy of the boxes is
    modified.
    """
    width = img.shape[1]

    mirrored = gt_bboxes.copy()
    # x1 and x2 trade places after mirroring.
    mirrored[..., 0::4] = width - gt_bboxes[..., 2::4] - 1
    mirrored[..., 2::4] = width - gt_bboxes[..., 0::4] - 1

    return (img, img_shape, mirrored, gt_label, gt_num)
+
+
def image_bgr_rgb(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Reverse the order of the last (channel) axis of an HWC image."""
    channel_reversed = img[:, :, ::-1]
    return (channel_reversed, img_shape, gt_bboxes, gt_label, gt_num)
+
+
def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Move the image from HWC to CHW layout and cast every column.

    The image, shape and boxes become float16, the labels int32 and
    ``gt_num`` a boolean array.
    """
    img_data = img.transpose(2, 0, 1).copy()
    img_data = img_data.astype(np.float16)
    img_shape = img_shape.astype(np.float16)
    gt_bboxes = gt_bboxes.astype(np.float16)
    gt_label = gt_label.astype(np.int32)
    # ``np.bool`` was removed in NumPy 1.24; ``np.bool_`` is valid everywhere.
    gt_num = gt_num.astype(np.bool_)

    return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
+
+
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Apply a default-configured ``PhotoMetricDistortion`` to the sample."""
    distort = PhotoMetricDistortion()
    distorted, gt_bboxes, gt_label = distort(img, gt_bboxes, gt_label)
    return (distorted, img_shape, gt_bboxes, gt_label, gt_num)
+
+
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """Apply a default-configured ``Expand`` to the sample."""
    expander = Expand()
    expanded, gt_bboxes, gt_label = expander(img, gt_bboxes, gt_label)
    return (expanded, img_shape, gt_bboxes, gt_label, gt_num)
+
+
def preprocess_fn(image, box, is_training):
    """Preprocess function for dataset.

    Swaps the first and third image channels, pads the annotations to a fixed
    128 rows, resizes the image, and finally swaps the channels back.

    Args:
        image (ndarray): HWC image with at least three channels.
        box (ndarray): shape (n, 6); columns 0-3 are the box, column 4 the
            label and column 5 the iscrowd flag.
        is_training (bool): selects ``resize_column`` (True) or
            ``resize_column_test`` (False) when ``config.keep_ratio`` is off.

    Returns:
        tuple: (image, image_shape, boxes, labels, valid_mask). Padded rows
        have label -1 and a valid_mask of 0.

    Raises:
        ValueError: if ``box`` holds more than 128 rows.
    """
    pad_max_number = 128

    # Swap channels 0 and 2; the swap is undone after resizing.
    image_bgr = image.copy()
    image_bgr[:, :, 0] = image[:, :, 2]
    image_bgr[:, :, 1] = image[:, :, 1]
    image_bgr[:, :, 2] = image[:, :, 0]
    image_shape = image_bgr.shape[:2]

    pad_rows = pad_max_number - box.shape[0]
    if pad_rows < 0:
        raise ValueError("Number of boxes ({}) exceeds the maximum of {}.".format(
            box.shape[0], pad_max_number))

    gt_box_new = np.pad(box[:, :4], ((0, pad_rows), (0, 0)), mode="constant", constant_values=0)
    gt_label_new = np.pad(box[:, 4], (0, pad_rows), mode="constant", constant_values=-1)
    # Padding rows are marked as crowd so that they end up invalid below.
    gt_iscrowd_new = np.pad(box[:, 5], (0, pad_rows), mode="constant", constant_values=1)
    # ``np.bool`` was removed in NumPy 1.24; ``np.bool_`` is valid everywhere.
    gt_iscrowd_new_revert = (~(gt_iscrowd_new.astype(np.bool_))).astype(np.int32)

    input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert

    if config.keep_ratio:
        input_data = rescale_column(*input_data)
    elif is_training:
        input_data = resize_column(*input_data)
    else:
        input_data = resize_column_test(*input_data)

    return image_bgr_rgb(*input_data)
+
+
def create_coco_label(is_training):
    """Collect image paths and annotations from the COCO annotation file.

    Returns:
        tuple: (image_files, image_anno_dict). Each annotation row is
        ``[x1, y1, x2, y2, class_index, iscrowd]``; an image without usable
        annotations gets the placeholder ``[0, 0, 0, 0, 0, 1]``.
    """
    from pycocotools.coco import COCO

    coco_root = config.coco_root
    data_type = config.train_data_type if is_training else config.val_data_type

    # Map every class name that should be kept to its index.
    class_to_index = {name: index for index, name in enumerate(config.coco_classes)}

    coco = COCO(os.path.join(coco_root, config.instance_set.format(data_type)))
    id_to_name = {cat["id"]: cat["name"] for cat in coco.loadCats(coco.getCatIds())}

    image_files = []
    image_anno_dict = {}

    for img_id in coco.getImgIds():
        file_name = coco.loadImgs(img_id)[0]["file_name"]
        records = coco.loadAnns(coco.getAnnIds(imgIds=img_id, iscrowd=None))
        image_path = os.path.join(coco_root, data_type, file_name)

        rows = []
        for record in records:
            class_name = id_to_name[record["category_id"]]
            if class_name not in class_to_index:
                continue
            # COCO stores [x, y, width, height]; convert to corner form.
            x, y, width, height = record["bbox"][:4]
            rows.append([x, y, x + width, y + height,
                         class_to_index[class_name], int(record["iscrowd"])])

        image_files.append(image_path)
        image_anno_dict[image_path] = np.array(rows) if rows else np.array([0, 0, 0, 0, 0, 1])

    return image_files, image_anno_dict
+
+
def anno_parser(annos_str):
    """Turn a sequence of comma-separated strings into lists of ints."""
    return [[int(field) for field in entry.strip().split(',')] for entry in annos_str]
+
+
def filter_valid_data(image_dir, anno_path):
    """Keep the annotated images that actually exist in ``image_dir``.

    Each line of the annotation file is ``<file_name> <anno> <anno> ...``,
    separated by single spaces.

    Raises:
        RuntimeError: if ``image_dir`` is not a directory or ``anno_path`` is
            not a file.
    """
    if not os.path.isdir(image_dir):
        raise RuntimeError("Path given is not valid.")
    if not os.path.isfile(anno_path):
        raise RuntimeError("Annotation file is not valid.")

    with open(anno_path, "rb") as anno_file:
        raw_lines = anno_file.readlines()

    image_files = []
    image_anno_dict = {}
    for raw_line in raw_lines:
        fields = str(raw_line.decode("utf-8").strip()).split(' ')
        image_path = os.path.join(image_dir, fields[0])
        if not os.path.isfile(image_path):
            continue
        image_anno_dict[image_path] = anno_parser(fields[1:])
        image_files.append(image_path)
    return image_files, image_anno_dict
+
+
def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fasterrcnn.mindrecord", file_num=8):
    """Write the images and annotations into MindRecord files.

    Output goes to ``config.mindrecord_dir`` under the name ``prefix``. Every
    record holds the raw image bytes and an int32 annotation array.
    """
    writer = FileWriter(os.path.join(config.mindrecord_dir, prefix), file_num)

    if dataset == "coco":
        image_files, image_anno_dict = create_coco_label(is_training)
    else:
        image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)

    schema = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int32", "shape": [-1, 6]},
    }
    writer.add_schema(schema, "fasterrcnn_json")

    for image_path in image_files:
        with open(image_path, 'rb') as image_file:
            image_bytes = image_file.read()
        annotation = np.array(image_anno_dict[image_path], dtype=np.int32)
        writer.write_raw_data([{"image": image_bytes, "annotation": annotation}])
    writer.commit()
+
+
def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0,
                              is_training=True, num_parallel_workers=4):
    """Create FasterRcnn dataset with MindDataset.

    Args:
        mindrecord_file (str): MindRecord file to read.
        batch_size (int): samples per batch; incomplete batches are dropped.
        repeat_num (int): number of times the dataset is repeated.
        device_num (int): number of shards.
        rank_id (int): shard index read by this process.
        is_training (bool): selects the training or the evaluation pipeline.
        num_parallel_workers (int): workers used for the Python map stages.

    Returns:
        The batched and repeated dataset with columns
        image, image_shape, box, label, valid_num.
    """
    ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
                        num_parallel_workers=1, shuffle=False)
    decode = C.Decode()
    ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1)
    compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))

    hwc_to_chw = C.HWC2CHW()
    normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
    horizontally_op = C.RandomHorizontalFlip(1)
    type_cast0 = CC.TypeCast(mstype.float32)
    type_cast1 = CC.TypeCast(mstype.float16)
    type_cast2 = CC.TypeCast(mstype.int32)
    type_cast3 = CC.TypeCast(mstype.bool_)

    if is_training:
        ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
                    output_columns=["image", "image_shape", "box", "label", "valid_num"],
                    column_order=["image", "image_shape", "box", "label", "valid_num"],
                    num_parallel_workers=num_parallel_workers)

        # NOTE: the flip decision is drawn once, when the pipeline is built,
        # so it applies to the whole dataset rather than to single samples.
        flip = (np.random.rand() < config.flip_ratio)
        if flip:
            # Flip the image as well as the boxes; mirroring only the boxes
            # would leave the annotations pointing at the wrong pixels.
            ds = ds.map(operations=[normalize_op, type_cast0, horizontally_op], input_columns=["image"],
                        num_parallel_workers=12)
            ds = ds.map(operations=flipped_generation,
                        input_columns=["image", "image_shape", "box", "label", "valid_num"],
                        num_parallel_workers=num_parallel_workers)
        else:
            ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
                        num_parallel_workers=12)
        ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"],
                    num_parallel_workers=12)

    else:
        ds = ds.map(operations=compose_map_func,
                    input_columns=["image", "annotation"],
                    output_columns=["image", "image_shape", "box", "label", "valid_num"],
                    column_order=["image", "image_shape", "box", "label", "valid_num"],
                    num_parallel_workers=num_parallel_workers)

        ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"],
                    num_parallel_workers=24)

    # transpose_column from python to c
    ds = ds.map(operations=[type_cast1], input_columns=["image_shape"])
    ds = ds.map(operations=[type_cast1], input_columns=["box"])
    ds = ds.map(operations=[type_cast2], input_columns=["label"])
    ds = ds.map(operations=[type_cast3], input_columns=["valid_num"])
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_num)

    return ds
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/lr_schedule.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/lr_schedule.py
new file mode 100644
index 0000000..d46510a
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/lr_schedule.py
@@ -0,0 +1,42 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""lr generator for fasterrcnn"""
+import math
+
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
    """Linearly interpolate from ``init_lr`` towards ``base_lr`` over ``warmup_steps``."""
    start = float(init_lr)
    per_step = (float(base_lr) - start) / float(warmup_steps)
    return start + per_step * current_step
+
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
    """Cosine-decayed learning rate, measured from the end of warmup."""
    progress = float(current_step - warmup_steps) / float(decay_steps)
    return (1 + math.cos(progress * math.pi)) / 2 * base_lr
+
def dynamic_lr(config, rank_size=1):
    """Build the per-step learning-rate list: linear warmup, then cosine decay.

    Reads ``base_lr``, ``base_step``, ``total_epoch``, ``warmup_step`` and
    ``warmup_ratio`` from ``config``.
    """
    base_lr = config.base_lr
    steps_per_epoch = (config.base_step // rank_size) + rank_size
    total_steps = int(steps_per_epoch * config.total_epoch)
    warmup_steps = int(config.warmup_step)

    def rate_at(step):
        if step < warmup_steps:
            return linear_warmup_learning_rate(step, warmup_steps, base_lr, base_lr * config.warmup_ratio)
        return a_cosine_learning_rate(step, base_lr, warmup_steps, total_steps)

    return [rate_at(step) for step in range(total_steps)]
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/network_define.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/network_define.py
new file mode 100644
index 0000000..e923bc6
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/network_define.py
@@ -0,0 +1,184 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FasterRcnn training network wrapper."""
+
+import time
+import numpy as np
+import mindspore.nn as nn
+from mindspore.common.tensor import Tensor
+from mindspore.ops import functional as F
+from mindspore.ops import composite as C
+from mindspore import ParameterTuple
+from mindspore.train.callback import Callback
+from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
+
+# pylint: disable=locally-disabled, missing-docstring, unused-argument
+
+
time_stamp_init = False
time_stamp_first = 0
class LossCallBack(Callback):
    """
    Monitor the loss in training.

    After every step the six losses returned by the network are appended to
    ``./loss_<rank_id>.log`` together with the seconds elapsed since the first
    LossCallBack was created in this process.

    Note:
        per_print_times is validated and stored, but the current
        implementation writes a log line on every step regardless of it.

    Args:
        per_print_times (int): Print loss every times. Default: 1.
        rank_id (int): Used in the log file name. Default: 0.

    Raises:
        ValueError: if per_print_times is not a non-negative int.
    """

    def __init__(self, per_print_times=1, rank_id=0):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self.count = 0
        self.rpn_loss_sum = 0
        self.rcnn_loss_sum = 0
        self.rpn_cls_loss_sum = 0
        self.rpn_reg_loss_sum = 0
        self.rcnn_cls_loss_sum = 0
        self.rcnn_reg_loss_sum = 0
        self.rank_id = rank_id

        # The reference time stamp is shared by all instances and set once.
        global time_stamp_init, time_stamp_first
        if not time_stamp_init:
            time_stamp_first = time.time()
            time_stamp_init = True

    def step_end(self, run_context):
        """Accumulate the step losses and append their averages to the log file."""
        cb_params = run_context.original_args()
        rpn_loss = cb_params.net_outputs[0].asnumpy()
        rcnn_loss = cb_params.net_outputs[1].asnumpy()
        rpn_cls_loss = cb_params.net_outputs[2].asnumpy()

        rpn_reg_loss = cb_params.net_outputs[3].asnumpy()
        rcnn_cls_loss = cb_params.net_outputs[4].asnumpy()
        rcnn_reg_loss = cb_params.net_outputs[5].asnumpy()

        self.count += 1
        self.rpn_loss_sum += float(rpn_loss)
        self.rcnn_loss_sum += float(rcnn_loss)
        self.rpn_cls_loss_sum += float(rpn_cls_loss)
        self.rpn_reg_loss_sum += float(rpn_reg_loss)
        self.rcnn_cls_loss_sum += float(rcnn_cls_loss)
        self.rcnn_reg_loss_sum += float(rcnn_reg_loss)

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if self.count >= 1:
            global time_stamp_first
            time_stamp_current = time.time()

            rpn_loss = self.rpn_loss_sum/self.count
            rcnn_loss = self.rcnn_loss_sum/self.count
            rpn_cls_loss = self.rpn_cls_loss_sum/self.count

            rpn_reg_loss = self.rpn_reg_loss_sum/self.count
            rcnn_cls_loss = self.rcnn_cls_loss_sum/self.count
            rcnn_reg_loss = self.rcnn_reg_loss_sum/self.count

            total_loss = rpn_loss + rcnn_loss

            # The context manager closes the file even if a write fails.
            with open("./loss_{}.log".format(self.rank_id), "a+") as loss_file:
                loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rcnn_loss: %.5f, rpn_cls_loss: %.5f, "
                                "rpn_reg_loss: %.5f, rcnn_cls_loss: %.5f, rcnn_reg_loss: %.5f, total_loss: %.5f" %
                                (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
                                 rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
                                 rcnn_cls_loss, rcnn_reg_loss, total_loss))
                loss_file.write("\n")

            self.count = 0
            self.rpn_loss_sum = 0
            self.rcnn_loss_sum = 0
            self.rpn_cls_loss_sum = 0
            self.rpn_reg_loss_sum = 0
            self.rcnn_cls_loss_sum = 0
            self.rcnn_reg_loss_sum = 0
+
class LossNet(nn.Cell):
    """FasterRcnn loss method: the sum of the first two inputs."""

    def construct(self, x1, x2, x3, x4, x5, x6):
        # Only x1 and x2 contribute; x3..x6 are accepted but ignored.
        total = x1 + x2
        return total
+
class WithLossCell(nn.Cell):
    """
    Chain a network with a loss function.

    The backbone is expected to return six values, which are forwarded
    unchanged to the loss function.

    Args:
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.
    """
    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._loss_fn = loss_fn
        self._backbone = backbone

    def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
        out1, out2, out3, out4, out5, out6 = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
        return self._loss_fn(out1, out2, out3, out4, out5, out6)

    @property
    def backbone_network(self):
        """
        The wrapped backbone network.

        Returns:
            Cell, return backbone network.
        """
        return self._backbone
+
+
class TrainOneStepCell(nn.Cell):
    """
    Network training package class.

    Append an optimizer to the training network after that the construct function
    can be called to create the backward graph.

    Args:
        network (Cell): The training network.
        network_backbone (Cell): The forward network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default value is 1.0.
        reduce_flag (bool): The reduce flag. Default value is False.
        mean (bool): Allreduce method. Default value is True.
        degree (int): Device number. Default value is None.
    """
    def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.backbone = network_backbone
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        # Sensitivity passed to the gradient operation as a float16 tensor.
        self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16))
        self.reduce_flag = reduce_flag
        if reduce_flag:
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
        weights = self.weights
        # The backbone supplies the six losses that are returned to the caller.
        loss1, loss2, loss3, loss4, loss5, loss6 = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
        grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
        if self.reduce_flag:
            grads = self.grad_reducer(grads)
        # F.depend ties the optimizer update to the first returned loss.
        return F.depend(loss1, self.optimizer(grads)), loss2, loss3, loss4, loss5, loss6
diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/src/util.py b/examples/model_security/model_attacks/cv/faster_rcnn/src/util.py
new file mode 100644
index 0000000..9b1045d
--- /dev/null
+++ b/examples/model_security/model_attacks/cv/faster_rcnn/src/util.py
@@ -0,0 +1,227 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""coco eval for fasterrcnn"""
+import json
+import numpy as np
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+import mmcv
+
+# pylint: disable=locally-disabled, invalid-name
+
+_init_value = np.array(0.0)
+summary_init = {
+ 'Precision/mAP': _init_value,
+ 'Precision/mAP@.50IOU': _init_value,
+ 'Precision/mAP@.75IOU': _init_value,
+ 'Precision/mAP (small)': _init_value,
+ 'Precision/mAP (medium)': _init_value,
+ 'Precision/mAP (large)': _init_value,
+ 'Recall/AR@1': _init_value,
+ 'Recall/AR@10': _init_value,
+ 'Recall/AR@100': _init_value,
+ 'Recall/AR@100 (small)': _init_value,
+ 'Recall/AR@100 (medium)': _init_value,
+ 'Recall/AR@100 (large)': _init_value,
+}
+
+
def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), single_result=False):
    """coco eval for fasterrcnn

    Args:
        result_files (dict): maps a result type to the path of its json file;
            must contain the key 'bbox'.
        result_types (list): result types to evaluate, e.g. 'bbox', 'proposal'.
        coco (COCO or str): ground-truth COCO object, or a path to load it from.
        max_dets (tuple): maxDets used for the 'proposal' result type.
        single_result (bool): if True, evaluate only the images present in the
            detections and additionally print a summary per image.

    Returns:
        dict: the twelve COCO summary metrics of the last evaluated result
        type, or ``summary_init`` when the bbox result file is empty.
    """
    # Close the result file deterministically instead of leaking the handle.
    with open(result_files['bbox']) as bbox_file:
        anns = json.load(bbox_file)
    if not anns:
        return summary_init

    if mmcv.is_str(coco):
        coco = COCO(coco)
    assert isinstance(coco, COCO)

    def _evaluate(coco_dets, iou_type, res_type, img_ids):
        """Run evaluate/accumulate/summarize on ``img_ids`` with a fresh COCOeval."""
        evaluator = COCOeval(coco, coco_dets, iou_type)
        if res_type == 'proposal':
            evaluator.params.useCats = 0
            evaluator.params.maxDets = list(max_dets)
        evaluator.params.imgIds = img_ids
        evaluator.evaluate()
        evaluator.accumulate()
        evaluator.summarize()
        return evaluator

    for res_type in result_types:
        result_file = result_files[res_type]
        assert result_file.endswith('.json')

        coco_dets = coco.loadRes(result_file)
        gt_img_ids = coco.getImgIds()
        det_img_ids = coco_dets.getImgIds()
        iou_type = 'bbox' if res_type == 'proposal' else res_type

        tgt_ids = gt_img_ids if not single_result else det_img_ids

        if single_result:
            # The per-image summaries are printed by summarize(); the
            # collected values are currently not returned.
            res_dict = dict()
            for id_i in tgt_ids:
                per_image = _evaluate(coco_dets, iou_type, res_type, [id_i])
                res_dict.update({coco.imgs[id_i]['file_name']: per_image.stats[1]})

        cocoEval = _evaluate(coco_dets, iou_type, res_type, tgt_ids)

        summary_metrics = {
            'Precision/mAP': cocoEval.stats[0],
            'Precision/mAP@.50IOU': cocoEval.stats[1],
            'Precision/mAP@.75IOU': cocoEval.stats[2],
            'Precision/mAP (small)': cocoEval.stats[3],
            'Precision/mAP (medium)': cocoEval.stats[4],
            'Precision/mAP (large)': cocoEval.stats[5],
            'Recall/AR@1': cocoEval.stats[6],
            'Recall/AR@10': cocoEval.stats[7],
            'Recall/AR@100': cocoEval.stats[8],
            'Recall/AR@100 (small)': cocoEval.stats[9],
            'Recall/AR@100 (medium)': cocoEval.stats[10],
            'Recall/AR@100 (large)': cocoEval.stats[11],
        }

    return summary_metrics
+
+
def xyxy2xywh(bbox):
    """Convert ``[x1, y1, x2, y2, ...]`` to the list ``[x, y, width, height]``.

    Width and height include both end pixels, hence the ``+ 1``.
    """
    x1, y1, x2, y2 = bbox.tolist()[:4]
    return [x1, y1, x2 - x1 + 1, y2 - y1 + 1]
+
def bbox2result_1image(bboxes, labels, num_classes):
    """Group the detections of one image by class label.

    Args:
        bboxes (Tensor): shape (n, 5)
        labels (Tensor): shape (n, )
        num_classes (int): class number, including background class

    Returns:
        list(ndarray): ``num_classes - 1`` arrays, one per label
        ``0 .. num_classes - 2``
    """
    foreground = range(num_classes - 1)
    if bboxes.shape[0] == 0:
        return [np.zeros((0, 5), dtype=np.float32) for _ in foreground]
    return [bboxes[labels == label, :] for label in foreground]
+
def proposal2json(dataset, results):
    """Convert per-image proposals into COCO-style json records.

    Every proposal is written with ``category_id`` 1.
    """
    img_ids = dataset.getImgIds()
    json_results = []
    for idx in range(dataset.get_dataset_size()*2):
        img_id = img_ids[idx]
        proposals = results[idx]
        for row in range(proposals.shape[0]):
            proposal = proposals[row]
            json_results.append({
                'image_id': img_id,
                'bbox': xyxy2xywh(proposal),
                'score': float(proposal[4]),
                'category_id': 1,
            })
    return json_results
+
def det2json(dataset, results):
    """Convert per-image, per-class detections into COCO-style json records.

    Processing stops at whichever of the image ids or ``results`` runs out
    first.
    """
    cat_ids = dataset.getCatIds()
    img_ids = dataset.getImgIds()
    json_results = []
    # zip() ends with the shorter sequence, so surplus image ids are skipped.
    for img_id, per_class in zip(img_ids, results):
        for label, bboxes in enumerate(per_class):
            for row in range(bboxes.shape[0]):
                detection = bboxes[row]
                json_results.append({
                    'image_id': img_id,
                    'bbox': xyxy2xywh(detection),
                    'score': float(detection[4]),
                    'category_id': cat_ids[label],
                })
    return json_results
+
def segm2json(dataset, results):
    """Convert detection + segmentation results into COCO-style json records.

    Returns:
        tuple: (bbox_json_results, segm_json_results)
    """
    bbox_json_results = []
    segm_json_results = []
    for idx in range(len(dataset)):
        img_id = dataset.img_ids[idx]
        det, seg = results[idx]
        for label, bboxes in enumerate(det):
            category_id_of = dataset.cat_ids
            num_boxes = bboxes.shape[0]

            # bbox results
            for i in range(num_boxes):
                bbox_json_results.append({
                    'image_id': img_id,
                    'bbox': xyxy2xywh(bboxes[i]),
                    'score': float(bboxes[i][4]),
                    'category_id': category_id_of[label],
                })

            # A two-element ``seg`` carries separate mask scores; otherwise
            # the box score is reused for the mask.
            if len(seg) == 2:
                segms = seg[0][label]
                mask_score = seg[1][label]
            else:
                segms = seg[label]
                mask_score = [bbox[4] for bbox in bboxes]

            for i in range(num_boxes):
                record = {
                    'image_id': img_id,
                    'score': float(mask_score[i]),
                    'category_id': category_id_of[label],
                }
                # 'counts' is decoded in place from bytes to str.
                segms[i]['counts'] = segms[i]['counts'].decode()
                record['segmentation'] = segms[i]
                segm_json_results.append(record)
    return bbox_json_results, segm_json_results
+
def results2json(dataset, results, out_file):
    """Dump the results to json files and return the paths keyed by result type.

    The type of ``results[0]`` selects the conversion: list -> detections,
    tuple -> detections + segmentations, ndarray -> proposals.

    Raises:
        TypeError: if ``results[0]`` is none of the supported types.
    """
    def json_path(kind):
        return '{}.{}.json'.format(out_file, kind)

    result_files = dict()
    first = results[0]
    if isinstance(first, list):
        json_results = det2json(dataset, results)
        result_files['bbox'] = json_path('bbox')
        # The proposal entry points at the bbox file.
        result_files['proposal'] = json_path('bbox')
        mmcv.dump(json_results, result_files['bbox'])
    elif isinstance(first, tuple):
        bbox_results, segm_results = segm2json(dataset, results)
        result_files['bbox'] = json_path('bbox')
        result_files['proposal'] = json_path('bbox')
        result_files['segm'] = json_path('segm')
        mmcv.dump(bbox_results, result_files['bbox'])
        mmcv.dump(segm_results, result_files['segm'])
    elif isinstance(first, np.ndarray):
        json_results = proposal2json(dataset, results)
        result_files['proposal'] = json_path('proposal')
        mmcv.dump(json_results, result_files['proposal'])
    else:
        raise TypeError('invalid type of results')
    return result_files
diff --git a/mindarmour/adv_robustness/attacks/gradient_method.py b/mindarmour/adv_robustness/attacks/gradient_method.py
index 5412631..ea705e3 100644
--- a/mindarmour/adv_robustness/attacks/gradient_method.py
+++ b/mindarmour/adv_robustness/attacks/gradient_method.py
@@ -19,7 +19,7 @@ from abc import abstractmethod
import numpy as np
from mindspore import Tensor
-from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
+from mindspore.nn import Cell
from mindarmour.utils.util import WithLossCell, GradWrapWithLoss
from mindarmour.utils.logger import LogUtil
@@ -44,12 +44,13 @@ class GradientMethod(Attack):
Default: None.
bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: None.
- loss_fn (Loss): Loss function for optimization. Default: None.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
- >>> attack = FastGradientMethod(network)
+ >>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""
@@ -71,9 +72,10 @@ class GradientMethod(Attack):
else:
self._alpha = alpha
if loss_fn is None:
- loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False)
- with_loss_cell = WithLossCell(self._network, loss_fn)
- self._grad_all = GradWrapWithLoss(with_loss_cell)
+ self._grad_all = self._network
+ else:
+ with_loss_cell = WithLossCell(self._network, loss_fn)
+ self._grad_all = GradWrapWithLoss(with_loss_cell)
self._grad_all.set_train()
def generate(self, inputs, labels):
@@ -83,13 +85,19 @@ class GradientMethod(Attack):
Args:
inputs (numpy.ndarray): Benign input samples used as references to create
adversarial examples.
- labels (numpy.ndarray): Original/target labels.
+ labels (Union[numpy.ndarray, tuple]): Original/target labels. \
+ When each input has more than one label, the label arrays are wrapped in a tuple.
Returns:
numpy.ndarray, generated adversarial examples.
"""
- inputs, labels = check_pair_numpy_param('inputs', inputs,
- 'labels', labels)
+ if isinstance(labels, tuple):
+ for i, labels_item in enumerate(labels):
+ inputs, _ = check_pair_numpy_param('inputs', inputs, \
+ 'labels[{}]'.format(i), labels_item)
+ else:
+ inputs, _ = check_pair_numpy_param('inputs', inputs, \
+ 'labels', labels)
self._dtype = inputs.dtype
gradient = self._gradient(inputs, labels)
# use random method or not
@@ -117,7 +125,8 @@ class GradientMethod(Attack):
Args:
inputs (numpy.ndarray): Benign input samples used as references to
create adversarial examples.
- labels (numpy.ndarray): Original/target labels.
+ labels (Union[numpy.ndarray, tuple]): Original/target labels. \
+ When each input has more than one label, the label arrays are wrapped in a tuple.
Raises:
NotImplementedError: It is an abstract method.
@@ -149,12 +158,13 @@ class FastGradientMethod(GradientMethod):
Possible values: np.inf, 1 or 2. Default: 2.
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
- loss_fn (Loss): Loss function for optimization. Default: None.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
- >>> attack = FastGradientMethod(network)
+ >>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""
@@ -175,12 +185,19 @@ class FastGradientMethod(GradientMethod):
Args:
inputs (numpy.ndarray): Input sample.
- labels (numpy.ndarray): Original/target label.
+ labels (Union[numpy.ndarray, tuple]): Original/target labels. \
+ When each input has more than one label, the label arrays are wrapped in a tuple.
Returns:
numpy.ndarray, gradient of inputs.
"""
- out_grad = self._grad_all(Tensor(inputs), Tensor(labels))
+ if isinstance(labels, tuple):
+ labels_tensor = tuple()
+ for item in labels:
+ labels_tensor += (Tensor(item),)
+ else:
+ labels_tensor = (Tensor(labels),)
+ out_grad = self._grad_all(Tensor(inputs), *labels_tensor)
if isinstance(out_grad, tuple):
out_grad = out_grad[0]
gradient = out_grad.asnumpy()
@@ -210,7 +227,8 @@ class RandomFastGradientMethod(FastGradientMethod):
Possible values: np.inf, 1 or 2. Default: 2.
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
- loss_fn (Loss): Loss function for optimization. Default: None.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
Raises:
ValueError: eps is smaller than alpha!
@@ -218,7 +236,7 @@ class RandomFastGradientMethod(FastGradientMethod):
Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
- >>> attack = RandomFastGradientMethod(network)
+ >>> attack = RandomFastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""
@@ -254,12 +272,13 @@ class FastGradientSignMethod(GradientMethod):
In form of (clip_min, clip_max). Default: (0.0, 1.0).
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
- loss_fn (Loss): Loss function for optimization. Default: None.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
- >>> attack = FastGradientSignMethod(network)
+ >>> attack = FastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""
@@ -279,12 +298,19 @@ class FastGradientSignMethod(GradientMethod):
Args:
inputs (numpy.ndarray): Input samples.
- labels (numpy.ndarray): Original/target labels.
+ labels (Union[numpy.ndarray, tuple]): Original/target labels. \
+ When each input has more than one label, the label arrays are wrapped in a tuple.
Returns:
numpy.ndarray, gradient of inputs.
"""
- out_grad = self._grad_all(Tensor(inputs), Tensor(labels))
+ if isinstance(labels, tuple):
+ labels_tensor = tuple()
+ for item in labels:
+ labels_tensor += (Tensor(item),)
+ else:
+ labels_tensor = (Tensor(labels),)
+ out_grad = self._grad_all(Tensor(inputs), *labels_tensor)
if isinstance(out_grad, tuple):
out_grad = out_grad[0]
gradient = out_grad.asnumpy()
@@ -311,7 +337,8 @@ class RandomFastGradientSignMethod(FastGradientSignMethod):
In form of (clip_min, clip_max). Default: (0.0, 1.0).
is_targeted (bool): True: targeted attack. False: untargeted attack.
Default: False.
- loss_fn (Loss): Loss function for optimization. Default: None.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
Raises:
ValueError: eps is smaller than alpha!
@@ -319,7 +346,7 @@ class RandomFastGradientSignMethod(FastGradientSignMethod):
Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
- >>> attack = RandomFastGradientSignMethod(network)
+ >>> attack = RandomFastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""
@@ -350,12 +377,13 @@ class LeastLikelyClassMethod(FastGradientSignMethod):
Default: None.
bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: (0.0, 1.0).
- loss_fn (Loss): Loss function for optimization. Default: None.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
- >>> attack = LeastLikelyClassMethod(network)
+ >>> attack = LeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""
@@ -384,7 +412,8 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod):
Default: 0.035.
bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: (0.0, 1.0).
- loss_fn (Loss): Loss function for optimization.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
Raises:
ValueError: eps is smaller than alpha!
@@ -392,7 +421,7 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod):
Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
- >>> attack = RandomLeastLikelyClassMethod(network)
+ >>> attack = RandomLeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""
diff --git a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py
index 6f94d3d..462edbe 100644
--- a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py
+++ b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py
@@ -17,7 +17,7 @@ from abc import abstractmethod
import numpy as np
from PIL import Image, ImageOps
-from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
+from mindspore.nn import Cell
from mindspore import Tensor
from mindarmour.utils.logger import LogUtil
@@ -114,7 +114,8 @@ class IterativeGradientMethod(Attack):
bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: (0.0, 1.0).
nb_iter (int): Number of iteration. Default: 5.
- loss_fn (Loss): Loss function for optimization. Default: None.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
"""
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5,
loss_fn=None):
@@ -123,12 +124,15 @@ class IterativeGradientMethod(Attack):
self._eps = check_value_positive('eps', eps)
self._eps_iter = check_value_positive('eps_iter', eps_iter)
self._nb_iter = check_int_positive('nb_iter', nb_iter)
- self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
- for b in self._bounds:
- _ = check_param_multi_types('bound', b, [int, float])
+ self._bounds = None
+ if bounds is not None:
+ self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
+ for b in self._bounds:
+ _ = check_param_multi_types('bound', b, [int, float])
if loss_fn is None:
- loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False)
- self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn))
+ self._loss_grad = network
+ else:
+ self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn))
self._loss_grad.set_train()
@abstractmethod
@@ -139,8 +143,8 @@ class IterativeGradientMethod(Attack):
Args:
inputs (numpy.ndarray): Benign input samples used as references to create
adversarial examples.
- labels (numpy.ndarray): Original/target labels.
-
+ labels (Union[numpy.ndarray, tuple]): Original/target labels. \
+ When each input has more than one label, the label arrays are wrapped in a tuple.
Raises:
NotImplementedError: This function is not available in
IterativeGradientMethod.
@@ -177,12 +181,13 @@ class BasicIterativeMethod(IterativeGradientMethod):
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
nb_iter (int): Number of iteration. Default: 5.
- loss_fn (Loss): Loss function for optimization. Default: None.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
attack (class): The single step gradient method of each iteration. In
this class, FGSM is used.
Examples:
- >>> attack = BasicIterativeMethod(network)
+ >>> attack = BasicIterativeMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
"""
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0),
is_targeted=False, nb_iter=5, loss_fn=None):
@@ -207,8 +212,8 @@ class BasicIterativeMethod(IterativeGradientMethod):
Args:
inputs (numpy.ndarray): Benign input samples used as references to
create adversarial examples.
- labels (numpy.ndarray): Original/target labels.
-
+ labels (Union[numpy.ndarray, tuple]): Original/target labels. \
+ When each input has more than one label, the label arrays are wrapped in a tuple.
Returns:
numpy.ndarray, generated adversarial examples.
@@ -218,8 +223,13 @@ class BasicIterativeMethod(IterativeGradientMethod):
>>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
>>> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]])
"""
- inputs, labels = check_pair_numpy_param('inputs', inputs,
- 'labels', labels)
+ if isinstance(labels, tuple):
+ for i, labels_item in enumerate(labels):
+ inputs, _ = check_pair_numpy_param('inputs', inputs, \
+ 'labels[{}]'.format(i), labels_item)
+ else:
+ inputs, _ = check_pair_numpy_param('inputs', inputs, \
+ 'labels', labels)
arr_x = inputs
if self._bounds is not None:
clip_min, clip_max = self._bounds
@@ -267,7 +277,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):
decay_factor (float): Decay factor in iterations. Default: 1.0.
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values:
np.inf, 1 or 2. Default: 'inf'.
- loss_fn (Loss): Loss function for optimization. Default: None.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
"""
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0),
@@ -290,7 +301,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):
Args:
inputs (numpy.ndarray): Benign input samples used as references to
create adversarial examples.
- labels (numpy.ndarray): Original/target labels.
+ labels (Union[numpy.ndarray, tuple]): Original/target labels. \
+ When each input has more than one label, the label arrays are wrapped in a tuple.
Returns:
numpy.ndarray, generated adversarial examples.
@@ -301,8 +313,13 @@ class MomentumIterativeMethod(IterativeGradientMethod):
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
>>> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]])
"""
- inputs, labels = check_pair_numpy_param('inputs', inputs,
- 'labels', labels)
+ if isinstance(labels, tuple):
+ for i, labels_item in enumerate(labels):
+ inputs, _ = check_pair_numpy_param('inputs', inputs, \
+ 'labels[{}]'.format(i), labels_item)
+ else:
+ inputs, _ = check_pair_numpy_param('inputs', inputs, \
+ 'labels', labels)
arr_x = inputs
momentum = 0
if self._bounds is not None:
@@ -340,7 +357,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):
Args:
inputs (numpy.ndarray): Input samples.
- labels (numpy.ndarray): Original/target labels.
+ labels (Union[numpy.ndarray, tuple]): Original/target labels. \
+ When each input has more than one label, the label arrays are wrapped in a tuple.
Returns:
numpy.ndarray, gradient of labels w.r.t inputs.
@@ -350,7 +368,13 @@ class MomentumIterativeMethod(IterativeGradientMethod):
>>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
"""
# get grad of loss over x
- out_grad = self._loss_grad(Tensor(inputs), Tensor(labels))
+ if isinstance(labels, tuple):
+ labels_tensor = tuple()
+ for item in labels:
+ labels_tensor += (Tensor(item),)
+ else:
+ labels_tensor = (Tensor(labels),)
+ out_grad = self._loss_grad(Tensor(inputs), *labels_tensor)
if isinstance(out_grad, tuple):
out_grad = out_grad[0]
gradient = out_grad.asnumpy()
@@ -384,7 +408,8 @@ class ProjectedGradientDescent(BasicIterativeMethod):
nb_iter (int): Number of iteration. Default: 5.
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values:
np.inf, 1 or 2. Default: 'inf'.
- loss_fn (Loss): Loss function for optimization. Default: None.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
"""
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0),
@@ -406,7 +431,8 @@ class ProjectedGradientDescent(BasicIterativeMethod):
Args:
inputs (numpy.ndarray): Benign input samples used as references to
create adversarial examples.
- labels (numpy.ndarray): Original/target labels.
+ labels (Union[numpy.ndarray, tuple]): Original/target labels. \
+ When each input has more than one label, the label arrays are wrapped in a tuple.
Returns:
numpy.ndarray, generated adversarial examples.
@@ -417,8 +443,13 @@ class ProjectedGradientDescent(BasicIterativeMethod):
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
>>> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
"""
- inputs, labels = check_pair_numpy_param('inputs', inputs,
- 'labels', labels)
+ if isinstance(labels, tuple):
+ for i, labels_item in enumerate(labels):
+ inputs, _ = check_pair_numpy_param('inputs', inputs, \
+ 'labels[{}]'.format(i), labels_item)
+ else:
+ inputs, _ = check_pair_numpy_param('inputs', inputs, \
+ 'labels', labels)
arr_x = inputs
if self._bounds is not None:
clip_min, clip_max = self._bounds
@@ -460,7 +491,8 @@ class DiverseInputIterativeMethod(BasicIterativeMethod):
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
prob (float): Transformation probability. Default: 0.5.
- loss_fn (Loss): Loss function for optimization. Default: None.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
"""
def __init__(self, network, eps=0.3, bounds=(0.0, 1.0),
is_targeted=False, prob=0.5, loss_fn=None):
@@ -495,7 +527,8 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod):
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values:
np.inf, 1 or 2. Default: 'l1'.
prob (float): Transformation probability. Default: 0.5.
- loss_fn (Loss): Loss function for optimization. Default: None.
+ loss_fn (Loss): Loss function for optimization. If None, the input network \
+ is expected to be already equipped with a loss function. Default: None.
"""
def __init__(self, network, eps=0.3, bounds=(0.0, 1.0),
is_targeted=False, norm_level='l1', prob=0.5, loss_fn=None):
diff --git a/mindarmour/fuzz_testing/fuzzing.py b/mindarmour/fuzz_testing/fuzzing.py
index ff14adf..407ed90 100644
--- a/mindarmour/fuzz_testing/fuzzing.py
+++ b/mindarmour/fuzz_testing/fuzzing.py
@@ -19,6 +19,7 @@ from random import choice
import numpy as np
from mindspore import Model
from mindspore import Tensor
+from mindspore import nn
from mindarmour.utils._check_param import check_model, check_numpy_param, \
check_param_multi_types, check_norm_level, check_param_in_range, \
@@ -451,6 +452,8 @@ class Fuzzer:
else:
network = self._target_model._network
loss_fn = self._target_model._loss_fn
+ if loss_fn is None:
+ loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
mutates[method] = self._strategies[method](network,
loss_fn=loss_fn)
return mutates
diff --git a/tests/ut/python/adv_robustness/attacks/test_batch_generate_attack.py b/tests/ut/python/adv_robustness/attacks/test_batch_generate_attack.py
index 23e5838..f2c92e9 100644
--- a/tests/ut/python/adv_robustness/attacks/test_batch_generate_attack.py
+++ b/tests/ut/python/adv_robustness/attacks/test_batch_generate_attack.py
@@ -18,7 +18,7 @@ import numpy as np
import pytest
import mindspore.ops.operations as P
-from mindspore.nn import Cell
+from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
import mindspore.context as context
from mindarmour.adv_robustness.attacks import FastGradientMethod
@@ -67,7 +67,7 @@ def test_batch_generate_attack():
label = np.random.randint(0, 10, 128).astype(np.int32)
label = np.eye(10)[label].astype(np.float32)
- attack = FastGradientMethod(Net())
+ attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.batch_generate(input_np, label, batch_size=32)
assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \
diff --git a/tests/ut/python/adv_robustness/attacks/test_gradient_method.py b/tests/ut/python/adv_robustness/attacks/test_gradient_method.py
index 8e6707c..bf7a638 100644
--- a/tests/ut/python/adv_robustness/attacks/test_gradient_method.py
+++ b/tests/ut/python/adv_robustness/attacks/test_gradient_method.py
@@ -71,7 +71,7 @@ def test_fast_gradient_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)
- attack = FastGradientMethod(Net())
+ attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \
@@ -91,7 +91,7 @@ def test_fast_gradient_method_gpu():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)
- attack = FastGradientMethod(Net())
+ attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \
@@ -132,7 +132,7 @@ def test_random_fast_gradient_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)
- attack = RandomFastGradientMethod(Net())
+ attack = RandomFastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Random fast gradient method: ' \
@@ -154,7 +154,7 @@ def test_fast_gradient_sign_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)
- attack = FastGradientSignMethod(Net())
+ attack = FastGradientSignMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Fast gradient sign method: generate' \
@@ -176,7 +176,7 @@ def test_random_fast_gradient_sign_method():
label = np.asarray([2], np.int32)
label = np.eye(28)[label].astype(np.float32)
- attack = RandomFastGradientSignMethod(Net())
+ attack = RandomFastGradientSignMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Random fast gradient sign method: ' \
@@ -198,7 +198,7 @@ def test_least_likely_class_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)
- attack = LeastLikelyClassMethod(Net())
+ attack = LeastLikelyClassMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Least likely class method: generate' \
@@ -220,7 +220,8 @@ def test_random_least_likely_class_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)
- attack = RandomLeastLikelyClassMethod(Net(), eps=0.1, alpha=0.01)
+ attack = RandomLeastLikelyClassMethod(Net(), eps=0.1, alpha=0.01, \
+ loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Random least likely class method: ' \
@@ -239,5 +240,6 @@ def test_assert_error():
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
with pytest.raises(ValueError) as e:
- assert RandomLeastLikelyClassMethod(Net(), eps=0.05, alpha=0.21)
+ assert RandomLeastLikelyClassMethod(Net(), eps=0.05, alpha=0.21, \
+ loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
assert str(e.value) == 'eps must be larger than alpha!'
diff --git a/tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py b/tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py
index 8263468..da330bc 100644
--- a/tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py
+++ b/tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py
@@ -20,6 +20,7 @@ import pytest
from mindspore.ops import operations as P
from mindspore.nn import Cell
from mindspore import context
+from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindarmour.adv_robustness.attacks import BasicIterativeMethod
from mindarmour.adv_robustness.attacks import MomentumIterativeMethod
@@ -70,7 +71,7 @@ def test_basic_iterative_method():
for i in range(5):
net = Net()
- attack = BasicIterativeMethod(net, nb_iter=i + 1)
+ attack = BasicIterativeMethod(net, nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(
ms_adv_x != input_np), 'Basic iterative method: generate value' \
@@ -91,7 +92,7 @@ def test_momentum_iterative_method():
label = np.eye(3)[label].astype(np.float32)
for i in range(5):
- attack = MomentumIterativeMethod(Net(), nb_iter=i + 1)
+ attack = MomentumIterativeMethod(Net(), nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Momentum iterative method: generate' \
' value must not be equal to' \
@@ -112,7 +113,7 @@ def test_projected_gradient_descent_method():
label = np.eye(3)[label].astype(np.float32)
for i in range(5):
- attack = ProjectedGradientDescent(Net(), nb_iter=i + 1)
+ attack = ProjectedGradientDescent(Net(), nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(
@@ -134,7 +135,7 @@ def test_diverse_input_iterative_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)
- attack = DiverseInputIterativeMethod(Net())
+ attack = DiverseInputIterativeMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Diverse input iterative method: generate' \
' value must not be equal to' \
@@ -154,7 +155,7 @@ def test_momentum_diverse_input_iterative_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)
- attack = MomentumDiverseInputIterativeMethod(Net())
+ attack = MomentumDiverseInputIterativeMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Momentum diverse input iterative method: ' \
'generate value must not be equal to' \
@@ -167,10 +168,7 @@ def test_momentum_diverse_input_iterative_method():
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_error():
- with pytest.raises(TypeError):
- # check_param_multi_types
- assert IterativeGradientMethod(Net(), bounds=None)
- attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0))
+ attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
with pytest.raises(NotImplementedError):
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32)
label = np.asarray([2], np.int32)
diff --git a/tests/ut/python/adv_robustness/defenses/test_ead.py b/tests/ut/python/adv_robustness/defenses/test_ead.py
index eb057e7..e1bd240 100644
--- a/tests/ut/python/adv_robustness/defenses/test_ead.py
+++ b/tests/ut/python/adv_robustness/defenses/test_ead.py
@@ -59,8 +59,8 @@ def test_ead():
optimizer = Momentum(net.trainable_params(), 0.001, 0.9)
net = Net()
- fgsm = FastGradientSignMethod(net)
- pgd = ProjectedGradientDescent(net)
+ fgsm = FastGradientSignMethod(net, loss_fn=loss_fn)
+ pgd = ProjectedGradientDescent(net, loss_fn=loss_fn)
ead = EnsembleAdversarialDefense(net, [fgsm, pgd], loss_fn=loss_fn,
optimizer=optimizer)
LOGGER.set_level(logging.DEBUG)
diff --git a/tests/ut/python/fuzzing/test_coverage_metrics.py b/tests/ut/python/fuzzing/test_coverage_metrics.py
index 282e1f4..b4912a5 100644
--- a/tests/ut/python/fuzzing/test_coverage_metrics.py
+++ b/tests/ut/python/fuzzing/test_coverage_metrics.py
@@ -117,7 +117,7 @@ def test_lenet_mnist_coverage_ascend():
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
# generate adv_data
- attack = FastGradientSignMethod(net, eps=0.3)
+ attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())