From 58ddc2d26759154c82dba54121992f94adfe9631 Mon Sep 17 00:00:00 2001
From: choocewhatulike <1901722105@qq.com>
Date: Mon, 12 Mar 2018 17:38:22 +0800
Subject: [PATCH] add train

---
 model_inplement/.gitignore                         |   1 +
 .../code/__pycache__/model.cpython-36.pyc          | Bin 2223 -> 2342 bytes
 model_inplement/code/model.py                      |  63 +++-------
 model_inplement/code/train.py                      | 138 +++++++++++++++++++++
 4 files changed, 153 insertions(+), 49 deletions(-)
 create mode 100644 model_inplement/.gitignore

diff --git a/model_inplement/.gitignore b/model_inplement/.gitignore
new file mode 100644
index 00000000..7e99e367
--- /dev/null
+++ b/model_inplement/.gitignore
@@ -0,0 +1 @@
+*.pyc
\ No newline at end of file
diff --git a/model_inplement/code/__pycache__/model.cpython-36.pyc b/model_inplement/code/__pycache__/model.cpython-36.pyc
index 2e18ae7f3f1addfca6601d873620127d292e35aa..167f6ff9d4401d8077c23d66909e6c96b6049a11 100644
Binary files a/model_inplement/code/__pycache__/model.cpython-36.pyc and b/model_inplement/code/__pycache__/model.cpython-36.pyc differ
diff --git a/model_inplement/code/model.py b/model_inplement/code/model.py
index 32ebdbf9..a4c2e59b 100644
--- a/model_inplement/code/model.py
+++ 
b/model_inplement/code/model.py @@ -20,16 +20,20 @@ class HAN(nn.Module): sent_num_layers, sent_context_size) self.output_layer = nn.Linear(2* sent_hidden_size, output_size) - self.softmax = nn.Softmax() + self.softmax = nn.LogSoftmax(dim=1) - def forward(self, doc): + def forward(self, batch_doc): # input is a sequence of vector # if level == w, a seq of words (a sent); level == s, a seq of sents (a doc) - s_list = [] - for sent in doc: - s_list.append(self.word_layer(sent)) - s_vec = torch.cat(s_list, dim=1).t() - doc_vec = self.sent_layer(s_vec) + doc_vec_list = [] + for doc in batch_doc: + s_list = [] + for sent in doc: + s_list.append(self.word_layer(sent)) + s_vec = torch.cat(s_list, dim=0) + vec = self.sent_layer(s_vec) + doc_vec_list.append(vec) + doc_vec = torch.cat(doc_vec_list, dim=0) output = self.softmax(self.output_layer(doc_vec)) return output @@ -51,7 +55,7 @@ class AttentionNet(nn.Module): # Attention self.fc = nn.Linear(2* gru_hidden_size, context_vec_size) self.tanh = nn.Tanh() - self.softmax = nn.Softmax() + self.softmax = nn.Softmax(dim=0) # context vector self.context_vec = nn.Parameter(torch.Tensor(context_vec_size, 1)) self.context_vec.data.uniform_(-0.1, 0.1) @@ -63,47 +67,8 @@ class AttentionNet(nn.Module): h_t = torch.squeeze(h_t, 1) u = self.tanh(self.fc(h_t)) alpha = self.softmax(torch.mm(u, self.context_vec)) - output = torch.mm(h_t.t(), alpha) - # output's dim (2*hidden_size, 1) + output = torch.mm(h_t.t(), alpha).t() + # output's dim (1, 2*hidden_size) return output -''' -Train process -''' -import math -import os -import copy -import pickle - -import matplotlib.pyplot as plt -import matplotlib.ticker as ticker -import numpy as np -import json -import nltk - -optimizer = torch.optim.SGD(lr=0.01) -criterion = nn.NLLLoss() -epoch = 1 -batch_size = 10 - -net = HAN(input_size=100, output_size=5, - word_hidden_size=50, word_num_layers=1, word_context_size=100, - sent_hidden_size=50, sent_num_layers=1, sent_context_size=100) - -def dataloader(filename): - samples = pickle.load(open(filename, 'rb')) - return samples - -def gen_doc(text): - pass - -class SampleDoc: - def __init__(self, doc, label): - self.doc = doc - self.label = label - - def __iter__(self): - for sent in self.doc: - for word in sent: - diff --git a/model_inplement/code/train.py b/model_inplement/code/train.py index e69de29b..ae7ee925 100644 --- a/model_inplement/code/train.py +++ b/model_inplement/code/train.py @@ -0,0 +1,138 @@ +import gensim +from gensim import models + +import os +import pickle + +class SampleIter: + def __init__(self, dirname): + self.dirname = dirname + + def __iter__(self): + for f in os.listdir(self.dirname): + for y, x in pickle.load(open(os.path.join(self.dirname, f), 'rb')): + yield x, y + +class SentIter: + def __init__(self, dirname, count): + self.dirname = dirname + self.count = int(count) + + def __iter__(self): + for f in os.listdir(self.dirname)[:self.count]: + for y, x in pickle.load(open(os.path.join(self.dirname, f), 'rb')): + for sent in x: + yield sent + +def train_word_vec(): + # load data + dirname = 'reviews' + sents = SentIter(dirname, 238) + # define model and train + model = models.Word2Vec(sentences=sents, size=200, sg=0, workers=4, min_count=5) + model.save('yelp.word2vec') + + +''' +Train process +''' +import math +import os +import copy +import pickle + +import matplotlib.pyplot as plt +import matplotlib.ticker as ticker +import numpy as np +import json +import nltk +from gensim.models import Word2Vec +import torch +from torch.utils.data import 
DataLoader, Dataset
+from torch.autograd import Variable
+from model import *
+
+net = HAN(input_size=200, output_size=5,
+          word_hidden_size=50, word_num_layers=1, word_context_size=100,
+          sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
+
+optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
+criterion = nn.NLLLoss()
+num_epoch = 1
+batch_size = 64
+
+class Embedding_layer:
+    def __init__(self, wv, vector_size):
+        self.wv = wv
+        self.vector_size = vector_size
+
+    def get_vec(self, w):
+        try:
+            v = self.wv[w]
+        except KeyError:
+            v = np.zeros(self.vector_size)
+        return v
+
+embed_model = Word2Vec.load('yelp.word2vec')
+embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
+del embed_model
+
+class YelpDocSet(Dataset):
+    def __init__(self, dirname, num_files, embedding):
+        self.dirname = dirname
+        self.num_files = num_files
+        self._len = num_files*5000
+        self._files = os.listdir(dirname)[:num_files]
+        self.embedding = embedding
+
+    def __len__(self):
+        return self._len
+
+    def __getitem__(self, n):
+        file_id = n // 5000
+        sample_list = pickle.load(open(
+            os.path.join(self.dirname, self._files[file_id]), 'rb'))
+        y, x = sample_list[n % 5000]
+        return x, y-1
+
+def collate(iterable):
+    y_list = []
+    x_list = []
+    for x, y in iterable:
+        y_list.append(y)
+        x_list.append(x)
+    return x_list, torch.LongTensor(y_list)
+
+if __name__ == '__main__':
+    dirname = 'reviews'
+    dataloader = DataLoader(YelpDocSet(dirname, 238, embedding), batch_size=batch_size, collate_fn=collate)
+    running_loss = 0.0
+    print_size = 10
+
+    for epoch in range(num_epoch):
+        for i, batch_samples in enumerate(dataloader):
+            x, y = batch_samples
+            doc_list = []
+            for sample in x:
+                doc = []
+                for sent in sample:
+                    sent_vec = []
+                    for word in sent:
+                        vec = embedding.get_vec(word)
+                        sent_vec.append(torch.Tensor(vec.reshape((1, -1))))
+                    sent_vec = torch.cat(sent_vec, dim=0)
+                    # print(sent_vec.size())
+                    doc.append(Variable(sent_vec))
+                doc_list.append(doc)
+            y = Variable(y)
+            predict = net(doc_list)
+            loss = criterion(predict, y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            running_loss += loss.data[0]
+            print(loss.data[0])
+            if i % print_size == print_size-1:
+                print(running_loss/print_size)
+                running_loss = 0.0
+
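
Note (not part of the patch): the modified HAN.forward() now takes a batch as a list of documents, where each document is a list of per-sentence word-vector tensors, exactly as built in the training loop above. Below is a minimal smoke-test sketch of that interface. It assumes model.py is importable from the working directory and 200-dimensional word vectors (matching input_size=200); the fake sentence lengths are illustrative only.

# Illustrative sketch only, not part of the author's code.
import torch
from torch.autograd import Variable

from model import HAN

net = HAN(input_size=200, output_size=5,
          word_hidden_size=50, word_num_layers=1, word_context_size=100,
          sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)

# A batch of two fake documents: each sentence is a (num_words, 200) tensor,
# the same shape the training loop feeds to net(doc_list).
batch = [
    [Variable(torch.randn(n, 200)) for n in (7, 5, 9)],  # doc 1: 3 sentences
    [Variable(torch.randn(n, 200)) for n in (4, 6)],     # doc 2: 2 sentences
]

out = net(batch)
print(out.size())  # expected (2, 5): one row of log-probabilities per document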