@@ -40,15 +40,16 @@ class SuppressPrivacyFactory:
         """
         Args:
             networks (Cell): The training network.
+                This networks parameter should be the same as the 'network' parameter of SuppressModel().
             mask_layers (list): Description of the training network layers that need to be suppressed.
             policy (str): Training policy for suppress privacy training. Default: "local_train", which means local training.
             end_epoch (int): The last epoch of the suppress operations, 0<start_epoch<=end_epoch<=100. Default: 10.
                 This end_epoch parameter should be the same as the 'epoch' parameter of mindspore.train.model.train().
             batch_num (int): The number of batches in an epoch, should be equal to num_samples/batch_size. Default: 20.
             start_epoch (int): The first epoch of the suppress operations, 0<start_epoch<=end_epoch<=100. Default: 3.
             mask_times (int): The number of suppress operations. Default: 1000.
             lr (Union[float, int]): Learning rate, should stay unchanged during training, 0<lr<=0.50. Default: 0.05.
                 This lr parameter should be the same as the 'learning_rate' parameter of mindspore.nn.SGD().
             sparse_end (float): The sparsity to reach, 0.0<=sparse_start<sparse_end<1.0. Default: 0.90.
             sparse_start (Union[float, int]): The sparsity to start from, 0.0<=sparse_start<sparse_end<1.0. Default: 0.0.
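Taken together, the Args above describe one coherent setup; the sketch below wires them into a training job. It is a minimal sketch assuming the MindArmour suppress-privacy API (`SuppressPrivacyFactory`, `SuppressModel`, `MaskLayerDes`); `LeNet5`, the mask-layer entry and all numeric values are illustrative placeholders, not values mandated by this diff.

```python
# A minimal sketch, assuming mindarmour.privacy.sup_privacy exposes these
# classes; LeNet5, the mask layer entry and all numbers are illustrative.
import mindspore.nn as nn
from mindarmour.privacy.sup_privacy import (MaskLayerDes, SuppressModel,
                                            SuppressPrivacyFactory)

network = LeNet5()  # hypothetical user-defined Cell
mask_layers = [MaskLayerDes("conv1.weight", 0, False, True, 10)]

suppress_ctrl = SuppressPrivacyFactory().create(
    networks=network,       # same object passed to SuppressModel() below
    mask_layers=mask_layers,
    policy="local_train",
    end_epoch=10,           # must match the 'epoch' of Model.train()
    batch_num=20,           # num_samples / batch_size
    start_epoch=3,
    mask_times=1000,
    lr=0.05,                # must match the SGD learning_rate below
    sparse_end=0.90,
    sparse_start=0.0)

net_opt = nn.SGD(network.trainable_params(), learning_rate=0.05)
model = SuppressModel(network=network,
                      loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=True),
                      optimizer=net_opt)
model.link_suppress_ctrl(suppress_ctrl)
```

The three "should be the same as" notes in the docstring are exactly the couplings visible here: the network, the learning rate and the epoch count each appear twice and must agree.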
@@ -164,7 +165,7 @@ class SuppressCtrl(Cell):
             msg = "mask_interval should be greater than 10 and smaller than 20, but got {}".format(self.mask_step_interval)
             msg += "\n The precision of the trained model may be poor!"
             msg += "\n Please modify epoch_start, epoch_end and batch_num!"
-            msg += "\n mask_interval = (epoch_end-epoch_start+1)*batch_num, batch_num = samples/batch_size"
+            msg += "\n mask_interval = (epoch_end-epoch_start+1)*batch_num/mask_times, batch_num = samples/batch_size"
             LOGGER.info(TAG, msg)

         if self.sparse_end >= 1.00 or self.sparse_end <= 0:
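The corrected `+` line divides the total number of suppress steps by `mask_times`, so `mask_interval` is really "steps between two suppress operations". A quick sanity check with the factory defaults from the Args above (start_epoch=3, end_epoch=10, batch_num=20, mask_times=1000):

```python
# mask_interval = total suppress steps / number of suppress operations
epoch_start, epoch_end, batch_num, mask_times = 3, 10, 20, 1000
mask_interval = (epoch_end - epoch_start + 1) * batch_num / mask_times
print(mask_interval)  # 0.16 -> far below 10, so the warning above would fire
```

With these defaults the interval falls below 10, which is exactly the situation the log message warns about; raising batch_num or lowering mask_times moves it into the 10-20 band.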
@@ -256,10 +257,11 @@ class SuppressCtrl(Cell):
                 msg = "can't match this mask layer: {} ".format(one_mask_layer.layer_name)
                 LOGGER.error(TAG, msg)
                 raise ValueError(msg)
-        msg = "this lr parameter should be the same as 'learning_rate' parameter of mindspore.nn.SGD()\n"
-        msg += "this end_epoch parameter should be the same as 'epoch' parameter of mindspore.train.model.train()\n"
-        msg += "sup_privacy only support SGD optimizer"
-        LOGGER.info(TAG, msg)
+        msg = "\nThis networks parameter should be the same as the 'network' parameter of SuppressModel()"
+        msg += "\nThis lr parameter should be the same as the 'learning_rate' parameter of mindspore.nn.SGD()"
+        msg += "\nThis end_epoch parameter should be the same as the 'epoch' parameter of mindspore.train.model.train()"
+        msg += "\nsup_privacy only supports the SGD optimizer"
+        LOGGER.warn(TAG, msg)

     def update_status(self, cur_epoch, cur_step, cur_step_in_epoch):
         """
@@ -289,21 +291,26 @@ class SuppressCtrl(Cell):
             self.to_do_mask = False
             self.mask_started = False

-    def update_mask(self, networks, cur_step):
+    def update_mask(self, networks, cur_step, target_sparse=0.0):
         """
         Update the add-mask arrays and multiply-mask arrays of the network layers.

         Args:
             networks (Cell): The training network.
             cur_step (int): Current step of the whole training process.
+            target_sparse (float): The sparsity to reach. Default: 0.0.
         """
         if self.sparse_end <= 0.0:
             return

-        self.cur_sparse = self.sparse_end +\
-                          (self.sparse_start - self.sparse_end)*\
-                          math.pow((1.0 - (cur_step + 0.0 - self.mask_start_step) / self.mask_all_steps), 3)
-        self.cur_sparse = min(self.cur_sparse, self.sparse_end)
+        last_sparse = self.cur_sparse
+        if target_sparse > 0.0:
+            self.cur_sparse = target_sparse
+        else:
+            self.cur_sparse = self.sparse_end + \
+                              (self.sparse_start - self.sparse_end) * \
+                              math.pow((1.0 - (cur_step + 0.0 - self.mask_start_step) / self.mask_all_steps), 3)
+        self.cur_sparse = max(self.sparse_start, max(last_sparse, min(self.cur_sparse, self.sparse_end)))
         m = 0
         for layer in networks.get_parameters(expand=True):
             grad_idx = self.grad_idx_map[m]
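The else-branch keeps the cubic schedule from the replaced lines: sparsity ramps from sparse_start toward sparse_end as the `(1 - progress)^3` term decays, and the new `max(last_sparse, ...)` clamp makes the schedule monotone even if update_mask is called with a stale step. A standalone sketch of that schedule (not the class method itself), with illustrative step values:

```python
import math

def scheduled_sparse(cur_step, start_step, all_steps, sparse_start=0.0, sparse_end=0.90):
    """Cubic ramp: sparse_start at start_step, approaching sparse_end."""
    frac = 1.0 - (cur_step - start_step) / all_steps
    return sparse_end + (sparse_start - sparse_end) * math.pow(frac, 3)

for step in (0, 250, 500, 750, 1000):
    print(step, scheduled_sparse(step, 0, 1000))
# sparsity climbs: 0.0 -> ~0.52 -> ~0.79 -> ~0.89 -> 0.90
```

Most of the suppression therefore happens early, while the final steps only nudge the mask toward sparse_end.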
@@ -450,7 +457,7 @@ class SuppressCtrl(Cell):
     def update_mask_layer_approximity(self, weight_array_flat, weight_array_flat_abs, actual_stop_pos, layer_index):
         """
         Update the add-mask arrays and multiply-mask arrays of one single layer with many parameters.
-        disable clipping loweer, clipping, adding noise operation
+        Disable the clipping-lower, clipping and adding-noise operations.

         Args:
             weight_array_flat (numpy.ndarray): The weight array of the layer's parameters.
@@ -477,6 +484,8 @@ class SuppressCtrl(Cell):
+        if last_sparse_pos <= 0:
+            init_batch_suppress = True
         for i in range(0, part_num):
             if split_k_num <= 0:
                 break
             array_row_mul_mask = mul_mask_array_flat[i * part_size : (i + 1) * part_size]
             array_row_flat_abs = weight_array_flat_abs[i * part_size : (i + 1) * part_size]
             if not init_batch_suppress:
@@ -488,7 +497,6 @@ class SuppressCtrl(Cell):
                     del array_row_flat_abs_masked
                     del set_abs
                     del list2
-                    gc.collect()
                 else:
                     val_array_align = array_row_flat_abs
@@ -497,7 +505,6 @@ class SuppressCtrl(Cell):
                 del array_row_flat_abs
                 del array_row_mul_mask
                 del val_array_align
-                gc.collect()
                 continue

             partition = np.partition(val_array_align, real_split_k_num - 1)
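`np.partition` does the actual magnitude selection here: it moves the `real_split_k_num` smallest absolute weights into the leading slots without a full sort, so the element at index `real_split_k_num - 1` is the suppression threshold for this chunk. A self-contained illustration of that pattern (the array and k are invented):

```python
import numpy as np

vals_abs = np.array([0.7, 0.05, 0.3, 0.01, 0.9, 0.2], dtype=np.float32)
k = 3                                  # number of weights to suppress
part = np.partition(vals_abs, k - 1)   # k smallest values now occupy part[:k]
threshold = part[k - 1]                # 0.2, the k-th smallest magnitude
keep_mask = vals_abs > threshold       # [ True False  True False  True False]
print(threshold, keep_mask)
```

Partitioning is O(n) versus O(n log n) for a full sort, which matters since this runs once per chunk on every suppress step.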
@@ -509,7 +516,6 @@ class SuppressCtrl(Cell):
             del array_row_mul_mask
             del val_array_align
             del partition
-            gc.collect()

         if real_part_num > 0:
             sparse_thd = sparse_thd / real_part_num
@@ -570,9 +576,9 @@ class SuppressCtrl(Cell):
             if array_mul_mask_flat_conv2[i] <= 0.0:
                 sparse = sparse + 1.0
                 sparse_value_2 += 1.0
-        sparse = sparse/full
-        sparse_value_1 = sparse_value_1/full_conv1
-        sparse_value_2 = sparse_value_2/full_conv2
+        sparse = sparse / full
+        sparse_value_1 = sparse_value_1 / full_conv1
+        sparse_value_2 = sparse_value_2 / full_conv2
         msg = "conv sparse mask={}, sparse_1={}, sparse_2={}".format(sparse, sparse_value_1, sparse_value_2)
         LOGGER.info(TAG, msg)
         return sparse, sparse_value_1, sparse_value_2
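All three ratios returned here are the same simple statistic: the fraction of mask entries that are non-positive, i.e. already suppressed. A toy check of that computation (array contents invented):

```python
import numpy as np

mul_mask_flat = np.array([1.0, 0.0, 1.0, 0.0, 0.0, 1.0], dtype=np.float32)
full = mul_mask_flat.size
sparse = float(np.sum(mul_mask_flat <= 0.0)) / full
print(sparse)  # 0.5 -> half of the entries are suppressed
```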
@@ -591,16 +597,24 @@ class SuppressCtrl(Cell):
         full_conv1 = 0.0
         full_conv2 = 0.0

+        conv1_matched = False
+        conv2_matched = False
+
         array_cur_conv1 = np.ones(np.shape([1]), dtype=np.float32)
         array_cur_conv2 = np.ones(np.shape([1]), dtype=np.float32)
         for layer in networks.get_parameters(expand=True):
-            if "networks.conv1.weight" in layer.name or "networks.layers.0.weight" in layer.name:  # lenet5/res50 vgg16
+            if not conv1_matched and \
+                    ("networks.conv1.weight" in layer.name or "networks.layers.0.weight" in layer.name):
+                # lenet5/res50/vgg16
                 array_cur_conv1 = layer.data.asnumpy()
-                print("calc_actual_sparse, match conv1")
-            if "networks.conv2.weight" in layer.name or "networks.layers.3.weight" in layer.name \
-                    or "networks.layer1.0.conv1.weight" in layer.name:  # res50
+                print("calc_actual_sparse, match conv1: {}".format(layer.name))
+                conv1_matched = True
+            if not conv2_matched and \
+                    ("networks.conv2.weight" in layer.name or "networks.layers.3.weight" in layer.name
+                     or "networks.layer1.0.conv1.weight" in layer.name):  # res50
                 array_cur_conv2 = layer.data.asnumpy()
-                print("calc_actual_sparse, match conv2")
+                print("calc_actual_sparse, match conv2: {}".format(layer.name))
+                conv2_matched = True

         array_mul_mask_flat_conv1 = array_cur_conv1.flatten()
         array_mul_mask_flat_conv2 = array_cur_conv2.flatten()
@@ -675,10 +689,18 @@ class SuppressCtrl(Cell):
         return sparse

     def print_paras(self):
         """
         Show parameters info.
         """
         msg = "paras: start_epoch:{}, end_epoch:{}, batch_num:{}, interval:{}, lr:{}, sparse_end:{}, sparse_start:{}" \
             .format(self.mask_start_epoch, self.mask_end_epoch, self.batch_num, self.mask_step_interval,
                     self.lr, self.sparse_end, self.sparse_start)
         LOGGER.info(TAG, msg)
-        msg = "\nThis lr parameter should be same as 'learning_rate' parameter of mindspore.nn.SGD()"
+        msg = "\nThis networks parameter should be the same as the 'network' parameter of SuppressModel()"
+        msg += "\nThis lr parameter should be the same as the 'learning_rate' parameter of mindspore.nn.SGD()"
         msg += "\nThis end_epoch parameter should be the same as the 'epoch' parameter of mindspore.train.model.train()"
         msg += "\nsup_privacy only supports the SGD optimizer"
         LOGGER.info(TAG, msg)


 def get_one_mask_layer(mask_layers, layer_name):
     """
@@ -713,11 +735,14 @@ class MaskLayerDes:
             grad layers (print in PYNATIVE_MODE).
         is_add_noise (bool): If True, noise can be added to the weights of this layer;
             if False, it cannot.
            If the parameter num is greater than 100000, is_add_noise has no effect.
         is_lower_clip (bool): If True, the weights of this layer are clipped to be greater than a lower bound value;
             if False, they are not clipped.
             If the parameter num is greater than 100000, is_lower_clip has no effect.
         min_num (int): The number of weights left unsuppressed.
             If min_num is smaller than (parameter num * SuppressCtrl.sparse_end), min_num has no effect.
         upper_bound (Union[float, int]): Maximum absolute value of the weights in this layer. Default: 1.20.
             If the parameter num is greater than 100000, upper_bound has no effect.
     """
     def __init__(self, layer_name, grad_idx, is_add_noise, is_lower_clip, min_num, upper_bound=1.20):
         self.layer_name = check_param_type('layer_name', layer_name, str)
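The __init__ signature above is the full per-layer contract. A sketch of constructing one description with every argument spelled out; the layer name and grad_idx are illustrative and must match the target network (see the parameter-listing sketch earlier):

```python
# Hypothetical mask-layer description for a small conv layer.
conv1_mask = MaskLayerDes(layer_name="conv1.weight",
                          grad_idx=0,
                          is_add_noise=False,
                          is_lower_clip=True,
                          min_num=10,      # floor on unsuppressed weights
                          upper_bound=1.20)
```

Note how the three >100000-parameter caveats interact: for very large layers, noise, lower clipping and upper_bound are all disabled, so only the suppression mask itself applies.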
@@ -753,7 +778,7 @@ class GradMaskInCell(Cell):
         self.is_add_noise = is_add_noise
         self.is_lower_clip = is_lower_clip
         self.min_num = min_num
-        self.upper_bound = check_value_positive('upper_bound', upper_bound)
+        self.upper_bound = max(0.10, check_value_positive('upper_bound', upper_bound))

         self.para_num = array.size
         self.is_approximity = False
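The change in this last hunk adds a floor of 0.10 under the validated upper_bound, so even a tiny positive value cannot collapse the clipping range. A sketch of the resulting behavior (the helper function is illustrative, not the library API):

```python
def effective_upper_bound(upper_bound):
    # mirrors: max(0.10, check_value_positive('upper_bound', upper_bound))
    if upper_bound <= 0:
        raise ValueError("upper_bound must be positive")
    return max(0.10, upper_bound)

print(effective_upper_bound(0.03))  # 0.1  (floored)
print(effective_upper_bound(1.20))  # 1.2  (unchanged)
```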