diff --git a/mindarmour/privacy/diff_privacy/train/model.py b/mindarmour/privacy/diff_privacy/train/model.py index b510b2b..93fc952 100644 --- a/mindarmour/privacy/diff_privacy/train/model.py +++ b/mindarmour/privacy/diff_privacy/train/model.py @@ -159,6 +159,13 @@ class DPModel(Model): self._clip_mech = clip_mech super(DPModel, self).__init__(**kwargs) + # judge device_target, only GPU or Ascend is supported until now + device = context.get_context("device_target") + if device not in ["GPU", "Ascend"]: + msg = "'device_target' or DP training should be 'GPU' or 'Ascend', but got {}.".format(device) + LOGGER.error(TAG, msg) + raise ValueError(msg) + def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs): """