|
|
@@ -35,20 +35,20 @@ class SuppressPrivacyFactory: |
|
|
|
pass |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def create(policy="local_train", end_epoch=10, batch_num=2, start_epoch=3, mask_times=100, networks=None, |
|
|
|
lr=0.05, sparse_end=0.60, sparse_start=0.0, mask_layers=None): |
|
|
|
def create(networks, mask_layers, policy="local_train", end_epoch=10, batch_num=20, start_epoch=3, |
|
|
|
mask_times=500, lr=0.10, sparse_end=0.90, sparse_start=0.0): |
|
|
|
""" |
|
|
|
Args: |
|
|
|
policy (str): Training policy for suppress privacy training. "local_train" means local training. |
|
|
|
end_epoch (int): The last epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100 . |
|
|
|
batch_num (int): The num of batch in an epoch, should be equal to num_samples/batch_size . |
|
|
|
start_epoch (int): The first epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100 . |
|
|
|
mask_times (int): The num of suppress operations. |
|
|
|
networks (Cell): The training network. |
|
|
|
lr (float): Learning rate. |
|
|
|
sparse_end (float): The sparsity to reach, 0.0 <= sparse_start < sparse_end < 1.0 . |
|
|
|
sparse_start (float): The sparsity to start, 0.0 <= sparse_start < sparse_end < 1.0 . |
|
|
|
mask_layers (list): Description of the training network layers that need to be suppressed. |
|
|
|
policy (str): Training policy for suppress privacy training. Default: "local_train", means local training. |
|
|
|
end_epoch (int): The last epoch in suppress operations, 0<start_epoch<=end_epoch<=100. Default: 10. |
|
|
|
batch_num (int): The num of batch in an epoch, should be equal to num_samples/batch_size. Default: 20. |
|
|
|
start_epoch (int): The first epoch in suppress operations, 0<start_epoch<=end_epoch<=100. Default: 3. |
|
|
|
mask_times (int): The num of suppress operations. Default: 500. |
|
|
|
lr (Union[float, int]): Learning rate, 0 < lr <= 0.5. Default: 0.10. |
|
|
|
sparse_end (float): The sparsity to reach, 0.0<=sparse_start<sparse_end<1.0. Default: 0.90. |
|
|
|
sparse_start (float): The sparsity to start, 0.0<=sparse_start<sparse_end<1.0. Default: 0.0. |
|
|
|
|
|
|
|
Returns: |
|
|
|
SuppressCtrl, class of Suppress Privavy Mechanism. |
|
|
@@ -84,8 +84,8 @@ class SuppressPrivacyFactory: |
|
|
|
dataset_sink_mode=False) |
|
|
|
""" |
|
|
|
if policy == "local_train": |
|
|
|
return SuppressCtrl(networks, end_epoch, batch_num, start_epoch, mask_times, lr, sparse_end, |
|
|
|
sparse_start, mask_layers) |
|
|
|
return SuppressCtrl(networks, mask_layers, end_epoch, batch_num, start_epoch, mask_times, lr, |
|
|
|
sparse_end, sparse_start) |
|
|
|
msg = "Only local training is supported now, federal training will be supported " \ |
|
|
|
"in the future. But got {}.".format(policy) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
@@ -95,6 +95,7 @@ class SuppressCtrl(Cell): |
|
|
|
""" |
|
|
|
Args: |
|
|
|
networks (Cell): The training network. |
|
|
|
mask_layers (list): Description of those layers that need to be suppressed. |
|
|
|
end_epoch (int): The last epoch in suppress operations. |
|
|
|
batch_num (int): The num of grad operation in an epoch. |
|
|
|
mask_start_epoch (int): The first epoch in suppress operations. |
|
|
@@ -102,14 +103,12 @@ class SuppressCtrl(Cell): |
|
|
|
lr (Union[float, int]): Learning rate. |
|
|
|
sparse_end (Union[float, int]): The sparsity to reach. |
|
|
|
sparse_start (float): The sparsity to start. |
|
|
|
mask_layers (list): Description of those layers that need to be suppressed. |
|
|
|
""" |
|
|
|
def __init__(self, networks, end_epoch, batch_num, mask_start_epoch=3, mask_times=500, lr=0.05, |
|
|
|
sparse_end=0.60, |
|
|
|
sparse_start=0.0, |
|
|
|
mask_layers=None): |
|
|
|
def __init__(self, networks, mask_layers, end_epoch, batch_num, mask_start_epoch, mask_times, lr, |
|
|
|
sparse_end, sparse_start): |
|
|
|
super(SuppressCtrl, self).__init__() |
|
|
|
self.networks = check_param_type('networks', networks, Cell) |
|
|
|
self.mask_layers = check_param_type('mask_layers', mask_layers, list) |
|
|
|
self.mask_end_epoch = check_int_positive('end_epoch', end_epoch) |
|
|
|
self.batch_num = check_int_positive('batch_num', batch_num) |
|
|
|
self.mask_start_epoch = check_int_positive('mask_start_epoch', mask_start_epoch) |
|
|
@@ -117,7 +116,6 @@ class SuppressCtrl(Cell): |
|
|
|
self.lr = check_value_positive('lr', lr) |
|
|
|
self.sparse_end = check_value_non_negative('sparse_end', sparse_end) |
|
|
|
self.sparse_start = check_value_non_negative('sparse_start', sparse_start) |
|
|
|
self.mask_layers = check_param_type('mask_layers', mask_layers, list) |
|
|
|
|
|
|
|
self.weight_lower_bound = 0.005 # all network weight will be larger than this value |
|
|
|
self.sparse_vibra = 0.02 # the sparsity may have certain range of variations |
|
|
@@ -137,13 +135,19 @@ class SuppressCtrl(Cell): |
|
|
|
self.mask_step_interval = self.mask_all_steps/mask_times # the amount of step contaied in one suppress operation |
|
|
|
self.mask_initialized = False # flag means the initialization is done |
|
|
|
|
|
|
|
if self.lr > 0.5: |
|
|
|
msg = "learning rate should be smaller than 0.5, but got {}".format(self.lr) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
if self.mask_start_epoch > self.mask_end_epoch: |
|
|
|
msg = "start_epoch error: {}".format(self.mask_start_epoch) |
|
|
|
msg = "start_epoch should not be greater than end_epoch, but got start_epoch and end_epoch are: " \ |
|
|
|
"{}, {}".format(self.mask_start_epoch, self.mask_end_epoch) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
if self.mask_end_epoch > 100: |
|
|
|
msg = "end_epoch error: {}".format(self.mask_end_epoch) |
|
|
|
msg = "The end_epoch should be smaller than 100, but got {}".format(self.mask_end_epoch) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
@@ -152,13 +156,14 @@ class SuppressCtrl(Cell): |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
if self.sparse_end > 1.00 or self.sparse_end <= 0: |
|
|
|
msg = "sparse_end error: {}".format(self.sparse_end) |
|
|
|
if self.sparse_end >= 1.00 or self.sparse_end <= 0: |
|
|
|
msg = "sparse_end should be in range (0, 1), but got {}".format(self.sparse_end) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
if self.sparse_start >= self.sparse_end: |
|
|
|
msg = "sparse_start error: {}".format(self.sparse_start) |
|
|
|
msg = "sparse_start should be smaller than sparse_end, but got sparse_start and sparse_end are: " \ |
|
|
|
"{}, {}".format(self.sparse_start, self.sparse_end) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|