
utils.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging

import torch
import torch.nn as nn

INF = 1E10
EPS = 1E-12

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def get_length(mask):
    # Sequence length is the number of non-padding positions in each row of the mask.
    length = torch.sum(mask, 1)
    length = length.long().cpu()
    return length


class GlobalAvgPool(nn.Module):
    def forward(self, x, mask):
        # x: (batch, channels, seq_len); mask: (batch, seq_len), 1 for valid positions.
        x = torch.sum(x, 2)
        length = torch.sum(mask, 1, keepdim=True).float()
        # Guard against division by zero for all-padding rows.
        length += torch.eq(length, 0.0).float() * EPS
        length = length.repeat(1, x.size()[1])
        x /= length
        return x


class GlobalMaxPool(nn.Module):
    def forward(self, x, mask):
        # Push padded positions to -INF so they never win the max.
        mask = torch.eq(mask.float(), 0.0).float()
        mask = torch.unsqueeze(mask, dim=1).repeat(1, x.size()[1], 1)
        mask *= -INF
        x, _ = torch.max(x + mask, 2)
        return x


class IteratorWrapper:
    def __init__(self, loader):
        self.loader = loader
        self.iterator = None

    def __iter__(self):
        self.iterator = iter(self.loader)
        return self

    def __len__(self):
        return len(self.loader)

    def __next__(self):
        data = next(self.iterator)
        text, length = data.text
        max_length = text.size(1)
        # Dataset labels are 1-based; shift them to 0-based class indices.
        label = data.label - 1
        bs = label.size(0)
        # Boolean mask that is True for positions before each sequence's length.
        mask = torch.arange(max_length, device=length.device).unsqueeze(0).repeat(bs, 1)
        mask = mask < length.unsqueeze(-1).repeat(1, max_length)
        return (text, mask), label


def accuracy(output, target):
    batch_size = target.size(0)
    _, predicted = torch.max(output.data, 1)
    return (predicted == target).sum().item() / batch_size


def dump_global_result(res_path, global_result, sort_keys=False):
    with open(res_path, "w") as ss_file:
        json.dump(global_result, ss_file, sort_keys=sort_keys, indent=2)
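For context, a minimal usage sketch of the pooling modules and the accuracy helper on random data. It is not part of the original file; the tensor shapes, the example lengths, and the `from utils import ...` path are assumptions about how these helpers are typically used.

# Minimal usage sketch, not part of the original file; shapes and the
# "utils" import path are assumptions about how these helpers are used.
import torch

from utils import GlobalAvgPool, GlobalMaxPool, accuracy, get_length

batch, channels, seq_len = 4, 8, 10
x = torch.randn(batch, channels, seq_len)

# Mask is 1 for valid tokens and 0 for padding, one row per example.
lengths = torch.tensor([10, 7, 3, 5])
mask = (torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(1)).float()

avg = GlobalAvgPool()(x, mask)  # (batch, channels), averaged over valid positions
mx = GlobalMaxPool()(x, mask)   # (batch, channels), max over valid positions
print(avg.shape, mx.shape, get_length(mask))

# accuracy expects class logits of shape (batch, num_classes) and integer targets.
logits = torch.randn(batch, 5)
targets = torch.randint(0, 5, (batch,))
print(accuracy(logits, targets))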
