Browse Source

!307 Add device_target check for DP Train

Merge pull request !307 from jxlang910/master
tags/v1.6.0
i-robot Gitee 3 years ago
parent
commit
35ac2fc8db
1 changed files with 7 additions and 0 deletions
  1. +7
    -0
      mindarmour/privacy/diff_privacy/train/model.py

+ 7
- 0
mindarmour/privacy/diff_privacy/train/model.py View File

@@ -159,6 +159,13 @@ class DPModel(Model):
self._clip_mech = clip_mech
super(DPModel, self).__init__(**kwargs)

# judge device_target, only GPU or Ascend is supported until now
device = context.get_context("device_target")
if device not in ["GPU", "Ascend"]:
msg = "'device_target' or DP training should be 'GPU' or 'Ascend', but got {}.".format(device)
LOGGER.error(TAG, msg)
raise ValueError(msg)

def _amp_build_train_network(self, network, optimizer, loss_fn=None,
level='O0', **kwargs):
"""


Loading…
Cancel
Save