
model.py

import torch
import torch.nn as nn


class HAN(nn.Module):
    """Hierarchical Attention Network: word-level attention over each
    sentence, then sentence-level attention over each document."""

    def __init__(self, input_size, output_size,
                 word_hidden_size, word_num_layers, word_context_size,
                 sent_hidden_size, sent_num_layers, sent_context_size):
        super(HAN, self).__init__()
        self.word_layer = AttentionNet(input_size,
                                       word_hidden_size,
                                       word_num_layers,
                                       word_context_size)
        self.sent_layer = AttentionNet(2 * word_hidden_size,
                                       sent_hidden_size,
                                       sent_num_layers,
                                       sent_context_size)
        self.output_layer = nn.Linear(2 * sent_hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, batch_doc):
        # batch_doc is a batch of documents; each document is a sequence of
        # sentences, and each sentence is a (seq_len, input_size) tensor.
        doc_vec_list = []
        for doc in batch_doc:
            s_list = []
            for sent in doc:
                s_list.append(self.word_layer(sent))  # (1, 2*word_hidden_size)
            s_vec = torch.cat(s_list, dim=0)          # (num_sents, 2*word_hidden_size)
            vec = self.sent_layer(s_vec)              # (1, 2*sent_hidden_size)
            doc_vec_list.append(vec)
        doc_vec = torch.cat(doc_vec_list, dim=0)      # (num_docs, 2*sent_hidden_size)
        output = self.softmax(self.output_layer(doc_vec))
        return output


class AttentionNet(nn.Module):
    def __init__(self, input_size, gru_hidden_size, gru_num_layers, context_vec_size):
        super(AttentionNet, self).__init__()
        self.input_size = input_size
        self.gru_hidden_size = gru_hidden_size
        self.gru_num_layers = gru_num_layers
        self.context_vec_size = context_vec_size
        # Encoder: bidirectional GRU over the input sequence
        self.gru = nn.GRU(input_size=input_size,
                          hidden_size=gru_hidden_size,
                          num_layers=gru_num_layers,
                          batch_first=False,
                          bidirectional=True)
        # Attention: project each hidden state, then score it against a
        # learned context vector
        self.fc = nn.Linear(2 * gru_hidden_size, context_vec_size)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=0)
        # Learned context vector used to score the time steps
        self.context_vec = nn.Parameter(torch.Tensor(context_vec_size, 1))
        self.context_vec.data.uniform_(-0.1, 0.1)

    def forward(self, inputs):
        # inputs: (seq_len, input_size)
        inputs = torch.unsqueeze(inputs, 1)                  # (seq_len, 1, input_size)
        h_t, hidden = self.gru(inputs)                       # (seq_len, 1, 2*gru_hidden_size)
        h_t = torch.squeeze(h_t, 1)                          # (seq_len, 2*gru_hidden_size)
        u = self.tanh(self.fc(h_t))                          # (seq_len, context_vec_size)
        alpha = self.softmax(torch.mm(u, self.context_vec))  # attention weights, (seq_len, 1)
        output = torch.mm(h_t.t(), alpha).t()                # weighted sum, (1, 2*gru_hidden_size)
        return output

A lightweight natural language processing (NLP) toolkit that aims to reduce the engineering boilerplate in user projects, such as data-processing loops, training loops, and multi-GPU execution.