You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

losses.py 5.8 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # ------------------------------------------------------------------------------
  2. # Portions of this code are from
  3. # CornerNet (https://github.com/princeton-vl/CornerNet)
  4. # Copyright (c) 2018, University of Michigan
  5. # Licensed under the BSD 3-Clause License
  6. # Modified by Nguyen Mau Dung (2020.08.09)
  7. # ------------------------------------------------------------------------------
  8. import os
  9. import sys
  10. import math
  11. import torch.nn as nn
  12. import torch
  13. import torch.nn.functional as F
  # Absolute directory containing this module. It is appended to sys.path
  # (only if missing) so that top-level packages living next to this file,
  # such as `utils` imported just below, resolve regardless of the CWD.
  src_dir = os.path.split(os.path.realpath(__file__))[0]
  if src_dir not in sys.path:
      sys.path.append(src_dir)
  19. from utils.torch_utils import to_cpu, _sigmoid
  20. def _gather_feat(feat, ind, mask=None):
  21. dim = feat.size(2)
  22. ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
  23. feat = feat.gather(1, ind)
  24. if mask is not None:
  25. mask = mask.unsqueeze(2).expand_as(feat)
  26. feat = feat[mask]
  27. feat = feat.view(-1, dim)
  28. return feat
  def _transpose_and_gather_feat(feat, ind):
      """Gather per-location feature vectors from a channels-first map.

      Args:
          feat: 4-D tensor (batch x C x H x W).
          ind: 2-D integer tensor (batch x K) of flattened spatial indices,
              i.e. ``row * W + col`` in row-major order over (H, W).

      Returns:
          Tensor (batch x K x C) holding the C-dim feature at each index.
      """
      # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C): one row per location.
      per_location = feat.flatten(start_dim=2).transpose(1, 2)
      return _gather_feat(per_location, ind)
  34. def _neg_loss(pred, gt, alpha=2, beta=4):
  35. ''' Modified focal loss. Exactly the same as CornerNet.
  36. Runs faster and costs a little bit more memory
  37. Arguments:
  38. pred (batch x c x h x w)
  39. gt_regr (batch x c x h x w)
  40. '''
  41. pos_inds = gt.eq(1).float()
  42. neg_inds = gt.lt(1).float()
  43. neg_weights = torch.pow(1 - gt, beta)
  44. loss = 0
  45. pos_loss = torch.log(pred) * torch.pow(1 - pred, alpha) * pos_inds
  46. neg_loss = torch.log(1 - pred) * torch.pow(pred, alpha) * neg_weights * neg_inds
  47. num_pos = pos_inds.float().sum()
  48. pos_loss = pos_loss.sum()
  49. neg_loss = neg_loss.sum()
  50. if num_pos == 0:
  51. loss = loss - neg_loss
  52. else:
  53. loss = loss - (pos_loss + neg_loss) / num_pos
  54. return loss
  class FocalLoss(nn.Module):
      """nn.Module wrapper around the module-level ``_neg_loss`` focal loss."""

      def __init__(self):
          super().__init__()
          self.neg_loss = _neg_loss

      def forward(self, out, target):
          """Return the focal loss of heatmap ``out`` against ``target``."""
          loss_fn = self.neg_loss
          return loss_fn(out, target)
  class L1Loss(nn.Module):
      """Masked L1 regression loss evaluated at gathered object locations."""

      def __init__(self):
          super(L1Loss, self).__init__()

      def forward(self, output, mask, ind, target):
          """Compute the mean absolute error over valid entries.

          Args:
              output: dense prediction map (batch x C x H x W).
              mask: (batch x K) flags; non-zero marks a valid entry of ``ind``.
              ind: (batch x K) flattened spatial indices of the objects.
              target: regression targets, (batch x K x C) like the gathered
                  predictions.

          Returns:
              0-dim tensor: summed absolute error over valid entries divided by
              the number of valid (entry, channel) elements (+1e-4 so an empty
              mask does not divide by zero).
          """
          pred = _transpose_and_gather_feat(output, ind)
          mask = mask.unsqueeze(2).expand_as(pred).float()
          # reduction='sum' replaces the deprecated size_average=False.
          loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
          return loss / (mask.sum() + 1e-4)
  class L1Loss_Balanced(nn.Module):
      """Balanced L1 loss evaluated at gathered object locations.

      Paper: Libra R-CNN, https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
      Reference code: https://github.com/OceanPang/Libra_R-CNN
      """

      def __init__(self, alpha=0.5, gamma=1.5, beta=1.0):
          super(L1Loss_Balanced, self).__init__()
          assert beta > 0
          self.alpha, self.gamma, self.beta = alpha, gamma, beta

      def forward(self, output, mask, ind, target):
          """Return the masked balanced-L1 loss, normalized by valid elements.

          ``output`` is a (batch x C x H x W) map; ``mask`` and ``ind`` are
          (batch x K); ``target`` matches the gathered (batch x K x C)
          predictions. Masked-out entries have zero difference and therefore
          contribute zero loss.
          """
          pred = _transpose_and_gather_feat(output, ind)
          weights = mask.unsqueeze(2).expand_as(pred).float()
          per_element = self.balanced_l1_loss(pred * weights, target * weights)
          return per_element.sum() / (weights.sum() + 1e-4)

      def balanced_l1_loss(self, pred, target):
          """Element-wise balanced L1 loss between equally-shaped tensors."""
          assert pred.size() == target.size() and target.numel() > 0
          alpha, gamma, beta = self.alpha, self.gamma, self.beta
          diff = (pred - target).abs()
          # b satisfies alpha * ln(b + 1) == gamma, which makes the two
          # branches below meet continuously at diff == beta.
          b = math.exp(gamma / alpha) - 1
          inlier = alpha / b * (b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff
          outlier = gamma * diff + gamma / b - alpha * beta
          return torch.where(diff < beta, inlier, outlier)
  class Compute_Loss(nn.Module):
      """Weighted sum of heatmap, offset, direction, z and dimension losses."""

      def __init__(self, device):
          super(Compute_Loss, self).__init__()
          # NOTE(review): stored but never read inside this class.
          self.device = device
          # One criterion per kind of head (registration order kept).
          self.focal_loss = FocalLoss()
          self.l1_loss = L1Loss()
          self.l1_loss_balanced = L1Loss_Balanced(alpha=0.5, gamma=1.5, beta=1.0)
          # Per-term weights in the total loss.
          self.weight_hm_cen = 1.
          self.weight_z_coor = 1.
          self.weight_cenoff = 1.
          self.weight_dim = 1.
          self.weight_direction = 1.

      def forward(self, outputs, tg):
          """Compute the total training loss and per-term statistics.

          Args:
              outputs: dict of head outputs keyed 'hm_cen', 'cen_offset',
                  'direction', 'z_coor', 'dim'. NOTE: the 'hm_cen' and
                  'cen_offset' entries are replaced in place by their
                  ``_sigmoid``-transformed versions.
              tg: dict of targets keyed 'hm_cen', 'obj_mask',
                  'indices_center', 'cen_offset', 'direction', 'z_coor', 'dim'.

          Returns:
              (total_loss, loss_stats): the weighted total loss tensor and a
              dict mapping each loss name to its ``.item()`` value.
          """
          for head in ('hm_cen', 'cen_offset'):
              outputs[head] = _sigmoid(outputs[head])

          l_hm_cen = self.focal_loss(outputs['hm_cen'], tg['hm_cen'])

          mask, ind = tg['obj_mask'], tg['indices_center']
          l_cen_offset = self.l1_loss(outputs['cen_offset'], mask, ind, tg['cen_offset'])
          l_direction = self.l1_loss(outputs['direction'], mask, ind, tg['direction'])
          # Balanced L1 is used for the z coordinate and box dimensions.
          l_z_coor = self.l1_loss_balanced(outputs['z_coor'], mask, ind, tg['z_coor'])
          l_dim = self.l1_loss_balanced(outputs['dim'], mask, ind, tg['dim'])

          total_loss = (l_hm_cen * self.weight_hm_cen
                        + l_cen_offset * self.weight_cenoff
                        + l_dim * self.weight_dim
                        + l_direction * self.weight_direction
                        + l_z_coor * self.weight_z_coor)

          named_terms = (
              ('total_loss', total_loss),
              ('hm_cen_loss', l_hm_cen),
              ('cen_offset_loss', l_cen_offset),
              ('dim_loss', l_dim),
              ('direction_loss', l_direction),
              ('z_coor_loss', l_z_coor),
          )
          loss_stats = {name: to_cpu(term).item() for name, term in named_terms}
          return total_loss, loss_stats

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。