|
|
@@ -159,6 +159,13 @@ class DPModel(Model): |
|
|
|
self._clip_mech = clip_mech |
|
|
|
super(DPModel, self).__init__(**kwargs) |
|
|
|
|
|
|
|
# judge device_target, only GPU or Ascend is supported until now |
|
|
|
device = context.get_context("device_target") |
|
|
|
if device not in ["GPU", "Ascend"]: |
|
|
|
msg = "'device_target' or DP training should be 'GPU' or 'Ascend', but got {}.".format(device) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
def _amp_build_train_network(self, network, optimizer, loss_fn=None, |
|
|
|
level='O0', **kwargs): |
|
|
|
""" |
|
|
|