|
|
@@ -15,6 +15,7 @@ |
|
|
|
control function of suppress-based privacy. |
|
|
|
""" |
|
|
|
import math |
|
|
|
import gc |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from mindspore import Tensor |
|
|
@@ -35,18 +36,20 @@ class SuppressPrivacyFactory: |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def create(networks, mask_layers, policy="local_train", end_epoch=10, batch_num=20, start_epoch=3, |
|
|
|
mask_times=1000, lr=0.10, sparse_end=0.90, sparse_start=0.0): |
|
|
|
mask_times=1000, lr=0.05, sparse_end=0.90, sparse_start=0.0): |
|
|
|
""" |
|
|
|
Args: |
|
|
|
networks (Cell): The training network. |
|
|
|
mask_layers (list): Description of the training network layers that need to be suppressed. |
|
|
|
policy (str): Training policy for suppress privacy training. Default: "local_train", means local training. |
|
|
|
end_epoch (int): The last epoch in suppress operations, 0<start_epoch<=end_epoch<=100. Default: 10. |
|
|
|
This end_epoch parameter should be the same as 'epoch' parameter of mindspore.train.model.train(). |
|
|
|
batch_num (int): The num of batch in an epoch, should be equal to num_samples/batch_size. Default: 20. |
|
|
|
start_epoch (int): The first epoch in suppress operations, 0<start_epoch<=end_epoch<=100. Default: 3. |
|
|
|
mask_times (int): The num of suppress operations. Default: 1000. |
|
|
|
lr (Union[float, int]): Learning rate, 0 < lr <= 0.5. Default: 0.10. |
|
|
|
sparse_end (Union[float, int]): The sparsity to reach, 0.0<=sparse_start<sparse_end<1.0. Default: 0.90. |
|
|
|
lr (Union[float, int]): Learning rate, should be unchanged during training. 0<lr<=0.50. Default: 0.05. |
|
|
|
This lr parameter should be the same as 'learning_rate' parameter of mindspore.nn.SGD(). |
|
|
|
sparse_end (float): The sparsity to reach, 0.0<=sparse_start<sparse_end<1.0. Default: 0.90. |
|
|
|
sparse_start (Union[float, int]): The sparsity to start, 0.0<=sparse_start<sparse_end<1.0. Default: 0.0. |
|
|
|
|
|
|
|
Returns: |
|
|
@@ -101,7 +104,7 @@ class SuppressCtrl(Cell): |
|
|
|
start_epoch (int): The first epoch in suppress operations. |
|
|
|
mask_times (int): The num of suppress operations. |
|
|
|
lr (Union[float, int]): Learning rate. |
|
|
|
sparse_end (Union[float, int]): The sparsity to reach. |
|
|
|
sparse_end (float): The sparsity to reach. |
|
|
|
sparse_start (Union[float, int]): The sparsity to start. |
|
|
|
""" |
|
|
|
def __init__(self, networks, mask_layers, end_epoch, batch_num, start_epoch, mask_times, lr, |
|
|
@@ -114,12 +117,12 @@ class SuppressCtrl(Cell): |
|
|
|
self.mask_start_epoch = check_int_positive('start_epoch', start_epoch) |
|
|
|
self.mask_times = check_int_positive('mask_times', mask_times) |
|
|
|
self.lr = check_value_positive('lr', lr) |
|
|
|
self.sparse_end = check_value_non_negative('sparse_end', sparse_end) |
|
|
|
self.sparse_end = check_param_type('sparse_end', sparse_end, float) |
|
|
|
self.sparse_start = check_value_non_negative('sparse_start', sparse_start) |
|
|
|
|
|
|
|
self.weight_lower_bound = 0.005 # all network weight will be larger than this value |
|
|
|
self.sparse_vibra = 0.02 # the sparsity may have certain range of variations |
|
|
|
self.sparse_valid_max_weight = 0.20 # if max network weight is less than this value, suppress operation stop temporarily |
|
|
|
self.sparse_valid_max_weight = 0.02 # if max network weight is less than this value, suppress operation stop temporarily |
|
|
|
self.add_noise_thd = 0.50 # if network weight is more than this value, noise is forced |
|
|
|
self.noise_volume = 0.1 # noise volume 0.1 |
|
|
|
self.base_ground_thd = 0.0000001 # if network weight is less than this value, will be considered as 0 |
|
|
@@ -253,6 +256,10 @@ class SuppressCtrl(Cell): |
|
|
|
msg = "can't match this mask layer: {} ".format(one_mask_layer.layer_name) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
msg = "this lr parameter should be the same as 'learning_rate' parameter of mindspore.nn.SGD()\n" |
|
|
|
msg += "this end_epoch parameter should be the same as 'epoch' parameter of mindspore.train.model.train()\n" |
|
|
|
msg += "sup_privacy only support SGD optimizer" |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
|
|
|
|
def update_status(self, cur_epoch, cur_step, cur_step_in_epoch): |
|
|
|
""" |
|
|
@@ -296,6 +303,7 @@ class SuppressCtrl(Cell): |
|
|
|
self.cur_sparse = self.sparse_end +\ |
|
|
|
(self.sparse_start - self.sparse_end)*\ |
|
|
|
math.pow((1.0 - (cur_step + 0.0 - self.mask_start_step) / self.mask_all_steps), 3) |
|
|
|
self.cur_sparse = min(self.cur_sparse, self.sparse_end) |
|
|
|
m = 0 |
|
|
|
for layer in networks.get_parameters(expand=True): |
|
|
|
grad_idx = self.grad_idx_map[m] |
|
|
@@ -303,31 +311,58 @@ class SuppressCtrl(Cell): |
|
|
|
m = m + 1 |
|
|
|
continue |
|
|
|
if self.grads_mask_list[grad_idx].mask_able: |
|
|
|
len_array = self.grads_mask_list[grad_idx].para_num |
|
|
|
min_num = self.grads_mask_list[grad_idx].min_num |
|
|
|
sparse_min_thd = 1.0 - min(min_num, len_array) / len_array |
|
|
|
actual_stop_pos = int(len_array * min(sparse_min_thd, self.cur_sparse)) |
|
|
|
|
|
|
|
grad_mask_cell = self.grads_mask_list[grad_idx] |
|
|
|
last_sparse_pos = grad_mask_cell.sparse_pos_list[-1] |
|
|
|
if actual_stop_pos <= 0 or \ |
|
|
|
(actual_stop_pos < last_sparse_pos + grad_mask_cell.part_num and \ |
|
|
|
grad_mask_cell.is_approximity and m > 0): |
|
|
|
sparse_weight_thd = 0 |
|
|
|
msg = "{} len={}, sparse={}, current sparse thd={}, [idle] \n" \ |
|
|
|
.format(layer.name, len_array, actual_stop_pos / len_array, sparse_weight_thd) |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
m = m + 1 |
|
|
|
continue |
|
|
|
|
|
|
|
weight_array = layer.data.asnumpy() |
|
|
|
weight_avg = np.mean(weight_array) |
|
|
|
weight_array_flat = weight_array.flatten() |
|
|
|
weight_array_flat_abs = np.abs(weight_array_flat) |
|
|
|
weight_abs_avg = np.mean(weight_array_flat_abs) |
|
|
|
weight_array_flat_abs.sort() |
|
|
|
len_array = weight_array.size |
|
|
|
weight_abs_max = np.max(weight_array_flat_abs) |
|
|
|
weight_abs_min = np.min(weight_array_flat_abs) |
|
|
|
|
|
|
|
if m == 0 and weight_abs_max < self.sparse_valid_max_weight: |
|
|
|
msg = "give up this masking .." |
|
|
|
msg = "layer 0 weight_abs_max = {}, give up this masking ... ".format(weight_abs_max) |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
del weight_array_flat_abs |
|
|
|
del weight_array_flat |
|
|
|
del weight_array |
|
|
|
gc.collect() |
|
|
|
return |
|
|
|
if self.grads_mask_list[grad_idx].min_num > 0: |
|
|
|
sparse_weight_thd, _, actual_stop_pos = self.calc_sparse_thd(weight_array_flat_abs, |
|
|
|
self.cur_sparse, grad_idx) |
|
|
|
else: |
|
|
|
actual_stop_pos = int(len_array * self.cur_sparse) |
|
|
|
sparse_weight_thd = weight_array_flat_abs[actual_stop_pos] |
|
|
|
|
|
|
|
self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, weight_abs_max, grad_idx) |
|
|
|
if grad_mask_cell.is_approximity and m > 0: |
|
|
|
sparse_weight_thd = self.update_mask_layer_approximity(weight_array_flat, weight_array_flat_abs, |
|
|
|
actual_stop_pos, grad_idx) |
|
|
|
else: |
|
|
|
partition = np.partition(weight_array_flat_abs, actual_stop_pos - 1) |
|
|
|
sparse_weight_thd = partition[actual_stop_pos - 1] |
|
|
|
self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, |
|
|
|
weight_abs_max, grad_idx) |
|
|
|
del partition |
|
|
|
|
|
|
|
msg = "{} len={}, sparse={}, current sparse thd={}, max={}, avg={}, avg_abs={} \n".format( |
|
|
|
msg = "{} len={}, sparse={}, current sparse thd={}, max={}, min={}, avg={}, avg_abs={} \n".format( |
|
|
|
layer.name, len_array, actual_stop_pos/len_array, sparse_weight_thd, |
|
|
|
weight_abs_max, weight_avg, weight_abs_avg) |
|
|
|
weight_abs_max, weight_abs_min, weight_avg, weight_abs_avg) |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
del weight_array_flat_abs |
|
|
|
del weight_array_flat |
|
|
|
del weight_array |
|
|
|
gc.collect() |
|
|
|
m = m + 1 |
|
|
|
|
|
|
|
def update_mask_layer(self, weight_array_flat, sparse_weight_thd, sparse_stop_pos, weight_abs_max, layer_index): |
|
|
@@ -335,7 +370,7 @@ class SuppressCtrl(Cell): |
|
|
|
Update add mask arrays and multiply mask arrays of one single layer. |
|
|
|
|
|
|
|
Args: |
|
|
|
weight_array (numpy.ndarray): The weight array of layer's parameters. |
|
|
|
weight_array_flat (numpy.ndarray): The weight array of layer's parameters. |
|
|
|
sparse_weight_thd (float): The weight threshold of sparse operation. |
|
|
|
sparse_stop_pos (int): The maximum number of elements to be suppressed. |
|
|
|
weight_abs_max (float): The maximum absolute value of weights. |
|
|
@@ -358,9 +393,13 @@ class SuppressCtrl(Cell): |
|
|
|
q = 0 |
|
|
|
# add noise on weights if not masking or clipping. |
|
|
|
weight_noise_bound = min(self.add_noise_thd, max(self.noise_volume*10, weight_abs_max*0.75)) |
|
|
|
for i in range(0, weight_array_flat.size): |
|
|
|
if abs(weight_array_flat[i]) <= sparse_weight_thd: |
|
|
|
if m < weight_array_flat.size - min_num and m < sparse_stop_pos: |
|
|
|
size = self.grads_mask_list[layer_index].para_num |
|
|
|
for i in range(0, size): |
|
|
|
if mul_mask_array_flat[i] <= 0.0: |
|
|
|
add_mask_array_flat[i] = weight_array_flat[i] / self.lr |
|
|
|
m = m + 1 |
|
|
|
elif abs(weight_array_flat[i]) <= sparse_weight_thd: |
|
|
|
if m < size - min_num and m < sparse_stop_pos: |
|
|
|
# to mask |
|
|
|
mul_mask_array_flat[i] = 0.0 |
|
|
|
add_mask_array_flat[i] = weight_array_flat[i] / self.lr |
|
|
@@ -368,9 +407,11 @@ class SuppressCtrl(Cell): |
|
|
|
else: |
|
|
|
# not mask |
|
|
|
if weight_array_flat[i] > 0.0: |
|
|
|
add_mask_array_flat[i] = (weight_array_flat[i] - self.weight_lower_bound) / self.lr |
|
|
|
add_mask_array_flat[i] = (weight_array_flat[i] \ |
|
|
|
- min(self.weight_lower_bound, sparse_weight_thd)) / self.lr |
|
|
|
else: |
|
|
|
add_mask_array_flat[i] = (weight_array_flat[i] + self.weight_lower_bound) / self.lr |
|
|
|
add_mask_array_flat[i] = (weight_array_flat[i] |
|
|
|
+ min(self.weight_lower_bound, sparse_weight_thd)) / self.lr |
|
|
|
p = p + 1 |
|
|
|
elif is_lower_clip and abs(weight_array_flat[i]) <= \ |
|
|
|
self.weight_lower_bound and sparse_weight_thd > self.weight_lower_bound*0.5: |
|
|
@@ -404,28 +445,99 @@ class SuppressCtrl(Cell): |
|
|
|
"suppressed elements, max-clip elements, min-clip elements and noised elements are {}, {}, {}, {}"\ |
|
|
|
.format(len(grad_mask_cell.mul_mask_array_shape), layer_index, m, n, p, q) |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
grad_mask_cell.sparse_pos_list.append(m) |
|
|
|
|
|
|
|
def calc_sparse_thd(self, array_flat, sparse_value, layer_index): |
|
|
|
def update_mask_layer_approximity(self, weight_array_flat, weight_array_flat_abs, actual_stop_pos, layer_index): |
|
|
|
""" |
|
|
|
Calculate the suppression threshold of one weight array. |
|
|
|
Update add mask arrays and multiply mask arrays of one single layer with many parameter. |
|
|
|
disable clipping loweer, clipping, adding noise operation |
|
|
|
|
|
|
|
Args: |
|
|
|
array_flat (numpy.ndarray): The flattened weight array. |
|
|
|
sparse_value (float): The target sparse value of weight array. |
|
|
|
weight_array_flat (numpy.ndarray): The weight array of layer's parameters. |
|
|
|
weight_array_flat_abs (numpy.ndarray): The abs weight array of layer's parameters. |
|
|
|
actual_stop_pos (int): The actually para num should be suppressed. |
|
|
|
layer_index (int): The index of target layer. |
|
|
|
""" |
|
|
|
grad_mask_cell = self.grads_mask_list[layer_index] |
|
|
|
mul_mask_array_flat = grad_mask_cell.mul_mask_array_flat |
|
|
|
de_weight_cell = self.de_weight_mask_list[layer_index] |
|
|
|
add_mask_array_flat = de_weight_cell.add_mask_array_flat |
|
|
|
|
|
|
|
Returns: |
|
|
|
- float, the sparse threshold of this array. |
|
|
|
part_size = grad_mask_cell.part_size |
|
|
|
part_num = grad_mask_cell.part_num |
|
|
|
para_num = grad_mask_cell.para_num |
|
|
|
init_batch_suppress = False |
|
|
|
|
|
|
|
- int, the number of weight elements to be suppressed. |
|
|
|
if not self.grads_mask_list[layer_index].mask_able: |
|
|
|
return 0.0 |
|
|
|
real_part_num = 0 |
|
|
|
sparse_thd = 0.0 |
|
|
|
last_sparse_pos = grad_mask_cell.sparse_pos_list[-1] |
|
|
|
split_k_num = max(0, int((actual_stop_pos - last_sparse_pos) / part_num)) |
|
|
|
if last_sparse_pos <= 0: |
|
|
|
init_batch_suppress = True |
|
|
|
for i in range(0, part_num): |
|
|
|
array_row_mul_mask = mul_mask_array_flat[i * part_size : (i + 1) * part_size] |
|
|
|
array_row_flat_abs = weight_array_flat_abs[i * part_size : (i + 1) * part_size] |
|
|
|
if not init_batch_suppress: |
|
|
|
array_row_flat_abs_masked = np.where(array_row_mul_mask <= 0.0, -1.0, array_row_flat_abs) |
|
|
|
set_abs = set(array_row_flat_abs_masked) |
|
|
|
set_abs.remove(-1.0) |
|
|
|
list2 = list(set_abs) |
|
|
|
val_array_align = np.array(list2) |
|
|
|
del array_row_flat_abs_masked |
|
|
|
del set_abs |
|
|
|
del list2 |
|
|
|
gc.collect() |
|
|
|
else: |
|
|
|
val_array_align = array_row_flat_abs |
|
|
|
|
|
|
|
real_split_k_num = min(split_k_num, len(val_array_align) - 1) |
|
|
|
if real_split_k_num <= 0: |
|
|
|
del array_row_flat_abs |
|
|
|
del array_row_mul_mask |
|
|
|
del val_array_align |
|
|
|
gc.collect() |
|
|
|
continue |
|
|
|
|
|
|
|
- int, the larger number of weight elements to be suppressed. |
|
|
|
""" |
|
|
|
size = len(array_flat) |
|
|
|
sparse_max_thd = 1.0 - min(self.grads_mask_list[layer_index].min_num, size) / size |
|
|
|
pos = int(size*min(sparse_max_thd, sparse_value)) |
|
|
|
thd = array_flat[pos] |
|
|
|
farther_stop_pos = int(size*min(sparse_max_thd, max(0, sparse_value + self.sparse_vibra / 2.0))) |
|
|
|
return thd, pos, farther_stop_pos |
|
|
|
partition = np.partition(val_array_align, real_split_k_num - 1) |
|
|
|
sparse_k_thd = partition[real_split_k_num - 1] |
|
|
|
if sparse_k_thd > 0 or init_batch_suppress: |
|
|
|
real_part_num = real_part_num + 1 |
|
|
|
sparse_thd = sparse_thd + sparse_k_thd |
|
|
|
del array_row_flat_abs |
|
|
|
del array_row_mul_mask |
|
|
|
del val_array_align |
|
|
|
del partition |
|
|
|
gc.collect() |
|
|
|
|
|
|
|
if real_part_num > 0: |
|
|
|
sparse_thd = sparse_thd / real_part_num |
|
|
|
new_mul_mask_array_flat = np.where(weight_array_flat_abs <= sparse_thd, 0.0, 1.0) |
|
|
|
grad_mask_cell.mul_mask_array_flat = new_mul_mask_array_flat |
|
|
|
new_add_mask_array_flat = np.where(new_mul_mask_array_flat <= 0.0, weight_array_flat / self.lr, 0.0) |
|
|
|
de_weight_cell.add_mask_array_flat = new_add_mask_array_flat |
|
|
|
grad_mask_cell.update() |
|
|
|
de_weight_cell.update() |
|
|
|
del mul_mask_array_flat |
|
|
|
del add_mask_array_flat |
|
|
|
gc.collect() |
|
|
|
real_suppress_num = para_num - int(np.sum(grad_mask_cell.mul_mask_array_flat)) |
|
|
|
grad_mask_cell.sparse_pos_list.append(real_suppress_num) |
|
|
|
else: |
|
|
|
real_suppress_num = 0 |
|
|
|
|
|
|
|
msg = "Dimension of mask tensor is {}D, which located in the {}-th layer of the network. " \ |
|
|
|
"\n The ideal number of suppressed elements is {}/{}/{}, real suppress elements is {}" \ |
|
|
|
.format(len(grad_mask_cell.mul_mask_array_shape), layer_index, |
|
|
|
split_k_num, (actual_stop_pos - last_sparse_pos), actual_stop_pos, real_suppress_num) |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
if init_batch_suppress: |
|
|
|
init_sparse_actual = real_suppress_num/para_num |
|
|
|
print("init batch suppresss, actual sparse = {}".format(init_sparse_actual)) |
|
|
|
|
|
|
|
gc.collect() |
|
|
|
return sparse_thd |
|
|
|
|
|
|
|
def reset_zeros(self): |
|
|
|
""" |
|
|
@@ -452,7 +564,6 @@ class SuppressCtrl(Cell): |
|
|
|
if array_mul_mask_flat_conv1[i] <= 0.0: |
|
|
|
sparse += 1.0 |
|
|
|
sparse_value_1 += 1.0 |
|
|
|
|
|
|
|
for i in range(0, array_mul_mask_flat_conv2.size): |
|
|
|
full = full + 1.0 |
|
|
|
full_conv2 = full_conv2 + 1.0 |
|
|
@@ -483,10 +594,13 @@ class SuppressCtrl(Cell): |
|
|
|
array_cur_conv1 = np.ones(np.shape([1]), dtype=np.float32) |
|
|
|
array_cur_conv2 = np.ones(np.shape([1]), dtype=np.float32) |
|
|
|
for layer in networks.get_parameters(expand=True): |
|
|
|
if "conv1.weight" in layer.name: |
|
|
|
if "networks.conv1.weight" in layer.name or "networks.layers.0.weight" in layer.name: # lenet5/res50 vgg16 |
|
|
|
array_cur_conv1 = layer.data.asnumpy() |
|
|
|
if "conv2.weight" in layer.name: |
|
|
|
print("calc_actual_sparse, match conv1") |
|
|
|
if "networks.conv2.weight" in layer.name or "networks.layers.3.weight" in layer.name \ |
|
|
|
or "networks.layer1.0.conv1.weight" in layer.name: # res50 |
|
|
|
array_cur_conv2 = layer.data.asnumpy() |
|
|
|
print("calc_actual_sparse, match conv2") |
|
|
|
|
|
|
|
array_mul_mask_flat_conv1 = array_cur_conv1.flatten() |
|
|
|
array_mul_mask_flat_conv2 = array_cur_conv2.flatten() |
|
|
@@ -510,10 +624,15 @@ class SuppressCtrl(Cell): |
|
|
|
sparse_value_2 = sparse_value_2 / full_conv2 |
|
|
|
msg = "conv sparse fact={}, sparse_1={}, sparse_2={}".format(sparse, sparse_value_1, sparse_value_2) |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
del array_mul_mask_flat_conv1 |
|
|
|
del array_mul_mask_flat_conv2 |
|
|
|
del array_cur_conv1 |
|
|
|
del array_cur_conv2 |
|
|
|
gc.collect() |
|
|
|
return sparse, sparse_value_1, sparse_value_2 |
|
|
|
|
|
|
|
def calc_actual_sparse_for_fc1(self, networks): |
|
|
|
self.calc_actual_sparse_for_layer(networks, "fc1.weight") |
|
|
|
return self.calc_actual_sparse_for_layer(networks, "fc1.weight") |
|
|
|
|
|
|
|
def calc_actual_sparse_for_layer(self, networks, layer_name): |
|
|
|
""" |
|
|
@@ -533,11 +652,12 @@ class SuppressCtrl(Cell): |
|
|
|
for layer in networks.get_parameters(expand=True): |
|
|
|
if layer_name in layer.name: |
|
|
|
array_cur = layer.data.asnumpy() |
|
|
|
break |
|
|
|
|
|
|
|
if array_cur is None: |
|
|
|
msg = "no such layer to calc sparse: {} ".format(layer_name) |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
return |
|
|
|
return 0.0 |
|
|
|
|
|
|
|
array_cur_flat = array_cur.flatten() |
|
|
|
|
|
|
@@ -549,6 +669,10 @@ class SuppressCtrl(Cell): |
|
|
|
sparse = sparse / full |
|
|
|
msg = "{} sparse fact={} ".format(layer_name, sparse) |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
del array_cur_flat |
|
|
|
del array_cur |
|
|
|
gc.collect() |
|
|
|
return sparse |
|
|
|
|
|
|
|
def print_paras(self): |
|
|
|
msg = "paras: start_epoch:{}, end_epoch:{}, batch_num:{}, interval:{}, lr:{}, sparse_end:{}, sparse_start:{}" \ |
|
|
@@ -631,6 +755,31 @@ class GradMaskInCell(Cell): |
|
|
|
self.min_num = min_num |
|
|
|
self.upper_bound = check_value_positive('upper_bound', upper_bound) |
|
|
|
|
|
|
|
self.para_num = array.size |
|
|
|
self.is_approximity = False |
|
|
|
self.sparse_pos_list = [0] |
|
|
|
self.part_num = 1 |
|
|
|
self.part_size = self.para_num |
|
|
|
self.part_num_max = 16 |
|
|
|
self.para_many_num = 10000 |
|
|
|
self.para_huge_num = 10*10000*10000 |
|
|
|
|
|
|
|
if self.para_num > self.para_many_num: |
|
|
|
self.is_approximity = True |
|
|
|
self.is_add_noise = False |
|
|
|
self.is_lower_clip = False |
|
|
|
|
|
|
|
ratio = 2 |
|
|
|
if self.part_size > self.para_huge_num: |
|
|
|
while self.part_size % ratio == 0 and self.part_size > self.para_huge_num \ |
|
|
|
and self.part_num < self.part_num_max: |
|
|
|
self.part_num = self.part_num * ratio |
|
|
|
self.part_size = int(self.part_size / ratio) |
|
|
|
msg = "this layer has {} para, disable the operation of clipping lower, clipping upper_bound, " \ |
|
|
|
"adding noise. \n part_num={}, part_size={}" \ |
|
|
|
.format(self.para_num, self.part_num, self.part_size) |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
|
|
|
|
def construct(self): |
|
|
|
""" |
|
|
|
Return the mask matrix for optimization. |
|
|
|