
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn.functional as F
from torch import nn

from utils import get_length, INF


class Mask(nn.Module):
    def forward(self, seq, mask):
        # seq: (N, C, L)
        # mask: (N, L)
        seq_mask = torch.unsqueeze(mask, 2)
        seq_mask = torch.transpose(seq_mask.repeat(1, 1, seq.size()[1]), 1, 2)
        return seq.where(torch.eq(seq_mask, 1), torch.zeros_like(seq))

    def __str__(self):
        return 'Mask'

class BatchNorm(nn.Module):
    def __init__(self, num_features, pre_mask, post_mask, eps=1e-5, decay=0.9, affine=True):
        super(BatchNorm, self).__init__()
        self.mask_opt = Mask()
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.bn = nn.BatchNorm1d(num_features, eps=eps, momentum=1.0 - decay, affine=affine)

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.bn(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        return seq

    def __str__(self):
        return 'BatchNorm'

class ConvBN(nn.Module):
    def __init__(self, kernal_size, in_channels, out_channels, cnn_keep_prob,
                 pre_mask, post_mask, with_bn=True, with_relu=True):
        super(ConvBN, self).__init__()
        self.mask_opt = Mask()
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.with_bn = with_bn
        self.with_relu = with_relu
        self.kernal_size = kernal_size
        self.conv = nn.Conv1d(in_channels, out_channels, kernal_size, 1, bias=True, padding=(kernal_size - 1) // 2)
        self.dropout = nn.Dropout(p=(1 - cnn_keep_prob))
        if with_bn:
            self.bn = BatchNorm(out_channels, not post_mask, True)
        if with_relu:
            self.relu = nn.ReLU()

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.conv(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        if self.with_bn:
            seq = self.bn(seq, mask)
        if self.with_relu:
            seq = self.relu(seq)
        seq = self.dropout(seq)
        return seq

    def __str__(self):
        return 'ConvBN_{}'.format(self.kernal_size)

class AvgPool(nn.Module):
    def __init__(self, kernal_size, pre_mask, post_mask):
        super(AvgPool, self).__init__()
        self.avg_pool = nn.AvgPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.mask_opt = Mask()
        self.kernal_size = kernal_size

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.avg_pool(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        return seq

    def __str__(self):
        return 'AvgPool{}'.format(self.kernal_size)

class MaxPool(nn.Module):
    def __init__(self, kernal_size, pre_mask, post_mask):
        super(MaxPool, self).__init__()
        self.max_pool = nn.MaxPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.mask_opt = Mask()
        self.kernel_size = kernal_size

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.max_pool(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        return seq

    def __str__(self):
        return 'MaxPool{}'.format(self.kernel_size)

class Attention(nn.Module):
    def __init__(self, num_units, num_heads, keep_prob, is_mask):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.keep_prob = keep_prob

        self.linear_q = nn.Linear(num_units, num_units)
        self.linear_k = nn.Linear(num_units, num_units)
        self.linear_v = nn.Linear(num_units, num_units)

        self.bn = BatchNorm(num_units, True, is_mask)
        self.dropout = nn.Dropout(p=1 - self.keep_prob)

    def forward(self, seq, mask):
        in_c = seq.size()[1]
        seq = torch.transpose(seq, 1, 2)  # (N, L, C)
        queries = seq
        keys = seq
        num_heads = self.num_heads

        # T_q = T_k = L
        Q = F.relu(self.linear_q(seq))  # (N, T_q, C)
        K = F.relu(self.linear_k(seq))  # (N, T_k, C)
        V = F.relu(self.linear_v(seq))  # (N, T_k, C)

        # Split and concat
        Q_ = torch.cat(torch.split(Q, in_c // num_heads, dim=2), dim=0)  # (h*N, T_q, C/h)
        K_ = torch.cat(torch.split(K, in_c // num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)
        V_ = torch.cat(torch.split(V, in_c // num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)

        # Multiplication
        outputs = torch.matmul(Q_, K_.transpose(1, 2))  # (h*N, T_q, T_k)
        # Scale
        outputs = outputs / (K_.size()[-1] ** 0.5)

        # Key Masking
        key_masks = mask.repeat(num_heads, 1)  # (h*N, T_k)
        key_masks = torch.unsqueeze(key_masks, 1)  # (h*N, 1, T_k)
        key_masks = key_masks.repeat(1, queries.size()[1], 1)  # (h*N, T_q, T_k)

        paddings = torch.ones_like(outputs) * (-INF)  # extremely small value
        outputs = torch.where(torch.eq(key_masks, 0), paddings, outputs)

        # Query Masking
        query_masks = mask.repeat(num_heads, 1)  # (h*N, T_q)
        query_masks = torch.unsqueeze(query_masks, -1)  # (h*N, T_q, 1)
        query_masks = query_masks.repeat(1, 1, keys.size()[1]).float()  # (h*N, T_q, T_k)

        att_scores = F.softmax(outputs, dim=-1) * query_masks  # (h*N, T_q, T_k)
        att_scores = self.dropout(att_scores)

        # Weighted sum
        x_outputs = torch.matmul(att_scores, V_)  # (h*N, T_q, C/h)
        # Restore shape
        x_outputs = torch.cat(
            torch.split(x_outputs, x_outputs.size()[0] // num_heads, dim=0),
            dim=2)  # (N, T_q, C)

        x = torch.transpose(x_outputs, 1, 2)  # (N, C, L)
        x = self.bn(x, mask)

        return x

    def __str__(self):
        return 'Attention'

class RNN(nn.Module):
    def __init__(self, hidden_size, output_keep_prob):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.bid_rnn = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.output_keep_prob = output_keep_prob

        self.out_dropout = nn.Dropout(p=(1 - self.output_keep_prob))

    def forward(self, seq, mask):
        # seq: (N, C, L)
        # mask: (N, L)
        max_len = seq.size()[2]
        length = get_length(mask)
        seq = torch.transpose(seq, 1, 2)  # to (N, L, C)
        packed_seq = nn.utils.rnn.pack_padded_sequence(seq, length, batch_first=True,
                                                       enforce_sorted=False)
        outputs, _ = self.bid_rnn(packed_seq)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True,
                                                   total_length=max_len)[0]
        outputs = outputs.view(-1, max_len, 2, self.hidden_size).sum(2)  # (N, L, C)
        outputs = self.out_dropout(outputs)  # output dropout
        return torch.transpose(outputs, 1, 2)  # back to: (N, C, L)

    def __str__(self):
        return 'RNN'

class LinearCombine(nn.Module):
    def __init__(self, layers_num, trainable=True, input_aware=False, word_level=False):
        super(LinearCombine, self).__init__()
        self.input_aware = input_aware
        self.word_level = word_level

        if input_aware:
            raise NotImplementedError("Input aware is not supported.")
        self.w = nn.Parameter(torch.full((layers_num, 1, 1, 1), 1.0 / layers_num),
                              requires_grad=trainable)

    def forward(self, seq):
        nw = F.softmax(self.w, dim=0)
        seq = torch.mul(seq, nw)
        seq = torch.sum(seq, dim=0)
        return seq
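

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file): it only
    # illustrates the (N, C, L) activation / (N, L) padding-mask convention
    # used by the ops above. The shapes, hyper-parameters, and mask layout
    # below are illustrative assumptions, not values taken from the TextNAS
    # search space itself.
    batch, channels, max_len = 2, 16, 10
    seq = torch.randn(batch, channels, max_len)           # (N, C, L) activations
    mask = torch.zeros(batch, max_len, dtype=torch.long)  # (N, L) padding mask
    mask[0, :7] = 1                                       # first sequence: 7 valid steps
    mask[1, :5] = 1                                       # second sequence: 5 valid steps

    # Convolution + batch norm, zeroing padded positions after the conv.
    conv = ConvBN(kernal_size=3, in_channels=channels, out_channels=channels,
                  cnn_keep_prob=0.8, pre_mask=False, post_mask=True)
    out = conv(seq, mask)
    print(out.shape)  # torch.Size([2, 16, 10]) -- same-length output thanks to padding

    # Max pooling over a window of 3, again masking padded positions afterwards.
    pool = MaxPool(kernal_size=3, pre_mask=False, post_mask=True)
    print(pool(out, mask).shape)  # torch.Size([2, 16, 10])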
