From a404c13fd4f8f1e2ef37cf34c1c2d3acb4d0983e Mon Sep 17 00:00:00 2001 From: ZhidanLiu Date: Tue, 22 Mar 2022 14:51:32 +0800 Subject: [PATCH] fix cmetrics problem --- mindarmour/privacy/diff_privacy/train/model.py | 6 ++---- mindarmour/privacy/sup_privacy/mask_monitor/masker.py | 1 + mindarmour/privacy/sup_privacy/train/model.py | 3 ++- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mindarmour/privacy/diff_privacy/train/model.py b/mindarmour/privacy/diff_privacy/train/model.py index faade7f..25e4a4b 100644 --- a/mindarmour/privacy/diff_privacy/train/model.py +++ b/mindarmour/privacy/diff_privacy/train/model.py @@ -621,10 +621,8 @@ class _TrainOneStepCell(Cell): self._noise_mech_param_updater = _MechanismsParamsUpdater( decay_policy=self._noise_mech._decay_policy, decay_rate=self._noise_mech._noise_decay_rate, - cur_noise_multiplier= - self._noise_mech._noise_multiplier, - init_noise_multiplier= - self._noise_mech._initial_noise_multiplier) + cur_noise_multiplier=self._noise_mech._noise_multiplier, + init_noise_multiplier=self._noise_mech._initial_noise_multiplier) def construct(self, data, label): """ diff --git a/mindarmour/privacy/sup_privacy/mask_monitor/masker.py b/mindarmour/privacy/sup_privacy/mask_monitor/masker.py index ef7f906..cad07c2 100644 --- a/mindarmour/privacy/sup_privacy/mask_monitor/masker.py +++ b/mindarmour/privacy/sup_privacy/mask_monitor/masker.py @@ -23,6 +23,7 @@ from mindarmour.privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl LOGGER = LogUtil.get_instance() TAG = 'suppress masker' + class SuppressMasker(Callback): """ For details, please check `Tutorial `_ diff --git a/mindarmour/privacy/sup_privacy/train/model.py b/mindarmour/privacy/sup_privacy/train/model.py index 3b06a8d..dedb6ff 100644 --- a/mindarmour/privacy/sup_privacy/train/model.py +++ b/mindarmour/privacy/sup_privacy/train/model.py @@ -184,6 +184,7 @@ class _TupleAdd(nn.Cell): out = self.hyper_map(self.add, input1, input2) return out + class _TupleMul(nn.Cell): """ Mul two tuple of data. @@ -198,7 +199,7 @@ class _TupleMul(nn.Cell): out = self.hyper_map(self.mul, input1, input2) return out -# come from nn.cell_wrapper.TrainOneStepCell + class TrainOneStepCell(Cell): r""" Network training package class.