|
|
@@ -47,6 +47,7 @@ class InversionLoss(Cell): |
|
|
|
self._mse_loss = MSELoss() |
|
|
|
self._weights = check_param_multi_types('weights', weights, [list, tuple]) |
|
|
|
self._get_shape = P.Shape() |
|
|
|
self._zeros = P.ZerosLike() |
|
|
|
|
|
|
|
def construct(self, input_data, target_features): |
|
|
|
""" |
|
|
@@ -65,15 +66,11 @@ class InversionLoss(Cell): |
|
|
|
loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0) |
|
|
|
|
|
|
|
data_shape = self._get_shape(input_data) |
|
|
|
split_op_1 = P.Split(2, data_shape[2]) |
|
|
|
split_op_2 = P.Split(3, data_shape[3]) |
|
|
|
data_split_1 = split_op_1(input_data) |
|
|
|
data_split_2 = split_op_2(input_data) |
|
|
|
loss_2 = 0 |
|
|
|
for i in range(1, data_shape[2]): |
|
|
|
loss_2 += self._mse_loss(data_split_1[i], data_split_1[i-1]) |
|
|
|
for j in range(1, data_shape[3]): |
|
|
|
loss_2 += self._mse_loss(data_split_2[j], data_split_2[j-1]) |
|
|
|
data_copy_1 = self._zeros(input_data) |
|
|
|
data_copy_2 = self._zeros(input_data) |
|
|
|
data_copy_1[:, :, :(data_shape[2] - 1), :] = input_data[:, :, 1:, :] |
|
|
|
data_copy_2[:, :, :, :(data_shape[2] - 1)] = input_data[:, :, :, 1:] |
|
|
|
loss_2 = self._mse_loss(input_data, data_copy_1) + self._mse_loss(input_data, data_copy_2) |
|
|
|
|
|
|
|
loss_3 = self._mse_loss(input_data, 0) |
|
|
|
|
|
|
|