Browse Source

Added support for FasterRCNN

tags/v1.2.1
liuluobin 4 years ago
parent
commit
0746b465d4
26 changed files with 3578 additions and 75 deletions
  1. +47
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/README.md
  2. +135
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_pgd.py
  3. +31
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/__init__.py
  4. +84
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/anchor_generator.py
  5. +166
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample.py
  6. +197
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample_stage2.py
  7. +428
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py
  8. +114
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/fpn_neck.py
  9. +201
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/proposal_generator.py
  10. +173
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/rcnn.py
  11. +250
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/resnet50.py
  12. +181
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/roi_align.py
  13. +315
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/rpn.py
  14. +158
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/config.py
  15. +505
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/dataset.py
  16. +42
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/lr_schedule.py
  17. +184
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/network_define.py
  18. +227
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/src/util.py
  19. +55
    -26
      mindarmour/adv_robustness/attacks/gradient_method.py
  20. +60
    -27
      mindarmour/adv_robustness/attacks/iterative_gradient_method.py
  21. +3
    -0
      mindarmour/fuzz_testing/fuzzing.py
  22. +2
    -2
      tests/ut/python/adv_robustness/attacks/test_batch_generate_attack.py
  23. +10
    -8
      tests/ut/python/adv_robustness/attacks/test_gradient_method.py
  24. +7
    -9
      tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py
  25. +2
    -2
      tests/ut/python/adv_robustness/defenses/test_ead.py
  26. +1
    -1
      tests/ut/python/fuzzing/test_coverage_metrics.py

+ 47
- 0
examples/model_security/model_attacks/cv/faster_rcnn/README.md View File

@@ -0,0 +1,47 @@
# Dataset

Dataset used: [COCO2017](<https://cocodataset.org/>)

- Dataset size:19G
- Train:18G,118000 images
- Val:1G,5000 images
- Annotations:241M,instances,captions,person_keypoints etc
- Data format:image and json files
- Note:Data will be processed in dataset.py

# Environment Requirements

- Install [MindSpore](https://www.mindspore.cn/install/en).

- Download the dataset COCO2017.

- We use COCO2017 as dataset in this example.

Install Cython and pycocotool, and you can also install mmcv to process data.

```
pip install Cython

pip install pycocotools

pip install mmcv==0.2.14
```

And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows:

```
.
└─cocodataset
├─annotations
├─instance_train2017.json
└─instance_val2017.json
├─val2017
└─train2017
```

# Quick start
You can download the pre-trained model checkpoint file [here](<https://www.mindspore.cn/resources/hub/details?2505/MindSpore/ascend/0.7/fasterrcnn_v1.0_coco2017>).
```
python coco_attack_pgd.py --ann_file [VAL_JSON_FILE] --pre_trained [PRETRAINED_CHECKPOINT_FILE]
```
> Adversarial samples will be generated and saved as pickle file.

+ 135
- 0
examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_pgd.py View File

@@ -0,0 +1,135 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PGD attack for faster rcnn"""
import os
import argparse
import pickle

from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore.nn import Cell
from mindspore.ops.composite import GradOperation

from mindarmour.adv_robustness.attacks import ProjectedGradientDescent

from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset

# pylint: disable=locally-disabled, unused-argument, redefined-outer-name

set_seed(1)

parser = argparse.ArgumentParser(description='FasterRCNN attack')
parser.add_argument('--ann_file', type=str, required=True, help='Ann file path.')
parser.add_argument('--pre_trained', type=str, required=True, help='pre-trained ckpt file path for target model.')
parser.add_argument('--device_id', type=int, default=0, help='Device id, default is 0.')
parser.add_argument('--num', type=int, default=5, help='Number of adversarial examples.')
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=args.device_id)


class LossNet(Cell):
"""loss function."""
def construct(self, x1, x2, x3, x4, x5, x6):
return x4 + x6


class WithLossCell(Cell):
"""Wrap the network with loss function."""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn

def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_num):
loss1, loss2, loss3, loss4, loss5, loss6 = self._backbone(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6)

@property
def backbone_network(self):
return self._backbone


class GradWrapWithLoss(Cell):
"""
Construct a network to compute the gradient of loss function in \
input space and weighted by `weight`.
"""
def __init__(self, network):
super(GradWrapWithLoss, self).__init__()
self._grad_all = GradOperation(get_all=True, sens_param=False)
self._network = network

def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_num):
gout = self._grad_all(self._network)(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
return gout[0]


if __name__ == '__main__':
prefix = 'FasterRcnn_eval.mindrecord'
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix)
pre_trained = args.pre_trained
ann_file = args.ann_file

print("CHECKING MINDRECORD FILES ...")
if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if os.path.isdir(config.coco_root):
print("Create Mindrecord. It may take some time.")
data_to_mindrecord_byte_image("coco", False, prefix, file_num=1)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("coco_root not exits.")

print('Start generate adversarial samples.')

# build network and dataset
ds = create_fasterrcnn_dataset(mindrecord_file, batch_size=config.test_batch_size, \
repeat_num=1, is_training=True)
net = Faster_Rcnn_Resnet50(config)
param_dict = load_checkpoint(pre_trained)
load_param_into_net(net, param_dict)
net = net.set_train()

# build attacker
with_loss_cell = WithLossCell(net, LossNet())
grad_with_loss_net = GradWrapWithLoss(with_loss_cell)
attack = ProjectedGradientDescent(grad_with_loss_net, bounds=None, eps=0.1)

# generate adversarial samples
num = args.num
num_batches = num // config.test_batch_size
channel = 3
adv_samples = [0] * (num_batches * config.test_batch_size)
adv_id = 0
for data in ds.create_dict_iterator(num_epochs=num_batches):
img_data = data['image']
img_metas = data['image_shape']
gt_bboxes = data['box']
gt_labels = data['label']
gt_num = data['valid_num']

adv_img = attack.generate(img_data.asnumpy(), \
(img_metas.asnumpy(), gt_bboxes.asnumpy(), gt_labels.asnumpy(), gt_num.asnumpy()))
for item in adv_img:
adv_samples[adv_id] = item
adv_id += 1

pickle.dump(adv_samples, open('adv_samples.pkl', 'wb'))
print('Generate adversarial samples complete.')

+ 31
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/__init__.py View File

@@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn Init."""

from .resnet50 import ResNetFea, ResidualBlockUsing
from .bbox_assign_sample import BboxAssignSample
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .fpn_neck import FeatPyramidNeck
from .proposal_generator import Proposal
from .rcnn import Rcnn
from .rpn import RPN
from .roi_align import SingleRoIExtractor
from .anchor_generator import AnchorGenerator

__all__ = [
"ResNetFea", "BboxAssignSample", "BboxAssignSampleForRcnn",
"FeatPyramidNeck", "Proposal", "Rcnn",
"RPN", "SingleRoIExtractor", "AnchorGenerator", "ResidualBlockUsing"
]

+ 84
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/anchor_generator.py View File

@@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn anchor generator."""

import numpy as np

class AnchorGenerator():
"""Anchor generator for FasterRcnn."""
def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
"""Anchor generator init method."""
self.base_size = base_size
self.scales = np.array(scales)
self.ratios = np.array(ratios)
self.scale_major = scale_major
self.ctr = ctr
self.base_anchors = self.gen_base_anchors()

def gen_base_anchors(self):
"""Generate a single anchor."""
w = self.base_size
h = self.base_size
if self.ctr is None:
x_ctr = 0.5 * (w - 1)
y_ctr = 0.5 * (h - 1)
else:
x_ctr, y_ctr = self.ctr

h_ratios = np.sqrt(self.ratios)
w_ratios = 1 / h_ratios
if self.scale_major:
ws = (w * w_ratios[:, None] * self.scales[None, :]).reshape(-1)
hs = (h * h_ratios[:, None] * self.scales[None, :]).reshape(-1)
else:
ws = (w * self.scales[:, None] * w_ratios[None, :]).reshape(-1)
hs = (h * self.scales[:, None] * h_ratios[None, :]).reshape(-1)

base_anchors = np.stack(
[
x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
],
axis=-1).round()

return base_anchors

def _meshgrid(self, x, y, row_major=True):
"""Generate grid."""
xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1)
yy = np.repeat(y, len(x))
if row_major:
return xx, yy

return yy, xx

def grid_anchors(self, featmap_size, stride=16):
"""Generate anchor list."""
base_anchors = self.base_anchors

feat_h, feat_w = featmap_size
shift_x = np.arange(0, feat_w) * stride
shift_y = np.arange(0, feat_h) * stride
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
shifts = shifts.astype(base_anchors.dtype)
# first feat_w elements correspond to the first row of shifts
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
# shifted anchors (K, A, 4), reshape to (K*A, 4)

all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
all_anchors = all_anchors.reshape(-1, 4)

return all_anchors

+ 166
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample.py View File

@@ -0,0 +1,166 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn positive and negative sample screening for RPN."""

import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype

# pylint: disable=locally-disabled, invalid-name, missing-docstring


class BboxAssignSample(nn.Cell):
"""
Bbox assigner and sampler defination.

Args:
config (dict): Config.
batch_size (int): Batchsize.
num_bboxes (int): The anchor nums.
add_gt_as_proposals (bool): add gt bboxes as proposals flag.

Returns:
Tensor, output tensor.
bbox_targets: bbox location, (batch_size, num_bboxes, 4)
bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
labels: label for every bboxes, (batch_size, num_bboxes, 1)
label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)

Examples:
BboxAssignSample(config, 2, 1024, True)
"""

def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSample, self).__init__()
cfg = config
self.batch_size = batch_size

self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16)
self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16)
self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16)
self.zero_thr = Tensor(0.0, mstype.float16)

self.num_bboxes = num_bboxes
self.num_gts = cfg.num_gts
self.num_expected_pos = cfg.num_expected_pos
self.num_expected_neg = cfg.num_expected_neg
self.add_gt_as_proposals = add_gt_as_proposals

if self.add_gt_as_proposals:
self.label_inds = Tensor(np.arange(1, self.num_gts + 1))

self.concat = P.Concat(axis=0)
self.max_gt = P.ArgMaxWithValue(axis=0)
self.max_anchor = P.ArgMaxWithValue(axis=1)
self.sum_inds = P.ReduceSum()
self.iou = P.IOU()
self.greaterequal = P.GreaterEqual()
self.greater = P.Greater()
self.select = P.Select()
self.gatherND = P.GatherNd()
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.logicaland = P.LogicalAnd()
self.less = P.Less()
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
self.scatterNdUpdate = P.ScatterNdUpdate()
self.scatterNd = P.ScatterNd()
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()

self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))

self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))


def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
(self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one)
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two)

overlaps = self.iou(bboxes, gt_bboxes_i)

max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
_, max_overlaps_w_ac = self.max_anchor(overlaps)

neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \
self.less(max_overlaps_w_gt, self.neg_iou_thr))
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)

pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
assigned_gt_inds4 = assigned_gt_inds3
for j in range(self.num_gts):
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::])

pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \
self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))

assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)

assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)

pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))

pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))

pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))

neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))

num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16)
num_pos = self.sum_inds(num_pos, -1)
unvalid_pos_index = self.less(self.range_pos_size, num_pos)
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)

pos_bboxes_ = self.gatherND(bboxes, pos_index)
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)

pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)

valid_pos_index = self.cast(valid_pos_index, mstype.int32)
valid_neg_index = self.cast(valid_neg_index, mstype.int32)
bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
total_index = self.concat((pos_index, neg_index))
total_valid_index = self.concat((valid_pos_index, valid_neg_index))
label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))

return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
labels_total, self.cast(label_weights_total, mstype.bool_)

+ 197
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample_stage2.py View File

@@ -0,0 +1,197 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn tpositive and negative sample screening for Rcnn."""

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor

# pylint: disable=locally-disabled, invalid-name, missing-docstring


class BboxAssignSampleForRcnn(nn.Cell):
"""
Bbox assigner and sampler defination.

Args:
config (dict): Config.
batch_size (int): Batchsize.
num_bboxes (int): The anchor nums.
add_gt_as_proposals (bool): add gt bboxes as proposals flag.

Returns:
Tensor, output tensor.
bbox_targets: bbox location, (batch_size, num_bboxes, 4)
bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
labels: label for every bboxes, (batch_size, num_bboxes, 1)
label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)

Examples:
BboxAssignSampleForRcnn(config, 2, 1024, True)
"""

def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSampleForRcnn, self).__init__()
cfg = config
self.batch_size = batch_size
self.neg_iou_thr = cfg.neg_iou_thr_stage2
self.pos_iou_thr = cfg.pos_iou_thr_stage2
self.min_pos_iou = cfg.min_pos_iou_stage2
self.num_gts = cfg.num_gts
self.num_bboxes = num_bboxes
self.num_expected_pos = cfg.num_expected_pos_stage2
self.num_expected_neg = cfg.num_expected_neg_stage2
self.num_expected_total = cfg.num_expected_total_stage2

self.add_gt_as_proposals = add_gt_as_proposals
self.label_inds = Tensor(np.arange(1, self.num_gts + 1).astype(np.int32))
self.add_gt_as_proposals_valid = Tensor(np.array(self.add_gt_as_proposals * np.ones(self.num_gts),
dtype=np.int32))

self.concat = P.Concat(axis=0)
self.max_gt = P.ArgMaxWithValue(axis=0)
self.max_anchor = P.ArgMaxWithValue(axis=1)
self.sum_inds = P.ReduceSum()
self.iou = P.IOU()
self.greaterequal = P.GreaterEqual()
self.greater = P.Greater()
self.select = P.Select()
self.gatherND = P.GatherNd()
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.logicaland = P.LogicalAnd()
self.less = P.Less()
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(0.1, 0.1, 0.2, 0.2))
self.concat_axis1 = P.Concat(axis=1)
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()

# Check
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))

# Init tensor
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))

self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16))
self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8))

self.reshape_shape_pos = (self.num_expected_pos, 1)
self.reshape_shape_neg = (self.num_expected_neg, 1)

self.scalar_zero = Tensor(0.0, dtype=mstype.float16)
self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=mstype.float16)
self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=mstype.float16)
self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=mstype.float16)

def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
(self.num_gts, 1)), (1, 4)), mstype.bool_), \
gt_bboxes_i, self.check_gt_one)
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), \
bboxes, self.check_anchor_two)

overlaps = self.iou(bboxes, gt_bboxes_i)

max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
_, max_overlaps_w_ac = self.max_anchor(overlaps)

neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt,
self.scalar_zero),
self.less(max_overlaps_w_gt,
self.scalar_neg_iou_thr))

assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)

pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.scalar_pos_iou_thr)
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)

for j in range(self.num_gts):
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
overlaps_w_ac_j = overlaps[j:j+1:1, ::]
temp1 = self.greaterequal(max_overlaps_w_ac_j, self.scalar_min_pos_iou)
temp2 = self.squeeze(self.equal(overlaps_w_ac_j, max_overlaps_w_ac_j))
pos_mask_j = self.logicaland(temp1, temp2)
assigned_gt_inds3 = self.select(pos_mask_j, (j+1)*self.assigned_gt_ones, assigned_gt_inds3)

assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds3, self.assigned_gt_ignores)

bboxes = self.concat((gt_bboxes_i, bboxes))
label_inds_valid = self.select(gt_valids, self.label_inds, self.gt_ignores)
label_inds_valid = label_inds_valid * self.add_gt_as_proposals_valid
assigned_gt_inds5 = self.concat((label_inds_valid, assigned_gt_inds5))

# Get pos index
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))

pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))

num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), mstype.float16), -1)
valid_pos_index = self.cast(valid_pos_index, mstype.int32)
pos_index = self.reshape(pos_index, self.reshape_shape_pos)
valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
pos_index = pos_index * valid_pos_index

pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
pos_assigned_gt_index = pos_assigned_gt_index * valid_pos_index

pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)

# Get neg index
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))

unvalid_pos_index = self.less(self.range_pos_size, num_pos)
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
neg_index = self.reshape(neg_index, self.reshape_shape_neg)

valid_neg_index = self.cast(valid_neg_index, mstype.int32)
valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
neg_index = neg_index * valid_neg_index

pos_bboxes_ = self.gatherND(bboxes, pos_index)

neg_bboxes_ = self.gatherND(bboxes, neg_index)
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)

total_bboxes = self.concat((pos_bboxes_, neg_bboxes_))
total_deltas = self.concat((pos_bbox_targets_, self.bboxs_neg_mask))
total_labels = self.concat((pos_gt_labels, self.labels_neg_mask))

valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
total_mask = self.concat((valid_pos_index, valid_neg_index))

return total_bboxes, total_deltas, total_labels, total_mask

+ 428
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py View File

@@ -0,0 +1,428 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn based on ResNet50."""

import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from .resnet50 import ResNetFea, ResidualBlockUsing
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .fpn_neck import FeatPyramidNeck
from .proposal_generator import Proposal
from .rcnn import Rcnn
from .rpn import RPN
from .roi_align import SingleRoIExtractor
from .anchor_generator import AnchorGenerator

# pylint: disable=locally-disabled, invalid-name, missing-docstring


class Faster_Rcnn_Resnet50(nn.Cell):
"""
FasterRcnn Network.

Note:
backbone = resnet50

Returns:
Tuple, tuple of output tensor.
rpn_loss: Scalar, Total loss of RPN subnet.
rcnn_loss: Scalar, Total loss of RCNN subnet.
rpn_cls_loss: Scalar, Classification loss of RPN subnet.
rpn_reg_loss: Scalar, Regression loss of RPN subnet.
rcnn_cls_loss: Scalar, Classification loss of RCNN subnet.
rcnn_reg_loss: Scalar, Regression loss of RCNN subnet.

Examples:
net = Faster_Rcnn_Resnet50()
"""
def __init__(self, config):
super(Faster_Rcnn_Resnet50, self).__init__()
self.train_batch_size = config.batch_size
self.num_classes = config.num_classes
self.anchor_scales = config.anchor_scales
self.anchor_ratios = config.anchor_ratios
self.anchor_strides = config.anchor_strides
self.target_means = tuple(config.rcnn_target_means)
self.target_stds = tuple(config.rcnn_target_stds)

# Anchor generator
anchor_base_sizes = None
self.anchor_base_sizes = list(
self.anchor_strides) if anchor_base_sizes is None else anchor_base_sizes

self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
self.anchor_generators.append(
AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios))

self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)

featmap_sizes = config.feature_shapes
assert len(featmap_sizes) == len(self.anchor_generators)

self.anchor_list = self.get_anchors(featmap_sizes)

# Backbone resnet50
self.backbone = ResNetFea(ResidualBlockUsing,
config.resnet_block,
config.resnet_in_channels,
config.resnet_out_channels,
False)

# Fpn
self.fpn_ncek = FeatPyramidNeck(config.fpn_in_channels,
config.fpn_out_channels,
config.fpn_num_outs)

# Rpn and rpn loss
self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8))
self.rpn_with_loss = RPN(config,
self.train_batch_size,
config.rpn_in_channels,
config.rpn_feat_channels,
config.num_anchors,
config.rpn_cls_out_channels)

# Proposal
self.proposal_generator = Proposal(config,
self.train_batch_size,
config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator.set_train_local(config, True)
self.proposal_generator_test = Proposal(config,
config.test_batch_size,
config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator_test.set_train_local(config, False)

# Assign and sampler stage two
self.bbox_assigner_sampler_for_rcnn = BboxAssignSampleForRcnn(config, self.train_batch_size,
config.num_bboxes_stage2, True)
self.decode = P.BoundingBoxDecode(max_shape=(768, 1280), means=self.target_means, \
stds=self.target_stds)

# Roi
self.roi_align = SingleRoIExtractor(config,
config.roi_layer,
config.roi_align_out_channels,
config.roi_align_featmap_strides,
self.train_batch_size,
config.roi_align_finest_scale)
self.roi_align.set_train_local(config, True)
self.roi_align_test = SingleRoIExtractor(config,
config.roi_layer,
config.roi_align_out_channels,
config.roi_align_featmap_strides,
1,
config.roi_align_finest_scale)
self.roi_align_test.set_train_local(config, False)

# Rcnn
self.rcnn = Rcnn(config, config.rcnn_in_channels * config.roi_layer['out_size'] * config.roi_layer['out_size'],
self.train_batch_size, self.num_classes)

# Op declare
self.squeeze = P.Squeeze()
self.cast = P.Cast()

self.concat = P.Concat(axis=0)
self.concat_1 = P.Concat(axis=1)
self.concat_2 = P.Concat(axis=2)
self.reshape = P.Reshape()
self.select = P.Select()
self.greater = P.Greater()
self.transpose = P.Transpose()

# Test mode
self.test_batch_size = config.test_batch_size
self.split = P.Split(axis=0, output_num=self.test_batch_size)
self.split_shape = P.Split(axis=0, output_num=4)
self.split_scores = P.Split(axis=1, output_num=self.num_classes)
self.split_cls = P.Split(axis=0, output_num=self.num_classes-1)
self.tile = P.Tile()
self.gather = P.GatherNd()

self.rpn_max_num = config.rpn_max_num

self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(np.float16))
self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool)
self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool)
self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask,
self.ones_mask, self.zeros_mask), axis=1))
self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask,
self.ones_mask, self.ones_mask, self.zeros_mask), axis=1))

self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_score_thr)
self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * 0)
self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(np.float16) * -1)
self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_iou_thr)
self.test_max_per_img = config.test_max_per_img
self.nms_test = P.NMSWithMask(config.test_iou_thr)
self.softmax = P.Softmax(axis=1)
self.logicand = P.LogicalAnd()
self.oneslike = P.OnesLike()
self.test_topk = P.TopK(sorted=True)
self.test_num_proposal = self.test_batch_size * self.rpn_max_num

# Improve speed
self.concat_start = min(self.num_classes - 2, 55)
self.concat_end = (self.num_classes - 1)

# Init tensor
roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i,
dtype=np.float16) for i in range(self.train_batch_size)]

roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=np.float16) \
for i in range(self.test_batch_size)]

self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index))
self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test))

def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
x = self.backbone(img_data)
x = self.fpn_ncek(x)

rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(x,
img_metas,
self.anchor_list,
gt_bboxes,
self.gt_labels_stage1,
gt_valids)

if self.training:
proposal, proposal_mask = self.proposal_generator(cls_score, bbox_pred, self.anchor_list)
else:
proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)

gt_labels = self.cast(gt_labels, mstype.int32)
gt_valids = self.cast(gt_valids, mstype.int32)
bboxes_tuple = ()
deltas_tuple = ()
labels_tuple = ()
mask_tuple = ()
if self.training:
for i in range(self.train_batch_size):
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])

gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
gt_labels_i = self.cast(gt_labels_i, mstype.uint8)

gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
gt_valids_i = self.cast(gt_valids_i, mstype.bool_)

bboxes, deltas, labels, mask = self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i,
gt_labels_i,
proposal_mask[i],
proposal[i][::, 0:4:1],
gt_valids_i)
bboxes_tuple += (bboxes,)
deltas_tuple += (deltas,)
labels_tuple += (labels,)
mask_tuple += (mask,)

bbox_targets = self.concat(deltas_tuple)
rcnn_labels = self.concat(labels_tuple)
bbox_targets = F.stop_gradient(bbox_targets)
rcnn_labels = F.stop_gradient(rcnn_labels)
rcnn_labels = self.cast(rcnn_labels, mstype.int32)
else:
mask_tuple += proposal_mask
bbox_targets = proposal_mask
rcnn_labels = proposal_mask
for p_i in proposal:
bboxes_tuple += (p_i[::, 0:4:1],)

if self.training:
if self.train_batch_size > 1:
bboxes_all = self.concat(bboxes_tuple)
else:
bboxes_all = bboxes_tuple[0]
rois = self.concat_1((self.roi_align_index_tensor, bboxes_all))
else:
if self.test_batch_size > 1:
bboxes_all = self.concat(bboxes_tuple)
else:
bboxes_all = bboxes_tuple[0]
rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))


rois = self.cast(rois, mstype.float32)
rois = F.stop_gradient(rois)

if self.training:
roi_feats = self.roi_align(rois,
self.cast(x[0], mstype.float32),
self.cast(x[1], mstype.float32),
self.cast(x[2], mstype.float32),
self.cast(x[3], mstype.float32))
else:
roi_feats = self.roi_align_test(rois,
self.cast(x[0], mstype.float32),
self.cast(x[1], mstype.float32),
self.cast(x[2], mstype.float32),
self.cast(x[3], mstype.float32))


roi_feats = self.cast(roi_feats, mstype.float16)
rcnn_masks = self.concat(mask_tuple)
rcnn_masks = F.stop_gradient(rcnn_masks)
rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_))
rcnn_loss, rcnn_cls_loss, rcnn_reg_loss, _ = self.rcnn(roi_feats,
bbox_targets,
rcnn_labels,
rcnn_mask_squeeze)

output = ()
if self.training:
output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss)
else:
output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, img_metas)

return output

def get_det_bboxes(self, cls_logits, reg_logits, mask_logits, rois, img_metas):
"""Get the actual detection box."""
scores = self.softmax(cls_logits)

boxes_all = ()
for i in range(self.num_classes):
k = i * 4
reg_logits_i = self.squeeze(reg_logits[::, k:k+4:1])
out_boxes_i = self.decode(rois, reg_logits_i)
boxes_all += (out_boxes_i,)

img_metas_all = self.split(img_metas)
scores_all = self.split(scores)
mask_all = self.split(self.cast(mask_logits, mstype.int32))

boxes_all_with_batchsize = ()
for i in range(self.test_batch_size):
scale = self.split_shape(self.squeeze(img_metas_all[i]))
scale_h = scale[2]
scale_w = scale[3]
boxes_tuple = ()
for j in range(self.num_classes):
boxes_tmp = self.split(boxes_all[j])
out_boxes_h = boxes_tmp[i] / scale_h
out_boxes_w = boxes_tmp[i] / scale_w
boxes_tuple += (self.select(self.bbox_mask, out_boxes_w, out_boxes_h),)
boxes_all_with_batchsize += (boxes_tuple,)

output = self.multiclass_nms(boxes_all_with_batchsize, scores_all, mask_all)

return output

def multiclass_nms(self, boxes_all, scores_all, mask_all):
"""Multiscale postprocessing."""
all_bboxes = ()
all_labels = ()
all_masks = ()

for i in range(self.test_batch_size):
bboxes = boxes_all[i]
scores = scores_all[i]
masks = self.cast(mask_all[i], mstype.bool_)

res_boxes_tuple = ()
res_labels_tuple = ()
res_masks_tuple = ()

for j in range(self.num_classes - 1):
k = j + 1
_cls_scores = scores[::, k:k + 1:1]
_bboxes = self.squeeze(bboxes[k])
_mask_o = self.reshape(masks, (self.rpn_max_num, 1))

cls_mask = self.greater(_cls_scores, self.test_score_thresh)
_mask = self.logicand(_mask_o, cls_mask)

_reg_mask = self.cast(self.tile(self.cast(_mask, mstype.int32), (1, 4)), mstype.bool_)

_bboxes = self.select(_reg_mask, _bboxes, self.test_box_zeros)
_cls_scores = self.select(_mask, _cls_scores, self.test_score_zeros)
__cls_scores = self.squeeze(_cls_scores)
scores_sorted, topk_inds = self.test_topk(__cls_scores, self.rpn_max_num)
topk_inds = self.reshape(topk_inds, (self.rpn_max_num, 1))
scores_sorted = self.reshape(scores_sorted, (self.rpn_max_num, 1))
_bboxes_sorted = self.gather(_bboxes, topk_inds)
_mask_sorted = self.gather(_mask, topk_inds)

scores_sorted = self.tile(scores_sorted, (1, 4))
cls_dets = self.concat_1((_bboxes_sorted, scores_sorted))
cls_dets = P.Slice()(cls_dets, (0, 0), (self.rpn_max_num, 5))

cls_dets, _index, _mask_nms = self.nms_test(cls_dets)
_index = self.reshape(_index, (self.rpn_max_num, 1))
_mask_nms = self.reshape(_mask_nms, (self.rpn_max_num, 1))

_mask_n = self.gather(_mask_sorted, _index)

_mask_n = self.logicand(_mask_n, _mask_nms)
cls_labels = self.oneslike(_index) * j
res_boxes_tuple += (cls_dets,)
res_labels_tuple += (cls_labels,)
res_masks_tuple += (_mask_n,)

res_boxes_start = self.concat(res_boxes_tuple[:self.concat_start])
res_labels_start = self.concat(res_labels_tuple[:self.concat_start])
res_masks_start = self.concat(res_masks_tuple[:self.concat_start])

res_boxes_end = self.concat(res_boxes_tuple[self.concat_start:self.concat_end])
res_labels_end = self.concat(res_labels_tuple[self.concat_start:self.concat_end])
res_masks_end = self.concat(res_masks_tuple[self.concat_start:self.concat_end])

res_boxes = self.concat((res_boxes_start, res_boxes_end))
res_labels = self.concat((res_labels_start, res_labels_end))
res_masks = self.concat((res_masks_start, res_masks_end))

reshape_size = (self.num_classes - 1) * self.rpn_max_num
res_boxes = self.reshape(res_boxes, (1, reshape_size, 5))
res_labels = self.reshape(res_labels, (1, reshape_size, 1))
res_masks = self.reshape(res_masks, (1, reshape_size, 1))

all_bboxes += (res_boxes,)
all_labels += (res_labels,)
all_masks += (res_masks,)

all_bboxes = self.concat(all_bboxes)
all_labels = self.concat(all_labels)
all_masks = self.concat(all_masks)
return all_bboxes, all_labels, all_masks

def get_anchors(self, featmap_sizes):
"""Get anchors according to feature map sizes.

Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.

Returns:
tuple: anchors of each image, valid flags of each image
"""
num_levels = len(featmap_sizes)

# since feature map sizes of all images are the same, we only compute
# anchors for one time
multi_level_anchors = ()
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
multi_level_anchors += (Tensor(anchors.astype(np.float16)),)

return multi_level_anchors

+ 114
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/fpn_neck.py View File

@@ -0,0 +1,114 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn feature pyramid network."""

import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer

# pylint: disable=locally-disabled, missing-docstring

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

def bias_init_zeros(shape):
"""Bias init method."""
return Tensor(np.array(np.zeros(shape).astype(np.float32)).astype(np.float16))

def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
"""Conv2D wrapper."""
shape = (out_channels, in_channels, kernel_size, kernel_size)
weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor()
shape_bias = (out_channels,)
biass = bias_init_zeros(shape_bias)
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=biass)

class FeatPyramidNeck(nn.Cell):
"""
Feature pyramid network cell, usually uses as network neck.

Applies the convolution on multiple, input feature maps
and output feature map with same channel size. if required num of
output larger then num of inputs, add extra maxpooling for further
downsampling;

Args:
in_channels (tuple) - Channel size of input feature maps.
out_channels (int) - Channel size output.
num_outs (int) - Num of output features.

Returns:
Tuple, with tensors of same channel size.

Examples:
neck = FeatPyramidNeck([100,200,300], 50, 4)
input_data = (normal(0,0.1,(1,c,1280//(4*2**i), 768//(4*2**i)),
dtype=np.float32) \
for i, c in enumerate(config.fpn_in_channels))
x = neck(input_data)
"""

def __init__(self,
in_channels,
out_channels,
num_outs):
super(FeatPyramidNeck, self).__init__()
self.num_outs = num_outs
self.in_channels = in_channels
self.fpn_layer = len(self.in_channels)

assert not self.num_outs < len(in_channels)

self.lateral_convs_list_ = []
self.fpn_convs_ = []

for _, channel in enumerate(in_channels):
l_conv = _conv(channel, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='valid')
fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same')
self.lateral_convs_list_.append(l_conv)
self.fpn_convs_.append(fpn_conv)
self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_)
self.interpolate1 = P.ResizeNearestNeighbor((48, 80))
self.interpolate2 = P.ResizeNearestNeighbor((96, 160))
self.interpolate3 = P.ResizeNearestNeighbor((192, 320))
self.maxpool = P.MaxPool(ksize=1, strides=2, padding="same")

def construct(self, inputs):
x = ()
for i in range(self.fpn_layer):
x += (self.lateral_convs_list[i](inputs[i]),)

y = (x[3],)
y = y + (x[2] + self.interpolate1(y[self.fpn_layer - 4]),)
y = y + (x[1] + self.interpolate2(y[self.fpn_layer - 3]),)
y = y + (x[0] + self.interpolate3(y[self.fpn_layer - 2]),)

z = ()
for i in range(self.fpn_layer - 1, -1, -1):
z = z + (y[i],)

outs = ()
for i in range(self.fpn_layer):
outs = outs + (self.fpn_convs_list[i](z[i]),)

for i in range(self.num_outs - self.fpn_layer):
outs = outs + (self.maxpool(outs[3]),)
return outs

+ 201
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/proposal_generator.py View File

@@ -0,0 +1,201 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn proposal generator."""

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore import context

# pylint: disable=locally-disabled, invalid-name, missing-docstring


context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Proposal(nn.Cell):
"""
Proposal subnet.

Args:
config (dict): Config.
batch_size (int): Batchsize.
num_classes (int) - Class number.
use_sigmoid_cls (bool) - Select sigmoid or softmax function.
target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0).
target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0).

Returns:
Tuple, tuple of output tensor,(proposal, mask).

Examples:
Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \
target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0))
"""
def __init__(self,
config,
batch_size,
num_classes,
use_sigmoid_cls,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0)
):
super(Proposal, self).__init__()
cfg = config
self.batch_size = batch_size
self.num_classes = num_classes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls

if self.use_sigmoid_cls:
self.cls_out_channels = num_classes - 1
self.activation = P.Sigmoid()
self.reshape_shape = (-1, 1)
else:
self.cls_out_channels = num_classes
self.activation = P.Softmax(axis=1)
self.reshape_shape = (-1, 2)

if self.cls_out_channels <= 0:
raise ValueError('num_classes={} is too small'.format(num_classes))

self.num_pre = cfg.rpn_proposal_nms_pre
self.min_box_size = cfg.rpn_proposal_min_bbox_size
self.nms_thr = cfg.rpn_proposal_nms_thr
self.nms_post = cfg.rpn_proposal_nms_post
self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
self.max_num = cfg.rpn_proposal_max_num
self.num_levels = cfg.fpn_num_outs

# Op Define
self.squeeze = P.Squeeze()
self.reshape = P.Reshape()
self.cast = P.Cast()

self.feature_shapes = cfg.feature_shapes

self.transpose_shape = (1, 2, 0)

self.decode = P.BoundingBoxDecode(max_shape=(cfg.img_height, cfg.img_width), \
means=self.target_means, \
stds=self.target_stds)

self.nms = P.NMSWithMask(self.nms_thr)
self.concat_axis0 = P.Concat(axis=0)
self.concat_axis1 = P.Concat(axis=1)
self.split = P.Split(axis=1, output_num=5)
self.min = P.Minimum()
self.gatherND = P.GatherNd()
self.slice = P.Slice()
self.select = P.Select()
self.greater = P.Greater()
self.transpose = P.Transpose()
self.tile = P.Tile()
self.set_train_local(config, training=True)

self.multi_10 = Tensor(10.0, mstype.float16)

def set_train_local(self, config, training=True):
"""Set training flag."""
self.training_local = training

cfg = config
self.topK_stage1 = ()
self.topK_shape = ()
total_max_topk_input = 0
if not self.training_local:
self.num_pre = cfg.rpn_nms_pre
self.min_box_size = cfg.rpn_min_bbox_min_size
self.nms_thr = cfg.rpn_nms_thr
self.nms_post = cfg.rpn_nms_post
self.nms_across_levels = cfg.rpn_nms_across_levels
self.max_num = cfg.rpn_max_num

for shp in self.feature_shapes:
k_num = min(self.num_pre, (shp[0] * shp[1] * 3))
total_max_topk_input += k_num
self.topK_stage1 += (k_num,)
self.topK_shape += ((k_num, 1),)

self.topKv2 = P.TopK(sorted=True)
self.topK_shape_stage2 = (self.max_num, 1)
self.min_float_num = -65536.0
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))

def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
proposals_tuple = ()
masks_tuple = ()
for img_id in range(self.batch_size):
cls_score_list = ()
bbox_pred_list = ()
for i in range(self.num_levels):
rpn_cls_score_i = self.squeeze(rpn_cls_score_total[i][img_id:img_id+1:1, ::, ::, ::])
rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[i][img_id:img_id+1:1, ::, ::, ::])

cls_score_list = cls_score_list + (rpn_cls_score_i,)
bbox_pred_list = bbox_pred_list + (rpn_bbox_pred_i,)

proposals, masks = self.get_bboxes_single(cls_score_list, bbox_pred_list, anchor_list)
proposals_tuple += (proposals,)
masks_tuple += (masks,)
return proposals_tuple, masks_tuple

def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors):
"""Get proposal boundingbox."""
mlvl_proposals = ()
mlvl_mask = ()
for idx in range(self.num_levels):
rpn_cls_score = self.transpose(cls_scores[idx], self.transpose_shape)
rpn_bbox_pred = self.transpose(bbox_preds[idx], self.transpose_shape)
anchors = mlvl_anchors[idx]

rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape)
rpn_cls_score = self.activation(rpn_cls_score)
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), mstype.float16)

rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16)

scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.topK_stage1[idx])

topk_inds = self.reshape(topk_inds, self.topK_shape[idx])

bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds)
anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16)

proposals_decode = self.decode(anchors_sorted, bboxes_sorted)

proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape[idx])))
proposals, _, mask_valid = self.nms(proposals_decode)

mlvl_proposals = mlvl_proposals + (proposals,)
mlvl_mask = mlvl_mask + (mask_valid,)

proposals = self.concat_axis0(mlvl_proposals)
masks = self.concat_axis0(mlvl_mask)

_, _, _, _, scores = self.split(proposals)
scores = self.squeeze(scores)
topk_mask = self.cast(self.topK_mask, mstype.float16)
scores_using = self.select(masks, scores, topk_mask)

_, topk_inds = self.topKv2(scores_using, self.max_num)

topk_inds = self.reshape(topk_inds, self.topK_shape_stage2)
proposals = self.gatherND(proposals, topk_inds)
masks = self.gatherND(masks, topk_inds)
return proposals, masks

+ 173
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/rcnn.py View File

@@ -0,0 +1,173 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn Rcnn network."""

import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter

# pylint: disable=locally-disabled, missing-docstring


class DenseNoTranpose(nn.Cell):
"""Dense method"""
def __init__(self, input_channels, output_channels, weight_init):
super(DenseNoTranpose, self).__init__()

self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16),
name="weight")
self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16).to_tensor(), name="bias")

self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()

def construct(self, x):
output = self.bias_add(self.matmul(x, self.weight), self.bias)
return output


class Rcnn(nn.Cell):
"""
Rcnn subnet.

Args:
config (dict) - Config.
representation_size (int) - Channels of shared dense.
batch_size (int) - Batchsize.
num_classes (int) - Class number.
target_means (list) - Means for encode function. Default: (.0, .0, .0, .0]).
target_stds (list) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2).

Returns:
Tuple, tuple of output tensor.

Examples:
Rcnn(config=config, representation_size = 1024, batch_size=2, num_classes = 81, \
target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2))
"""
def __init__(self,
config,
representation_size,
batch_size,
num_classes,
target_means=(0., 0., 0., 0.),
target_stds=(0.1, 0.1, 0.2, 0.2)
):
super(Rcnn, self).__init__()
cfg = config
self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(np.float16))
self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(np.float16))
self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels
self.target_means = target_means
self.target_stds = target_stds
self.num_classes = num_classes
self.in_channels = cfg.rcnn_in_channels
self.train_batch_size = batch_size
self.test_batch_size = cfg.test_batch_size

shape_0 = (self.rcnn_fc_out_channels, representation_size)
weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16).to_tensor()
shape_1 = (self.rcnn_fc_out_channels, self.rcnn_fc_out_channels)
weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16).to_tensor()
self.shared_fc_0 = DenseNoTranpose(representation_size, self.rcnn_fc_out_channels, weights_0)
self.shared_fc_1 = DenseNoTranpose(self.rcnn_fc_out_channels, self.rcnn_fc_out_channels, weights_1)

cls_weight = initializer('Normal', shape=[num_classes, self.rcnn_fc_out_channels][::-1],
dtype=mstype.float16).to_tensor()
reg_weight = initializer('Normal', shape=[num_classes * 4, self.rcnn_fc_out_channels][::-1],
dtype=mstype.float16).to_tensor()
self.cls_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes, cls_weight)
self.reg_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes * 4, reg_weight)

self.flatten = P.Flatten()
self.relu = P.ReLU()
self.logicaland = P.LogicalAnd()
self.loss_cls = P.SoftmaxCrossEntropyWithLogits()
self.loss_bbox = P.SmoothL1Loss(beta=1.0)
self.reshape = P.Reshape()
self.onehot = P.OneHot()
self.greater = P.Greater()
self.cast = P.Cast()
self.sum_loss = P.ReduceSum()
self.tile = P.Tile()
self.expandims = P.ExpandDims()

self.gather = P.GatherNd()
self.argmax = P.ArgMaxWithValue(axis=1)

self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.value = Tensor(1.0, mstype.float16)

self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size

rmv_first = np.ones((self.num_bboxes, self.num_classes))
rmv_first[:, 0] = np.zeros((self.num_bboxes,))
self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16))

self.num_bboxes_test = cfg.rpn_max_num * cfg.test_batch_size

range_max = np.arange(self.num_bboxes_test).astype(np.int32)
self.range_max = Tensor(range_max)

def construct(self, featuremap, bbox_targets, labels, mask):
x = self.flatten(featuremap)

x = self.relu(self.shared_fc_0(x))
x = self.relu(self.shared_fc_1(x))

x_cls = self.cls_scores(x)
x_reg = self.reg_scores(x)

if self.training:
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), mstype.float16)
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1))

loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask)
out = (loss, loss_cls, loss_reg, loss_print)
else:
out = (x_cls, (x_cls / self.value), x_reg, x_cls)

return out

def loss(self, cls_score, bbox_pred, bbox_targets, bbox_weights, labels, weights):
"""Loss method."""
loss_print = ()
loss_cls, _ = self.loss_cls(cls_score, labels)

weights = self.cast(weights, mstype.float16)
loss_cls = loss_cls * weights
loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,))

bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
mstype.float16)
bbox_weights = bbox_weights * self.rmv_first_tensor

pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4))
loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets)
loss_reg = self.sum_loss(loss_reg, (2,))
loss_reg = loss_reg * bbox_weights
loss_reg = loss_reg / self.sum_loss(weights, (0,))
loss_reg = self.sum_loss(loss_reg, (0, 1))

loss = self.rcnn_loss_cls_weight * loss_cls + self.rcnn_loss_reg_weight * loss_reg
loss_print += (loss_cls, loss_reg)

return loss, loss_cls, loss_reg, loss_print

+ 250
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/resnet50.py View File

@@ -0,0 +1,250 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Resnet50 backbone."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore import context
# pylint: disable=locally-disabled, invalid-name, missing-docstring
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def weight_init_ones(shape):
"""Weight init."""
return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01).astype(np.float16))
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
"""Conv2D wrapper."""
shape = (out_channels, in_channels, kernel_size, kernel_size)
weights = weight_init_ones(shape)
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=False)
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True):
"""Batchnorm2D wrapper."""
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16))
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16))
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16))
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16))
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
beta_init=beta_init, moving_mean_init=moving_mean_init,
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)
class ResNetFea(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
weights_update (bool): Weight update flag.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> False)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
weights_update=False):
super(ResNetFea, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of "
"layer_num, inchannel, outchannel list must be 4!")
bn_training = False
self.conv1 = _conv(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
self.bn1 = _BatchNorm2dInit(64, affine=bn_training, use_batch_statistics=bn_training)
self.relu = P.ReLU()
self.maxpool = P.MaxPool(ksize=3, strides=2, padding="SAME")
self.weights_update = weights_update
if not self.weights_update:
self.conv1.weight.requires_grad = False
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=1,
training=bn_training,
weights_update=self.weights_update)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=2,
training=bn_training,
weights_update=True)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=2,
training=bn_training,
weights_update=True)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=2,
training=bn_training,
weights_update=True)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride, training=False, weights_update=False):
"""Make block layer."""
layers = []
down_sample = False
if stride != 1 or in_channel != out_channel:
down_sample = True
resblk = block(in_channel,
out_channel,
stride=stride,
down_sample=down_sample,
training=training,
weights_update=weights_update)
layers.append(resblk)
for _ in range(1, layer_num):
resblk = block(out_channel, out_channel, stride=1, training=training, weights_update=weights_update)
layers.append(resblk)
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
identity = c2
if not self.weights_update:
identity = F.stop_gradient(c2)
c3 = self.layer2(identity)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
return identity, c3, c4, c5
class ResidualBlockUsing(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channels (int) - Input channel.
out_channels (int) - Output channel.
stride (int) - Stride size for the initial convolutional layer. Default: 1.
down_sample (bool) - If to do the downsample in block. Default: False.
momentum (float) - Momentum for batchnorm layer. Default: 0.1.
training (bool) - Training flag. Default: False.
weights_updata (bool) - Weights update flag. Default: False.
Returns:
Tensor, output tensor.
Examples:
ResidualBlock(3,256,stride=2,down_sample=True)
"""
expansion = 4
def __init__(self,
in_channels,
out_channels,
stride=1,
down_sample=False,
momentum=0.1,
training=False,
weights_update=False):
super(ResidualBlockUsing, self).__init__()
self.affine = weights_update
out_chls = out_channels // self.expansion
self.conv1 = _conv(in_channels, out_chls, kernel_size=1, stride=1, padding=0)
self.bn1 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training)
self.conv2 = _conv(out_chls, out_chls, kernel_size=3, stride=stride, padding=1)
self.bn2 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training)
self.conv3 = _conv(out_chls, out_channels, kernel_size=1, stride=1, padding=0)
self.bn3 = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine, use_batch_statistics=training)
if training:
self.bn1 = self.bn1.set_train()
self.bn2 = self.bn2.set_train()
self.bn3 = self.bn3.set_train()
if not weights_update:
self.conv1.weight.requires_grad = False
self.conv2.weight.requires_grad = False
self.conv3.weight.requires_grad = False
self.relu = P.ReLU()
self.downsample = down_sample
if self.downsample:
self.conv_down_sample = _conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)
self.bn_down_sample = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine,
use_batch_statistics=training)
if training:
self.bn_down_sample = self.bn_down_sample.set_train()
if not weights_update:
self.conv_down_sample.weight.requires_grad = False
self.add = P.TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample:
identity = self.conv_down_sample(identity)
identity = self.bn_down_sample(identity)
out = self.add(out, identity)
out = self.relu(out)
return out

+ 181
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/roi_align.py View File

@@ -0,0 +1,181 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn ROIAlign module."""

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.nn import layer as L
from mindspore.common.tensor import Tensor

# pylint: disable=locally-disabled, invalid-name, missing-docstring


class ROIAlign(nn.Cell):
"""
Extract RoI features from mulitple feature map.

Args:
out_size_h (int) - RoI height.
out_size_w (int) - RoI width.
spatial_scale (int) - RoI spatial scale.
sample_num (int) - RoI sample number.
"""
def __init__(self,
out_size_h,
out_size_w,
spatial_scale,
sample_num=0):
super(ROIAlign, self).__init__()

self.out_size = (out_size_h, out_size_w)
self.spatial_scale = float(spatial_scale)
self.sample_num = int(sample_num)
self.align_op = P.ROIAlign(self.out_size[0], self.out_size[1],
self.spatial_scale, self.sample_num)

def construct(self, features, rois):
return self.align_op(features, rois)

def __repr__(self):
format_str = self.__class__.__name__
format_str += '(out_size={}, spatial_scale={}, sample_num={}'.format(
self.out_size, self.spatial_scale, self.sample_num)
return format_str


class SingleRoIExtractor(nn.Cell):
"""
Extract RoI features from a single level feature map.

If there are mulitple input feature levels, each RoI is mapped to a level
according to its scale.

Args:
config (dict): Config
roi_layer (dict): Specify RoI layer type and arguments.
out_channels (int): Output channels of RoI layers.
featmap_strides (int): Strides of input feature maps.
batch_size (int): Batchsize.
finest_scale (int): Scale threshold of mapping to level 0.
"""

def __init__(self,
config,
roi_layer,
out_channels,
featmap_strides,
batch_size=1,
finest_scale=56):
super(SingleRoIExtractor, self).__init__()
cfg = config
self.train_batch_size = batch_size
self.out_channels = out_channels
self.featmap_strides = featmap_strides
self.num_levels = len(self.featmap_strides)
self.out_size = roi_layer['out_size']
self.sample_num = roi_layer['sample_num']
self.roi_layers = self.build_roi_layers(self.featmap_strides)
self.roi_layers = L.CellList(self.roi_layers)

self.sqrt = P.Sqrt()
self.log = P.Log()
self.finest_scale_ = finest_scale
self.clamp = C.clip_by_value

self.cast = P.Cast()
self.equal = P.Equal()
self.select = P.Select()

_mode_16 = False
self.dtype = np.float16 if _mode_16 else np.float32
self.ms_dtype = mstype.float16 if _mode_16 else mstype.float32
self.set_train_local(cfg, training=True)

def set_train_local(self, config, training=True):
"""Set training flag."""
self.training_local = training

cfg = config
# Init tensor
self.batch_size = cfg.roi_sample_num if self.training_local else cfg.rpn_max_num
self.batch_size = self.train_batch_size*self.batch_size \
if self.training_local else cfg.test_batch_size*self.batch_size
self.ones = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype))
finest_scale = np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * self.finest_scale_
self.finest_scale = Tensor(finest_scale)
self.epslion = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)*self.dtype(1e-6))
self.zeros = Tensor(np.array(np.zeros((self.batch_size, 1)), dtype=np.int32))
self.max_levels = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=np.int32)*(self.num_levels-1))
self.twos = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * 2)
self.res_ = Tensor(np.array(np.zeros((self.batch_size, self.out_channels,
self.out_size, self.out_size)), dtype=self.dtype))
def num_inputs(self):
return len(self.featmap_strides)

def init_weights(self):
pass

def log2(self, value):
return self.log(value) / self.log(self.twos)

def build_roi_layers(self, featmap_strides):
roi_layers = []
for s in featmap_strides:
layer_cls = ROIAlign(self.out_size, self.out_size,
spatial_scale=1 / s,
sample_num=self.sample_num)
roi_layers.append(layer_cls)
return roi_layers

def _c_map_roi_levels(self, rois):
"""Map rois to corresponding feature levels by scales.

- scale < finest_scale * 2: level 0
- finest_scale * 2 <= scale < finest_scale * 4: level 1
- finest_scale * 4 <= scale < finest_scale * 8: level 2
- scale >= finest_scale * 8: level 3

Args:
rois (Tensor): Input RoIs, shape (k, 5).
num_levels (int): Total level number.

Returns:
Tensor: Level index (0-based) of each RoI, shape (k, )
"""
scale = self.sqrt(rois[::, 3:4:1] - rois[::, 1:2:1] + self.ones) * \
self.sqrt(rois[::, 4:5:1] - rois[::, 2:3:1] + self.ones)

target_lvls = self.log2(scale / self.finest_scale + self.epslion)
target_lvls = P.Floor()(target_lvls)
target_lvls = self.cast(target_lvls, mstype.int32)
target_lvls = self.clamp(target_lvls, self.zeros, self.max_levels)

return target_lvls

def construct(self, rois, feat1, feat2, feat3, feat4):
feats = (feat1, feat2, feat3, feat4)
res = self.res_
target_lvls = self._c_map_roi_levels(rois)
for i in range(self.num_levels):
mask = self.equal(target_lvls, P.ScalarToArray()(i))
mask = P.Reshape()(mask, (-1, 1, 1, 1))
roi_feats_t = self.roi_layers[i](feats[i], rois)
mask = self.cast(P.Tile()(self.cast(mask, mstype.int32), (1, 256, 7, 7)), mstype.bool_)
res = self.select(mask, roi_feats_t, res)

return res

+ 315
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/FasterRcnn/rpn.py View File

@@ -0,0 +1,315 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RPN for fasterRCNN"""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from .bbox_assign_sample import BboxAssignSample

# pylint: disable=locally-disabled, invalid-name, missing-docstring

# pylint: disable=locally-disabled, invalid-name, missing-docstring


class RpnRegClsBlock(nn.Cell):
"""
Rpn reg cls block for rpn layer

Args:
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
weight_conv (Tensor) - weight init for rpn conv.
bias_conv (Tensor) - bias init for rpn conv.
weight_cls (Tensor) - weight init for rpn cls conv.
bias_cls (Tensor) - bias init for rpn cls conv.
weight_reg (Tensor) - weight init for rpn reg conv.
bias_reg (Tensor) - bias init for rpn reg conv.

Returns:
Tensor, output tensor.
"""
def __init__(self,
in_channels,
feat_channels,
num_anchors,
cls_out_channels,
weight_conv,
bias_conv,
weight_cls,
bias_cls,
weight_reg,
bias_reg):
super(RpnRegClsBlock, self).__init__()
self.rpn_conv = nn.Conv2d(in_channels, feat_channels, kernel_size=3, stride=1, pad_mode='same',
has_bias=True, weight_init=weight_conv, bias_init=bias_conv)
self.relu = nn.ReLU()

self.rpn_cls = nn.Conv2d(feat_channels, num_anchors * cls_out_channels, kernel_size=1, pad_mode='valid',
has_bias=True, weight_init=weight_cls, bias_init=bias_cls)
self.rpn_reg = nn.Conv2d(feat_channels, num_anchors * 4, kernel_size=1, pad_mode='valid',
has_bias=True, weight_init=weight_reg, bias_init=bias_reg)

def construct(self, x):
x = self.relu(self.rpn_conv(x))

x1 = self.rpn_cls(x)
x2 = self.rpn_reg(x)

return x1, x2


class RPN(nn.Cell):
"""
ROI proposal network..

Args:
config (dict) - Config.
batch_size (int) - Batchsize.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.

Returns:
Tuple, tuple of output tensor.

Examples:
RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024,
num_anchors=3, cls_out_channels=512)
"""
def __init__(self,
config,
batch_size,
in_channels,
feat_channels,
num_anchors,
cls_out_channels):
super(RPN, self).__init__()
cfg_rpn = config
self.num_bboxes = cfg_rpn.num_bboxes
self.slice_index = ()
self.feature_anchor_shape = ()
self.slice_index += (0,)
index = 0
for shape in cfg_rpn.feature_shapes:
self.slice_index += (self.slice_index[index] + shape[0] * shape[1] * num_anchors,)
self.feature_anchor_shape += (shape[0] * shape[1] * num_anchors * batch_size,)
index += 1

self.num_anchors = num_anchors
self.batch_size = batch_size
self.test_batch_size = cfg_rpn.test_batch_size
self.num_layers = 5
self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16))

self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels,
num_anchors, cls_out_channels))

self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.concat = P.Concat(axis=0)
self.fill = P.Fill()
self.placeh1 = Tensor(np.ones((1,)).astype(np.float16))

self.trans_shape = (0, 2, 3, 1)

self.reshape_shape_reg = (-1, 4)
self.reshape_shape_cls = (-1,)
self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16))
self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16))
self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16))
self.num_bboxes = cfg_rpn.num_bboxes
self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False)
self.CheckValid = P.CheckValid()
self.sum_loss = P.ReduceSum()
self.loss_cls = P.SigmoidCrossEntropyWithLogits()
self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0)
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()
self.loss = Tensor(np.zeros((1,)).astype(np.float16))
self.clsloss = Tensor(np.zeros((1,)).astype(np.float16))
self.regloss = Tensor(np.zeros((1,)).astype(np.float16))

def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels):
"""
make rpn layer for rpn proposal network

Args:
num_layers (int) - layer num.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.

Returns:
List, list of RpnRegClsBlock cells.
"""
rpn_layer = []

shp_weight_conv = (feat_channels, in_channels, 3, 3)
shp_bias_conv = (feat_channels,)
weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor()
bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor()

shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1)
shp_bias_cls = (num_anchors * cls_out_channels,)
weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16).to_tensor()
bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16).to_tensor()

shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1)
shp_bias_reg = (num_anchors * 4,)
weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16).to_tensor()
bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16).to_tensor()

for i in range(num_layers):
rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
weight_conv, bias_conv, weight_cls, \
bias_cls, weight_reg, bias_reg))

for i in range(1, num_layers):
rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight
rpn_layer[i].rpn_cls.weight = rpn_layer[0].rpn_cls.weight
rpn_layer[i].rpn_reg.weight = rpn_layer[0].rpn_reg.weight

rpn_layer[i].rpn_conv.bias = rpn_layer[0].rpn_conv.bias
rpn_layer[i].rpn_cls.bias = rpn_layer[0].rpn_cls.bias
rpn_layer[i].rpn_reg.bias = rpn_layer[0].rpn_reg.bias

return rpn_layer

def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids):
loss_print = ()
rpn_cls_score = ()
rpn_bbox_pred = ()
rpn_cls_score_total = ()
rpn_bbox_pred_total = ()

for i in range(self.num_layers):
x1, x2 = self.rpn_convs_list[i](inputs[i])

rpn_cls_score_total = rpn_cls_score_total + (x1,)
rpn_bbox_pred_total = rpn_bbox_pred_total + (x2,)

x1 = self.transpose(x1, self.trans_shape)
x1 = self.reshape(x1, self.reshape_shape_cls)

x2 = self.transpose(x2, self.trans_shape)
x2 = self.reshape(x2, self.reshape_shape_reg)

rpn_cls_score = rpn_cls_score + (x1,)
rpn_bbox_pred = rpn_bbox_pred + (x2,)

loss = self.loss
clsloss = self.clsloss
regloss = self.regloss
bbox_targets = ()
bbox_weights = ()
labels = ()
label_weights = ()

output = ()
if self.training:
for i in range(self.batch_size):
multi_level_flags = ()
anchor_list_tuple = ()

for j in range(self.num_layers):
res = self.cast(self.CheckValid(anchor_list[j], self.squeeze(img_metas[i:i + 1:1, ::])),
mstype.int32)
multi_level_flags = multi_level_flags + (res,)
anchor_list_tuple = anchor_list_tuple + (anchor_list[j],)

valid_flag_list = self.concat(multi_level_flags)
anchor_using_list = self.concat(anchor_list_tuple)

gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])

bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i,
gt_labels_i,
self.cast(valid_flag_list,
mstype.bool_),
anchor_using_list, gt_valids_i)

bbox_weight = self.cast(bbox_weight, mstype.float16)
label = self.cast(label, mstype.float16)
label_weight = self.cast(label_weight, mstype.float16)

for j in range(self.num_layers):
begin = self.slice_index[j]
end = self.slice_index[j + 1]
stride = 1
bbox_targets += (bbox_target[begin:end:stride, ::],)
bbox_weights += (bbox_weight[begin:end:stride],)
labels += (label[begin:end:stride],)
label_weights += (label_weight[begin:end:stride],)

for i in range(self.num_layers):
bbox_target_using = ()
bbox_weight_using = ()
label_using = ()
label_weight_using = ()

for j in range(self.batch_size):
bbox_target_using += (bbox_targets[i + (self.num_layers * j)],)
bbox_weight_using += (bbox_weights[i + (self.num_layers * j)],)
label_using += (labels[i + (self.num_layers * j)],)
label_weight_using += (label_weights[i + (self.num_layers * j)],)

bbox_target_with_batchsize = self.concat(bbox_target_using)
bbox_weight_with_batchsize = self.concat(bbox_weight_using)
label_with_batchsize = self.concat(label_using)
label_weight_with_batchsize = self.concat(label_weight_using)

# stop
bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
label_ = F.stop_gradient(label_with_batchsize)
label_weight_ = F.stop_gradient(label_weight_with_batchsize)

cls_score_i = rpn_cls_score[i]
reg_score_i = rpn_bbox_pred[i]

loss_cls = self.loss_cls(cls_score_i, label_)
loss_cls_item = loss_cls * label_weight_
loss_cls_item = self.sum_loss(loss_cls_item, (0,)) / self.num_expected_total

loss_reg = self.loss_bbox(reg_score_i, bbox_target_)
bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape[i], 1)), (1, 4))
loss_reg = loss_reg * bbox_weight_
loss_reg_item = self.sum_loss(loss_reg, (1,))
loss_reg_item = self.sum_loss(loss_reg_item, (0,)) / self.num_expected_total

loss_total = self.rpn_loss_cls_weight * loss_cls_item + self.rpn_loss_reg_weight * loss_reg_item

loss += loss_total
loss_print += (loss_total, loss_cls_item, loss_reg_item)
clsloss += loss_cls_item
regloss += loss_reg_item

output = (loss, rpn_cls_score_total, rpn_bbox_pred_total, clsloss, regloss, loss_print)
else:
output = (self.placeh1, rpn_cls_score_total, rpn_bbox_pred_total, self.placeh1, self.placeh1, self.placeh1)

return output

+ 158
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/config.py View File

@@ -0,0 +1,158 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed

config = ed({
"img_width": 1280,
"img_height": 768,
"keep_ratio": False,
"flip_ratio": 0.5,
"photo_ratio": 0.5,
"expand_ratio": 1.0,

# anchor
"feature_shapes": [(192, 320), (96, 160), (48, 80), (24, 40), (12, 20)],
"anchor_scales": [8],
"anchor_ratios": [0.5, 1.0, 2.0],
"anchor_strides": [4, 8, 16, 32, 64],
"num_anchors": 3,

# resnet
"resnet_block": [3, 4, 6, 3],
"resnet_in_channels": [64, 256, 512, 1024],
"resnet_out_channels": [256, 512, 1024, 2048],

# fpn
"fpn_in_channels": [256, 512, 1024, 2048],
"fpn_out_channels": 256,
"fpn_num_outs": 5,

# rpn
"rpn_in_channels": 256,
"rpn_feat_channels": 256,
"rpn_loss_cls_weight": 1.0,
"rpn_loss_reg_weight": 1.0,
"rpn_cls_out_channels": 1,
"rpn_target_means": [0., 0., 0., 0.],
"rpn_target_stds": [1.0, 1.0, 1.0, 1.0],

# bbox_assign_sampler
"neg_iou_thr": 0.3,
"pos_iou_thr": 0.7,
"min_pos_iou": 0.3,
"num_bboxes": 245520,
"num_gts": 128,
"num_expected_neg": 256,
"num_expected_pos": 128,

# proposal
"activate_num_classes": 2,
"use_sigmoid_cls": True,

# roi_align
"roi_layer": dict(type='RoIAlign', out_size=7, sample_num=2),
"roi_align_out_channels": 256,
"roi_align_featmap_strides": [4, 8, 16, 32],
"roi_align_finest_scale": 56,
"roi_sample_num": 640,

# bbox_assign_sampler_stage2
"neg_iou_thr_stage2": 0.5,
"pos_iou_thr_stage2": 0.5,
"min_pos_iou_stage2": 0.5,
"num_bboxes_stage2": 2000,
"num_expected_pos_stage2": 128,
"num_expected_neg_stage2": 512,
"num_expected_total_stage2": 512,

# rcnn
"rcnn_num_layers": 2,
"rcnn_in_channels": 256,
"rcnn_fc_out_channels": 1024,
"rcnn_loss_cls_weight": 1,
"rcnn_loss_reg_weight": 1,
"rcnn_target_means": [0., 0., 0., 0.],
"rcnn_target_stds": [0.1, 0.1, 0.2, 0.2],

# train proposal
"rpn_proposal_nms_across_levels": False,
"rpn_proposal_nms_pre": 2000,
"rpn_proposal_nms_post": 2000,
"rpn_proposal_max_num": 2000,
"rpn_proposal_nms_thr": 0.7,
"rpn_proposal_min_bbox_size": 0,

# test proposal
"rpn_nms_across_levels": False,
"rpn_nms_pre": 1000,
"rpn_nms_post": 1000,
"rpn_max_num": 1000,
"rpn_nms_thr": 0.7,
"rpn_min_bbox_min_size": 0,
"test_score_thr": 0.05,
"test_iou_thr": 0.5,
"test_max_per_img": 100,
"test_batch_size": 1,

"rpn_head_loss_type": "CrossEntropyLoss",
"rpn_head_use_sigmoid": True,
"rpn_head_weight": 1.0,

# LR
"base_lr": 0.02,
"base_step": 58633,
"total_epoch": 13,
"warmup_step": 500,
"warmup_mode": "linear",
"warmup_ratio": 1/3.0,
"sgd_step": [8, 11],
"sgd_momentum": 0.9,

# train
"batch_size": 1,
"loss_scale": 1,
"momentum": 0.91,
"weight_decay": 1e-4,
"epoch_size": 12,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",

"mindrecord_dir": "../MindRecord_COCO_TRAIN",
"coco_root": "./cocodataset/",
"train_data_type": "train2017",
"val_data_type": "val2017",
"instance_set": "annotations/instances_{}.json",
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
"num_classes": 81
})

+ 505
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/dataset.py View File

@@ -0,0 +1,505 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""FasterRcnn dataset"""
from __future__ import division

import os
import numpy as np
from numpy import random

import mmcv
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as CC
import mindspore.common.dtype as mstype
from mindspore.mindrecord import FileWriter
from src.config import config

# pylint: disable=locally-disabled, unused-variable


def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
"""Calculate the ious between each bbox of bboxes1 and bboxes2.

Args:
bboxes1(ndarray): shape (n, 4)
bboxes2(ndarray): shape (k, 4)
mode(str): iou (intersection over union) or iof (intersection
over foreground)

Returns:
ious(ndarray): shape (n, k)
"""

assert mode in ['iou', 'iof']

bboxes1 = bboxes1.astype(np.float32)
bboxes2 = bboxes2.astype(np.float32)
rows = bboxes1.shape[0]
cols = bboxes2.shape[0]
ious = np.zeros((rows, cols), dtype=np.float32)
if rows * cols == 0:
return ious
exchange = False
if bboxes1.shape[0] > bboxes2.shape[0]:
bboxes1, bboxes2 = bboxes2, bboxes1
ious = np.zeros((cols, rows), dtype=np.float32)
exchange = True
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (bboxes1[:, 3] - bboxes1[:, 1] + 1)
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (bboxes2[:, 3] - bboxes2[:, 1] + 1)
for i in range(bboxes1.shape[0]):
x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum(
y_end - y_start + 1, 0)
if mode == 'iou':
union = area1[i] + area2 - overlap
else:
union = area1[i] if not exchange else area2
ious[i, :] = overlap / union
if exchange:
ious = ious.T
return ious


class PhotoMetricDistortion:
"""Photo Metric Distortion"""

def __init__(self,
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18):
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta

def __call__(self, img, boxes, labels):
# random brightness
img = img.astype('float32')

if random.randint(2):
delta = random.uniform(-self.brightness_delta,
self.brightness_delta)
img += delta

# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
mode = random.randint(2)
if mode == 1:
if random.randint(2):
alpha = random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha

# convert color from BGR to HSV
img = mmcv.bgr2hsv(img)

# random saturation
if random.randint(2):
img[..., 1] *= random.uniform(self.saturation_lower,
self.saturation_upper)

# random hue
if random.randint(2):
img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360

# convert color from HSV to BGR
img = mmcv.hsv2bgr(img)

# random contrast
if mode == 0:
if random.randint(2):
alpha = random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha

# randomly swap channels
if random.randint(2):
img = img[..., random.permutation(3)]

return img, boxes, labels


class Expand:
"""expand image"""

def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
if to_rgb:
self.mean = mean[::-1]
else:
self.mean = mean
self.min_ratio, self.max_ratio = ratio_range

def __call__(self, img, boxes, labels):
if random.randint(2):
return img, boxes, labels

h, w, c = img.shape
ratio = random.uniform(self.min_ratio, self.max_ratio)
expand_img = np.full((int(h * ratio), int(w * ratio), c),
self.mean).astype(img.dtype)
left = int(random.uniform(0, w * ratio - w))
top = int(random.uniform(0, h * ratio - h))
expand_img[top:top + h, left:left + w] = img
img = expand_img
boxes += np.tile((left, top), 2)
return img, boxes, labels


def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""rescale operation for image"""
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
if img_data.shape[0] > config.img_height:
img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_width), return_scale=True)
scale_factor = scale_factor * scale_factor2
img_shape = np.append(img_shape, scale_factor)
img_shape = np.asarray(img_shape, dtype=np.float32)
gt_bboxes = gt_bboxes * scale_factor

gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)

return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""resize operation for image"""
img_data = img
img_data, w_scale, h_scale = mmcv.imresize(
img_data, (config.img_width, config.img_height), return_scale=True)
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
img_shape = (config.img_height, config.img_width, 1.0)
img_shape = np.asarray(img_shape, dtype=np.float32)

gt_bboxes = gt_bboxes * scale_factor

gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)

return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
"""resize operation for image of eval"""
img_data = img
img_data, w_scale, h_scale = mmcv.imresize(
img_data, (config.img_width, config.img_height), return_scale=True)
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
img_shape = np.append(img_shape, (h_scale, w_scale))
img_shape = np.asarray(img_shape, dtype=np.float32)

gt_bboxes = gt_bboxes * scale_factor

gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)

return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""impad operation for image"""
img_data = mmcv.impad(img, (config.img_height, config.img_width))
img_data = img_data.astype(np.float32)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""imnormalize operation for image"""
img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True)
img_data = img_data.astype(np.float32)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""flip operation for image"""
img_data = img
img_data = mmcv.imflip(img_data)
flipped = gt_bboxes.copy()
_, w, _ = img_data.shape

flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1

return (img_data, img_shape, flipped, gt_label, gt_num)


def flipped_generation(img, img_shape, gt_bboxes, gt_label, gt_num):
"""flipped generation"""
img_data = img
flipped = gt_bboxes.copy()
_, w, _ = img_data.shape

flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1

return (img_data, img_shape, flipped, gt_label, gt_num)


def image_bgr_rgb(img, img_shape, gt_bboxes, gt_label, gt_num):
img_data = img[:, :, ::-1]
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""transpose operation for image"""
img_data = img.transpose(2, 0, 1).copy()
img_data = img_data.astype(np.float16)
img_shape = img_shape.astype(np.float16)
gt_bboxes = gt_bboxes.astype(np.float16)
gt_label = gt_label.astype(np.int32)
gt_num = gt_num.astype(np.bool)

return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""photo crop operation for image"""
random_photo = PhotoMetricDistortion()
img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label)

return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""expand operation for image"""
expand = Expand()
img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label)

return (img, img_shape, gt_bboxes, gt_label, gt_num)


def preprocess_fn(image, box, is_training):
"""Preprocess function for dataset."""

def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert):
image_shape = image_shape[:2]
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert

if config.keep_ratio:
input_data = rescale_column(*input_data)
else:
input_data = resize_column_test(*input_data)

input_data = image_bgr_rgb(*input_data)

output_data = input_data
return output_data

def _data_aug(image, box, is_training):
"""Data augmentation function."""
image_bgr = image.copy()
image_bgr[:, :, 0] = image[:, :, 2]
image_bgr[:, :, 1] = image[:, :, 1]
image_bgr[:, :, 2] = image[:, :, 0]
image_shape = image_bgr.shape[:2]
gt_box = box[:, :4]
gt_label = box[:, 4]
gt_iscrowd = box[:, 5]

pad_max_number = 128
gt_box_new = np.pad(gt_box, ((0, pad_max_number - box.shape[0]), (0, 0)), mode="constant", constant_values=0)
gt_label_new = np.pad(gt_label, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=-1)
gt_iscrowd_new = np.pad(gt_iscrowd, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=1)
gt_iscrowd_new_revert = (~(gt_iscrowd_new.astype(np.bool))).astype(np.int32)

if not is_training:
return _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert)

input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert

if config.keep_ratio:
input_data = rescale_column(*input_data)
else:
input_data = resize_column(*input_data)

input_data = image_bgr_rgb(*input_data)

output_data = input_data
return output_data

return _data_aug(image, box, is_training)


def create_coco_label(is_training):
"""Get image path and annotation from COCO."""
from pycocotools.coco import COCO

coco_root = config.coco_root
data_type = config.val_data_type
if is_training:
data_type = config.train_data_type

# Classes need to train or test.
train_cls = config.coco_classes
train_cls_dict = {}
for i, cls in enumerate(train_cls):
train_cls_dict[cls] = i

anno_json = os.path.join(coco_root, config.instance_set.format(data_type))

coco = COCO(anno_json)
classs_dict = {}
cat_ids = coco.loadCats(coco.getCatIds())
for cat in cat_ids:
classs_dict[cat["id"]] = cat["name"]

image_ids = coco.getImgIds()
image_files = []
image_anno_dict = {}

for img_id in image_ids:
image_info = coco.loadImgs(img_id)
file_name = image_info[0]["file_name"]
anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = coco.loadAnns(anno_ids)
image_path = os.path.join(coco_root, data_type, file_name)
annos = []
for label in anno:
bbox = label["bbox"]
class_name = classs_dict[label["category_id"]]
if class_name in train_cls:
x1, x2 = bbox[0], bbox[0] + bbox[2]
y1, y2 = bbox[1], bbox[1] + bbox[3]
annos.append([x1, y1, x2, y2] + [train_cls_dict[class_name]] + [int(label["iscrowd"])])

image_files.append(image_path)
if annos:
image_anno_dict[image_path] = np.array(annos)
else:
image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1])

return image_files, image_anno_dict


def anno_parser(annos_str):
"""Parse annotation from string to list."""
annos = []
for anno_str in annos_str:
anno = list(map(int, anno_str.strip().split(',')))
annos.append(anno)
return annos


def filter_valid_data(image_dir, anno_path):
"""Filter valid image file, which both in image_dir and anno_path."""
image_files = []
image_anno_dict = {}
if not os.path.isdir(image_dir):
raise RuntimeError("Path given is not valid.")
if not os.path.isfile(anno_path):
raise RuntimeError("Annotation file is not valid.")

with open(anno_path, "rb") as f:
lines = f.readlines()
for line in lines:
line_str = line.decode("utf-8").strip()
line_split = str(line_str).split(' ')
file_name = line_split[0]
image_path = os.path.join(image_dir, file_name)
if os.path.isfile(image_path):
image_anno_dict[image_path] = anno_parser(line_split[1:])
image_files.append(image_path)
return image_files, image_anno_dict


def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fasterrcnn.mindrecord", file_num=8):
"""Create MindRecord file."""
mindrecord_dir = config.mindrecord_dir
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
if dataset == "coco":
image_files, image_anno_dict = create_coco_label(is_training)
else:
image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)

fasterrcnn_json = {
"image": {"type": "bytes"},
"annotation": {"type": "int32", "shape": [-1, 6]},
}
writer.add_schema(fasterrcnn_json, "fasterrcnn_json")

for image_name in image_files:
with open(image_name, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
row = {"image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()


def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0,
is_training=True, num_parallel_workers=4):
"""Creatr FasterRcnn dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
num_parallel_workers=1, shuffle=False)
decode = C.Decode()
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))

hwc_to_chw = C.HWC2CHW()
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
horizontally_op = C.RandomHorizontalFlip(1)
type_cast0 = CC.TypeCast(mstype.float32)
type_cast1 = CC.TypeCast(mstype.float16)
type_cast2 = CC.TypeCast(mstype.int32)
type_cast3 = CC.TypeCast(mstype.bool_)

if is_training:
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
column_order=["image", "image_shape", "box", "label", "valid_num"],
num_parallel_workers=num_parallel_workers)

flip = (np.random.rand() < config.flip_ratio)
if flip:
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
num_parallel_workers=12)
ds = ds.map(operations=flipped_generation,
input_columns=["image", "image_shape", "box", "label", "valid_num"],
num_parallel_workers=num_parallel_workers)
else:
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
num_parallel_workers=12)
ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"],
num_parallel_workers=12)

else:
ds = ds.map(operations=compose_map_func,
input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
column_order=["image", "image_shape", "box", "label", "valid_num"],
num_parallel_workers=num_parallel_workers)

ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"],
num_parallel_workers=24)

# transpose_column from python to c
ds = ds.map(operations=[type_cast1], input_columns=["image_shape"])
ds = ds.map(operations=[type_cast1], input_columns=["box"])
ds = ds.map(operations=[type_cast2], input_columns=["label"])
ds = ds.map(operations=[type_cast3], input_columns=["valid_num"])
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)

return ds

+ 42
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/lr_schedule.py View File

@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr generator for fasterrcnn"""
import math

def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate

def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
base = float(current_step - warmup_steps) / float(decay_steps)
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
return learning_rate

def dynamic_lr(config, rank_size=1):
"""dynamic learning rate generator"""
base_lr = config.base_lr

base_step = (config.base_step // rank_size) + rank_size
total_steps = int(base_step * config.total_epoch)
warmup_steps = int(config.warmup_step)
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
else:
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))

return lr

+ 184
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/network_define.py View File

@@ -0,0 +1,184 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn training network wrapper."""

import time
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

# pylint: disable=locally-disabled, missing-docstring, unused-argument


time_stamp_init = False
time_stamp_first = 0
class LossCallBack(Callback):
"""
Monitor the loss in training.

If the loss is NAN or INF terminating training.

Note:
If per_print_times is 0 do not print loss.

Args:
per_print_times (int): Print loss every times. Default: 1.
"""

def __init__(self, per_print_times=1, rank_id=0):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.rank_id = rank_id

global time_stamp_init, time_stamp_first
if not time_stamp_init:
time_stamp_first = time.time()
time_stamp_init = True

def step_end(self, run_context):
cb_params = run_context.original_args()
rpn_loss = cb_params.net_outputs[0].asnumpy()
rcnn_loss = cb_params.net_outputs[1].asnumpy()
rpn_cls_loss = cb_params.net_outputs[2].asnumpy()

rpn_reg_loss = cb_params.net_outputs[3].asnumpy()
rcnn_cls_loss = cb_params.net_outputs[4].asnumpy()
rcnn_reg_loss = cb_params.net_outputs[5].asnumpy()

self.count += 1
self.rpn_loss_sum += float(rpn_loss)
self.rcnn_loss_sum += float(rcnn_loss)
self.rpn_cls_loss_sum += float(rpn_cls_loss)
self.rpn_reg_loss_sum += float(rpn_reg_loss)
self.rcnn_cls_loss_sum += float(rcnn_cls_loss)
self.rcnn_reg_loss_sum += float(rcnn_reg_loss)

cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

if self.count >= 1:
global time_stamp_first
time_stamp_current = time.time()

rpn_loss = self.rpn_loss_sum/self.count
rcnn_loss = self.rcnn_loss_sum/self.count
rpn_cls_loss = self.rpn_cls_loss_sum/self.count

rpn_reg_loss = self.rpn_reg_loss_sum/self.count
rcnn_cls_loss = self.rcnn_cls_loss_sum/self.count
rcnn_reg_loss = self.rcnn_reg_loss_sum/self.count

total_loss = rpn_loss + rcnn_loss

loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rcnn_loss: %.5f, rpn_cls_loss: %.5f, "
"rpn_reg_loss: %.5f, rcnn_cls_loss: %.5f, rcnn_reg_loss: %.5f, total_loss: %.5f" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
rcnn_cls_loss, rcnn_reg_loss, total_loss))
loss_file.write("\n")
loss_file.close()

self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0

class LossNet(nn.Cell):
"""FasterRcnn loss method"""
def construct(self, x1, x2, x3, x4, x5, x6):
return x1 + x2

class WithLossCell(nn.Cell):
"""
Wrap the network with loss function to compute loss.

Args:
backbone (Cell): The target network to wrap.
loss_fn (Cell): The loss function used to compute loss.
"""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn

def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
loss1, loss2, loss3, loss4, loss5, loss6 = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6)

@property
def backbone_network(self):
"""
Get the backbone network.

Returns:
Cell, return backbone network.
"""
return self._backbone


class TrainOneStepCell(nn.Cell):
"""
Network training package class.

Append an optimizer to the training network after that the construct function
can be called to create the backward graph.

Args:
network (Cell): The training network.
network_backbone (Cell): The forward network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default value is 1.0.
reduce_flag (bool): The reduce flag. Default value is False.
mean (bool): Allreduce method. Default value is False.
degree (int): Device number. Default value is None.
"""
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16))
self.reduce_flag = reduce_flag
if reduce_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
weights = self.weights
loss1, loss2, loss3, loss4, loss5, loss6 = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
return F.depend(loss1, self.optimizer(grads)), loss2, loss3, loss4, loss5, loss6

+ 227
- 0
examples/model_security/model_attacks/cv/faster_rcnn/src/util.py View File

@@ -0,0 +1,227 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""coco eval for fasterrcnn"""
import json
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import mmcv

# pylint: disable=locally-disabled, invalid-name

_init_value = np.array(0.0)
summary_init = {
'Precision/mAP': _init_value,
'Precision/mAP@.50IOU': _init_value,
'Precision/mAP@.75IOU': _init_value,
'Precision/mAP (small)': _init_value,
'Precision/mAP (medium)': _init_value,
'Precision/mAP (large)': _init_value,
'Recall/AR@1': _init_value,
'Recall/AR@10': _init_value,
'Recall/AR@100': _init_value,
'Recall/AR@100 (small)': _init_value,
'Recall/AR@100 (medium)': _init_value,
'Recall/AR@100 (large)': _init_value,
}


def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), single_result=False):
"""coco eval for fasterrcnn"""
anns = json.load(open(result_files['bbox']))
if not anns:
return summary_init

if mmcv.is_str(coco):
coco = COCO(coco)
assert isinstance(coco, COCO)

for res_type in result_types:
result_file = result_files[res_type]
assert result_file.endswith('.json')

coco_dets = coco.loadRes(result_file)
gt_img_ids = coco.getImgIds()
det_img_ids = coco_dets.getImgIds()
iou_type = 'bbox' if res_type == 'proposal' else res_type
cocoEval = COCOeval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)

tgt_ids = gt_img_ids if not single_result else det_img_ids

if single_result:
res_dict = dict()
for id_i in tgt_ids:
cocoEval = COCOeval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)

cocoEval.params.imgIds = [id_i]
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
res_dict.update({coco.imgs[id_i]['file_name']: cocoEval.stats[1]})

cocoEval = COCOeval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)

cocoEval.params.imgIds = tgt_ids
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()

summary_metrics = {
'Precision/mAP': cocoEval.stats[0],
'Precision/mAP@.50IOU': cocoEval.stats[1],
'Precision/mAP@.75IOU': cocoEval.stats[2],
'Precision/mAP (small)': cocoEval.stats[3],
'Precision/mAP (medium)': cocoEval.stats[4],
'Precision/mAP (large)': cocoEval.stats[5],
'Recall/AR@1': cocoEval.stats[6],
'Recall/AR@10': cocoEval.stats[7],
'Recall/AR@100': cocoEval.stats[8],
'Recall/AR@100 (small)': cocoEval.stats[9],
'Recall/AR@100 (medium)': cocoEval.stats[10],
'Recall/AR@100 (large)': cocoEval.stats[11],
}

return summary_metrics


def xyxy2xywh(bbox):
_bbox = bbox.tolist()
return [
_bbox[0],
_bbox[1],
_bbox[2] - _bbox[0] + 1,
_bbox[3] - _bbox[1] + 1,
]

def bbox2result_1image(bboxes, labels, num_classes):
"""Convert detection results to a list of numpy arrays.

Args:
bboxes (Tensor): shape (n, 5)
labels (Tensor): shape (n, )
num_classes (int): class number, including background class

Returns:
list(ndarray): bbox results of each class
"""
if bboxes.shape[0] == 0:
result = [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes - 1)]
else:
result = [bboxes[labels == i, :] for i in range(num_classes - 1)]
return result

def proposal2json(dataset, results):
"""convert proposal to json mode"""
img_ids = dataset.getImgIds()
json_results = []
dataset_len = dataset.get_dataset_size()*2
for idx in range(dataset_len):
img_id = img_ids[idx]
bboxes = results[idx]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = 1
json_results.append(data)
return json_results

def det2json(dataset, results):
"""convert det to json mode"""
cat_ids = dataset.getCatIds()
img_ids = dataset.getImgIds()
json_results = []
dataset_len = len(img_ids)
for idx in range(dataset_len):
img_id = img_ids[idx]
if idx == len(results): break
result = results[idx]
for label, result_label in enumerate(result):
bboxes = result_label
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = cat_ids[label]
json_results.append(data)
return json_results

def segm2json(dataset, results):
"""convert segm to json mode"""
bbox_json_results = []
segm_json_results = []
for idx in range(len(dataset)):
img_id = dataset.img_ids[idx]
det, seg = results[idx]
for label, det_label in enumerate(det):
# bbox results
bboxes = det_label
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = dataset.cat_ids[label]
bbox_json_results.append(data)

if len(seg) == 2:
segms = seg[0][label]
mask_score = seg[1][label]
else:
segms = seg[label]
mask_score = [bbox[4] for bbox in bboxes]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['score'] = float(mask_score[i])
data['category_id'] = dataset.cat_ids[label]
segms[i]['counts'] = segms[i]['counts'].decode()
data['segmentation'] = segms[i]
segm_json_results.append(data)
return bbox_json_results, segm_json_results

def results2json(dataset, results, out_file):
"""convert result convert to json mode"""
result_files = dict()
if isinstance(results[0], list):
json_results = det2json(dataset, results)
result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox')
result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox')
mmcv.dump(json_results, result_files['bbox'])
elif isinstance(results[0], tuple):
json_results = segm2json(dataset, results)
result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox')
result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox')
result_files['segm'] = '{}.{}.json'.format(out_file, 'segm')
mmcv.dump(json_results[0], result_files['bbox'])
mmcv.dump(json_results[1], result_files['segm'])
elif isinstance(results[0], np.ndarray):
json_results = proposal2json(dataset, results)
result_files['proposal'] = '{}.{}.json'.format(out_file, 'proposal')
mmcv.dump(json_results, result_files['proposal'])
else:
raise TypeError('invalid type of results')
return result_files

+ 55
- 26
mindarmour/adv_robustness/attacks/gradient_method.py View File

@@ -19,7 +19,7 @@ from abc import abstractmethod
import numpy as np

from mindspore import Tensor
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
from mindspore.nn import Cell

from mindarmour.utils.util import WithLossCell, GradWrapWithLoss
from mindarmour.utils.logger import LogUtil
@@ -44,12 +44,13 @@ class GradientMethod(Attack):
Default: None.
bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: None.
loss_fn (Loss): Loss function for optimization. Default: None.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.

Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = FastGradientMethod(network)
>>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -71,9 +72,10 @@ class GradientMethod(Attack):
else:
self._alpha = alpha
if loss_fn is None:
loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False)
with_loss_cell = WithLossCell(self._network, loss_fn)
self._grad_all = GradWrapWithLoss(with_loss_cell)
self._grad_all = self._network
else:
with_loss_cell = WithLossCell(self._network, loss_fn)
self._grad_all = GradWrapWithLoss(with_loss_cell)
self._grad_all.set_train()

def generate(self, inputs, labels):
@@ -83,13 +85,19 @@ class GradientMethod(Attack):
Args:
inputs (numpy.ndarray): Benign input samples used as references to create
adversarial examples.
labels (numpy.ndarray): Original/target labels.
labels (Union[numpy.ndarray, tuple]): Original/target labels. \
For each input if it has more than one label, it is wrapped in a tuple.

Returns:
numpy.ndarray, generated adversarial examples.
"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
inputs, _ = check_pair_numpy_param('inputs', inputs, \
'labels[{}]'.format(i), labels_item)
else:
inputs, _ = check_pair_numpy_param('inputs', inputs, \
'labels', labels)
self._dtype = inputs.dtype
gradient = self._gradient(inputs, labels)
# use random method or not
@@ -117,7 +125,8 @@ class GradientMethod(Attack):
Args:
inputs (numpy.ndarray): Benign input samples used as references to
create adversarial examples.
labels (numpy.ndarray): Original/target labels.
labels (Union[numpy.ndarray, tuple]): Original/target labels. \
For each input if it has more than one label, it is wrapped in a tuple.

Raises:
NotImplementedError: It is an abstract method.
@@ -149,12 +158,13 @@ class FastGradientMethod(GradientMethod):
Possible values: np.inf, 1 or 2. Default: 2.
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
loss_fn (Loss): Loss function for optimization. Default: None.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.

Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = FastGradientMethod(network)
>>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -175,12 +185,19 @@ class FastGradientMethod(GradientMethod):

Args:
inputs (numpy.ndarray): Input sample.
labels (numpy.ndarray): Original/target label.
labels (Union[numpy.ndarray, tuple]): Original/target labels. \
For each input if it has more than one label, it is wrapped in a tuple.

Returns:
numpy.ndarray, gradient of inputs.
"""
out_grad = self._grad_all(Tensor(inputs), Tensor(labels))
if isinstance(labels, tuple):
labels_tensor = tuple()
for item in labels:
labels_tensor += (Tensor(item),)
else:
labels_tensor = (Tensor(labels),)
out_grad = self._grad_all(Tensor(inputs), *labels_tensor)
if isinstance(out_grad, tuple):
out_grad = out_grad[0]
gradient = out_grad.asnumpy()
@@ -210,7 +227,8 @@ class RandomFastGradientMethod(FastGradientMethod):
Possible values: np.inf, 1 or 2. Default: 2.
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
loss_fn (Loss): Loss function for optimization. Default: None.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.

Raises:
ValueError: eps is smaller than alpha!
@@ -218,7 +236,7 @@ class RandomFastGradientMethod(FastGradientMethod):
Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = RandomFastGradientMethod(network)
>>> attack = RandomFastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -254,12 +272,13 @@ class FastGradientSignMethod(GradientMethod):
In form of (clip_min, clip_max). Default: (0.0, 1.0).
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
loss_fn (Loss): Loss function for optimization. Default: None.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.

Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = FastGradientSignMethod(network)
>>> attack = FastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -279,12 +298,19 @@ class FastGradientSignMethod(GradientMethod):

Args:
inputs (numpy.ndarray): Input samples.
labels (numpy.ndarray): Original/target labels.
labels (union[numpy.ndarray, tuple]): original/target labels. \
for each input if it has more than one label, it is wrapped in a tuple.

Returns:
numpy.ndarray, gradient of inputs.
"""
out_grad = self._grad_all(Tensor(inputs), Tensor(labels))
if isinstance(labels, tuple):
labels_tensor = tuple()
for item in labels:
labels_tensor += (Tensor(item),)
else:
labels_tensor = (Tensor(labels),)
out_grad = self._grad_all(Tensor(inputs), *labels_tensor)
if isinstance(out_grad, tuple):
out_grad = out_grad[0]
gradient = out_grad.asnumpy()
@@ -311,7 +337,8 @@ class RandomFastGradientSignMethod(FastGradientSignMethod):
In form of (clip_min, clip_max). Default: (0.0, 1.0).
is_targeted (bool): True: targeted attack. False: untargeted attack.
Default: False.
loss_fn (Loss): Loss function for optimization. Default: None.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.

Raises:
ValueError: eps is smaller than alpha!
@@ -319,7 +346,7 @@ class RandomFastGradientSignMethod(FastGradientSignMethod):
Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = RandomFastGradientSignMethod(network)
>>> attack = RandomFastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -350,12 +377,13 @@ class LeastLikelyClassMethod(FastGradientSignMethod):
Default: None.
bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: (0.0, 1.0).
loss_fn (Loss): Loss function for optimization. Default: None.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.

Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = LeastLikelyClassMethod(network)
>>> attack = LeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -384,7 +412,8 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod):
Default: 0.035.
bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: (0.0, 1.0).
loss_fn (Loss): Loss function for optimization.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.

Raises:
ValueError: eps is smaller than alpha!
@@ -392,7 +421,7 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod):
Examples:
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = RandomLeastLikelyClassMethod(network)
>>> attack = RandomLeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""



+ 60
- 27
mindarmour/adv_robustness/attacks/iterative_gradient_method.py View File

@@ -17,7 +17,7 @@ from abc import abstractmethod
import numpy as np
from PIL import Image, ImageOps

from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
from mindspore.nn import Cell
from mindspore import Tensor

from mindarmour.utils.logger import LogUtil
@@ -114,7 +114,8 @@ class IterativeGradientMethod(Attack):
bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: (0.0, 1.0).
nb_iter (int): Number of iteration. Default: 5.
loss_fn (Loss): Loss function for optimization. Default: None.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.
"""
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5,
loss_fn=None):
@@ -123,12 +124,15 @@ class IterativeGradientMethod(Attack):
self._eps = check_value_positive('eps', eps)
self._eps_iter = check_value_positive('eps_iter', eps_iter)
self._nb_iter = check_int_positive('nb_iter', nb_iter)
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
for b in self._bounds:
_ = check_param_multi_types('bound', b, [int, float])
self._bounds = None
if bounds is not None:
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
for b in self._bounds:
_ = check_param_multi_types('bound', b, [int, float])
if loss_fn is None:
loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False)
self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn))
self._loss_grad = network
else:
self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn))
self._loss_grad.set_train()

@abstractmethod
@@ -139,8 +143,8 @@ class IterativeGradientMethod(Attack):
Args:
inputs (numpy.ndarray): Benign input samples used as references to create
adversarial examples.
labels (numpy.ndarray): Original/target labels.
labels (Union[numpy.ndarray, tuple]): Original/target labels. \
For each input if it has more than one label, it is wrapped in a tuple.
Raises:
NotImplementedError: This function is not available in
IterativeGradientMethod.
@@ -177,12 +181,13 @@ class BasicIterativeMethod(IterativeGradientMethod):
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
nb_iter (int): Number of iteration. Default: 5.
loss_fn (Loss): Loss function for optimization. Default: None.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.
attack (class): The single step gradient method of each iteration. In
this class, FGSM is used.

Examples:
>>> attack = BasicIterativeMethod(network)
>>> attack = BasicIterativeMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
"""
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0),
is_targeted=False, nb_iter=5, loss_fn=None):
@@ -207,8 +212,8 @@ class BasicIterativeMethod(IterativeGradientMethod):
Args:
inputs (numpy.ndarray): Benign input samples used as references to
create adversarial examples.
labels (numpy.ndarray): Original/target labels.
labels (Union[numpy.ndarray, tuple]): Original/target labels. \
For each input if it has more than one label, it is wrapped in a tuple.
Returns:
numpy.ndarray, generated adversarial examples.

@@ -218,8 +223,13 @@ class BasicIterativeMethod(IterativeGradientMethod):
>>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
>>> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]])
"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
inputs, _ = check_pair_numpy_param('inputs', inputs, \
'labels[{}]'.format(i), labels_item)
else:
inputs, _ = check_pair_numpy_param('inputs', inputs, \
'labels', labels)
arr_x = inputs
if self._bounds is not None:
clip_min, clip_max = self._bounds
@@ -267,7 +277,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):
decay_factor (float): Decay factor in iterations. Default: 1.0.
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values:
np.inf, 1 or 2. Default: 'inf'.
loss_fn (Loss): Loss function for optimization. Default: None.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.
"""

def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0),
@@ -290,7 +301,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):
Args:
inputs (numpy.ndarray): Benign input samples used as references to
create adversarial examples.
labels (numpy.ndarray): Original/target labels.
labels (Union[numpy.ndarray, tuple]): Original/target labels. \
For each input if it has more than one label, it is wrapped in a tuple.

Returns:
numpy.ndarray, generated adversarial examples.
@@ -301,8 +313,13 @@ class MomentumIterativeMethod(IterativeGradientMethod):
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
>>> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]])
"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
inputs, _ = check_pair_numpy_param('inputs', inputs, \
'labels[{}]'.format(i), labels_item)
else:
inputs, _ = check_pair_numpy_param('inputs', inputs, \
'labels', labels)
arr_x = inputs
momentum = 0
if self._bounds is not None:
@@ -340,7 +357,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):

Args:
inputs (numpy.ndarray): Input samples.
labels (numpy.ndarray): Original/target labels.
labels (Union[numpy.ndarray, tuple]): Original/target labels. \
For each input if it has more than one label, it is wrapped in a tuple.

Returns:
numpy.ndarray, gradient of labels w.r.t inputs.
@@ -350,7 +368,13 @@ class MomentumIterativeMethod(IterativeGradientMethod):
>>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
"""
# get grad of loss over x
out_grad = self._loss_grad(Tensor(inputs), Tensor(labels))
if isinstance(labels, tuple):
labels_tensor = tuple()
for item in labels:
labels_tensor += (Tensor(item),)
else:
labels_tensor = (Tensor(labels),)
out_grad = self._loss_grad(Tensor(inputs), *labels_tensor)
if isinstance(out_grad, tuple):
out_grad = out_grad[0]
gradient = out_grad.asnumpy()
@@ -384,7 +408,8 @@ class ProjectedGradientDescent(BasicIterativeMethod):
nb_iter (int): Number of iteration. Default: 5.
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values:
np.inf, 1 or 2. Default: 'inf'.
loss_fn (Loss): Loss function for optimization. Default: None.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.
"""

def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0),
@@ -406,7 +431,8 @@ class ProjectedGradientDescent(BasicIterativeMethod):
Args:
inputs (numpy.ndarray): Benign input samples used as references to
create adversarial examples.
labels (numpy.ndarray): Original/target labels.
labels (Union[numpy.ndarray, tuple]): Original/target labels. \
For each input if it has more than one label, it is wrapped in a tuple.

Returns:
numpy.ndarray, generated adversarial examples.
@@ -417,8 +443,13 @@ class ProjectedGradientDescent(BasicIterativeMethod):
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
>>> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
inputs, _ = check_pair_numpy_param('inputs', inputs, \
'labels[{}]'.format(i), labels_item)
else:
inputs, _ = check_pair_numpy_param('inputs', inputs, \
'labels', labels)
arr_x = inputs
if self._bounds is not None:
clip_min, clip_max = self._bounds
@@ -460,7 +491,8 @@ class DiverseInputIterativeMethod(BasicIterativeMethod):
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
prob (float): Transformation probability. Default: 0.5.
loss_fn (Loss): Loss function for optimization. Default: None.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.
"""
def __init__(self, network, eps=0.3, bounds=(0.0, 1.0),
is_targeted=False, prob=0.5, loss_fn=None):
@@ -495,7 +527,8 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod):
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values:
np.inf, 1 or 2. Default: 'l1'.
prob (float): Transformation probability. Default: 0.5.
loss_fn (Loss): Loss function for optimization. Default: None.
loss_fn (Loss): Loss function for optimization. If None, the input network \
is already equipped with loss function. Default: None.
"""
def __init__(self, network, eps=0.3, bounds=(0.0, 1.0),
is_targeted=False, norm_level='l1', prob=0.5, loss_fn=None):


+ 3
- 0
mindarmour/fuzz_testing/fuzzing.py View File

@@ -19,6 +19,7 @@ from random import choice
import numpy as np
from mindspore import Model
from mindspore import Tensor
from mindspore import nn

from mindarmour.utils._check_param import check_model, check_numpy_param, \
check_param_multi_types, check_norm_level, check_param_in_range, \
@@ -451,6 +452,8 @@ class Fuzzer:
else:
network = self._target_model._network
loss_fn = self._target_model._loss_fn
if loss_fn is None:
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
mutates[method] = self._strategies[method](network,
loss_fn=loss_fn)
return mutates


+ 2
- 2
tests/ut/python/adv_robustness/attacks/test_batch_generate_attack.py View File

@@ -18,7 +18,7 @@ import numpy as np
import pytest

import mindspore.ops.operations as P
from mindspore.nn import Cell
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
import mindspore.context as context

from mindarmour.adv_robustness.attacks import FastGradientMethod
@@ -67,7 +67,7 @@ def test_batch_generate_attack():
label = np.random.randint(0, 10, 128).astype(np.int32)
label = np.eye(10)[label].astype(np.float32)

attack = FastGradientMethod(Net())
attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.batch_generate(input_np, label, batch_size=32)

assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \


+ 10
- 8
tests/ut/python/adv_robustness/attacks/test_gradient_method.py View File

@@ -71,7 +71,7 @@ def test_fast_gradient_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)

attack = FastGradientMethod(Net())
attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)

assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \
@@ -91,7 +91,7 @@ def test_fast_gradient_method_gpu():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)

attack = FastGradientMethod(Net())
attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)

assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \
@@ -132,7 +132,7 @@ def test_random_fast_gradient_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)

attack = RandomFastGradientMethod(Net())
attack = RandomFastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)

assert np.any(ms_adv_x != input_np), 'Random fast gradient method: ' \
@@ -154,7 +154,7 @@ def test_fast_gradient_sign_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)

attack = FastGradientSignMethod(Net())
attack = FastGradientSignMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)

assert np.any(ms_adv_x != input_np), 'Fast gradient sign method: generate' \
@@ -176,7 +176,7 @@ def test_random_fast_gradient_sign_method():
label = np.asarray([2], np.int32)
label = np.eye(28)[label].astype(np.float32)

attack = RandomFastGradientSignMethod(Net())
attack = RandomFastGradientSignMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)

assert np.any(ms_adv_x != input_np), 'Random fast gradient sign method: ' \
@@ -198,7 +198,7 @@ def test_least_likely_class_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)

attack = LeastLikelyClassMethod(Net())
attack = LeastLikelyClassMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)

assert np.any(ms_adv_x != input_np), 'Least likely class method: generate' \
@@ -220,7 +220,8 @@ def test_random_least_likely_class_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)

attack = RandomLeastLikelyClassMethod(Net(), eps=0.1, alpha=0.01)
attack = RandomLeastLikelyClassMethod(Net(), eps=0.1, alpha=0.01, \
loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)

assert np.any(ms_adv_x != input_np), 'Random least likely class method: ' \
@@ -239,5 +240,6 @@ def test_assert_error():
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
with pytest.raises(ValueError) as e:
assert RandomLeastLikelyClassMethod(Net(), eps=0.05, alpha=0.21)
assert RandomLeastLikelyClassMethod(Net(), eps=0.05, alpha=0.21, \
loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
assert str(e.value) == 'eps must be larger than alpha!'

+ 7
- 9
tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py View File

@@ -20,6 +20,7 @@ import pytest
from mindspore.ops import operations as P
from mindspore.nn import Cell
from mindspore import context
from mindspore.nn import SoftmaxCrossEntropyWithLogits

from mindarmour.adv_robustness.attacks import BasicIterativeMethod
from mindarmour.adv_robustness.attacks import MomentumIterativeMethod
@@ -70,7 +71,7 @@ def test_basic_iterative_method():

for i in range(5):
net = Net()
attack = BasicIterativeMethod(net, nb_iter=i + 1)
attack = BasicIterativeMethod(net, nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(
ms_adv_x != input_np), 'Basic iterative method: generate value' \
@@ -91,7 +92,7 @@ def test_momentum_iterative_method():
label = np.eye(3)[label].astype(np.float32)

for i in range(5):
attack = MomentumIterativeMethod(Net(), nb_iter=i + 1)
attack = MomentumIterativeMethod(Net(), nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Momentum iterative method: generate' \
' value must not be equal to' \
@@ -112,7 +113,7 @@ def test_projected_gradient_descent_method():
label = np.eye(3)[label].astype(np.float32)

for i in range(5):
attack = ProjectedGradientDescent(Net(), nb_iter=i + 1)
attack = ProjectedGradientDescent(Net(), nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)

assert np.any(
@@ -134,7 +135,7 @@ def test_diverse_input_iterative_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)

attack = DiverseInputIterativeMethod(Net())
attack = DiverseInputIterativeMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Diverse input iterative method: generate' \
' value must not be equal to' \
@@ -154,7 +155,7 @@ def test_momentum_diverse_input_iterative_method():
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)

attack = MomentumDiverseInputIterativeMethod(Net())
attack = MomentumDiverseInputIterativeMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Momentum diverse input iterative method: ' \
'generate value must not be equal to' \
@@ -167,10 +168,7 @@ def test_momentum_diverse_input_iterative_method():
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_error():
with pytest.raises(TypeError):
# check_param_multi_types
assert IterativeGradientMethod(Net(), bounds=None)
attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0))
attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
with pytest.raises(NotImplementedError):
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32)
label = np.asarray([2], np.int32)


+ 2
- 2
tests/ut/python/adv_robustness/defenses/test_ead.py View File

@@ -59,8 +59,8 @@ def test_ead():
optimizer = Momentum(net.trainable_params(), 0.001, 0.9)

net = Net()
fgsm = FastGradientSignMethod(net)
pgd = ProjectedGradientDescent(net)
fgsm = FastGradientSignMethod(net, loss_fn=loss_fn)
pgd = ProjectedGradientDescent(net, loss_fn=loss_fn)
ead = EnsembleAdversarialDefense(net, [fgsm, pgd], loss_fn=loss_fn,
optimizer=optimizer)
LOGGER.set_level(logging.DEBUG)


+ 1
- 1
tests/ut/python/fuzzing/test_coverage_metrics.py View File

@@ -117,7 +117,7 @@ def test_lenet_mnist_coverage_ascend():
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

# generate adv_data
attack = FastGradientSignMethod(net, eps=0.3)
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())


Loading…
Cancel
Save