
model.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN_text(nn.Module):
    def __init__(self, kernel_h=(3, 4, 5), kernel_num=100, embed_num=1000, embed_dim=300,
                 dropout=0.5, L2_constrain=3, batchsize=50, pretrained_embeddings=None):
        super(CNN_text, self).__init__()
        self.embedding = nn.Embedding(embed_num, embed_dim)
        self.dropout = nn.Dropout(dropout)
        if pretrained_embeddings is not None:
            # Initialize the embedding table from a (embed_num, embed_dim) numpy array.
            self.embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
        # The network structure: one Conv2d per kernel height; each maps
        # (N, 1, L, embed_dim) to (N, kernel_num, L-K+1, 1), e.g. (50, 100, 62, 1) for K=3, L=64.
        self.conv1 = nn.ModuleList([nn.Conv2d(1, kernel_num, (K, embed_dim)) for K in kernel_h])
        # The pooled feature maps are concatenated, so fc1 sees len(kernel_h) * kernel_num features.
        self.fc1 = nn.Linear(len(kernel_h) * kernel_num, 2)

    def conv_and_pool(self, x, conv):
        # Convolve, drop the width-1 dimension, then max-pool over the remaining length.
        x = F.relu(conv(x)).squeeze(3)             # (N, kernel_num, L-K+1), e.g. (50, 100, 62)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)  # (50, 100, 1) -> (50, 100)
        return x

    def forward(self, x):
        x = self.embedding(x)   # (N, L, embed_dim), e.g. (50, 64, 300)
        x = x.unsqueeze(1)      # (N, 1, L, embed_dim): add a channel axis for Conv2d
        x = [self.conv_and_pool(x, conv) for conv in self.conv1]  # list of (N, kernel_num)
        x = torch.cat(x, 1)     # (N, len(kernel_h) * kernel_num), e.g. (50, 300)
        x = self.dropout(x)
        x = self.fc1(x)         # (N, 2) class logits
        return x
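For reference, a minimal usage sketch of the model above. The batch size of 50 and sequence length of 64 are taken from the shape comments in the file and are assumptions, not requirements; any batch of padded index sequences works as long as the indices are below embed_num.

    import torch

    model = CNN_text(embed_num=1000, embed_dim=300)
    tokens = torch.randint(0, 1000, (50, 64))  # (N, L) integer word indices, padded to length 64
    logits = model(tokens)                     # (50, 2) class logits
    print(logits.shape)                        # torch.Size([50, 2])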

A lightweight natural language processing (NLP) toolkit that aims to reduce the amount of engineering code in user projects, such as data-processing loops, training loops, and multi-GPU execution.