from itertools import chain

import torch
import torch.nn.functional as F
from torch import nn


def seq_lens_to_mask(seq_lens):
    """Build a boolean padding mask from a tensor of sequence lengths.

    Args:
        seq_lens: tensor of lengths, assumed 1-D of shape ``(batch_size,)``.

    Returns:
        Bool tensor of shape ``(batch_size, max(seq_lens))`` whose entry
        ``[b, t]`` is True iff ``t < seq_lens[b]``.
    """
    positions = torch.arange(seq_lens.max().item(), device=seq_lens.device)
    # Broadcasting (1, max_len) against (batch_size, 1) replaces an explicit repeat.
    return positions.unsqueeze(0).lt(seq_lens.unsqueeze(1))


def refine_ys_on_seq_len(ys, seq_lens):
    """Truncate each padded sequence in ``ys`` to its true length.

    Args:
        ys: indexable collection of per-sample sequences.
        seq_lens: iterable of lengths, one per sample, in the same order.

    Returns:
        A list of lists where element ``b`` is ``list(ys[b][:seq_lens[b]])``.
    """
    return [list(ys[b_idx][:length]) for b_idx, length in enumerate(seq_lens)]


def flat_nested_list(nested_list):
    """Concatenate an iterable of iterables into a single flat list."""
    return list(chain.from_iterable(nested_list))


def _count_correct_segapp(true_ys, pred_ys):
    """Count gold words reproduced exactly by the prediction ('segapp' tags).

    Every position whose gold tag is 1 starts a word (the word count used by
    the caller is the number of 1s). A word runs up to, but not including,
    the next such position or the end of the sequence. It is correct only when
    the predicted tags equal the gold tags over the word AND on the position
    right after it, i.e. both of its boundaries agree.
    """
    starts = [i for i, tag in enumerate(true_ys) if tag == 1]
    ends = starts[1:] + [len(true_ys)]
    # Slicing one past the word end also compares the next word's start marker;
    # for the final word the slice simply stops at the end of the sequence, so
    # the last word is checked like every other one.
    return sum(true_ys[s:e + 1] == pred_ys[s:e + 1] for s, e in zip(starts, ends))


def _count_correct_bmes(true_ys, pred_ys):
    """Count gold words reproduced exactly by the prediction ('bmes' tags).

    A gold tag of 2 or 3 closes a word (presumably E and S in a B=0/M=1/E=2/S=3
    scheme -- TODO confirm). A word is correct when all predicted tags from
    the previous word end up to and including its closing tag equal the gold tags.
    """
    correct = 0
    start = 0
    for i, tag in enumerate(true_ys):
        if tag in (2, 3):
            if true_ys[start:i + 1] == pred_ys[start:i + 1]:
                correct += 1
            start = i + 1
    return correct


_WORD_COUNTERS = {
    'segapp': ((1,), _count_correct_segapp),
    'bmes': ((2, 3), _count_correct_bmes),
}


def calculate_pre_rec_f1(model, batcher, type='segapp'):
    """Compute word-level precision, recall and F1 of ``model`` on ``batcher``.

    Args:
        model: object accepted by ``decode_iterator``.
        batcher: iterable of ``(batch_x, batch_y)`` pairs, see ``decode_iterator``.
        type: tagging scheme, either ``'segapp'`` or ``'bmes'``.

    Returns:
        Tuple ``(precision, recall, f1)`` of floats. A 1e-6 smoothing term keeps
        the divisions finite, so an empty corpus yields ``(0.0, 0.0, 0.0)``.

    Raises:
        ValueError: if ``type`` is not a known tagging scheme.
    """
    if type not in _WORD_COUNTERS:
        raise ValueError('unknown tagging scheme: {!r}'.format(type))
    word_end_tags, count_correct = _WORD_COUNTERS[type]

    true_ys, pred_ys = decode_iterator(model, batcher)
    # Sentences are concatenated; a sentence's first tag then acts as the
    # right boundary of the previous sentence's last word.
    true_ys = flat_nested_list(true_ys)
    pred_ys = flat_nested_list(pred_ys)

    yp_wordnum = sum(pred_ys.count(tag) for tag in word_end_tags)
    yt_wordnum = sum(true_ys.count(tag) for tag in word_end_tags)
    cor_num = count_correct(true_ys, pred_ys)

    precision = cor_num / (float(yp_wordnum) + 1e-6)
    recall = cor_num / (float(yt_wordnum) + 1e-6)
    f1 = 2 * precision * recall / (precision + recall + 1e-6)
    return precision, recall, f1


def decode_iterator(model, batcher):
    """Run ``model`` over ``batcher`` and collect gold and predicted tag sequences.

    Each ``batch_x`` must be a dict holding a ``'seq_lens'`` tensor and is passed
    as keyword arguments to ``model.predict``, which must return a dict with a
    ``'pred_tags'`` tensor. Each ``batch_y`` must hold a ``'tags'`` tensor.
    Gradients are disabled and the model is put in eval mode while decoding;
    its previous train/eval mode is restored afterwards, even on error.

    Returns:
        ``(true_ys, pred_ys)``: lists of per-sample tag lists, each truncated to
        that sample's sequence length.
    """
    true_ys = []
    pred_ys = []
    seq_lens = []
    # Models without a ``training`` flag fall back to train mode, as before.
    was_training = getattr(model, 'training', True)
    model.eval()
    try:
        with torch.no_grad():
            for batch_x, batch_y in batcher:
                pred_dict = model.predict(**batch_x)
                seq_lens.extend(batch_x['seq_lens'].cpu().numpy().tolist())
                pred_ys.extend(pred_dict['pred_tags'].cpu().numpy().tolist())
                true_ys.extend(batch_y['tags'].cpu().numpy().tolist())
    finally:
        model.train(was_training)
    true_ys = refine_ys_on_seq_len(true_ys, seq_lens)
    pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens)
    return true_ys, pred_ys


class FocalLoss(nn.Module):
    r"""Focal loss from "Focal Loss for Dense Object Detection" (Lin et al.).

    For each observation with logits ``x`` and target class ``c``::

        loss = -(1 - softmax(x)[c]) ** gamma * log(softmax(x)[c])

    Args:
        class_num (int): number of classes. Stored on the module; not used in
            the computation.
        gamma (float): focusing parameter (>= 0). Larger values reduce the
            relative loss of well-classified examples (p > .5), putting more
            focus on hard, misclassified examples. ``gamma=0`` is cross-entropy.
        size_average (bool): only used when ``reduce`` is True. Average the
            per-observation losses if True, otherwise sum them.
        reduce (bool): if False (the default), ``forward`` returns the
            unreduced per-observation losses with shape ``(N, 1)``.
    """

    def __init__(self, class_num, gamma=2, size_average=True, reduce=False):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
        self.reduce = reduce

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: logits of shape ``(N, C)``.
            targets: int64 class indices with ``N`` elements.

        Returns:
            Tensor of shape ``(N, 1)`` when ``reduce`` is False, otherwise a
            scalar mean (``size_average``) or sum.
        """
        ids = targets.view(-1, 1)
        # log_softmax stays finite for large logits, whereas log(softmax(x))
        # underflows to log(0) = -inf when the target probability is tiny.
        log_p = F.log_softmax(inputs, dim=-1).gather(1, ids)
        probs = log_p.exp()
        batch_loss = -torch.pow(1 - probs, self.gamma) * log_p
        if self.reduce:
            if self.size_average:
                return batch_loss.mean()
            return batch_loss.sum()
        return batch_loss