From b94508c7022f323f40bd77ce02d45816c32bc02b Mon Sep 17 00:00:00 2001 From: jin-xiulang Date: Fri, 24 Dec 2021 09:25:48 +0800 Subject: [PATCH] Add device_target check for DP Train --- mindarmour/privacy/diff_privacy/train/model.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mindarmour/privacy/diff_privacy/train/model.py b/mindarmour/privacy/diff_privacy/train/model.py index b510b2b..93fc952 100644 --- a/mindarmour/privacy/diff_privacy/train/model.py +++ b/mindarmour/privacy/diff_privacy/train/model.py @@ -159,6 +159,13 @@ class DPModel(Model): self._clip_mech = clip_mech super(DPModel, self).__init__(**kwargs) + # judge device_target, only GPU or Ascend is supported until now + device = context.get_context("device_target") + if device not in ["GPU", "Ascend"]: + msg = "'device_target' or DP training should be 'GPU' or 'Ascend', but got {}.".format(device) + LOGGER.error(TAG, msg) + raise ValueError(msg) + def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs): """