Browse Source

remove redundant codes

pull/320/head
jin-xiulang 3 years ago
parent
commit
77eb1c017f
10 changed files with 40 additions and 38 deletions
  1. +5
    -6
      mindarmour/adv_robustness/attacks/attack.py
  2. +4
    -4
      mindarmour/fuzz_testing/image_transform.py
  3. +2
    -4
      mindarmour/privacy/diff_privacy/monitor/monitor.py
  4. +1
    -1
      mindarmour/privacy/sup_privacy/mask_monitor/masker.py
  5. +0
    -1
      mindarmour/privacy/sup_privacy/train/model.py
  6. +22
    -11
      mindarmour/reliability/concept_drift/concept_drift_check_images.py
  7. +6
    -8
      mindarmour/reliability/concept_drift/concept_drift_check_time_series.py
  8. +0
    -1
      mindarmour/reliability/model_fault_injection/__init__.py
  9. +0
    -1
      mindarmour/reliability/model_fault_injection/fault_injection.py
  10. +0
    -1
      mindarmour/reliability/model_fault_injection/fault_type.py

+ 5
- 6
mindarmour/adv_robustness/attacks/attack.py View File

@@ -183,13 +183,12 @@ class Attack:
best_position = check_numpy_param('best_position', best_position)
x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position)
_, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model)
# pylint: disable=invalid-name
REDUCTION_ITERS = 6 # recover 10% difference each time and recover 60% totally.
for _ in range(REDUCTION_ITERS):
BLOCK_NUM = 30 # divide the image into 30 segments
block_width = best_position.shape[0] // BLOCK_NUM
reduction_iters = 6 # recover 10% difference each time and recover 60% totally.
for _ in range(reduction_iters):
block_num = 30 # divide the image into 30 segments
block_width = best_position.shape[0] // block_num
if block_width > 0:
for i in range(BLOCK_NUM):
for i in range(block_num):
diff = x_ori[i*block_width: (i+1)*block_width, :, :]\
- best_position[i*block_width:(i+1)*block_width, :, :]
if np.max(np.abs(diff)) >= 0.1*(self._bounds[1] - self._bounds[0]):


+ 4
- 4
mindarmour/fuzz_testing/image_transform.py View File

@@ -107,8 +107,8 @@ def is_normalized(img):
"""
if is_numpy(img):
minimal = np.min(img)
maximun = np.max(img)
if minimal >= 0 and maximun <= 1:
maximum = np.max(img)
if minimal >= 0 and maximum <= 1:
return True
return False
raise TypeError('img should be Numpy array. Got {}'.format(type(img)))
@@ -359,9 +359,9 @@ class Translate(ImageTransform):
Translate an image.

Args:
x_bias (Union[int, float]): X-direction translation, x = x + x_bias*image_length.
x_bias (Union[int, float]): X-direction translation: x + x_bias*image_length.
Default: 0.
y_bias (Union[int, float]): Y-direction translation, y = y + y_bias*image_wide.
y_bias (Union[int, float]): Y-direction translation: y + y_bias*image_wide.
Default: 0.
"""



+ 2
- 4
mindarmour/privacy/diff_privacy/monitor/monitor.py View File

@@ -347,9 +347,8 @@ class RDPMonitor(Callback):
Returns:
float, delta budget.
"""
orders = np.atleast_1d(self._orders)
rdps = np.atleast_1d(rdp)
deltas = np.exp((rdps - self._target_eps)*(orders - 1))
deltas = np.exp((rdps - self._target_eps)*(np.atleast_1d(self._orders) - 1))
min_delta = np.min(deltas)
return np.min([min_delta, 1.])

@@ -363,9 +362,8 @@ class RDPMonitor(Callback):
Returns:
float, eps budget.
"""
orders = np.atleast_1d(self._orders)
rdps = np.atleast_1d(rdp)
eps = rdps - np.log(self._target_delta) / (orders - 1)
eps = rdps - np.log(self._target_delta) / (np.atleast_1d(self._orders) - 1)
return np.min(eps)




+ 1
- 1
mindarmour/privacy/sup_privacy/mask_monitor/masker.py View File

@@ -26,7 +26,7 @@ TAG = 'suppress masker'
class SuppressMasker(Callback):
"""
Args:
model (SuppressModel): SuppressModel instance.
model (SuppressModel): SuppressModel instance.
suppress_ctrl (SuppressCtrl): SuppressCtrl instance.

Examples:


+ 0
- 1
mindarmour/privacy/sup_privacy/train/model.py View File

@@ -225,7 +225,6 @@ class _TupleMul(nn.Cell):
def construct(self, input1, input2):
"""Add two tuple of data."""
out = self.hyper_map(self.mul, input1, input2)
#print(out)
return out

# come from nn.cell_wrapper.TrainOneStepCell


+ 22
- 11
mindarmour/reliability/concept_drift/concept_drift_check_images.py View File

@@ -11,20 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Out-of-Distribution detection for images.
"""

import heapq
from abc import abstractmethod
import numpy as np
from mindspore import Tensor
from sklearn.cluster import KMeans
from mindarmour.utils._check_param import check_param_type, check_param_in_range
from mindspore.train.summary.summary_record import _get_summary_tensor_data

from mindspore.train.summary.summary_record import _get_summary_tensor_data
from mindspore import Tensor
from mindarmour.utils._check_param import check_param_type, check_param_in_range
from mindarmour.utils.logger import LogUtil

"""
Out-of-Distribution detection for images.
"""
LOGGER = LogUtil.get_instance()
TAG = 'concept drift detection'


class OodDetector:
@@ -56,6 +58,7 @@ class OodDetector:
layer_out = _get_summary_tensor_data()
return layer_out[layer].asnumpy()

@abstractmethod
def get_optimal_threshold(self, label, ds_eval):
"""
Get the optimal threshold.
@@ -67,8 +70,12 @@ class OodDetector:
Returns:
- float, the optimal threshold.
"""
pass
msg = 'The function generate() is an abstract function in class ' \
'`OodDetector` and should be implemented in child class.'
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)

@abstractmethod
def ood_predict(self, threshold, ds_test):
"""
The out-of-distribution detection.
@@ -81,7 +88,10 @@ class OodDetector:
Returns:
- numpy.ndarray, the detection result. 0 means the data is not ood, 1 means the data is ood.
"""
pass
msg = 'The function generate() is an abstract function in class ' \
'`OodDetector` and should be implemented in child class.'
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)


class OodDetectorFeatureCluster(OodDetector):
@@ -103,6 +113,7 @@ class OodDetectorFeatureCluster(OodDetector):
"""

def __init__(self, model, ds_train, n_cluster, layer):
super().__init__(model, ds_train)
self.model = model
self.ds_train = check_param_type('ds_train', ds_train, np.ndarray)
self.n_cluster = check_param_type('n_cluster', n_cluster, int)
@@ -173,7 +184,7 @@ class OodDetectorFeatureCluster(OodDetector):
threshold.append(threshold_change)
acc = np.array(acc)
threshold = np.array(threshold)
optimal_threshold = threshold[np.where(acc==np.max(acc))[0]][0]
optimal_threshold = threshold[np.where(acc == np.max(acc))[0]][0]
return optimal_threshold

def ood_predict(self, threshold, ds_test):


+ 6
- 8
mindarmour/reliability/concept_drift/concept_drift_check_time_series.py View File

@@ -9,7 +9,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Concpt drift module
@@ -72,8 +71,7 @@ class ConceptDriftCheckTimeSeries:
>>> w, x = ConceptDriftCheckTimeSeries._reservoir_model_feature(window_data)
"""
# Initialize weights
res_size = self._res_size
x_state = _w_generate(res_size, len(window_data), window_data)
x_state = _w_generate(self._res_size, len(window_data), window_data)
x_state_t = x_state.T
# Data reshape
data_channel = None
@@ -85,7 +83,7 @@ class ConceptDriftCheckTimeSeries:
reg = 1e-8
# Calculate w_out
w_out = np.dot(np.dot(y_t, x_state_t),
np.linalg.inv(np.dot(x_state, x_state_t) + reg*np.eye(res_size)))
np.linalg.inv(np.dot(x_state, x_state_t) + reg*np.eye(self._res_size)))
return w_out, x_state

def _concept_distance(self, data_x, data_y):
@@ -391,11 +389,11 @@ def _cal_threshold(distance, threshold_index):
Returns:
- float, [0, 1].
"""
distance = distance[distance > 0]
pos_distance = distance[distance > 0]
# Threshold calculation
if distance.size > 0:
q_1 = np.percentile(distance, 25)
q_3 = np.percentile(distance, 75)
if pos_distance.size > 0:
q_1 = np.percentile(pos_distance, 25)
q_3 = np.percentile(pos_distance, 75)
q_diff = q_3 - q_1
threshold = np.clip(0.1 + threshold_index*q_diff, 0, 1)
else:


+ 0
- 1
mindarmour/reliability/model_fault_injection/__init__.py View File

@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
This module provides model fault injection to evaluate the reliability of given model.
"""


+ 0
- 1
mindarmour/reliability/model_fault_injection/fault_injection.py View File

@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Fault injection module


+ 0
- 1
mindarmour/reliability/model_fault_injection/fault_type.py View File

@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Fault type module


Loading…
Cancel
Save