@@ -0,0 +1,47 @@ | |||
# Dataset | |||
Dataset used: [COCO2017](<https://cocodataset.org/>) | |||
- Dataset size:19G | |||
- Train:18G,118000 images | |||
- Val:1G,5000 images | |||
- Annotations:241M,instances,captions,person_keypoints etc | |||
- Data format:image and json files | |||
- Note:Data will be processed in dataset.py | |||
# Environment Requirements | |||
- Install [MindSpore](https://www.mindspore.cn/install/en). | |||
- Download the dataset COCO2017. | |||
- We use COCO2017 as dataset in this example. | |||
Install Cython and pycocotool, and you can also install mmcv to process data. | |||
``` | |||
pip install Cython | |||
pip install pycocotools | |||
pip install mmcv==0.2.14 | |||
``` | |||
And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows: | |||
``` | |||
. | |||
└─cocodataset | |||
├─annotations | |||
├─instance_train2017.json | |||
└─instance_val2017.json | |||
├─val2017 | |||
└─train2017 | |||
``` | |||
# Quick start | |||
You can download the pre-trained model checkpoint file [here](<https://www.mindspore.cn/resources/hub/details?2505/MindSpore/ascend/0.7/fasterrcnn_v1.0_coco2017>). | |||
``` | |||
python coco_attack_pgd.py --ann_file [VAL_JSON_FILE] --pre_trained [PRETRAINED_CHECKPOINT_FILE] | |||
``` | |||
> Adversarial samples will be generated and saved as pickle file. |
@@ -0,0 +1,135 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""PGD attack for faster rcnn""" | |||
import os | |||
import argparse | |||
import pickle | |||
from mindspore import context | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindspore.common import set_seed | |||
from mindspore.nn import Cell | |||
from mindspore.ops.composite import GradOperation | |||
from mindarmour.adv_robustness.attacks import ProjectedGradientDescent | |||
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 | |||
from src.config import config | |||
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset | |||
# pylint: disable=locally-disabled, unused-argument, redefined-outer-name | |||
set_seed(1) | |||
parser = argparse.ArgumentParser(description='FasterRCNN attack') | |||
parser.add_argument('--ann_file', type=str, required=True, help='Ann file path.') | |||
parser.add_argument('--pre_trained', type=str, required=True, help='pre-trained ckpt file path for target model.') | |||
parser.add_argument('--device_id', type=int, default=0, help='Device id, default is 0.') | |||
parser.add_argument('--num', type=int, default=5, help='Number of adversarial examples.') | |||
args = parser.parse_args() | |||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=args.device_id) | |||
class LossNet(Cell): | |||
"""loss function.""" | |||
def construct(self, x1, x2, x3, x4, x5, x6): | |||
return x4 + x6 | |||
class WithLossCell(Cell): | |||
"""Wrap the network with loss function.""" | |||
def __init__(self, backbone, loss_fn): | |||
super(WithLossCell, self).__init__(auto_prefix=False) | |||
self._backbone = backbone | |||
self._loss_fn = loss_fn | |||
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_num): | |||
loss1, loss2, loss3, loss4, loss5, loss6 = self._backbone(img_data, img_metas, gt_bboxes, gt_labels, gt_num) | |||
return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6) | |||
@property | |||
def backbone_network(self): | |||
return self._backbone | |||
class GradWrapWithLoss(Cell): | |||
""" | |||
Construct a network to compute the gradient of loss function in \ | |||
input space and weighted by `weight`. | |||
""" | |||
def __init__(self, network): | |||
super(GradWrapWithLoss, self).__init__() | |||
self._grad_all = GradOperation(get_all=True, sens_param=False) | |||
self._network = network | |||
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_num): | |||
gout = self._grad_all(self._network)(img_data, img_metas, gt_bboxes, gt_labels, gt_num) | |||
return gout[0] | |||
if __name__ == '__main__': | |||
prefix = 'FasterRcnn_eval.mindrecord' | |||
mindrecord_dir = config.mindrecord_dir | |||
mindrecord_file = os.path.join(mindrecord_dir, prefix) | |||
pre_trained = args.pre_trained | |||
ann_file = args.ann_file | |||
print("CHECKING MINDRECORD FILES ...") | |||
if not os.path.exists(mindrecord_file): | |||
if not os.path.isdir(mindrecord_dir): | |||
os.makedirs(mindrecord_dir) | |||
if os.path.isdir(config.coco_root): | |||
print("Create Mindrecord. It may take some time.") | |||
data_to_mindrecord_byte_image("coco", False, prefix, file_num=1) | |||
print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
else: | |||
print("coco_root not exits.") | |||
print('Start generate adversarial samples.') | |||
# build network and dataset | |||
ds = create_fasterrcnn_dataset(mindrecord_file, batch_size=config.test_batch_size, \ | |||
repeat_num=1, is_training=True) | |||
net = Faster_Rcnn_Resnet50(config) | |||
param_dict = load_checkpoint(pre_trained) | |||
load_param_into_net(net, param_dict) | |||
net = net.set_train() | |||
# build attacker | |||
with_loss_cell = WithLossCell(net, LossNet()) | |||
grad_with_loss_net = GradWrapWithLoss(with_loss_cell) | |||
attack = ProjectedGradientDescent(grad_with_loss_net, bounds=None, eps=0.1) | |||
# generate adversarial samples | |||
num = args.num | |||
num_batches = num // config.test_batch_size | |||
channel = 3 | |||
adv_samples = [0] * (num_batches * config.test_batch_size) | |||
adv_id = 0 | |||
for data in ds.create_dict_iterator(num_epochs=num_batches): | |||
img_data = data['image'] | |||
img_metas = data['image_shape'] | |||
gt_bboxes = data['box'] | |||
gt_labels = data['label'] | |||
gt_num = data['valid_num'] | |||
adv_img = attack.generate(img_data.asnumpy(), \ | |||
(img_metas.asnumpy(), gt_bboxes.asnumpy(), gt_labels.asnumpy(), gt_num.asnumpy())) | |||
for item in adv_img: | |||
adv_samples[adv_id] = item | |||
adv_id += 1 | |||
pickle.dump(adv_samples, open('adv_samples.pkl', 'wb')) | |||
print('Generate adversarial samples complete.') |
@@ -0,0 +1,31 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""FasterRcnn Init.""" | |||
from .resnet50 import ResNetFea, ResidualBlockUsing | |||
from .bbox_assign_sample import BboxAssignSample | |||
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn | |||
from .fpn_neck import FeatPyramidNeck | |||
from .proposal_generator import Proposal | |||
from .rcnn import Rcnn | |||
from .rpn import RPN | |||
from .roi_align import SingleRoIExtractor | |||
from .anchor_generator import AnchorGenerator | |||
__all__ = [ | |||
"ResNetFea", "BboxAssignSample", "BboxAssignSampleForRcnn", | |||
"FeatPyramidNeck", "Proposal", "Rcnn", | |||
"RPN", "SingleRoIExtractor", "AnchorGenerator", "ResidualBlockUsing" | |||
] |
@@ -0,0 +1,84 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""FasterRcnn anchor generator.""" | |||
import numpy as np | |||
class AnchorGenerator(): | |||
"""Anchor generator for FasterRcnn.""" | |||
def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None): | |||
"""Anchor generator init method.""" | |||
self.base_size = base_size | |||
self.scales = np.array(scales) | |||
self.ratios = np.array(ratios) | |||
self.scale_major = scale_major | |||
self.ctr = ctr | |||
self.base_anchors = self.gen_base_anchors() | |||
def gen_base_anchors(self): | |||
"""Generate a single anchor.""" | |||
w = self.base_size | |||
h = self.base_size | |||
if self.ctr is None: | |||
x_ctr = 0.5 * (w - 1) | |||
y_ctr = 0.5 * (h - 1) | |||
else: | |||
x_ctr, y_ctr = self.ctr | |||
h_ratios = np.sqrt(self.ratios) | |||
w_ratios = 1 / h_ratios | |||
if self.scale_major: | |||
ws = (w * w_ratios[:, None] * self.scales[None, :]).reshape(-1) | |||
hs = (h * h_ratios[:, None] * self.scales[None, :]).reshape(-1) | |||
else: | |||
ws = (w * self.scales[:, None] * w_ratios[None, :]).reshape(-1) | |||
hs = (h * self.scales[:, None] * h_ratios[None, :]).reshape(-1) | |||
base_anchors = np.stack( | |||
[ | |||
x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1), | |||
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1) | |||
], | |||
axis=-1).round() | |||
return base_anchors | |||
def _meshgrid(self, x, y, row_major=True): | |||
"""Generate grid.""" | |||
xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1) | |||
yy = np.repeat(y, len(x)) | |||
if row_major: | |||
return xx, yy | |||
return yy, xx | |||
def grid_anchors(self, featmap_size, stride=16): | |||
"""Generate anchor list.""" | |||
base_anchors = self.base_anchors | |||
feat_h, feat_w = featmap_size | |||
shift_x = np.arange(0, feat_w) * stride | |||
shift_y = np.arange(0, feat_h) * stride | |||
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) | |||
shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1) | |||
shifts = shifts.astype(base_anchors.dtype) | |||
# first feat_w elements correspond to the first row of shifts | |||
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get | |||
# shifted anchors (K, A, 4), reshape to (K*A, 4) | |||
all_anchors = base_anchors[None, :, :] + shifts[:, None, :] | |||
all_anchors = all_anchors.reshape(-1, 4) | |||
return all_anchors |
@@ -0,0 +1,166 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""FasterRcnn positive and negative sample screening for RPN.""" | |||
import numpy as np | |||
import mindspore.nn as nn | |||
from mindspore.ops import operations as P | |||
from mindspore.common.tensor import Tensor | |||
import mindspore.common.dtype as mstype | |||
# pylint: disable=locally-disabled, invalid-name, missing-docstring | |||
class BboxAssignSample(nn.Cell): | |||
""" | |||
Bbox assigner and sampler defination. | |||
Args: | |||
config (dict): Config. | |||
batch_size (int): Batchsize. | |||
num_bboxes (int): The anchor nums. | |||
add_gt_as_proposals (bool): add gt bboxes as proposals flag. | |||
Returns: | |||
Tensor, output tensor. | |||
bbox_targets: bbox location, (batch_size, num_bboxes, 4) | |||
bbox_weights: bbox weights, (batch_size, num_bboxes, 1) | |||
labels: label for every bboxes, (batch_size, num_bboxes, 1) | |||
label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1) | |||
Examples: | |||
BboxAssignSample(config, 2, 1024, True) | |||
""" | |||
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals): | |||
super(BboxAssignSample, self).__init__() | |||
cfg = config | |||
self.batch_size = batch_size | |||
self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16) | |||
self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16) | |||
self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16) | |||
self.zero_thr = Tensor(0.0, mstype.float16) | |||
self.num_bboxes = num_bboxes | |||
self.num_gts = cfg.num_gts | |||
self.num_expected_pos = cfg.num_expected_pos | |||
self.num_expected_neg = cfg.num_expected_neg | |||
self.add_gt_as_proposals = add_gt_as_proposals | |||
if self.add_gt_as_proposals: | |||
self.label_inds = Tensor(np.arange(1, self.num_gts + 1)) | |||
self.concat = P.Concat(axis=0) | |||
self.max_gt = P.ArgMaxWithValue(axis=0) | |||
self.max_anchor = P.ArgMaxWithValue(axis=1) | |||
self.sum_inds = P.ReduceSum() | |||
self.iou = P.IOU() | |||
self.greaterequal = P.GreaterEqual() | |||
self.greater = P.Greater() | |||
self.select = P.Select() | |||
self.gatherND = P.GatherNd() | |||
self.squeeze = P.Squeeze() | |||
self.cast = P.Cast() | |||
self.logicaland = P.LogicalAnd() | |||
self.less = P.Less() | |||
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos) | |||
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg) | |||
self.reshape = P.Reshape() | |||
self.equal = P.Equal() | |||
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)) | |||
self.scatterNdUpdate = P.ScatterNdUpdate() | |||
self.scatterNd = P.ScatterNd() | |||
self.logicalnot = P.LogicalNot() | |||
self.tile = P.Tile() | |||
self.zeros_like = P.ZerosLike() | |||
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) | |||
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32)) | |||
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32)) | |||
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) | |||
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32)) | |||
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool)) | |||
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16)) | |||
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16)) | |||
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16)) | |||
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids): | |||
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \ | |||
(self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one) | |||
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \ | |||
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two) | |||
overlaps = self.iou(bboxes, gt_bboxes_i) | |||
max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps) | |||
_, max_overlaps_w_ac = self.max_anchor(overlaps) | |||
neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \ | |||
self.less(max_overlaps_w_gt, self.neg_iou_thr)) | |||
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds) | |||
pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr) | |||
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \ | |||
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2) | |||
assigned_gt_inds4 = assigned_gt_inds3 | |||
for j in range(self.num_gts): | |||
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1] | |||
overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::]) | |||
pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \ | |||
self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j)) | |||
assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4) | |||
assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores) | |||
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0)) | |||
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16) | |||
pos_check_valid = self.sum_inds(pos_check_valid, -1) | |||
valid_pos_index = self.less(self.range_pos_size, pos_check_valid) | |||
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1)) | |||
pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones | |||
pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32) | |||
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1)) | |||
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0)) | |||
num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16) | |||
num_pos = self.sum_inds(num_pos, -1) | |||
unvalid_pos_index = self.less(self.range_pos_size, num_pos) | |||
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index) | |||
pos_bboxes_ = self.gatherND(bboxes, pos_index) | |||
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index) | |||
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index) | |||
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_) | |||
valid_pos_index = self.cast(valid_pos_index, mstype.int32) | |||
valid_neg_index = self.cast(valid_neg_index, mstype.int32) | |||
bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4)) | |||
bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,)) | |||
labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,)) | |||
total_index = self.concat((pos_index, neg_index)) | |||
total_valid_index = self.concat((valid_pos_index, valid_neg_index)) | |||
label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,)) | |||
return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \ | |||
labels_total, self.cast(label_weights_total, mstype.bool_) |
@@ -0,0 +1,197 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""FasterRcnn tpositive and negative sample screening for Rcnn.""" | |||
import numpy as np | |||
import mindspore.nn as nn | |||
import mindspore.common.dtype as mstype | |||
from mindspore.ops import operations as P | |||
from mindspore.common.tensor import Tensor | |||
# pylint: disable=locally-disabled, invalid-name, missing-docstring | |||
class BboxAssignSampleForRcnn(nn.Cell): | |||
""" | |||
Bbox assigner and sampler defination. | |||
Args: | |||
config (dict): Config. | |||
batch_size (int): Batchsize. | |||
num_bboxes (int): The anchor nums. | |||
add_gt_as_proposals (bool): add gt bboxes as proposals flag. | |||
Returns: | |||
Tensor, output tensor. | |||
bbox_targets: bbox location, (batch_size, num_bboxes, 4) | |||
bbox_weights: bbox weights, (batch_size, num_bboxes, 1) | |||
labels: label for every bboxes, (batch_size, num_bboxes, 1) | |||
label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1) | |||
Examples: | |||
BboxAssignSampleForRcnn(config, 2, 1024, True) | |||
""" | |||
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals): | |||
super(BboxAssignSampleForRcnn, self).__init__() | |||
cfg = config | |||
self.batch_size = batch_size | |||
self.neg_iou_thr = cfg.neg_iou_thr_stage2 | |||
self.pos_iou_thr = cfg.pos_iou_thr_stage2 | |||
self.min_pos_iou = cfg.min_pos_iou_stage2 | |||
self.num_gts = cfg.num_gts | |||
self.num_bboxes = num_bboxes | |||
self.num_expected_pos = cfg.num_expected_pos_stage2 | |||
self.num_expected_neg = cfg.num_expected_neg_stage2 | |||
self.num_expected_total = cfg.num_expected_total_stage2 | |||
self.add_gt_as_proposals = add_gt_as_proposals | |||
self.label_inds = Tensor(np.arange(1, self.num_gts + 1).astype(np.int32)) | |||
self.add_gt_as_proposals_valid = Tensor(np.array(self.add_gt_as_proposals * np.ones(self.num_gts), | |||
dtype=np.int32)) | |||
self.concat = P.Concat(axis=0) | |||
self.max_gt = P.ArgMaxWithValue(axis=0) | |||
self.max_anchor = P.ArgMaxWithValue(axis=1) | |||
self.sum_inds = P.ReduceSum() | |||
self.iou = P.IOU() | |||
self.greaterequal = P.GreaterEqual() | |||
self.greater = P.Greater() | |||
self.select = P.Select() | |||
self.gatherND = P.GatherNd() | |||
self.squeeze = P.Squeeze() | |||
self.cast = P.Cast() | |||
self.logicaland = P.LogicalAnd() | |||
self.less = P.Less() | |||
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos) | |||
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg) | |||
self.reshape = P.Reshape() | |||
self.equal = P.Equal() | |||
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(0.1, 0.1, 0.2, 0.2)) | |||
self.concat_axis1 = P.Concat(axis=1) | |||
self.logicalnot = P.LogicalNot() | |||
self.tile = P.Tile() | |||
# Check | |||
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16)) | |||
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16)) | |||
# Init tensor | |||
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) | |||
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32)) | |||
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32)) | |||
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) | |||
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32)) | |||
self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32)) | |||
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16)) | |||
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool)) | |||
self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16)) | |||
self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8)) | |||
self.reshape_shape_pos = (self.num_expected_pos, 1) | |||
self.reshape_shape_neg = (self.num_expected_neg, 1) | |||
self.scalar_zero = Tensor(0.0, dtype=mstype.float16) | |||
self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=mstype.float16) | |||
self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=mstype.float16) | |||
self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=mstype.float16) | |||
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids): | |||
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \ | |||
(self.num_gts, 1)), (1, 4)), mstype.bool_), \ | |||
gt_bboxes_i, self.check_gt_one) | |||
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \ | |||
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), \ | |||
bboxes, self.check_anchor_two) | |||
overlaps = self.iou(bboxes, gt_bboxes_i) | |||
max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps) | |||
_, max_overlaps_w_ac = self.max_anchor(overlaps) | |||
neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, | |||
self.scalar_zero), | |||
self.less(max_overlaps_w_gt, | |||
self.scalar_neg_iou_thr)) | |||
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds) | |||
pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.scalar_pos_iou_thr) | |||
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \ | |||
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2) | |||
for j in range(self.num_gts): | |||
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1] | |||
overlaps_w_ac_j = overlaps[j:j+1:1, ::] | |||
temp1 = self.greaterequal(max_overlaps_w_ac_j, self.scalar_min_pos_iou) | |||
temp2 = self.squeeze(self.equal(overlaps_w_ac_j, max_overlaps_w_ac_j)) | |||
pos_mask_j = self.logicaland(temp1, temp2) | |||
assigned_gt_inds3 = self.select(pos_mask_j, (j+1)*self.assigned_gt_ones, assigned_gt_inds3) | |||
assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds3, self.assigned_gt_ignores) | |||
bboxes = self.concat((gt_bboxes_i, bboxes)) | |||
label_inds_valid = self.select(gt_valids, self.label_inds, self.gt_ignores) | |||
label_inds_valid = label_inds_valid * self.add_gt_as_proposals_valid | |||
assigned_gt_inds5 = self.concat((label_inds_valid, assigned_gt_inds5)) | |||
# Get pos index | |||
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0)) | |||
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16) | |||
pos_check_valid = self.sum_inds(pos_check_valid, -1) | |||
valid_pos_index = self.less(self.range_pos_size, pos_check_valid) | |||
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1)) | |||
num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), mstype.float16), -1) | |||
valid_pos_index = self.cast(valid_pos_index, mstype.int32) | |||
pos_index = self.reshape(pos_index, self.reshape_shape_pos) | |||
valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos) | |||
pos_index = pos_index * valid_pos_index | |||
pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones | |||
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos) | |||
pos_assigned_gt_index = pos_assigned_gt_index * valid_pos_index | |||
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index) | |||
# Get neg index | |||
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0)) | |||
unvalid_pos_index = self.less(self.range_pos_size, num_pos) | |||
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index) | |||
neg_index = self.reshape(neg_index, self.reshape_shape_neg) | |||
valid_neg_index = self.cast(valid_neg_index, mstype.int32) | |||
valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg) | |||
neg_index = neg_index * valid_neg_index | |||
pos_bboxes_ = self.gatherND(bboxes, pos_index) | |||
neg_bboxes_ = self.gatherND(bboxes, neg_index) | |||
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos) | |||
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index) | |||
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_) | |||
total_bboxes = self.concat((pos_bboxes_, neg_bboxes_)) | |||
total_deltas = self.concat((pos_bbox_targets_, self.bboxs_neg_mask)) | |||
total_labels = self.concat((pos_gt_labels, self.labels_neg_mask)) | |||
valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos) | |||
valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg) | |||
total_mask = self.concat((valid_pos_index, valid_neg_index)) | |||
return total_bboxes, total_deltas, total_labels, total_mask |
@@ -0,0 +1,428 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""FasterRcnn based on ResNet50.""" | |||
import numpy as np | |||
import mindspore.nn as nn | |||
from mindspore.ops import operations as P | |||
from mindspore.common.tensor import Tensor | |||
import mindspore.common.dtype as mstype | |||
from mindspore.ops import functional as F | |||
from .resnet50 import ResNetFea, ResidualBlockUsing | |||
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn | |||
from .fpn_neck import FeatPyramidNeck | |||
from .proposal_generator import Proposal | |||
from .rcnn import Rcnn | |||
from .rpn import RPN | |||
from .roi_align import SingleRoIExtractor | |||
from .anchor_generator import AnchorGenerator | |||
# pylint: disable=locally-disabled, invalid-name, missing-docstring | |||
class Faster_Rcnn_Resnet50(nn.Cell): | |||
""" | |||
FasterRcnn Network. | |||
Note: | |||
backbone = resnet50 | |||
Returns: | |||
Tuple, tuple of output tensor. | |||
rpn_loss: Scalar, Total loss of RPN subnet. | |||
rcnn_loss: Scalar, Total loss of RCNN subnet. | |||
rpn_cls_loss: Scalar, Classification loss of RPN subnet. | |||
rpn_reg_loss: Scalar, Regression loss of RPN subnet. | |||
rcnn_cls_loss: Scalar, Classification loss of RCNN subnet. | |||
rcnn_reg_loss: Scalar, Regression loss of RCNN subnet. | |||
Examples: | |||
net = Faster_Rcnn_Resnet50() | |||
""" | |||
def __init__(self, config): | |||
super(Faster_Rcnn_Resnet50, self).__init__() | |||
self.train_batch_size = config.batch_size | |||
self.num_classes = config.num_classes | |||
self.anchor_scales = config.anchor_scales | |||
self.anchor_ratios = config.anchor_ratios | |||
self.anchor_strides = config.anchor_strides | |||
self.target_means = tuple(config.rcnn_target_means) | |||
self.target_stds = tuple(config.rcnn_target_stds) | |||
# Anchor generator | |||
anchor_base_sizes = None | |||
self.anchor_base_sizes = list( | |||
self.anchor_strides) if anchor_base_sizes is None else anchor_base_sizes | |||
self.anchor_generators = [] | |||
for anchor_base in self.anchor_base_sizes: | |||
self.anchor_generators.append( | |||
AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios)) | |||
self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales) | |||
featmap_sizes = config.feature_shapes | |||
assert len(featmap_sizes) == len(self.anchor_generators) | |||
self.anchor_list = self.get_anchors(featmap_sizes) | |||
# Backbone resnet50 | |||
self.backbone = ResNetFea(ResidualBlockUsing, | |||
config.resnet_block, | |||
config.resnet_in_channels, | |||
config.resnet_out_channels, | |||
False) | |||
# Fpn | |||
self.fpn_ncek = FeatPyramidNeck(config.fpn_in_channels, | |||
config.fpn_out_channels, | |||
config.fpn_num_outs) | |||
# Rpn and rpn loss | |||
self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8)) | |||
self.rpn_with_loss = RPN(config, | |||
self.train_batch_size, | |||
config.rpn_in_channels, | |||
config.rpn_feat_channels, | |||
config.num_anchors, | |||
config.rpn_cls_out_channels) | |||
# Proposal | |||
self.proposal_generator = Proposal(config, | |||
self.train_batch_size, | |||
config.activate_num_classes, | |||
config.use_sigmoid_cls) | |||
self.proposal_generator.set_train_local(config, True) | |||
self.proposal_generator_test = Proposal(config, | |||
config.test_batch_size, | |||
config.activate_num_classes, | |||
config.use_sigmoid_cls) | |||
self.proposal_generator_test.set_train_local(config, False) | |||
# Assign and sampler stage two | |||
self.bbox_assigner_sampler_for_rcnn = BboxAssignSampleForRcnn(config, self.train_batch_size, | |||
config.num_bboxes_stage2, True) | |||
self.decode = P.BoundingBoxDecode(max_shape=(768, 1280), means=self.target_means, \ | |||
stds=self.target_stds) | |||
# Roi | |||
self.roi_align = SingleRoIExtractor(config, | |||
config.roi_layer, | |||
config.roi_align_out_channels, | |||
config.roi_align_featmap_strides, | |||
self.train_batch_size, | |||
config.roi_align_finest_scale) | |||
self.roi_align.set_train_local(config, True) | |||
self.roi_align_test = SingleRoIExtractor(config, | |||
config.roi_layer, | |||
config.roi_align_out_channels, | |||
config.roi_align_featmap_strides, | |||
1, | |||
config.roi_align_finest_scale) | |||
self.roi_align_test.set_train_local(config, False) | |||
# Rcnn | |||
self.rcnn = Rcnn(config, config.rcnn_in_channels * config.roi_layer['out_size'] * config.roi_layer['out_size'], | |||
self.train_batch_size, self.num_classes) | |||
# Op declare | |||
self.squeeze = P.Squeeze() | |||
self.cast = P.Cast() | |||
self.concat = P.Concat(axis=0) | |||
self.concat_1 = P.Concat(axis=1) | |||
self.concat_2 = P.Concat(axis=2) | |||
self.reshape = P.Reshape() | |||
self.select = P.Select() | |||
self.greater = P.Greater() | |||
self.transpose = P.Transpose() | |||
# Test mode | |||
self.test_batch_size = config.test_batch_size | |||
self.split = P.Split(axis=0, output_num=self.test_batch_size) | |||
self.split_shape = P.Split(axis=0, output_num=4) | |||
self.split_scores = P.Split(axis=1, output_num=self.num_classes) | |||
self.split_cls = P.Split(axis=0, output_num=self.num_classes-1) | |||
self.tile = P.Tile() | |||
self.gather = P.GatherNd() | |||
self.rpn_max_num = config.rpn_max_num | |||
self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(np.float16)) | |||
self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool) | |||
self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool) | |||
self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask, | |||
self.ones_mask, self.zeros_mask), axis=1)) | |||
self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask, | |||
self.ones_mask, self.ones_mask, self.zeros_mask), axis=1)) | |||
self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_score_thr) | |||
self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * 0) | |||
self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(np.float16) * -1) | |||
self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_iou_thr) | |||
self.test_max_per_img = config.test_max_per_img | |||
self.nms_test = P.NMSWithMask(config.test_iou_thr) | |||
self.softmax = P.Softmax(axis=1) | |||
self.logicand = P.LogicalAnd() | |||
self.oneslike = P.OnesLike() | |||
self.test_topk = P.TopK(sorted=True) | |||
self.test_num_proposal = self.test_batch_size * self.rpn_max_num | |||
# Improve speed | |||
self.concat_start = min(self.num_classes - 2, 55) | |||
self.concat_end = (self.num_classes - 1) | |||
# Init tensor | |||
roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i, | |||
dtype=np.float16) for i in range(self.train_batch_size)] | |||
roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=np.float16) \ | |||
for i in range(self.test_batch_size)] | |||
self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index)) | |||
self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test)) | |||
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids): | |||
x = self.backbone(img_data) | |||
x = self.fpn_ncek(x) | |||
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(x, | |||
img_metas, | |||
self.anchor_list, | |||
gt_bboxes, | |||
self.gt_labels_stage1, | |||
gt_valids) | |||
if self.training: | |||
proposal, proposal_mask = self.proposal_generator(cls_score, bbox_pred, self.anchor_list) | |||
else: | |||
proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list) | |||
gt_labels = self.cast(gt_labels, mstype.int32) | |||
gt_valids = self.cast(gt_valids, mstype.int32) | |||
bboxes_tuple = () | |||
deltas_tuple = () | |||
labels_tuple = () | |||
mask_tuple = () | |||
if self.training: | |||
for i in range(self.train_batch_size): | |||
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::]) | |||
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::]) | |||
gt_labels_i = self.cast(gt_labels_i, mstype.uint8) | |||
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::]) | |||
gt_valids_i = self.cast(gt_valids_i, mstype.bool_) | |||
bboxes, deltas, labels, mask = self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i, | |||
gt_labels_i, | |||
proposal_mask[i], | |||
proposal[i][::, 0:4:1], | |||
gt_valids_i) | |||
bboxes_tuple += (bboxes,) | |||
deltas_tuple += (deltas,) | |||
labels_tuple += (labels,) | |||
mask_tuple += (mask,) | |||
bbox_targets = self.concat(deltas_tuple) | |||
rcnn_labels = self.concat(labels_tuple) | |||
bbox_targets = F.stop_gradient(bbox_targets) | |||
rcnn_labels = F.stop_gradient(rcnn_labels) | |||
rcnn_labels = self.cast(rcnn_labels, mstype.int32) | |||
else: | |||
mask_tuple += proposal_mask | |||
bbox_targets = proposal_mask | |||
rcnn_labels = proposal_mask | |||
for p_i in proposal: | |||
bboxes_tuple += (p_i[::, 0:4:1],) | |||
if self.training: | |||
if self.train_batch_size > 1: | |||
bboxes_all = self.concat(bboxes_tuple) | |||
else: | |||
bboxes_all = bboxes_tuple[0] | |||
rois = self.concat_1((self.roi_align_index_tensor, bboxes_all)) | |||
else: | |||
if self.test_batch_size > 1: | |||
bboxes_all = self.concat(bboxes_tuple) | |||
else: | |||
bboxes_all = bboxes_tuple[0] | |||
rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all)) | |||
rois = self.cast(rois, mstype.float32) | |||
rois = F.stop_gradient(rois) | |||
if self.training: | |||
roi_feats = self.roi_align(rois, | |||
self.cast(x[0], mstype.float32), | |||
self.cast(x[1], mstype.float32), | |||
self.cast(x[2], mstype.float32), | |||
self.cast(x[3], mstype.float32)) | |||
else: | |||
roi_feats = self.roi_align_test(rois, | |||
self.cast(x[0], mstype.float32), | |||
self.cast(x[1], mstype.float32), | |||
self.cast(x[2], mstype.float32), | |||
self.cast(x[3], mstype.float32)) | |||
roi_feats = self.cast(roi_feats, mstype.float16) | |||
rcnn_masks = self.concat(mask_tuple) | |||
rcnn_masks = F.stop_gradient(rcnn_masks) | |||
rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_)) | |||
rcnn_loss, rcnn_cls_loss, rcnn_reg_loss, _ = self.rcnn(roi_feats, | |||
bbox_targets, | |||
rcnn_labels, | |||
rcnn_mask_squeeze) | |||
output = () | |||
if self.training: | |||
output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss) | |||
else: | |||
output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, img_metas) | |||
return output | |||
def get_det_bboxes(self, cls_logits, reg_logits, mask_logits, rois, img_metas): | |||
"""Get the actual detection box.""" | |||
scores = self.softmax(cls_logits) | |||
boxes_all = () | |||
for i in range(self.num_classes): | |||
k = i * 4 | |||
reg_logits_i = self.squeeze(reg_logits[::, k:k+4:1]) | |||
out_boxes_i = self.decode(rois, reg_logits_i) | |||
boxes_all += (out_boxes_i,) | |||
img_metas_all = self.split(img_metas) | |||
scores_all = self.split(scores) | |||
mask_all = self.split(self.cast(mask_logits, mstype.int32)) | |||
boxes_all_with_batchsize = () | |||
for i in range(self.test_batch_size): | |||
scale = self.split_shape(self.squeeze(img_metas_all[i])) | |||
scale_h = scale[2] | |||
scale_w = scale[3] | |||
boxes_tuple = () | |||
for j in range(self.num_classes): | |||
boxes_tmp = self.split(boxes_all[j]) | |||
out_boxes_h = boxes_tmp[i] / scale_h | |||
out_boxes_w = boxes_tmp[i] / scale_w | |||
boxes_tuple += (self.select(self.bbox_mask, out_boxes_w, out_boxes_h),) | |||
boxes_all_with_batchsize += (boxes_tuple,) | |||
output = self.multiclass_nms(boxes_all_with_batchsize, scores_all, mask_all) | |||
return output | |||
def multiclass_nms(self, boxes_all, scores_all, mask_all): | |||
"""Multiscale postprocessing.""" | |||
all_bboxes = () | |||
all_labels = () | |||
all_masks = () | |||
for i in range(self.test_batch_size): | |||
bboxes = boxes_all[i] | |||
scores = scores_all[i] | |||
masks = self.cast(mask_all[i], mstype.bool_) | |||
res_boxes_tuple = () | |||
res_labels_tuple = () | |||
res_masks_tuple = () | |||
for j in range(self.num_classes - 1): | |||
k = j + 1 | |||
_cls_scores = scores[::, k:k + 1:1] | |||
_bboxes = self.squeeze(bboxes[k]) | |||
_mask_o = self.reshape(masks, (self.rpn_max_num, 1)) | |||
cls_mask = self.greater(_cls_scores, self.test_score_thresh) | |||
_mask = self.logicand(_mask_o, cls_mask) | |||
_reg_mask = self.cast(self.tile(self.cast(_mask, mstype.int32), (1, 4)), mstype.bool_) | |||
_bboxes = self.select(_reg_mask, _bboxes, self.test_box_zeros) | |||
_cls_scores = self.select(_mask, _cls_scores, self.test_score_zeros) | |||
__cls_scores = self.squeeze(_cls_scores) | |||
scores_sorted, topk_inds = self.test_topk(__cls_scores, self.rpn_max_num) | |||
topk_inds = self.reshape(topk_inds, (self.rpn_max_num, 1)) | |||
scores_sorted = self.reshape(scores_sorted, (self.rpn_max_num, 1)) | |||
_bboxes_sorted = self.gather(_bboxes, topk_inds) | |||
_mask_sorted = self.gather(_mask, topk_inds) | |||
scores_sorted = self.tile(scores_sorted, (1, 4)) | |||
cls_dets = self.concat_1((_bboxes_sorted, scores_sorted)) | |||
cls_dets = P.Slice()(cls_dets, (0, 0), (self.rpn_max_num, 5)) | |||
cls_dets, _index, _mask_nms = self.nms_test(cls_dets) | |||
_index = self.reshape(_index, (self.rpn_max_num, 1)) | |||
_mask_nms = self.reshape(_mask_nms, (self.rpn_max_num, 1)) | |||
_mask_n = self.gather(_mask_sorted, _index) | |||
_mask_n = self.logicand(_mask_n, _mask_nms) | |||
cls_labels = self.oneslike(_index) * j | |||
res_boxes_tuple += (cls_dets,) | |||
res_labels_tuple += (cls_labels,) | |||
res_masks_tuple += (_mask_n,) | |||
res_boxes_start = self.concat(res_boxes_tuple[:self.concat_start]) | |||
res_labels_start = self.concat(res_labels_tuple[:self.concat_start]) | |||
res_masks_start = self.concat(res_masks_tuple[:self.concat_start]) | |||
res_boxes_end = self.concat(res_boxes_tuple[self.concat_start:self.concat_end]) | |||
res_labels_end = self.concat(res_labels_tuple[self.concat_start:self.concat_end]) | |||
res_masks_end = self.concat(res_masks_tuple[self.concat_start:self.concat_end]) | |||
res_boxes = self.concat((res_boxes_start, res_boxes_end)) | |||
res_labels = self.concat((res_labels_start, res_labels_end)) | |||
res_masks = self.concat((res_masks_start, res_masks_end)) | |||
reshape_size = (self.num_classes - 1) * self.rpn_max_num | |||
res_boxes = self.reshape(res_boxes, (1, reshape_size, 5)) | |||
res_labels = self.reshape(res_labels, (1, reshape_size, 1)) | |||
res_masks = self.reshape(res_masks, (1, reshape_size, 1)) | |||
all_bboxes += (res_boxes,) | |||
all_labels += (res_labels,) | |||
all_masks += (res_masks,) | |||
all_bboxes = self.concat(all_bboxes) | |||
all_labels = self.concat(all_labels) | |||
all_masks = self.concat(all_masks) | |||
return all_bboxes, all_labels, all_masks | |||
def get_anchors(self, featmap_sizes): | |||
"""Get anchors according to feature map sizes. | |||
Args: | |||
featmap_sizes (list[tuple]): Multi-level feature map sizes. | |||
img_metas (list[dict]): Image meta info. | |||
Returns: | |||
tuple: anchors of each image, valid flags of each image | |||
""" | |||
num_levels = len(featmap_sizes) | |||
# since feature map sizes of all images are the same, we only compute | |||
# anchors for one time | |||
multi_level_anchors = () | |||
for i in range(num_levels): | |||
anchors = self.anchor_generators[i].grid_anchors( | |||
featmap_sizes[i], self.anchor_strides[i]) | |||
multi_level_anchors += (Tensor(anchors.astype(np.float16)),) | |||
return multi_level_anchors |
@@ -0,0 +1,114 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""FasterRcnn feature pyramid network.""" | |||
import numpy as np | |||
import mindspore.nn as nn | |||
from mindspore import context | |||
from mindspore.ops import operations as P | |||
from mindspore.common.tensor import Tensor | |||
from mindspore.common import dtype as mstype | |||
from mindspore.common.initializer import initializer | |||
# pylint: disable=locally-disabled, missing-docstring | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
def bias_init_zeros(shape): | |||
"""Bias init method.""" | |||
return Tensor(np.array(np.zeros(shape).astype(np.float32)).astype(np.float16)) | |||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'): | |||
"""Conv2D wrapper.""" | |||
shape = (out_channels, in_channels, kernel_size, kernel_size) | |||
weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor() | |||
shape_bias = (out_channels,) | |||
biass = bias_init_zeros(shape_bias) | |||
return nn.Conv2d(in_channels, out_channels, | |||
kernel_size=kernel_size, stride=stride, padding=padding, | |||
pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=biass) | |||
class FeatPyramidNeck(nn.Cell): | |||
""" | |||
Feature pyramid network cell, usually uses as network neck. | |||
Applies the convolution on multiple, input feature maps | |||
and output feature map with same channel size. if required num of | |||
output larger then num of inputs, add extra maxpooling for further | |||
downsampling; | |||
Args: | |||
in_channels (tuple) - Channel size of input feature maps. | |||
out_channels (int) - Channel size output. | |||
num_outs (int) - Num of output features. | |||
Returns: | |||
Tuple, with tensors of same channel size. | |||
Examples: | |||
neck = FeatPyramidNeck([100,200,300], 50, 4) | |||
input_data = (normal(0,0.1,(1,c,1280//(4*2**i), 768//(4*2**i)), | |||
dtype=np.float32) \ | |||
for i, c in enumerate(config.fpn_in_channels)) | |||
x = neck(input_data) | |||
""" | |||
def __init__(self, | |||
in_channels, | |||
out_channels, | |||
num_outs): | |||
super(FeatPyramidNeck, self).__init__() | |||
self.num_outs = num_outs | |||
self.in_channels = in_channels | |||
self.fpn_layer = len(self.in_channels) | |||
assert not self.num_outs < len(in_channels) | |||
self.lateral_convs_list_ = [] | |||
self.fpn_convs_ = [] | |||
for _, channel in enumerate(in_channels): | |||
l_conv = _conv(channel, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='valid') | |||
fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same') | |||
self.lateral_convs_list_.append(l_conv) | |||
self.fpn_convs_.append(fpn_conv) | |||
self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_) | |||
self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_) | |||
self.interpolate1 = P.ResizeNearestNeighbor((48, 80)) | |||
self.interpolate2 = P.ResizeNearestNeighbor((96, 160)) | |||
self.interpolate3 = P.ResizeNearestNeighbor((192, 320)) | |||
self.maxpool = P.MaxPool(ksize=1, strides=2, padding="same") | |||
def construct(self, inputs): | |||
x = () | |||
for i in range(self.fpn_layer): | |||
x += (self.lateral_convs_list[i](inputs[i]),) | |||
y = (x[3],) | |||
y = y + (x[2] + self.interpolate1(y[self.fpn_layer - 4]),) | |||
y = y + (x[1] + self.interpolate2(y[self.fpn_layer - 3]),) | |||
y = y + (x[0] + self.interpolate3(y[self.fpn_layer - 2]),) | |||
z = () | |||
for i in range(self.fpn_layer - 1, -1, -1): | |||
z = z + (y[i],) | |||
outs = () | |||
for i in range(self.fpn_layer): | |||
outs = outs + (self.fpn_convs_list[i](z[i]),) | |||
for i in range(self.num_outs - self.fpn_layer): | |||
outs = outs + (self.maxpool(outs[3]),) | |||
return outs |
@@ -0,0 +1,201 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""FasterRcnn proposal generator.""" | |||
import numpy as np | |||
import mindspore.nn as nn | |||
import mindspore.common.dtype as mstype | |||
from mindspore.ops import operations as P | |||
from mindspore import Tensor | |||
from mindspore import context | |||
# pylint: disable=locally-disabled, invalid-name, missing-docstring | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
class Proposal(nn.Cell): | |||
""" | |||
Proposal subnet. | |||
Args: | |||
config (dict): Config. | |||
batch_size (int): Batchsize. | |||
num_classes (int) - Class number. | |||
use_sigmoid_cls (bool) - Select sigmoid or softmax function. | |||
target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0). | |||
target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0). | |||
Returns: | |||
Tuple, tuple of output tensor,(proposal, mask). | |||
Examples: | |||
Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \ | |||
target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0)) | |||
""" | |||
def __init__(self, | |||
config, | |||
batch_size, | |||
num_classes, | |||
use_sigmoid_cls, | |||
target_means=(.0, .0, .0, .0), | |||
target_stds=(1.0, 1.0, 1.0, 1.0) | |||
): | |||
super(Proposal, self).__init__() | |||
cfg = config | |||
self.batch_size = batch_size | |||
self.num_classes = num_classes | |||
self.target_means = target_means | |||
self.target_stds = target_stds | |||
self.use_sigmoid_cls = use_sigmoid_cls | |||
if self.use_sigmoid_cls: | |||
self.cls_out_channels = num_classes - 1 | |||
self.activation = P.Sigmoid() | |||
self.reshape_shape = (-1, 1) | |||
else: | |||
self.cls_out_channels = num_classes | |||
self.activation = P.Softmax(axis=1) | |||
self.reshape_shape = (-1, 2) | |||
if self.cls_out_channels <= 0: | |||
raise ValueError('num_classes={} is too small'.format(num_classes)) | |||
self.num_pre = cfg.rpn_proposal_nms_pre | |||
self.min_box_size = cfg.rpn_proposal_min_bbox_size | |||
self.nms_thr = cfg.rpn_proposal_nms_thr | |||
self.nms_post = cfg.rpn_proposal_nms_post | |||
self.nms_across_levels = cfg.rpn_proposal_nms_across_levels | |||
self.max_num = cfg.rpn_proposal_max_num | |||
self.num_levels = cfg.fpn_num_outs | |||
# Op Define | |||
self.squeeze = P.Squeeze() | |||
self.reshape = P.Reshape() | |||
self.cast = P.Cast() | |||
self.feature_shapes = cfg.feature_shapes | |||
self.transpose_shape = (1, 2, 0) | |||
self.decode = P.BoundingBoxDecode(max_shape=(cfg.img_height, cfg.img_width), \ | |||
means=self.target_means, \ | |||
stds=self.target_stds) | |||
self.nms = P.NMSWithMask(self.nms_thr) | |||
self.concat_axis0 = P.Concat(axis=0) | |||
self.concat_axis1 = P.Concat(axis=1) | |||
self.split = P.Split(axis=1, output_num=5) | |||
self.min = P.Minimum() | |||
self.gatherND = P.GatherNd() | |||
self.slice = P.Slice() | |||
self.select = P.Select() | |||
self.greater = P.Greater() | |||
self.transpose = P.Transpose() | |||
self.tile = P.Tile() | |||
self.set_train_local(config, training=True) | |||
self.multi_10 = Tensor(10.0, mstype.float16) | |||
def set_train_local(self, config, training=True): | |||
"""Set training flag.""" | |||
self.training_local = training | |||
cfg = config | |||
self.topK_stage1 = () | |||
self.topK_shape = () | |||
total_max_topk_input = 0 | |||
if not self.training_local: | |||
self.num_pre = cfg.rpn_nms_pre | |||
self.min_box_size = cfg.rpn_min_bbox_min_size | |||
self.nms_thr = cfg.rpn_nms_thr | |||
self.nms_post = cfg.rpn_nms_post | |||
self.nms_across_levels = cfg.rpn_nms_across_levels | |||
self.max_num = cfg.rpn_max_num | |||
for shp in self.feature_shapes: | |||
k_num = min(self.num_pre, (shp[0] * shp[1] * 3)) | |||
total_max_topk_input += k_num | |||
self.topK_stage1 += (k_num,) | |||
self.topK_shape += ((k_num, 1),) | |||
self.topKv2 = P.TopK(sorted=True) | |||
self.topK_shape_stage2 = (self.max_num, 1) | |||
self.min_float_num = -65536.0 | |||
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16)) | |||
def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list): | |||
proposals_tuple = () | |||
masks_tuple = () | |||
for img_id in range(self.batch_size): | |||
cls_score_list = () | |||
bbox_pred_list = () | |||
for i in range(self.num_levels): | |||
rpn_cls_score_i = self.squeeze(rpn_cls_score_total[i][img_id:img_id+1:1, ::, ::, ::]) | |||
rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[i][img_id:img_id+1:1, ::, ::, ::]) | |||
cls_score_list = cls_score_list + (rpn_cls_score_i,) | |||
bbox_pred_list = bbox_pred_list + (rpn_bbox_pred_i,) | |||
proposals, masks = self.get_bboxes_single(cls_score_list, bbox_pred_list, anchor_list) | |||
proposals_tuple += (proposals,) | |||
masks_tuple += (masks,) | |||
return proposals_tuple, masks_tuple | |||
def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors): | |||
"""Get proposal boundingbox.""" | |||
mlvl_proposals = () | |||
mlvl_mask = () | |||
for idx in range(self.num_levels): | |||
rpn_cls_score = self.transpose(cls_scores[idx], self.transpose_shape) | |||
rpn_bbox_pred = self.transpose(bbox_preds[idx], self.transpose_shape) | |||
anchors = mlvl_anchors[idx] | |||
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape) | |||
rpn_cls_score = self.activation(rpn_cls_score) | |||
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), mstype.float16) | |||
rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16) | |||
scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.topK_stage1[idx]) | |||
topk_inds = self.reshape(topk_inds, self.topK_shape[idx]) | |||
bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds) | |||
anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16) | |||
proposals_decode = self.decode(anchors_sorted, bboxes_sorted) | |||
proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape[idx]))) | |||
proposals, _, mask_valid = self.nms(proposals_decode) | |||
mlvl_proposals = mlvl_proposals + (proposals,) | |||
mlvl_mask = mlvl_mask + (mask_valid,) | |||
proposals = self.concat_axis0(mlvl_proposals) | |||
masks = self.concat_axis0(mlvl_mask) | |||
_, _, _, _, scores = self.split(proposals) | |||
scores = self.squeeze(scores) | |||
topk_mask = self.cast(self.topK_mask, mstype.float16) | |||
scores_using = self.select(masks, scores, topk_mask) | |||
_, topk_inds = self.topKv2(scores_using, self.max_num) | |||
topk_inds = self.reshape(topk_inds, self.topK_shape_stage2) | |||
proposals = self.gatherND(proposals, topk_inds) | |||
masks = self.gatherND(masks, topk_inds) | |||
return proposals, masks |
@@ -0,0 +1,173 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""FasterRcnn Rcnn network.""" | |||
import numpy as np | |||
import mindspore.common.dtype as mstype | |||
import mindspore.nn as nn | |||
from mindspore.ops import operations as P | |||
from mindspore.common.tensor import Tensor | |||
from mindspore.common.initializer import initializer | |||
from mindspore.common.parameter import Parameter | |||
# pylint: disable=locally-disabled, missing-docstring | |||
class DenseNoTranpose(nn.Cell): | |||
"""Dense method""" | |||
def __init__(self, input_channels, output_channels, weight_init): | |||
super(DenseNoTranpose, self).__init__() | |||
self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16), | |||
name="weight") | |||
self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16).to_tensor(), name="bias") | |||
self.matmul = P.MatMul(transpose_b=False) | |||
self.bias_add = P.BiasAdd() | |||
def construct(self, x): | |||
output = self.bias_add(self.matmul(x, self.weight), self.bias) | |||
return output | |||
class Rcnn(nn.Cell): | |||
""" | |||
Rcnn subnet. | |||
Args: | |||
config (dict) - Config. | |||
representation_size (int) - Channels of shared dense. | |||
batch_size (int) - Batchsize. | |||
num_classes (int) - Class number. | |||
target_means (list) - Means for encode function. Default: (.0, .0, .0, .0]). | |||
target_stds (list) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2). | |||
Returns: | |||
Tuple, tuple of output tensor. | |||
Examples: | |||
Rcnn(config=config, representation_size = 1024, batch_size=2, num_classes = 81, \ | |||
target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2)) | |||
""" | |||
def __init__(self, | |||
config, | |||
representation_size, | |||
batch_size, | |||
num_classes, | |||
target_means=(0., 0., 0., 0.), | |||
target_stds=(0.1, 0.1, 0.2, 0.2) | |||
): | |||
super(Rcnn, self).__init__() | |||
cfg = config | |||
self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(np.float16)) | |||
self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(np.float16)) | |||
self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels | |||
self.target_means = target_means | |||
self.target_stds = target_stds | |||
self.num_classes = num_classes | |||
self.in_channels = cfg.rcnn_in_channels | |||
self.train_batch_size = batch_size | |||
self.test_batch_size = cfg.test_batch_size | |||
shape_0 = (self.rcnn_fc_out_channels, representation_size) | |||
weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16).to_tensor() | |||
shape_1 = (self.rcnn_fc_out_channels, self.rcnn_fc_out_channels) | |||
weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16).to_tensor() | |||
self.shared_fc_0 = DenseNoTranpose(representation_size, self.rcnn_fc_out_channels, weights_0) | |||
self.shared_fc_1 = DenseNoTranpose(self.rcnn_fc_out_channels, self.rcnn_fc_out_channels, weights_1) | |||
cls_weight = initializer('Normal', shape=[num_classes, self.rcnn_fc_out_channels][::-1], | |||
dtype=mstype.float16).to_tensor() | |||
reg_weight = initializer('Normal', shape=[num_classes * 4, self.rcnn_fc_out_channels][::-1], | |||
dtype=mstype.float16).to_tensor() | |||
self.cls_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes, cls_weight) | |||
self.reg_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes * 4, reg_weight) | |||
self.flatten = P.Flatten() | |||
self.relu = P.ReLU() | |||
self.logicaland = P.LogicalAnd() | |||
self.loss_cls = P.SoftmaxCrossEntropyWithLogits() | |||
self.loss_bbox = P.SmoothL1Loss(beta=1.0) | |||
self.reshape = P.Reshape() | |||
self.onehot = P.OneHot() | |||
self.greater = P.Greater() | |||
self.cast = P.Cast() | |||
self.sum_loss = P.ReduceSum() | |||
self.tile = P.Tile() | |||
self.expandims = P.ExpandDims() | |||
self.gather = P.GatherNd() | |||
self.argmax = P.ArgMaxWithValue(axis=1) | |||
self.on_value = Tensor(1.0, mstype.float32) | |||
self.off_value = Tensor(0.0, mstype.float32) | |||
self.value = Tensor(1.0, mstype.float16) | |||
self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size | |||
rmv_first = np.ones((self.num_bboxes, self.num_classes)) | |||
rmv_first[:, 0] = np.zeros((self.num_bboxes,)) | |||
self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16)) | |||
self.num_bboxes_test = cfg.rpn_max_num * cfg.test_batch_size | |||
range_max = np.arange(self.num_bboxes_test).astype(np.int32) | |||
self.range_max = Tensor(range_max) | |||
def construct(self, featuremap, bbox_targets, labels, mask): | |||
x = self.flatten(featuremap) | |||
x = self.relu(self.shared_fc_0(x)) | |||
x = self.relu(self.shared_fc_1(x)) | |||
x_cls = self.cls_scores(x) | |||
x_reg = self.reg_scores(x) | |||
if self.training: | |||
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels | |||
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), mstype.float16) | |||
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1)) | |||
loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask) | |||
out = (loss, loss_cls, loss_reg, loss_print) | |||
else: | |||
out = (x_cls, (x_cls / self.value), x_reg, x_cls) | |||
return out | |||
def loss(self, cls_score, bbox_pred, bbox_targets, bbox_weights, labels, weights): | |||
"""Loss method.""" | |||
loss_print = () | |||
loss_cls, _ = self.loss_cls(cls_score, labels) | |||
weights = self.cast(weights, mstype.float16) | |||
loss_cls = loss_cls * weights | |||
loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,)) | |||
bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value), | |||
mstype.float16) | |||
bbox_weights = bbox_weights * self.rmv_first_tensor | |||
pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4)) | |||
loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets) | |||
loss_reg = self.sum_loss(loss_reg, (2,)) | |||
loss_reg = loss_reg * bbox_weights | |||
loss_reg = loss_reg / self.sum_loss(weights, (0,)) | |||
loss_reg = self.sum_loss(loss_reg, (0, 1)) | |||
loss = self.rcnn_loss_cls_weight * loss_cls + self.rcnn_loss_reg_weight * loss_reg | |||
loss_print += (loss_cls, loss_reg) | |||
return loss, loss_cls, loss_reg, loss_print |
@@ -0,0 +1,250 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""Resnet50 backbone.""" | |||
import numpy as np | |||
import mindspore.nn as nn | |||
from mindspore.ops import operations as P | |||
from mindspore.common.tensor import Tensor | |||
from mindspore.ops import functional as F | |||
from mindspore import context | |||
# pylint: disable=locally-disabled, invalid-name, missing-docstring | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
def weight_init_ones(shape): | |||
"""Weight init.""" | |||
return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01).astype(np.float16)) | |||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'): | |||
"""Conv2D wrapper.""" | |||
shape = (out_channels, in_channels, kernel_size, kernel_size) | |||
weights = weight_init_ones(shape) | |||
return nn.Conv2d(in_channels, out_channels, | |||
kernel_size=kernel_size, stride=stride, padding=padding, | |||
pad_mode=pad_mode, weight_init=weights, has_bias=False) | |||
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True): | |||
"""Batchnorm2D wrapper.""" | |||
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16)) | |||
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16)) | |||
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16)) | |||
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16)) | |||
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init, | |||
beta_init=beta_init, moving_mean_init=moving_mean_init, | |||
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics) | |||
class ResNetFea(nn.Cell): | |||
""" | |||
ResNet architecture. | |||
Args: | |||
block (Cell): Block for network. | |||
layer_nums (list): Numbers of block in different layers. | |||
in_channels (list): Input channel in each layer. | |||
out_channels (list): Output channel in each layer. | |||
weights_update (bool): Weight update flag. | |||
Returns: | |||
Tensor, output tensor. | |||
Examples: | |||
>>> ResNet(ResidualBlock, | |||
>>> [3, 4, 6, 3], | |||
>>> [64, 256, 512, 1024], | |||
>>> [256, 512, 1024, 2048], | |||
>>> False) | |||
""" | |||
def __init__(self, | |||
block, | |||
layer_nums, | |||
in_channels, | |||
out_channels, | |||
weights_update=False): | |||
super(ResNetFea, self).__init__() | |||
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: | |||
raise ValueError("the length of " | |||
"layer_num, inchannel, outchannel list must be 4!") | |||
bn_training = False | |||
self.conv1 = _conv(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad') | |||
self.bn1 = _BatchNorm2dInit(64, affine=bn_training, use_batch_statistics=bn_training) | |||
self.relu = P.ReLU() | |||
self.maxpool = P.MaxPool(ksize=3, strides=2, padding="SAME") | |||
self.weights_update = weights_update | |||
if not self.weights_update: | |||
self.conv1.weight.requires_grad = False | |||
self.layer1 = self._make_layer(block, | |||
layer_nums[0], | |||
in_channel=in_channels[0], | |||
out_channel=out_channels[0], | |||
stride=1, | |||
training=bn_training, | |||
weights_update=self.weights_update) | |||
self.layer2 = self._make_layer(block, | |||
layer_nums[1], | |||
in_channel=in_channels[1], | |||
out_channel=out_channels[1], | |||
stride=2, | |||
training=bn_training, | |||
weights_update=True) | |||
self.layer3 = self._make_layer(block, | |||
layer_nums[2], | |||
in_channel=in_channels[2], | |||
out_channel=out_channels[2], | |||
stride=2, | |||
training=bn_training, | |||
weights_update=True) | |||
self.layer4 = self._make_layer(block, | |||
layer_nums[3], | |||
in_channel=in_channels[3], | |||
out_channel=out_channels[3], | |||
stride=2, | |||
training=bn_training, | |||
weights_update=True) | |||
def _make_layer(self, block, layer_num, in_channel, out_channel, stride, training=False, weights_update=False): | |||
"""Make block layer.""" | |||
layers = [] | |||
down_sample = False | |||
if stride != 1 or in_channel != out_channel: | |||
down_sample = True | |||
resblk = block(in_channel, | |||
out_channel, | |||
stride=stride, | |||
down_sample=down_sample, | |||
training=training, | |||
weights_update=weights_update) | |||
layers.append(resblk) | |||
for _ in range(1, layer_num): | |||
resblk = block(out_channel, out_channel, stride=1, training=training, weights_update=weights_update) | |||
layers.append(resblk) | |||
return nn.SequentialCell(layers) | |||
def construct(self, x): | |||
x = self.conv1(x) | |||
x = self.bn1(x) | |||
x = self.relu(x) | |||
c1 = self.maxpool(x) | |||
c2 = self.layer1(c1) | |||
identity = c2 | |||
if not self.weights_update: | |||
identity = F.stop_gradient(c2) | |||
c3 = self.layer2(identity) | |||
c4 = self.layer3(c3) | |||
c5 = self.layer4(c4) | |||
return identity, c3, c4, c5 | |||
class ResidualBlockUsing(nn.Cell): | |||
""" | |||
ResNet V1 residual block definition. | |||
Args: | |||
in_channels (int) - Input channel. | |||
out_channels (int) - Output channel. | |||
stride (int) - Stride size for the initial convolutional layer. Default: 1. | |||
down_sample (bool) - If to do the downsample in block. Default: False. | |||
momentum (float) - Momentum for batchnorm layer. Default: 0.1. | |||
training (bool) - Training flag. Default: False. | |||
weights_updata (bool) - Weights update flag. Default: False. | |||
Returns: | |||
Tensor, output tensor. | |||
Examples: | |||
ResidualBlock(3,256,stride=2,down_sample=True) | |||
""" | |||
expansion = 4 | |||
def __init__(self, | |||
in_channels, | |||
out_channels, | |||
stride=1, | |||
down_sample=False, | |||
momentum=0.1, | |||
training=False, | |||
weights_update=False): | |||
super(ResidualBlockUsing, self).__init__() | |||
self.affine = weights_update | |||
out_chls = out_channels // self.expansion | |||
self.conv1 = _conv(in_channels, out_chls, kernel_size=1, stride=1, padding=0) | |||
self.bn1 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training) | |||
self.conv2 = _conv(out_chls, out_chls, kernel_size=3, stride=stride, padding=1) | |||
self.bn2 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training) | |||
self.conv3 = _conv(out_chls, out_channels, kernel_size=1, stride=1, padding=0) | |||
self.bn3 = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine, use_batch_statistics=training) | |||
if training: | |||
self.bn1 = self.bn1.set_train() | |||
self.bn2 = self.bn2.set_train() | |||
self.bn3 = self.bn3.set_train() | |||
if not weights_update: | |||
self.conv1.weight.requires_grad = False | |||
self.conv2.weight.requires_grad = False | |||
self.conv3.weight.requires_grad = False | |||
self.relu = P.ReLU() | |||
self.downsample = down_sample | |||
if self.downsample: | |||
self.conv_down_sample = _conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0) | |||
self.bn_down_sample = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine, | |||
use_batch_statistics=training) | |||
if training: | |||
self.bn_down_sample = self.bn_down_sample.set_train() | |||
if not weights_update: | |||
self.conv_down_sample.weight.requires_grad = False | |||
self.add = P.TensorAdd() | |||
def construct(self, x): | |||
identity = x | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
out = self.relu(out) | |||
out = self.conv3(out) | |||
out = self.bn3(out) | |||
if self.downsample: | |||
identity = self.conv_down_sample(identity) | |||
identity = self.bn_down_sample(identity) | |||
out = self.add(out, identity) | |||
out = self.relu(out) | |||
return out |
@@ -0,0 +1,181 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""FasterRcnn ROIAlign module.""" | |||
import numpy as np | |||
import mindspore.nn as nn | |||
import mindspore.common.dtype as mstype | |||
from mindspore.ops import operations as P | |||
from mindspore.ops import composite as C | |||
from mindspore.nn import layer as L | |||
from mindspore.common.tensor import Tensor | |||
# pylint: disable=locally-disabled, invalid-name, missing-docstring | |||
class ROIAlign(nn.Cell): | |||
""" | |||
Extract RoI features from mulitple feature map. | |||
Args: | |||
out_size_h (int) - RoI height. | |||
out_size_w (int) - RoI width. | |||
spatial_scale (int) - RoI spatial scale. | |||
sample_num (int) - RoI sample number. | |||
""" | |||
def __init__(self, | |||
out_size_h, | |||
out_size_w, | |||
spatial_scale, | |||
sample_num=0): | |||
super(ROIAlign, self).__init__() | |||
self.out_size = (out_size_h, out_size_w) | |||
self.spatial_scale = float(spatial_scale) | |||
self.sample_num = int(sample_num) | |||
self.align_op = P.ROIAlign(self.out_size[0], self.out_size[1], | |||
self.spatial_scale, self.sample_num) | |||
def construct(self, features, rois): | |||
return self.align_op(features, rois) | |||
def __repr__(self): | |||
format_str = self.__class__.__name__ | |||
format_str += '(out_size={}, spatial_scale={}, sample_num={}'.format( | |||
self.out_size, self.spatial_scale, self.sample_num) | |||
return format_str | |||
class SingleRoIExtractor(nn.Cell): | |||
""" | |||
Extract RoI features from a single level feature map. | |||
If there are mulitple input feature levels, each RoI is mapped to a level | |||
according to its scale. | |||
Args: | |||
config (dict): Config | |||
roi_layer (dict): Specify RoI layer type and arguments. | |||
out_channels (int): Output channels of RoI layers. | |||
featmap_strides (int): Strides of input feature maps. | |||
batch_size (int): Batchsize. | |||
finest_scale (int): Scale threshold of mapping to level 0. | |||
""" | |||
def __init__(self, | |||
config, | |||
roi_layer, | |||
out_channels, | |||
featmap_strides, | |||
batch_size=1, | |||
finest_scale=56): | |||
super(SingleRoIExtractor, self).__init__() | |||
cfg = config | |||
self.train_batch_size = batch_size | |||
self.out_channels = out_channels | |||
self.featmap_strides = featmap_strides | |||
self.num_levels = len(self.featmap_strides) | |||
self.out_size = roi_layer['out_size'] | |||
self.sample_num = roi_layer['sample_num'] | |||
self.roi_layers = self.build_roi_layers(self.featmap_strides) | |||
self.roi_layers = L.CellList(self.roi_layers) | |||
self.sqrt = P.Sqrt() | |||
self.log = P.Log() | |||
self.finest_scale_ = finest_scale | |||
self.clamp = C.clip_by_value | |||
self.cast = P.Cast() | |||
self.equal = P.Equal() | |||
self.select = P.Select() | |||
_mode_16 = False | |||
self.dtype = np.float16 if _mode_16 else np.float32 | |||
self.ms_dtype = mstype.float16 if _mode_16 else mstype.float32 | |||
self.set_train_local(cfg, training=True) | |||
def set_train_local(self, config, training=True): | |||
"""Set training flag.""" | |||
self.training_local = training | |||
cfg = config | |||
# Init tensor | |||
self.batch_size = cfg.roi_sample_num if self.training_local else cfg.rpn_max_num | |||
self.batch_size = self.train_batch_size*self.batch_size \ | |||
if self.training_local else cfg.test_batch_size*self.batch_size | |||
self.ones = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)) | |||
finest_scale = np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * self.finest_scale_ | |||
self.finest_scale = Tensor(finest_scale) | |||
self.epslion = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)*self.dtype(1e-6)) | |||
self.zeros = Tensor(np.array(np.zeros((self.batch_size, 1)), dtype=np.int32)) | |||
self.max_levels = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=np.int32)*(self.num_levels-1)) | |||
self.twos = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * 2) | |||
self.res_ = Tensor(np.array(np.zeros((self.batch_size, self.out_channels, | |||
self.out_size, self.out_size)), dtype=self.dtype)) | |||
def num_inputs(self): | |||
return len(self.featmap_strides) | |||
def init_weights(self): | |||
pass | |||
def log2(self, value): | |||
return self.log(value) / self.log(self.twos) | |||
def build_roi_layers(self, featmap_strides): | |||
roi_layers = [] | |||
for s in featmap_strides: | |||
layer_cls = ROIAlign(self.out_size, self.out_size, | |||
spatial_scale=1 / s, | |||
sample_num=self.sample_num) | |||
roi_layers.append(layer_cls) | |||
return roi_layers | |||
def _c_map_roi_levels(self, rois): | |||
"""Map rois to corresponding feature levels by scales. | |||
- scale < finest_scale * 2: level 0 | |||
- finest_scale * 2 <= scale < finest_scale * 4: level 1 | |||
- finest_scale * 4 <= scale < finest_scale * 8: level 2 | |||
- scale >= finest_scale * 8: level 3 | |||
Args: | |||
rois (Tensor): Input RoIs, shape (k, 5). | |||
num_levels (int): Total level number. | |||
Returns: | |||
Tensor: Level index (0-based) of each RoI, shape (k, ) | |||
""" | |||
scale = self.sqrt(rois[::, 3:4:1] - rois[::, 1:2:1] + self.ones) * \ | |||
self.sqrt(rois[::, 4:5:1] - rois[::, 2:3:1] + self.ones) | |||
target_lvls = self.log2(scale / self.finest_scale + self.epslion) | |||
target_lvls = P.Floor()(target_lvls) | |||
target_lvls = self.cast(target_lvls, mstype.int32) | |||
target_lvls = self.clamp(target_lvls, self.zeros, self.max_levels) | |||
return target_lvls | |||
def construct(self, rois, feat1, feat2, feat3, feat4): | |||
feats = (feat1, feat2, feat3, feat4) | |||
res = self.res_ | |||
target_lvls = self._c_map_roi_levels(rois) | |||
for i in range(self.num_levels): | |||
mask = self.equal(target_lvls, P.ScalarToArray()(i)) | |||
mask = P.Reshape()(mask, (-1, 1, 1, 1)) | |||
roi_feats_t = self.roi_layers[i](feats[i], rois) | |||
mask = self.cast(P.Tile()(self.cast(mask, mstype.int32), (1, 256, 7, 7)), mstype.bool_) | |||
res = self.select(mask, roi_feats_t, res) | |||
return res |
@@ -0,0 +1,315 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""RPN for fasterRCNN""" | |||
import numpy as np | |||
import mindspore.nn as nn | |||
import mindspore.common.dtype as mstype | |||
from mindspore.ops import operations as P | |||
from mindspore import Tensor | |||
from mindspore.ops import functional as F | |||
from mindspore.common.initializer import initializer | |||
from .bbox_assign_sample import BboxAssignSample | |||
# pylint: disable=locally-disabled, invalid-name, missing-docstring | |||
# pylint: disable=locally-disabled, invalid-name, missing-docstring | |||
class RpnRegClsBlock(nn.Cell): | |||
""" | |||
Rpn reg cls block for rpn layer | |||
Args: | |||
in_channels (int) - Input channels of shared convolution. | |||
feat_channels (int) - Output channels of shared convolution. | |||
num_anchors (int) - The anchor number. | |||
cls_out_channels (int) - Output channels of classification convolution. | |||
weight_conv (Tensor) - weight init for rpn conv. | |||
bias_conv (Tensor) - bias init for rpn conv. | |||
weight_cls (Tensor) - weight init for rpn cls conv. | |||
bias_cls (Tensor) - bias init for rpn cls conv. | |||
weight_reg (Tensor) - weight init for rpn reg conv. | |||
bias_reg (Tensor) - bias init for rpn reg conv. | |||
Returns: | |||
Tensor, output tensor. | |||
""" | |||
def __init__(self, | |||
in_channels, | |||
feat_channels, | |||
num_anchors, | |||
cls_out_channels, | |||
weight_conv, | |||
bias_conv, | |||
weight_cls, | |||
bias_cls, | |||
weight_reg, | |||
bias_reg): | |||
super(RpnRegClsBlock, self).__init__() | |||
self.rpn_conv = nn.Conv2d(in_channels, feat_channels, kernel_size=3, stride=1, pad_mode='same', | |||
has_bias=True, weight_init=weight_conv, bias_init=bias_conv) | |||
self.relu = nn.ReLU() | |||
self.rpn_cls = nn.Conv2d(feat_channels, num_anchors * cls_out_channels, kernel_size=1, pad_mode='valid', | |||
has_bias=True, weight_init=weight_cls, bias_init=bias_cls) | |||
self.rpn_reg = nn.Conv2d(feat_channels, num_anchors * 4, kernel_size=1, pad_mode='valid', | |||
has_bias=True, weight_init=weight_reg, bias_init=bias_reg) | |||
def construct(self, x): | |||
x = self.relu(self.rpn_conv(x)) | |||
x1 = self.rpn_cls(x) | |||
x2 = self.rpn_reg(x) | |||
return x1, x2 | |||
class RPN(nn.Cell): | |||
""" | |||
ROI proposal network.. | |||
Args: | |||
config (dict) - Config. | |||
batch_size (int) - Batchsize. | |||
in_channels (int) - Input channels of shared convolution. | |||
feat_channels (int) - Output channels of shared convolution. | |||
num_anchors (int) - The anchor number. | |||
cls_out_channels (int) - Output channels of classification convolution. | |||
Returns: | |||
Tuple, tuple of output tensor. | |||
Examples: | |||
RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024, | |||
num_anchors=3, cls_out_channels=512) | |||
""" | |||
def __init__(self, | |||
config, | |||
batch_size, | |||
in_channels, | |||
feat_channels, | |||
num_anchors, | |||
cls_out_channels): | |||
super(RPN, self).__init__() | |||
cfg_rpn = config | |||
self.num_bboxes = cfg_rpn.num_bboxes | |||
self.slice_index = () | |||
self.feature_anchor_shape = () | |||
self.slice_index += (0,) | |||
index = 0 | |||
for shape in cfg_rpn.feature_shapes: | |||
self.slice_index += (self.slice_index[index] + shape[0] * shape[1] * num_anchors,) | |||
self.feature_anchor_shape += (shape[0] * shape[1] * num_anchors * batch_size,) | |||
index += 1 | |||
self.num_anchors = num_anchors | |||
self.batch_size = batch_size | |||
self.test_batch_size = cfg_rpn.test_batch_size | |||
self.num_layers = 5 | |||
self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16)) | |||
self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels, | |||
num_anchors, cls_out_channels)) | |||
self.transpose = P.Transpose() | |||
self.reshape = P.Reshape() | |||
self.concat = P.Concat(axis=0) | |||
self.fill = P.Fill() | |||
self.placeh1 = Tensor(np.ones((1,)).astype(np.float16)) | |||
self.trans_shape = (0, 2, 3, 1) | |||
self.reshape_shape_reg = (-1, 4) | |||
self.reshape_shape_cls = (-1,) | |||
self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16)) | |||
self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16)) | |||
self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16)) | |||
self.num_bboxes = cfg_rpn.num_bboxes | |||
self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False) | |||
self.CheckValid = P.CheckValid() | |||
self.sum_loss = P.ReduceSum() | |||
self.loss_cls = P.SigmoidCrossEntropyWithLogits() | |||
self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0) | |||
self.squeeze = P.Squeeze() | |||
self.cast = P.Cast() | |||
self.tile = P.Tile() | |||
self.zeros_like = P.ZerosLike() | |||
self.loss = Tensor(np.zeros((1,)).astype(np.float16)) | |||
self.clsloss = Tensor(np.zeros((1,)).astype(np.float16)) | |||
self.regloss = Tensor(np.zeros((1,)).astype(np.float16)) | |||
def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels): | |||
""" | |||
make rpn layer for rpn proposal network | |||
Args: | |||
num_layers (int) - layer num. | |||
in_channels (int) - Input channels of shared convolution. | |||
feat_channels (int) - Output channels of shared convolution. | |||
num_anchors (int) - The anchor number. | |||
cls_out_channels (int) - Output channels of classification convolution. | |||
Returns: | |||
List, list of RpnRegClsBlock cells. | |||
""" | |||
rpn_layer = [] | |||
shp_weight_conv = (feat_channels, in_channels, 3, 3) | |||
shp_bias_conv = (feat_channels,) | |||
weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor() | |||
bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor() | |||
shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1) | |||
shp_bias_cls = (num_anchors * cls_out_channels,) | |||
weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16).to_tensor() | |||
bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16).to_tensor() | |||
shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1) | |||
shp_bias_reg = (num_anchors * 4,) | |||
weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16).to_tensor() | |||
bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16).to_tensor() | |||
for i in range(num_layers): | |||
rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ | |||
weight_conv, bias_conv, weight_cls, \ | |||
bias_cls, weight_reg, bias_reg)) | |||
for i in range(1, num_layers): | |||
rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight | |||
rpn_layer[i].rpn_cls.weight = rpn_layer[0].rpn_cls.weight | |||
rpn_layer[i].rpn_reg.weight = rpn_layer[0].rpn_reg.weight | |||
rpn_layer[i].rpn_conv.bias = rpn_layer[0].rpn_conv.bias | |||
rpn_layer[i].rpn_cls.bias = rpn_layer[0].rpn_cls.bias | |||
rpn_layer[i].rpn_reg.bias = rpn_layer[0].rpn_reg.bias | |||
return rpn_layer | |||
def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids): | |||
loss_print = () | |||
rpn_cls_score = () | |||
rpn_bbox_pred = () | |||
rpn_cls_score_total = () | |||
rpn_bbox_pred_total = () | |||
for i in range(self.num_layers): | |||
x1, x2 = self.rpn_convs_list[i](inputs[i]) | |||
rpn_cls_score_total = rpn_cls_score_total + (x1,) | |||
rpn_bbox_pred_total = rpn_bbox_pred_total + (x2,) | |||
x1 = self.transpose(x1, self.trans_shape) | |||
x1 = self.reshape(x1, self.reshape_shape_cls) | |||
x2 = self.transpose(x2, self.trans_shape) | |||
x2 = self.reshape(x2, self.reshape_shape_reg) | |||
rpn_cls_score = rpn_cls_score + (x1,) | |||
rpn_bbox_pred = rpn_bbox_pred + (x2,) | |||
loss = self.loss | |||
clsloss = self.clsloss | |||
regloss = self.regloss | |||
bbox_targets = () | |||
bbox_weights = () | |||
labels = () | |||
label_weights = () | |||
output = () | |||
if self.training: | |||
for i in range(self.batch_size): | |||
multi_level_flags = () | |||
anchor_list_tuple = () | |||
for j in range(self.num_layers): | |||
res = self.cast(self.CheckValid(anchor_list[j], self.squeeze(img_metas[i:i + 1:1, ::])), | |||
mstype.int32) | |||
multi_level_flags = multi_level_flags + (res,) | |||
anchor_list_tuple = anchor_list_tuple + (anchor_list[j],) | |||
valid_flag_list = self.concat(multi_level_flags) | |||
anchor_using_list = self.concat(anchor_list_tuple) | |||
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::]) | |||
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::]) | |||
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::]) | |||
bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i, | |||
gt_labels_i, | |||
self.cast(valid_flag_list, | |||
mstype.bool_), | |||
anchor_using_list, gt_valids_i) | |||
bbox_weight = self.cast(bbox_weight, mstype.float16) | |||
label = self.cast(label, mstype.float16) | |||
label_weight = self.cast(label_weight, mstype.float16) | |||
for j in range(self.num_layers): | |||
begin = self.slice_index[j] | |||
end = self.slice_index[j + 1] | |||
stride = 1 | |||
bbox_targets += (bbox_target[begin:end:stride, ::],) | |||
bbox_weights += (bbox_weight[begin:end:stride],) | |||
labels += (label[begin:end:stride],) | |||
label_weights += (label_weight[begin:end:stride],) | |||
for i in range(self.num_layers): | |||
bbox_target_using = () | |||
bbox_weight_using = () | |||
label_using = () | |||
label_weight_using = () | |||
for j in range(self.batch_size): | |||
bbox_target_using += (bbox_targets[i + (self.num_layers * j)],) | |||
bbox_weight_using += (bbox_weights[i + (self.num_layers * j)],) | |||
label_using += (labels[i + (self.num_layers * j)],) | |||
label_weight_using += (label_weights[i + (self.num_layers * j)],) | |||
bbox_target_with_batchsize = self.concat(bbox_target_using) | |||
bbox_weight_with_batchsize = self.concat(bbox_weight_using) | |||
label_with_batchsize = self.concat(label_using) | |||
label_weight_with_batchsize = self.concat(label_weight_using) | |||
# stop | |||
bbox_target_ = F.stop_gradient(bbox_target_with_batchsize) | |||
bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize) | |||
label_ = F.stop_gradient(label_with_batchsize) | |||
label_weight_ = F.stop_gradient(label_weight_with_batchsize) | |||
cls_score_i = rpn_cls_score[i] | |||
reg_score_i = rpn_bbox_pred[i] | |||
loss_cls = self.loss_cls(cls_score_i, label_) | |||
loss_cls_item = loss_cls * label_weight_ | |||
loss_cls_item = self.sum_loss(loss_cls_item, (0,)) / self.num_expected_total | |||
loss_reg = self.loss_bbox(reg_score_i, bbox_target_) | |||
bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape[i], 1)), (1, 4)) | |||
loss_reg = loss_reg * bbox_weight_ | |||
loss_reg_item = self.sum_loss(loss_reg, (1,)) | |||
loss_reg_item = self.sum_loss(loss_reg_item, (0,)) / self.num_expected_total | |||
loss_total = self.rpn_loss_cls_weight * loss_cls_item + self.rpn_loss_reg_weight * loss_reg_item | |||
loss += loss_total | |||
loss_print += (loss_total, loss_cls_item, loss_reg_item) | |||
clsloss += loss_cls_item | |||
regloss += loss_reg_item | |||
output = (loss, rpn_cls_score_total, rpn_bbox_pred_total, clsloss, regloss, loss_print) | |||
else: | |||
output = (self.placeh1, rpn_cls_score_total, rpn_bbox_pred_total, self.placeh1, self.placeh1, self.placeh1) | |||
return output |
@@ -0,0 +1,158 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# =========================================================================== | |||
""" | |||
network config setting, will be used in train.py and eval.py | |||
""" | |||
from easydict import EasyDict as ed | |||
config = ed({ | |||
"img_width": 1280, | |||
"img_height": 768, | |||
"keep_ratio": False, | |||
"flip_ratio": 0.5, | |||
"photo_ratio": 0.5, | |||
"expand_ratio": 1.0, | |||
# anchor | |||
"feature_shapes": [(192, 320), (96, 160), (48, 80), (24, 40), (12, 20)], | |||
"anchor_scales": [8], | |||
"anchor_ratios": [0.5, 1.0, 2.0], | |||
"anchor_strides": [4, 8, 16, 32, 64], | |||
"num_anchors": 3, | |||
# resnet | |||
"resnet_block": [3, 4, 6, 3], | |||
"resnet_in_channels": [64, 256, 512, 1024], | |||
"resnet_out_channels": [256, 512, 1024, 2048], | |||
# fpn | |||
"fpn_in_channels": [256, 512, 1024, 2048], | |||
"fpn_out_channels": 256, | |||
"fpn_num_outs": 5, | |||
# rpn | |||
"rpn_in_channels": 256, | |||
"rpn_feat_channels": 256, | |||
"rpn_loss_cls_weight": 1.0, | |||
"rpn_loss_reg_weight": 1.0, | |||
"rpn_cls_out_channels": 1, | |||
"rpn_target_means": [0., 0., 0., 0.], | |||
"rpn_target_stds": [1.0, 1.0, 1.0, 1.0], | |||
# bbox_assign_sampler | |||
"neg_iou_thr": 0.3, | |||
"pos_iou_thr": 0.7, | |||
"min_pos_iou": 0.3, | |||
"num_bboxes": 245520, | |||
"num_gts": 128, | |||
"num_expected_neg": 256, | |||
"num_expected_pos": 128, | |||
# proposal | |||
"activate_num_classes": 2, | |||
"use_sigmoid_cls": True, | |||
# roi_align | |||
"roi_layer": dict(type='RoIAlign', out_size=7, sample_num=2), | |||
"roi_align_out_channels": 256, | |||
"roi_align_featmap_strides": [4, 8, 16, 32], | |||
"roi_align_finest_scale": 56, | |||
"roi_sample_num": 640, | |||
# bbox_assign_sampler_stage2 | |||
"neg_iou_thr_stage2": 0.5, | |||
"pos_iou_thr_stage2": 0.5, | |||
"min_pos_iou_stage2": 0.5, | |||
"num_bboxes_stage2": 2000, | |||
"num_expected_pos_stage2": 128, | |||
"num_expected_neg_stage2": 512, | |||
"num_expected_total_stage2": 512, | |||
# rcnn | |||
"rcnn_num_layers": 2, | |||
"rcnn_in_channels": 256, | |||
"rcnn_fc_out_channels": 1024, | |||
"rcnn_loss_cls_weight": 1, | |||
"rcnn_loss_reg_weight": 1, | |||
"rcnn_target_means": [0., 0., 0., 0.], | |||
"rcnn_target_stds": [0.1, 0.1, 0.2, 0.2], | |||
# train proposal | |||
"rpn_proposal_nms_across_levels": False, | |||
"rpn_proposal_nms_pre": 2000, | |||
"rpn_proposal_nms_post": 2000, | |||
"rpn_proposal_max_num": 2000, | |||
"rpn_proposal_nms_thr": 0.7, | |||
"rpn_proposal_min_bbox_size": 0, | |||
# test proposal | |||
"rpn_nms_across_levels": False, | |||
"rpn_nms_pre": 1000, | |||
"rpn_nms_post": 1000, | |||
"rpn_max_num": 1000, | |||
"rpn_nms_thr": 0.7, | |||
"rpn_min_bbox_min_size": 0, | |||
"test_score_thr": 0.05, | |||
"test_iou_thr": 0.5, | |||
"test_max_per_img": 100, | |||
"test_batch_size": 1, | |||
"rpn_head_loss_type": "CrossEntropyLoss", | |||
"rpn_head_use_sigmoid": True, | |||
"rpn_head_weight": 1.0, | |||
# LR | |||
"base_lr": 0.02, | |||
"base_step": 58633, | |||
"total_epoch": 13, | |||
"warmup_step": 500, | |||
"warmup_mode": "linear", | |||
"warmup_ratio": 1/3.0, | |||
"sgd_step": [8, 11], | |||
"sgd_momentum": 0.9, | |||
# train | |||
"batch_size": 1, | |||
"loss_scale": 1, | |||
"momentum": 0.91, | |||
"weight_decay": 1e-4, | |||
"epoch_size": 12, | |||
"save_checkpoint": True, | |||
"save_checkpoint_epochs": 1, | |||
"keep_checkpoint_max": 10, | |||
"save_checkpoint_path": "./", | |||
"mindrecord_dir": "../MindRecord_COCO_TRAIN", | |||
"coco_root": "./cocodataset/", | |||
"train_data_type": "train2017", | |||
"val_data_type": "val2017", | |||
"instance_set": "annotations/instances_{}.json", | |||
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', | |||
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', | |||
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', | |||
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', | |||
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', | |||
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', | |||
'kite', 'baseball bat', 'baseball glove', 'skateboard', | |||
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', | |||
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', | |||
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', | |||
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', | |||
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', | |||
'refrigerator', 'book', 'clock', 'vase', 'scissors', | |||
'teddy bear', 'hair drier', 'toothbrush'), | |||
"num_classes": 81 | |||
}) |
@@ -0,0 +1,505 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""FasterRcnn dataset""" | |||
from __future__ import division | |||
import os | |||
import numpy as np | |||
from numpy import random | |||
import mmcv | |||
import mindspore.dataset as de | |||
import mindspore.dataset.vision.c_transforms as C | |||
import mindspore.dataset.transforms.c_transforms as CC | |||
import mindspore.common.dtype as mstype | |||
from mindspore.mindrecord import FileWriter | |||
from src.config import config | |||
# pylint: disable=locally-disabled, unused-variable | |||
def bbox_overlaps(bboxes1, bboxes2, mode='iou'): | |||
"""Calculate the ious between each bbox of bboxes1 and bboxes2. | |||
Args: | |||
bboxes1(ndarray): shape (n, 4) | |||
bboxes2(ndarray): shape (k, 4) | |||
mode(str): iou (intersection over union) or iof (intersection | |||
over foreground) | |||
Returns: | |||
ious(ndarray): shape (n, k) | |||
""" | |||
assert mode in ['iou', 'iof'] | |||
bboxes1 = bboxes1.astype(np.float32) | |||
bboxes2 = bboxes2.astype(np.float32) | |||
rows = bboxes1.shape[0] | |||
cols = bboxes2.shape[0] | |||
ious = np.zeros((rows, cols), dtype=np.float32) | |||
if rows * cols == 0: | |||
return ious | |||
exchange = False | |||
if bboxes1.shape[0] > bboxes2.shape[0]: | |||
bboxes1, bboxes2 = bboxes2, bboxes1 | |||
ious = np.zeros((cols, rows), dtype=np.float32) | |||
exchange = True | |||
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (bboxes1[:, 3] - bboxes1[:, 1] + 1) | |||
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (bboxes2[:, 3] - bboxes2[:, 1] + 1) | |||
for i in range(bboxes1.shape[0]): | |||
x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) | |||
y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) | |||
x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) | |||
y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) | |||
overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum( | |||
y_end - y_start + 1, 0) | |||
if mode == 'iou': | |||
union = area1[i] + area2 - overlap | |||
else: | |||
union = area1[i] if not exchange else area2 | |||
ious[i, :] = overlap / union | |||
if exchange: | |||
ious = ious.T | |||
return ious | |||
class PhotoMetricDistortion: | |||
"""Photo Metric Distortion""" | |||
def __init__(self, | |||
brightness_delta=32, | |||
contrast_range=(0.5, 1.5), | |||
saturation_range=(0.5, 1.5), | |||
hue_delta=18): | |||
self.brightness_delta = brightness_delta | |||
self.contrast_lower, self.contrast_upper = contrast_range | |||
self.saturation_lower, self.saturation_upper = saturation_range | |||
self.hue_delta = hue_delta | |||
def __call__(self, img, boxes, labels): | |||
# random brightness | |||
img = img.astype('float32') | |||
if random.randint(2): | |||
delta = random.uniform(-self.brightness_delta, | |||
self.brightness_delta) | |||
img += delta | |||
# mode == 0 --> do random contrast first | |||
# mode == 1 --> do random contrast last | |||
mode = random.randint(2) | |||
if mode == 1: | |||
if random.randint(2): | |||
alpha = random.uniform(self.contrast_lower, | |||
self.contrast_upper) | |||
img *= alpha | |||
# convert color from BGR to HSV | |||
img = mmcv.bgr2hsv(img) | |||
# random saturation | |||
if random.randint(2): | |||
img[..., 1] *= random.uniform(self.saturation_lower, | |||
self.saturation_upper) | |||
# random hue | |||
if random.randint(2): | |||
img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta) | |||
img[..., 0][img[..., 0] > 360] -= 360 | |||
img[..., 0][img[..., 0] < 0] += 360 | |||
# convert color from HSV to BGR | |||
img = mmcv.hsv2bgr(img) | |||
# random contrast | |||
if mode == 0: | |||
if random.randint(2): | |||
alpha = random.uniform(self.contrast_lower, | |||
self.contrast_upper) | |||
img *= alpha | |||
# randomly swap channels | |||
if random.randint(2): | |||
img = img[..., random.permutation(3)] | |||
return img, boxes, labels | |||
class Expand: | |||
"""expand image""" | |||
def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)): | |||
if to_rgb: | |||
self.mean = mean[::-1] | |||
else: | |||
self.mean = mean | |||
self.min_ratio, self.max_ratio = ratio_range | |||
def __call__(self, img, boxes, labels): | |||
if random.randint(2): | |||
return img, boxes, labels | |||
h, w, c = img.shape | |||
ratio = random.uniform(self.min_ratio, self.max_ratio) | |||
expand_img = np.full((int(h * ratio), int(w * ratio), c), | |||
self.mean).astype(img.dtype) | |||
left = int(random.uniform(0, w * ratio - w)) | |||
top = int(random.uniform(0, h * ratio - h)) | |||
expand_img[top:top + h, left:left + w] = img | |||
img = expand_img | |||
boxes += np.tile((left, top), 2) | |||
return img, boxes, labels | |||
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
"""rescale operation for image""" | |||
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True) | |||
if img_data.shape[0] > config.img_height: | |||
img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_width), return_scale=True) | |||
scale_factor = scale_factor * scale_factor2 | |||
img_shape = np.append(img_shape, scale_factor) | |||
img_shape = np.asarray(img_shape, dtype=np.float32) | |||
gt_bboxes = gt_bboxes * scale_factor | |||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) | |||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) | |||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
"""resize operation for image""" | |||
img_data = img | |||
img_data, w_scale, h_scale = mmcv.imresize( | |||
img_data, (config.img_width, config.img_height), return_scale=True) | |||
scale_factor = np.array( | |||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32) | |||
img_shape = (config.img_height, config.img_width, 1.0) | |||
img_shape = np.asarray(img_shape, dtype=np.float32) | |||
gt_bboxes = gt_bboxes * scale_factor | |||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) | |||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) | |||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
"""resize operation for image of eval""" | |||
img_data = img | |||
img_data, w_scale, h_scale = mmcv.imresize( | |||
img_data, (config.img_width, config.img_height), return_scale=True) | |||
scale_factor = np.array( | |||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32) | |||
img_shape = np.append(img_shape, (h_scale, w_scale)) | |||
img_shape = np.asarray(img_shape, dtype=np.float32) | |||
gt_bboxes = gt_bboxes * scale_factor | |||
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) | |||
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) | |||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||
def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
"""impad operation for image""" | |||
img_data = mmcv.impad(img, (config.img_height, config.img_width)) | |||
img_data = img_data.astype(np.float32) | |||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||
def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
"""imnormalize operation for image""" | |||
img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True) | |||
img_data = img_data.astype(np.float32) | |||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||
def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
"""flip operation for image""" | |||
img_data = img | |||
img_data = mmcv.imflip(img_data) | |||
flipped = gt_bboxes.copy() | |||
_, w, _ = img_data.shape | |||
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1 | |||
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1 | |||
return (img_data, img_shape, flipped, gt_label, gt_num) | |||
def flipped_generation(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
"""flipped generation""" | |||
img_data = img | |||
flipped = gt_bboxes.copy() | |||
_, w, _ = img_data.shape | |||
flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1 | |||
flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1 | |||
return (img_data, img_shape, flipped, gt_label, gt_num) | |||
def image_bgr_rgb(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
img_data = img[:, :, ::-1] | |||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||
def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
"""transpose operation for image""" | |||
img_data = img.transpose(2, 0, 1).copy() | |||
img_data = img_data.astype(np.float16) | |||
img_shape = img_shape.astype(np.float16) | |||
gt_bboxes = gt_bboxes.astype(np.float16) | |||
gt_label = gt_label.astype(np.int32) | |||
gt_num = gt_num.astype(np.bool) | |||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||
def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
"""photo crop operation for image""" | |||
random_photo = PhotoMetricDistortion() | |||
img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label) | |||
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||
def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
"""expand operation for image""" | |||
expand = Expand() | |||
img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label) | |||
return (img, img_shape, gt_bboxes, gt_label, gt_num) | |||
def preprocess_fn(image, box, is_training): | |||
"""Preprocess function for dataset.""" | |||
def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert): | |||
image_shape = image_shape[:2] | |||
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert | |||
if config.keep_ratio: | |||
input_data = rescale_column(*input_data) | |||
else: | |||
input_data = resize_column_test(*input_data) | |||
input_data = image_bgr_rgb(*input_data) | |||
output_data = input_data | |||
return output_data | |||
def _data_aug(image, box, is_training): | |||
"""Data augmentation function.""" | |||
image_bgr = image.copy() | |||
image_bgr[:, :, 0] = image[:, :, 2] | |||
image_bgr[:, :, 1] = image[:, :, 1] | |||
image_bgr[:, :, 2] = image[:, :, 0] | |||
image_shape = image_bgr.shape[:2] | |||
gt_box = box[:, :4] | |||
gt_label = box[:, 4] | |||
gt_iscrowd = box[:, 5] | |||
pad_max_number = 128 | |||
gt_box_new = np.pad(gt_box, ((0, pad_max_number - box.shape[0]), (0, 0)), mode="constant", constant_values=0) | |||
gt_label_new = np.pad(gt_label, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=-1) | |||
gt_iscrowd_new = np.pad(gt_iscrowd, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=1) | |||
gt_iscrowd_new_revert = (~(gt_iscrowd_new.astype(np.bool))).astype(np.int32) | |||
if not is_training: | |||
return _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert) | |||
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert | |||
if config.keep_ratio: | |||
input_data = rescale_column(*input_data) | |||
else: | |||
input_data = resize_column(*input_data) | |||
input_data = image_bgr_rgb(*input_data) | |||
output_data = input_data | |||
return output_data | |||
return _data_aug(image, box, is_training) | |||
def create_coco_label(is_training): | |||
"""Get image path and annotation from COCO.""" | |||
from pycocotools.coco import COCO | |||
coco_root = config.coco_root | |||
data_type = config.val_data_type | |||
if is_training: | |||
data_type = config.train_data_type | |||
# Classes need to train or test. | |||
train_cls = config.coco_classes | |||
train_cls_dict = {} | |||
for i, cls in enumerate(train_cls): | |||
train_cls_dict[cls] = i | |||
anno_json = os.path.join(coco_root, config.instance_set.format(data_type)) | |||
coco = COCO(anno_json) | |||
classs_dict = {} | |||
cat_ids = coco.loadCats(coco.getCatIds()) | |||
for cat in cat_ids: | |||
classs_dict[cat["id"]] = cat["name"] | |||
image_ids = coco.getImgIds() | |||
image_files = [] | |||
image_anno_dict = {} | |||
for img_id in image_ids: | |||
image_info = coco.loadImgs(img_id) | |||
file_name = image_info[0]["file_name"] | |||
anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) | |||
anno = coco.loadAnns(anno_ids) | |||
image_path = os.path.join(coco_root, data_type, file_name) | |||
annos = [] | |||
for label in anno: | |||
bbox = label["bbox"] | |||
class_name = classs_dict[label["category_id"]] | |||
if class_name in train_cls: | |||
x1, x2 = bbox[0], bbox[0] + bbox[2] | |||
y1, y2 = bbox[1], bbox[1] + bbox[3] | |||
annos.append([x1, y1, x2, y2] + [train_cls_dict[class_name]] + [int(label["iscrowd"])]) | |||
image_files.append(image_path) | |||
if annos: | |||
image_anno_dict[image_path] = np.array(annos) | |||
else: | |||
image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1]) | |||
return image_files, image_anno_dict | |||
def anno_parser(annos_str): | |||
"""Parse annotation from string to list.""" | |||
annos = [] | |||
for anno_str in annos_str: | |||
anno = list(map(int, anno_str.strip().split(','))) | |||
annos.append(anno) | |||
return annos | |||
def filter_valid_data(image_dir, anno_path): | |||
"""Filter valid image file, which both in image_dir and anno_path.""" | |||
image_files = [] | |||
image_anno_dict = {} | |||
if not os.path.isdir(image_dir): | |||
raise RuntimeError("Path given is not valid.") | |||
if not os.path.isfile(anno_path): | |||
raise RuntimeError("Annotation file is not valid.") | |||
with open(anno_path, "rb") as f: | |||
lines = f.readlines() | |||
for line in lines: | |||
line_str = line.decode("utf-8").strip() | |||
line_split = str(line_str).split(' ') | |||
file_name = line_split[0] | |||
image_path = os.path.join(image_dir, file_name) | |||
if os.path.isfile(image_path): | |||
image_anno_dict[image_path] = anno_parser(line_split[1:]) | |||
image_files.append(image_path) | |||
return image_files, image_anno_dict | |||
def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fasterrcnn.mindrecord", file_num=8): | |||
"""Create MindRecord file.""" | |||
mindrecord_dir = config.mindrecord_dir | |||
mindrecord_path = os.path.join(mindrecord_dir, prefix) | |||
writer = FileWriter(mindrecord_path, file_num) | |||
if dataset == "coco": | |||
image_files, image_anno_dict = create_coco_label(is_training) | |||
else: | |||
image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH) | |||
fasterrcnn_json = { | |||
"image": {"type": "bytes"}, | |||
"annotation": {"type": "int32", "shape": [-1, 6]}, | |||
} | |||
writer.add_schema(fasterrcnn_json, "fasterrcnn_json") | |||
for image_name in image_files: | |||
with open(image_name, 'rb') as f: | |||
img = f.read() | |||
annos = np.array(image_anno_dict[image_name], dtype=np.int32) | |||
row = {"image": img, "annotation": annos} | |||
writer.write_raw_data([row]) | |||
writer.commit() | |||
def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0, | |||
is_training=True, num_parallel_workers=4): | |||
"""Creatr FasterRcnn dataset with MindDataset.""" | |||
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id, | |||
num_parallel_workers=1, shuffle=False) | |||
decode = C.Decode() | |||
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1) | |||
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) | |||
hwc_to_chw = C.HWC2CHW() | |||
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) | |||
horizontally_op = C.RandomHorizontalFlip(1) | |||
type_cast0 = CC.TypeCast(mstype.float32) | |||
type_cast1 = CC.TypeCast(mstype.float16) | |||
type_cast2 = CC.TypeCast(mstype.int32) | |||
type_cast3 = CC.TypeCast(mstype.bool_) | |||
if is_training: | |||
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"], | |||
output_columns=["image", "image_shape", "box", "label", "valid_num"], | |||
column_order=["image", "image_shape", "box", "label", "valid_num"], | |||
num_parallel_workers=num_parallel_workers) | |||
flip = (np.random.rand() < config.flip_ratio) | |||
if flip: | |||
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"], | |||
num_parallel_workers=12) | |||
ds = ds.map(operations=flipped_generation, | |||
input_columns=["image", "image_shape", "box", "label", "valid_num"], | |||
num_parallel_workers=num_parallel_workers) | |||
else: | |||
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"], | |||
num_parallel_workers=12) | |||
ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"], | |||
num_parallel_workers=12) | |||
else: | |||
ds = ds.map(operations=compose_map_func, | |||
input_columns=["image", "annotation"], | |||
output_columns=["image", "image_shape", "box", "label", "valid_num"], | |||
column_order=["image", "image_shape", "box", "label", "valid_num"], | |||
num_parallel_workers=num_parallel_workers) | |||
ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"], | |||
num_parallel_workers=24) | |||
# transpose_column from python to c | |||
ds = ds.map(operations=[type_cast1], input_columns=["image_shape"]) | |||
ds = ds.map(operations=[type_cast1], input_columns=["box"]) | |||
ds = ds.map(operations=[type_cast2], input_columns=["label"]) | |||
ds = ds.map(operations=[type_cast3], input_columns=["valid_num"]) | |||
ds = ds.batch(batch_size, drop_remainder=True) | |||
ds = ds.repeat(repeat_num) | |||
return ds |
@@ -0,0 +1,42 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""lr generator for fasterrcnn""" | |||
import math | |||
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr): | |||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) | |||
learning_rate = float(init_lr) + lr_inc * current_step | |||
return learning_rate | |||
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps): | |||
base = float(current_step - warmup_steps) / float(decay_steps) | |||
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr | |||
return learning_rate | |||
def dynamic_lr(config, rank_size=1): | |||
"""dynamic learning rate generator""" | |||
base_lr = config.base_lr | |||
base_step = (config.base_step // rank_size) + rank_size | |||
total_steps = int(base_step * config.total_epoch) | |||
warmup_steps = int(config.warmup_step) | |||
lr = [] | |||
for i in range(total_steps): | |||
if i < warmup_steps: | |||
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio)) | |||
else: | |||
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps)) | |||
return lr |
@@ -0,0 +1,184 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""FasterRcnn training network wrapper.""" | |||
import time | |||
import numpy as np | |||
import mindspore.nn as nn | |||
from mindspore.common.tensor import Tensor | |||
from mindspore.ops import functional as F | |||
from mindspore.ops import composite as C | |||
from mindspore import ParameterTuple | |||
from mindspore.train.callback import Callback | |||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||
# pylint: disable=locally-disabled, missing-docstring, unused-argument | |||
time_stamp_init = False | |||
time_stamp_first = 0 | |||
class LossCallBack(Callback): | |||
""" | |||
Monitor the loss in training. | |||
If the loss is NAN or INF terminating training. | |||
Note: | |||
If per_print_times is 0 do not print loss. | |||
Args: | |||
per_print_times (int): Print loss every times. Default: 1. | |||
""" | |||
def __init__(self, per_print_times=1, rank_id=0): | |||
super(LossCallBack, self).__init__() | |||
if not isinstance(per_print_times, int) or per_print_times < 0: | |||
raise ValueError("print_step must be int and >= 0.") | |||
self._per_print_times = per_print_times | |||
self.count = 0 | |||
self.rpn_loss_sum = 0 | |||
self.rcnn_loss_sum = 0 | |||
self.rpn_cls_loss_sum = 0 | |||
self.rpn_reg_loss_sum = 0 | |||
self.rcnn_cls_loss_sum = 0 | |||
self.rcnn_reg_loss_sum = 0 | |||
self.rank_id = rank_id | |||
global time_stamp_init, time_stamp_first | |||
if not time_stamp_init: | |||
time_stamp_first = time.time() | |||
time_stamp_init = True | |||
def step_end(self, run_context): | |||
cb_params = run_context.original_args() | |||
rpn_loss = cb_params.net_outputs[0].asnumpy() | |||
rcnn_loss = cb_params.net_outputs[1].asnumpy() | |||
rpn_cls_loss = cb_params.net_outputs[2].asnumpy() | |||
rpn_reg_loss = cb_params.net_outputs[3].asnumpy() | |||
rcnn_cls_loss = cb_params.net_outputs[4].asnumpy() | |||
rcnn_reg_loss = cb_params.net_outputs[5].asnumpy() | |||
self.count += 1 | |||
self.rpn_loss_sum += float(rpn_loss) | |||
self.rcnn_loss_sum += float(rcnn_loss) | |||
self.rpn_cls_loss_sum += float(rpn_cls_loss) | |||
self.rpn_reg_loss_sum += float(rpn_reg_loss) | |||
self.rcnn_cls_loss_sum += float(rcnn_cls_loss) | |||
self.rcnn_reg_loss_sum += float(rcnn_reg_loss) | |||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||
if self.count >= 1: | |||
global time_stamp_first | |||
time_stamp_current = time.time() | |||
rpn_loss = self.rpn_loss_sum/self.count | |||
rcnn_loss = self.rcnn_loss_sum/self.count | |||
rpn_cls_loss = self.rpn_cls_loss_sum/self.count | |||
rpn_reg_loss = self.rpn_reg_loss_sum/self.count | |||
rcnn_cls_loss = self.rcnn_cls_loss_sum/self.count | |||
rcnn_reg_loss = self.rcnn_reg_loss_sum/self.count | |||
total_loss = rpn_loss + rcnn_loss | |||
loss_file = open("./loss_{}.log".format(self.rank_id), "a+") | |||
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rcnn_loss: %.5f, rpn_cls_loss: %.5f, " | |||
"rpn_reg_loss: %.5f, rcnn_cls_loss: %.5f, rcnn_reg_loss: %.5f, total_loss: %.5f" % | |||
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, | |||
rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, | |||
rcnn_cls_loss, rcnn_reg_loss, total_loss)) | |||
loss_file.write("\n") | |||
loss_file.close() | |||
self.count = 0 | |||
self.rpn_loss_sum = 0 | |||
self.rcnn_loss_sum = 0 | |||
self.rpn_cls_loss_sum = 0 | |||
self.rpn_reg_loss_sum = 0 | |||
self.rcnn_cls_loss_sum = 0 | |||
self.rcnn_reg_loss_sum = 0 | |||
class LossNet(nn.Cell): | |||
"""FasterRcnn loss method""" | |||
def construct(self, x1, x2, x3, x4, x5, x6): | |||
return x1 + x2 | |||
class WithLossCell(nn.Cell): | |||
""" | |||
Wrap the network with loss function to compute loss. | |||
Args: | |||
backbone (Cell): The target network to wrap. | |||
loss_fn (Cell): The loss function used to compute loss. | |||
""" | |||
def __init__(self, backbone, loss_fn): | |||
super(WithLossCell, self).__init__(auto_prefix=False) | |||
self._backbone = backbone | |||
self._loss_fn = loss_fn | |||
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num): | |||
loss1, loss2, loss3, loss4, loss5, loss6 = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num) | |||
return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6) | |||
@property | |||
def backbone_network(self): | |||
""" | |||
Get the backbone network. | |||
Returns: | |||
Cell, return backbone network. | |||
""" | |||
return self._backbone | |||
class TrainOneStepCell(nn.Cell): | |||
""" | |||
Network training package class. | |||
Append an optimizer to the training network after that the construct function | |||
can be called to create the backward graph. | |||
Args: | |||
network (Cell): The training network. | |||
network_backbone (Cell): The forward network. | |||
optimizer (Cell): Optimizer for updating the weights. | |||
sens (Number): The adjust parameter. Default value is 1.0. | |||
reduce_flag (bool): The reduce flag. Default value is False. | |||
mean (bool): Allreduce method. Default value is False. | |||
degree (int): Device number. Default value is None. | |||
""" | |||
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None): | |||
super(TrainOneStepCell, self).__init__(auto_prefix=False) | |||
self.network = network | |||
self.network.set_grad() | |||
self.backbone = network_backbone | |||
self.weights = ParameterTuple(network.trainable_params()) | |||
self.optimizer = optimizer | |||
self.grad = C.GradOperation(get_by_list=True, | |||
sens_param=True) | |||
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16)) | |||
self.reduce_flag = reduce_flag | |||
if reduce_flag: | |||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num): | |||
weights = self.weights | |||
loss1, loss2, loss3, loss4, loss5, loss6 = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num) | |||
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens) | |||
if self.reduce_flag: | |||
grads = self.grad_reducer(grads) | |||
return F.depend(loss1, self.optimizer(grads)), loss2, loss3, loss4, loss5, loss6 |
@@ -0,0 +1,227 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""coco eval for fasterrcnn""" | |||
import json | |||
import numpy as np | |||
from pycocotools.coco import COCO | |||
from pycocotools.cocoeval import COCOeval | |||
import mmcv | |||
# pylint: disable=locally-disabled, invalid-name | |||
_init_value = np.array(0.0) | |||
summary_init = { | |||
'Precision/mAP': _init_value, | |||
'Precision/mAP@.50IOU': _init_value, | |||
'Precision/mAP@.75IOU': _init_value, | |||
'Precision/mAP (small)': _init_value, | |||
'Precision/mAP (medium)': _init_value, | |||
'Precision/mAP (large)': _init_value, | |||
'Recall/AR@1': _init_value, | |||
'Recall/AR@10': _init_value, | |||
'Recall/AR@100': _init_value, | |||
'Recall/AR@100 (small)': _init_value, | |||
'Recall/AR@100 (medium)': _init_value, | |||
'Recall/AR@100 (large)': _init_value, | |||
} | |||
def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), single_result=False): | |||
"""coco eval for fasterrcnn""" | |||
anns = json.load(open(result_files['bbox'])) | |||
if not anns: | |||
return summary_init | |||
if mmcv.is_str(coco): | |||
coco = COCO(coco) | |||
assert isinstance(coco, COCO) | |||
for res_type in result_types: | |||
result_file = result_files[res_type] | |||
assert result_file.endswith('.json') | |||
coco_dets = coco.loadRes(result_file) | |||
gt_img_ids = coco.getImgIds() | |||
det_img_ids = coco_dets.getImgIds() | |||
iou_type = 'bbox' if res_type == 'proposal' else res_type | |||
cocoEval = COCOeval(coco, coco_dets, iou_type) | |||
if res_type == 'proposal': | |||
cocoEval.params.useCats = 0 | |||
cocoEval.params.maxDets = list(max_dets) | |||
tgt_ids = gt_img_ids if not single_result else det_img_ids | |||
if single_result: | |||
res_dict = dict() | |||
for id_i in tgt_ids: | |||
cocoEval = COCOeval(coco, coco_dets, iou_type) | |||
if res_type == 'proposal': | |||
cocoEval.params.useCats = 0 | |||
cocoEval.params.maxDets = list(max_dets) | |||
cocoEval.params.imgIds = [id_i] | |||
cocoEval.evaluate() | |||
cocoEval.accumulate() | |||
cocoEval.summarize() | |||
res_dict.update({coco.imgs[id_i]['file_name']: cocoEval.stats[1]}) | |||
cocoEval = COCOeval(coco, coco_dets, iou_type) | |||
if res_type == 'proposal': | |||
cocoEval.params.useCats = 0 | |||
cocoEval.params.maxDets = list(max_dets) | |||
cocoEval.params.imgIds = tgt_ids | |||
cocoEval.evaluate() | |||
cocoEval.accumulate() | |||
cocoEval.summarize() | |||
summary_metrics = { | |||
'Precision/mAP': cocoEval.stats[0], | |||
'Precision/mAP@.50IOU': cocoEval.stats[1], | |||
'Precision/mAP@.75IOU': cocoEval.stats[2], | |||
'Precision/mAP (small)': cocoEval.stats[3], | |||
'Precision/mAP (medium)': cocoEval.stats[4], | |||
'Precision/mAP (large)': cocoEval.stats[5], | |||
'Recall/AR@1': cocoEval.stats[6], | |||
'Recall/AR@10': cocoEval.stats[7], | |||
'Recall/AR@100': cocoEval.stats[8], | |||
'Recall/AR@100 (small)': cocoEval.stats[9], | |||
'Recall/AR@100 (medium)': cocoEval.stats[10], | |||
'Recall/AR@100 (large)': cocoEval.stats[11], | |||
} | |||
return summary_metrics | |||
def xyxy2xywh(bbox): | |||
_bbox = bbox.tolist() | |||
return [ | |||
_bbox[0], | |||
_bbox[1], | |||
_bbox[2] - _bbox[0] + 1, | |||
_bbox[3] - _bbox[1] + 1, | |||
] | |||
def bbox2result_1image(bboxes, labels, num_classes): | |||
"""Convert detection results to a list of numpy arrays. | |||
Args: | |||
bboxes (Tensor): shape (n, 5) | |||
labels (Tensor): shape (n, ) | |||
num_classes (int): class number, including background class | |||
Returns: | |||
list(ndarray): bbox results of each class | |||
""" | |||
if bboxes.shape[0] == 0: | |||
result = [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes - 1)] | |||
else: | |||
result = [bboxes[labels == i, :] for i in range(num_classes - 1)] | |||
return result | |||
def proposal2json(dataset, results): | |||
"""convert proposal to json mode""" | |||
img_ids = dataset.getImgIds() | |||
json_results = [] | |||
dataset_len = dataset.get_dataset_size()*2 | |||
for idx in range(dataset_len): | |||
img_id = img_ids[idx] | |||
bboxes = results[idx] | |||
for i in range(bboxes.shape[0]): | |||
data = dict() | |||
data['image_id'] = img_id | |||
data['bbox'] = xyxy2xywh(bboxes[i]) | |||
data['score'] = float(bboxes[i][4]) | |||
data['category_id'] = 1 | |||
json_results.append(data) | |||
return json_results | |||
def det2json(dataset, results): | |||
"""convert det to json mode""" | |||
cat_ids = dataset.getCatIds() | |||
img_ids = dataset.getImgIds() | |||
json_results = [] | |||
dataset_len = len(img_ids) | |||
for idx in range(dataset_len): | |||
img_id = img_ids[idx] | |||
if idx == len(results): break | |||
result = results[idx] | |||
for label, result_label in enumerate(result): | |||
bboxes = result_label | |||
for i in range(bboxes.shape[0]): | |||
data = dict() | |||
data['image_id'] = img_id | |||
data['bbox'] = xyxy2xywh(bboxes[i]) | |||
data['score'] = float(bboxes[i][4]) | |||
data['category_id'] = cat_ids[label] | |||
json_results.append(data) | |||
return json_results | |||
def segm2json(dataset, results): | |||
"""convert segm to json mode""" | |||
bbox_json_results = [] | |||
segm_json_results = [] | |||
for idx in range(len(dataset)): | |||
img_id = dataset.img_ids[idx] | |||
det, seg = results[idx] | |||
for label, det_label in enumerate(det): | |||
# bbox results | |||
bboxes = det_label | |||
for i in range(bboxes.shape[0]): | |||
data = dict() | |||
data['image_id'] = img_id | |||
data['bbox'] = xyxy2xywh(bboxes[i]) | |||
data['score'] = float(bboxes[i][4]) | |||
data['category_id'] = dataset.cat_ids[label] | |||
bbox_json_results.append(data) | |||
if len(seg) == 2: | |||
segms = seg[0][label] | |||
mask_score = seg[1][label] | |||
else: | |||
segms = seg[label] | |||
mask_score = [bbox[4] for bbox in bboxes] | |||
for i in range(bboxes.shape[0]): | |||
data = dict() | |||
data['image_id'] = img_id | |||
data['score'] = float(mask_score[i]) | |||
data['category_id'] = dataset.cat_ids[label] | |||
segms[i]['counts'] = segms[i]['counts'].decode() | |||
data['segmentation'] = segms[i] | |||
segm_json_results.append(data) | |||
return bbox_json_results, segm_json_results | |||
def results2json(dataset, results, out_file): | |||
"""convert result convert to json mode""" | |||
result_files = dict() | |||
if isinstance(results[0], list): | |||
json_results = det2json(dataset, results) | |||
result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox') | |||
result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox') | |||
mmcv.dump(json_results, result_files['bbox']) | |||
elif isinstance(results[0], tuple): | |||
json_results = segm2json(dataset, results) | |||
result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox') | |||
result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox') | |||
result_files['segm'] = '{}.{}.json'.format(out_file, 'segm') | |||
mmcv.dump(json_results[0], result_files['bbox']) | |||
mmcv.dump(json_results[1], result_files['segm']) | |||
elif isinstance(results[0], np.ndarray): | |||
json_results = proposal2json(dataset, results) | |||
result_files['proposal'] = '{}.{}.json'.format(out_file, 'proposal') | |||
mmcv.dump(json_results, result_files['proposal']) | |||
else: | |||
raise TypeError('invalid type of results') | |||
return result_files |
@@ -19,7 +19,7 @@ from abc import abstractmethod | |||
import numpy as np | |||
from mindspore import Tensor | |||
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||
from mindspore.nn import Cell | |||
from mindarmour.utils.util import WithLossCell, GradWrapWithLoss | |||
from mindarmour.utils.logger import LogUtil | |||
@@ -44,12 +44,13 @@ class GradientMethod(Attack): | |||
Default: None. | |||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||
In form of (clip_min, clip_max). Default: None. | |||
loss_fn (Loss): Loss function for optimization. Default: None. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
Examples: | |||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | |||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | |||
>>> attack = FastGradientMethod(network) | |||
>>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
>>> adv_x = attack.generate(inputs, labels) | |||
""" | |||
@@ -71,9 +72,10 @@ class GradientMethod(Attack): | |||
else: | |||
self._alpha = alpha | |||
if loss_fn is None: | |||
loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False) | |||
with_loss_cell = WithLossCell(self._network, loss_fn) | |||
self._grad_all = GradWrapWithLoss(with_loss_cell) | |||
self._grad_all = self._network | |||
else: | |||
with_loss_cell = WithLossCell(self._network, loss_fn) | |||
self._grad_all = GradWrapWithLoss(with_loss_cell) | |||
self._grad_all.set_train() | |||
def generate(self, inputs, labels): | |||
@@ -83,13 +85,19 @@ class GradientMethod(Attack): | |||
Args: | |||
inputs (numpy.ndarray): Benign input samples used as references to create | |||
adversarial examples. | |||
labels (numpy.ndarray): Original/target labels. | |||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||
For each input if it has more than one label, it is wrapped in a tuple. | |||
Returns: | |||
numpy.ndarray, generated adversarial examples. | |||
""" | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
if isinstance(labels, tuple): | |||
for i, labels_item in enumerate(labels): | |||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||
'labels[{}]'.format(i), labels_item) | |||
else: | |||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||
'labels', labels) | |||
self._dtype = inputs.dtype | |||
gradient = self._gradient(inputs, labels) | |||
# use random method or not | |||
@@ -117,7 +125,8 @@ class GradientMethod(Attack): | |||
Args: | |||
inputs (numpy.ndarray): Benign input samples used as references to | |||
create adversarial examples. | |||
labels (numpy.ndarray): Original/target labels. | |||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||
For each input if it has more than one label, it is wrapped in a tuple. | |||
Raises: | |||
NotImplementedError: It is an abstract method. | |||
@@ -149,12 +158,13 @@ class FastGradientMethod(GradientMethod): | |||
Possible values: np.inf, 1 or 2. Default: 2. | |||
is_targeted (bool): If True, targeted attack. If False, untargeted | |||
attack. Default: False. | |||
loss_fn (Loss): Loss function for optimization. Default: None. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
Examples: | |||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | |||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | |||
>>> attack = FastGradientMethod(network) | |||
>>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
>>> adv_x = attack.generate(inputs, labels) | |||
""" | |||
@@ -175,12 +185,19 @@ class FastGradientMethod(GradientMethod): | |||
Args: | |||
inputs (numpy.ndarray): Input sample. | |||
labels (numpy.ndarray): Original/target label. | |||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||
For each input if it has more than one label, it is wrapped in a tuple. | |||
Returns: | |||
numpy.ndarray, gradient of inputs. | |||
""" | |||
out_grad = self._grad_all(Tensor(inputs), Tensor(labels)) | |||
if isinstance(labels, tuple): | |||
labels_tensor = tuple() | |||
for item in labels: | |||
labels_tensor += (Tensor(item),) | |||
else: | |||
labels_tensor = (Tensor(labels),) | |||
out_grad = self._grad_all(Tensor(inputs), *labels_tensor) | |||
if isinstance(out_grad, tuple): | |||
out_grad = out_grad[0] | |||
gradient = out_grad.asnumpy() | |||
@@ -210,7 +227,8 @@ class RandomFastGradientMethod(FastGradientMethod): | |||
Possible values: np.inf, 1 or 2. Default: 2. | |||
is_targeted (bool): If True, targeted attack. If False, untargeted | |||
attack. Default: False. | |||
loss_fn (Loss): Loss function for optimization. Default: None. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
Raises: | |||
ValueError: eps is smaller than alpha! | |||
@@ -218,7 +236,7 @@ class RandomFastGradientMethod(FastGradientMethod): | |||
Examples: | |||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | |||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | |||
>>> attack = RandomFastGradientMethod(network) | |||
>>> attack = RandomFastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
>>> adv_x = attack.generate(inputs, labels) | |||
""" | |||
@@ -254,12 +272,13 @@ class FastGradientSignMethod(GradientMethod): | |||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||
is_targeted (bool): If True, targeted attack. If False, untargeted | |||
attack. Default: False. | |||
loss_fn (Loss): Loss function for optimization. Default: None. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
Examples: | |||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | |||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | |||
>>> attack = FastGradientSignMethod(network) | |||
>>> attack = FastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
>>> adv_x = attack.generate(inputs, labels) | |||
""" | |||
@@ -279,12 +298,19 @@ class FastGradientSignMethod(GradientMethod): | |||
Args: | |||
inputs (numpy.ndarray): Input samples. | |||
labels (numpy.ndarray): Original/target labels. | |||
labels (union[numpy.ndarray, tuple]): original/target labels. \ | |||
for each input if it has more than one label, it is wrapped in a tuple. | |||
Returns: | |||
numpy.ndarray, gradient of inputs. | |||
""" | |||
out_grad = self._grad_all(Tensor(inputs), Tensor(labels)) | |||
if isinstance(labels, tuple): | |||
labels_tensor = tuple() | |||
for item in labels: | |||
labels_tensor += (Tensor(item),) | |||
else: | |||
labels_tensor = (Tensor(labels),) | |||
out_grad = self._grad_all(Tensor(inputs), *labels_tensor) | |||
if isinstance(out_grad, tuple): | |||
out_grad = out_grad[0] | |||
gradient = out_grad.asnumpy() | |||
@@ -311,7 +337,8 @@ class RandomFastGradientSignMethod(FastGradientSignMethod): | |||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||
is_targeted (bool): True: targeted attack. False: untargeted attack. | |||
Default: False. | |||
loss_fn (Loss): Loss function for optimization. Default: None. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
Raises: | |||
ValueError: eps is smaller than alpha! | |||
@@ -319,7 +346,7 @@ class RandomFastGradientSignMethod(FastGradientSignMethod): | |||
Examples: | |||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | |||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | |||
>>> attack = RandomFastGradientSignMethod(network) | |||
>>> attack = RandomFastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
>>> adv_x = attack.generate(inputs, labels) | |||
""" | |||
@@ -350,12 +377,13 @@ class LeastLikelyClassMethod(FastGradientSignMethod): | |||
Default: None. | |||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||
loss_fn (Loss): Loss function for optimization. Default: None. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
Examples: | |||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | |||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | |||
>>> attack = LeastLikelyClassMethod(network) | |||
>>> attack = LeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
>>> adv_x = attack.generate(inputs, labels) | |||
""" | |||
@@ -384,7 +412,8 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod): | |||
Default: 0.035. | |||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||
loss_fn (Loss): Loss function for optimization. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
Raises: | |||
ValueError: eps is smaller than alpha! | |||
@@ -392,7 +421,7 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod): | |||
Examples: | |||
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) | |||
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) | |||
>>> attack = RandomLeastLikelyClassMethod(network) | |||
>>> attack = RandomLeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
>>> adv_x = attack.generate(inputs, labels) | |||
""" | |||
@@ -17,7 +17,7 @@ from abc import abstractmethod | |||
import numpy as np | |||
from PIL import Image, ImageOps | |||
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||
from mindspore.nn import Cell | |||
from mindspore import Tensor | |||
from mindarmour.utils.logger import LogUtil | |||
@@ -114,7 +114,8 @@ class IterativeGradientMethod(Attack): | |||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||
nb_iter (int): Number of iteration. Default: 5. | |||
loss_fn (Loss): Loss function for optimization. Default: None. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
""" | |||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5, | |||
loss_fn=None): | |||
@@ -123,12 +124,15 @@ class IterativeGradientMethod(Attack): | |||
self._eps = check_value_positive('eps', eps) | |||
self._eps_iter = check_value_positive('eps_iter', eps_iter) | |||
self._nb_iter = check_int_positive('nb_iter', nb_iter) | |||
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) | |||
for b in self._bounds: | |||
_ = check_param_multi_types('bound', b, [int, float]) | |||
self._bounds = None | |||
if bounds is not None: | |||
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) | |||
for b in self._bounds: | |||
_ = check_param_multi_types('bound', b, [int, float]) | |||
if loss_fn is None: | |||
loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False) | |||
self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn)) | |||
self._loss_grad = network | |||
else: | |||
self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn)) | |||
self._loss_grad.set_train() | |||
@abstractmethod | |||
@@ -139,8 +143,8 @@ class IterativeGradientMethod(Attack): | |||
Args: | |||
inputs (numpy.ndarray): Benign input samples used as references to create | |||
adversarial examples. | |||
labels (numpy.ndarray): Original/target labels. | |||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||
For each input if it has more than one label, it is wrapped in a tuple. | |||
Raises: | |||
NotImplementedError: This function is not available in | |||
IterativeGradientMethod. | |||
@@ -177,12 +181,13 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||
is_targeted (bool): If True, targeted attack. If False, untargeted | |||
attack. Default: False. | |||
nb_iter (int): Number of iteration. Default: 5. | |||
loss_fn (Loss): Loss function for optimization. Default: None. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
attack (class): The single step gradient method of each iteration. In | |||
this class, FGSM is used. | |||
Examples: | |||
>>> attack = BasicIterativeMethod(network) | |||
>>> attack = BasicIterativeMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
""" | |||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | |||
is_targeted=False, nb_iter=5, loss_fn=None): | |||
@@ -207,8 +212,8 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||
Args: | |||
inputs (numpy.ndarray): Benign input samples used as references to | |||
create adversarial examples. | |||
labels (numpy.ndarray): Original/target labels. | |||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||
For each input if it has more than one label, it is wrapped in a tuple. | |||
Returns: | |||
numpy.ndarray, generated adversarial examples. | |||
@@ -218,8 +223,13 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||
>>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0], | |||
>>> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]]) | |||
""" | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
if isinstance(labels, tuple): | |||
for i, labels_item in enumerate(labels): | |||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||
'labels[{}]'.format(i), labels_item) | |||
else: | |||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||
'labels', labels) | |||
arr_x = inputs | |||
if self._bounds is not None: | |||
clip_min, clip_max = self._bounds | |||
@@ -267,7 +277,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||
decay_factor (float): Decay factor in iterations. Default: 1.0. | |||
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | |||
np.inf, 1 or 2. Default: 'inf'. | |||
loss_fn (Loss): Loss function for optimization. Default: None. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
""" | |||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | |||
@@ -290,7 +301,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||
Args: | |||
inputs (numpy.ndarray): Benign input samples used as references to | |||
create adversarial examples. | |||
labels (numpy.ndarray): Original/target labels. | |||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||
For each input if it has more than one label, it is wrapped in a tuple. | |||
Returns: | |||
numpy.ndarray, generated adversarial examples. | |||
@@ -301,8 +313,13 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], | |||
>>> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]]) | |||
""" | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
if isinstance(labels, tuple): | |||
for i, labels_item in enumerate(labels): | |||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||
'labels[{}]'.format(i), labels_item) | |||
else: | |||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||
'labels', labels) | |||
arr_x = inputs | |||
momentum = 0 | |||
if self._bounds is not None: | |||
@@ -340,7 +357,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||
Args: | |||
inputs (numpy.ndarray): Input samples. | |||
labels (numpy.ndarray): Original/target labels. | |||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||
For each input if it has more than one label, it is wrapped in a tuple. | |||
Returns: | |||
numpy.ndarray, gradient of labels w.r.t inputs. | |||
@@ -350,7 +368,13 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||
>>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]) | |||
""" | |||
# get grad of loss over x | |||
out_grad = self._loss_grad(Tensor(inputs), Tensor(labels)) | |||
if isinstance(labels, tuple): | |||
labels_tensor = tuple() | |||
for item in labels: | |||
labels_tensor += (Tensor(item),) | |||
else: | |||
labels_tensor = (Tensor(labels),) | |||
out_grad = self._loss_grad(Tensor(inputs), *labels_tensor) | |||
if isinstance(out_grad, tuple): | |||
out_grad = out_grad[0] | |||
gradient = out_grad.asnumpy() | |||
@@ -384,7 +408,8 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||
nb_iter (int): Number of iteration. Default: 5. | |||
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | |||
np.inf, 1 or 2. Default: 'inf'. | |||
loss_fn (Loss): Loss function for optimization. Default: None. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
""" | |||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | |||
@@ -406,7 +431,8 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||
Args: | |||
inputs (numpy.ndarray): Benign input samples used as references to | |||
create adversarial examples. | |||
labels (numpy.ndarray): Original/target labels. | |||
labels (Union[numpy.ndarray, tuple]): Original/target labels. \ | |||
For each input if it has more than one label, it is wrapped in a tuple. | |||
Returns: | |||
numpy.ndarray, generated adversarial examples. | |||
@@ -417,8 +443,13 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1], | |||
>>> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) | |||
""" | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
if isinstance(labels, tuple): | |||
for i, labels_item in enumerate(labels): | |||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||
'labels[{}]'.format(i), labels_item) | |||
else: | |||
inputs, _ = check_pair_numpy_param('inputs', inputs, \ | |||
'labels', labels) | |||
arr_x = inputs | |||
if self._bounds is not None: | |||
clip_min, clip_max = self._bounds | |||
@@ -460,7 +491,8 @@ class DiverseInputIterativeMethod(BasicIterativeMethod): | |||
is_targeted (bool): If True, targeted attack. If False, untargeted | |||
attack. Default: False. | |||
prob (float): Transformation probability. Default: 0.5. | |||
loss_fn (Loss): Loss function for optimization. Default: None. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
""" | |||
def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), | |||
is_targeted=False, prob=0.5, loss_fn=None): | |||
@@ -495,7 +527,8 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod): | |||
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | |||
np.inf, 1 or 2. Default: 'l1'. | |||
prob (float): Transformation probability. Default: 0.5. | |||
loss_fn (Loss): Loss function for optimization. Default: None. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
""" | |||
def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), | |||
is_targeted=False, norm_level='l1', prob=0.5, loss_fn=None): | |||
@@ -19,6 +19,7 @@ from random import choice | |||
import numpy as np | |||
from mindspore import Model | |||
from mindspore import Tensor | |||
from mindspore import nn | |||
from mindarmour.utils._check_param import check_model, check_numpy_param, \ | |||
check_param_multi_types, check_norm_level, check_param_in_range, \ | |||
@@ -451,6 +452,8 @@ class Fuzzer: | |||
else: | |||
network = self._target_model._network | |||
loss_fn = self._target_model._loss_fn | |||
if loss_fn is None: | |||
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) | |||
mutates[method] = self._strategies[method](network, | |||
loss_fn=loss_fn) | |||
return mutates | |||
@@ -18,7 +18,7 @@ import numpy as np | |||
import pytest | |||
import mindspore.ops.operations as P | |||
from mindspore.nn import Cell | |||
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||
import mindspore.context as context | |||
from mindarmour.adv_robustness.attacks import FastGradientMethod | |||
@@ -67,7 +67,7 @@ def test_batch_generate_attack(): | |||
label = np.random.randint(0, 10, 128).astype(np.int32) | |||
label = np.eye(10)[label].astype(np.float32) | |||
attack = FastGradientMethod(Net()) | |||
attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.batch_generate(input_np, label, batch_size=32) | |||
assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \ | |||
@@ -71,7 +71,7 @@ def test_fast_gradient_method(): | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
attack = FastGradientMethod(Net()) | |||
attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \ | |||
@@ -91,7 +91,7 @@ def test_fast_gradient_method_gpu(): | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
attack = FastGradientMethod(Net()) | |||
attack = FastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \ | |||
@@ -132,7 +132,7 @@ def test_random_fast_gradient_method(): | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
attack = RandomFastGradientMethod(Net()) | |||
attack = RandomFastGradientMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any(ms_adv_x != input_np), 'Random fast gradient method: ' \ | |||
@@ -154,7 +154,7 @@ def test_fast_gradient_sign_method(): | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
attack = FastGradientSignMethod(Net()) | |||
attack = FastGradientSignMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any(ms_adv_x != input_np), 'Fast gradient sign method: generate' \ | |||
@@ -176,7 +176,7 @@ def test_random_fast_gradient_sign_method(): | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(28)[label].astype(np.float32) | |||
attack = RandomFastGradientSignMethod(Net()) | |||
attack = RandomFastGradientSignMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any(ms_adv_x != input_np), 'Random fast gradient sign method: ' \ | |||
@@ -198,7 +198,7 @@ def test_least_likely_class_method(): | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
attack = LeastLikelyClassMethod(Net()) | |||
attack = LeastLikelyClassMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any(ms_adv_x != input_np), 'Least likely class method: generate' \ | |||
@@ -220,7 +220,8 @@ def test_random_least_likely_class_method(): | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
attack = RandomLeastLikelyClassMethod(Net(), eps=0.1, alpha=0.01) | |||
attack = RandomLeastLikelyClassMethod(Net(), eps=0.1, alpha=0.01, \ | |||
loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any(ms_adv_x != input_np), 'Random least likely class method: ' \ | |||
@@ -239,5 +240,6 @@ def test_assert_error(): | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
with pytest.raises(ValueError) as e: | |||
assert RandomLeastLikelyClassMethod(Net(), eps=0.05, alpha=0.21) | |||
assert RandomLeastLikelyClassMethod(Net(), eps=0.05, alpha=0.21, \ | |||
loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
assert str(e.value) == 'eps must be larger than alpha!' |
@@ -20,6 +20,7 @@ import pytest | |||
from mindspore.ops import operations as P | |||
from mindspore.nn import Cell | |||
from mindspore import context | |||
from mindspore.nn import SoftmaxCrossEntropyWithLogits | |||
from mindarmour.adv_robustness.attacks import BasicIterativeMethod | |||
from mindarmour.adv_robustness.attacks import MomentumIterativeMethod | |||
@@ -70,7 +71,7 @@ def test_basic_iterative_method(): | |||
for i in range(5): | |||
net = Net() | |||
attack = BasicIterativeMethod(net, nb_iter=i + 1) | |||
attack = BasicIterativeMethod(net, nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any( | |||
ms_adv_x != input_np), 'Basic iterative method: generate value' \ | |||
@@ -91,7 +92,7 @@ def test_momentum_iterative_method(): | |||
label = np.eye(3)[label].astype(np.float32) | |||
for i in range(5): | |||
attack = MomentumIterativeMethod(Net(), nb_iter=i + 1) | |||
attack = MomentumIterativeMethod(Net(), nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any(ms_adv_x != input_np), 'Momentum iterative method: generate' \ | |||
' value must not be equal to' \ | |||
@@ -112,7 +113,7 @@ def test_projected_gradient_descent_method(): | |||
label = np.eye(3)[label].astype(np.float32) | |||
for i in range(5): | |||
attack = ProjectedGradientDescent(Net(), nb_iter=i + 1) | |||
attack = ProjectedGradientDescent(Net(), nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any( | |||
@@ -134,7 +135,7 @@ def test_diverse_input_iterative_method(): | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
attack = DiverseInputIterativeMethod(Net()) | |||
attack = DiverseInputIterativeMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any(ms_adv_x != input_np), 'Diverse input iterative method: generate' \ | |||
' value must not be equal to' \ | |||
@@ -154,7 +155,7 @@ def test_momentum_diverse_input_iterative_method(): | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
attack = MomentumDiverseInputIterativeMethod(Net()) | |||
attack = MomentumDiverseInputIterativeMethod(Net(), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any(ms_adv_x != input_np), 'Momentum diverse input iterative method: ' \ | |||
'generate value must not be equal to' \ | |||
@@ -167,10 +168,7 @@ def test_momentum_diverse_input_iterative_method(): | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_error(): | |||
with pytest.raises(TypeError): | |||
# check_param_multi_types | |||
assert IterativeGradientMethod(Net(), bounds=None) | |||
attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0)) | |||
attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
with pytest.raises(NotImplementedError): | |||
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||
label = np.asarray([2], np.int32) | |||
@@ -59,8 +59,8 @@ def test_ead(): | |||
optimizer = Momentum(net.trainable_params(), 0.001, 0.9) | |||
net = Net() | |||
fgsm = FastGradientSignMethod(net) | |||
pgd = ProjectedGradientDescent(net) | |||
fgsm = FastGradientSignMethod(net, loss_fn=loss_fn) | |||
pgd = ProjectedGradientDescent(net, loss_fn=loss_fn) | |||
ead = EnsembleAdversarialDefense(net, [fgsm, pgd], loss_fn=loss_fn, | |||
optimizer=optimizer) | |||
LOGGER.set_level(logging.DEBUG) | |||
@@ -117,7 +117,7 @@ def test_lenet_mnist_coverage_ascend(): | |||
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) | |||
# generate adv_data | |||
attack = FastGradientSignMethod(net, eps=0.3) | |||
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) | |||
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | |||