
model.py

import torch
import torch.nn as nn
class HAN(nn.Module):
    def __init__(self, input_size, output_size,
                 word_hidden_size, word_num_layers, word_context_size,
                 sent_hidden_size, sent_num_layers, sent_context_size):
        super(HAN, self).__init__()
        # word-level attention encoder: turns a sentence (a sequence of
        # word vectors) into a single sentence vector
        self.word_layer = AttentionNet(input_size,
                                       word_hidden_size,
                                       word_num_layers,
                                       word_context_size)
        # sentence-level attention encoder: the GRU is bidirectional,
        # so sentence vectors have size 2 * word_hidden_size
        self.sent_layer = AttentionNet(2 * word_hidden_size,
                                       sent_hidden_size,
                                       sent_num_layers,
                                       sent_context_size)
        self.output_layer = nn.Linear(2 * sent_hidden_size, output_size)
        # NLLLoss (used below) expects log-probabilities, so apply
        # LogSoftmax rather than Softmax to the class scores
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, doc):
        # doc is a sequence of sentences; each sentence is a
        # (seq_len, input_size) tensor of word vectors
        s_list = []
        for sent in doc:
            s_list.append(self.word_layer(sent))  # (2*word_hidden_size, 1)
        s_vec = torch.cat(s_list, dim=1).t()      # (num_sents, 2*word_hidden_size)
        doc_vec = self.sent_layer(s_vec)          # (2*sent_hidden_size, 1)
        # transpose so nn.Linear sees (1, 2*sent_hidden_size)
        output = self.softmax(self.output_layer(doc_vec.t()))  # (1, output_size)
        return output
class AttentionNet(nn.Module):
    def __init__(self, input_size, gru_hidden_size, gru_num_layers, context_vec_size):
        super(AttentionNet, self).__init__()
        self.input_size = input_size
        self.gru_hidden_size = gru_hidden_size
        self.gru_num_layers = gru_num_layers
        self.context_vec_size = context_vec_size
        # Encoder: bidirectional GRU over the input sequence
        self.gru = nn.GRU(input_size=input_size,
                          hidden_size=gru_hidden_size,
                          num_layers=gru_num_layers,
                          batch_first=False,
                          bidirectional=True)
        # Attention: project the hidden states, score them against a learned
        # context vector, and normalize over the sequence dimension (dim=0)
        self.fc = nn.Linear(2 * gru_hidden_size, context_vec_size)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=0)
        # context vector
        self.context_vec = nn.Parameter(torch.Tensor(context_vec_size, 1))
        self.context_vec.data.uniform_(-0.1, 0.1)

    def forward(self, inputs):
        # inputs: (seq_len, input_size)
        inputs = torch.unsqueeze(inputs, 1)  # add a batch dim: (seq_len, 1, input_size)
        h_t, hidden = self.gru(inputs)       # (seq_len, 1, 2*gru_hidden_size)
        h_t = torch.squeeze(h_t, 1)          # (seq_len, 2*gru_hidden_size)
        u = self.tanh(self.fc(h_t))          # (seq_len, context_vec_size)
        alpha = self.softmax(torch.mm(u, self.context_vec))  # (seq_len, 1) attention weights
        output = torch.mm(h_t.t(), alpha)    # attention-weighted sum of hidden states
        # output: (2*gru_hidden_size, 1)
        return output
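# As a quick sanity check of the shapes above: a 7-word sentence of 100-dim
# word vectors should come out as a single (2*gru_hidden_size, 1) vector.
# (The sizes here are illustrative, not taken from the original file.)
attn = AttentionNet(input_size=100, gru_hidden_size=50,
                    gru_num_layers=1, context_vec_size=100)
sent = torch.randn(7, 100)  # (seq_len, input_size)
vec = attn(sent)
print(vec.shape)            # torch.Size([100, 1]), i.e. (2*50, 1)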
'''
Train process
'''
import math
import os
import copy
import pickle
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import json
import nltk

# build the model first: SGD needs the parameters to optimize
net = HAN(input_size=100, output_size=5,
          word_hidden_size=50, word_num_layers=1, word_context_size=100,
          sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.NLLLoss()
epoch = 1
batch_size = 10

def dataloader(filename):
    samples = pickle.load(open(filename, 'rb'))
    return samples

def gen_doc(text):
    pass
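# A forward pass through the full model, sketched with a dummy document of
# three sentences of random 100-dim word vectors (matching input_size=100
# above); this stands in for whatever dataloader()/gen_doc() would produce.
doc = [torch.randn(n, 100) for n in (4, 6, 5)]
log_probs = net(doc)
print(log_probs.shape)  # torch.Size([1, 5]): log-probabilities over 5 classes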
class SampleDoc:
    def __init__(self, doc, label):
        self.doc = doc
        self.label = label

    def __iter__(self):
        for sent in self.doc:
            for word in sent:
                pass  # the rest of this method is truncated in the source file
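# The file cuts off before the training loop itself. Below is a minimal
# sketch of how the pieces above might be wired together, assuming
# dataloader() yields (doc, label) pairs where doc is a list of per-sentence
# word-vector tensors; the filename and sample format are assumptions, not
# from the original. Batching (batch_size above) is omitted for brevity.
def train(filename='train.pkl'):  # hypothetical filename
    samples = dataloader(filename)
    for ep in range(epoch):
        total_loss = 0.0
        for doc, label in samples:
            optimizer.zero_grad()
            log_probs = net(doc)            # (1, output_size) log-probabilities
            target = torch.tensor([label])  # gold class index
            loss = criterion(log_probs, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print('epoch %d, avg loss %.4f' % (ep, total_loss / len(samples)))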

A lightweight natural language processing (NLP) toolkit that aims to reduce the amount of engineering code in user projects, such as data-processing loops, training loops, and multi-GPU execution.