diff --git a/mindarmour/privacy/diff_privacy/train/model.py b/mindarmour/privacy/diff_privacy/train/model.py index c68c119..7af3480 100644 --- a/mindarmour/privacy/diff_privacy/train/model.py +++ b/mindarmour/privacy/diff_privacy/train/model.py @@ -137,7 +137,7 @@ class DPModel(Model): if clip_mech is None or isinstance(clip_mech, Cell): self._clip_mech = clip_mech - super(DPModel, self).__init__(**kwargs) + super(DPModel, self).__init__(optimizer=optimizer, **kwargs) # judge device_target, only GPU or Ascend is supported until now device = context.get_context("device_target")