Browse Source

add base methods for model.base_model

tags/v0.4.10
FengZiYjun 7 years ago
parent
commit
7b46f422c7
18 changed files with 84 additions and 40 deletions
  1. +4
    -0
      .idea/deployment.xml
  2. +11
    -0
      .idea/fastNLP.iml
  3. +4
    -0
      .idea/misc.xml
  4. +8
    -0
      .idea/modules.xml
  5. +6
    -0
      .idea/vcs.xml
  6. +1
    -1
      README.md
  7. +0
    -1
      action/README.md
  8. +4
    -4
      action/action.py
  9. +1
    -1
      action/tester.py
  10. +1
    -1
      action/trainer.py
  11. +1
    -2
      loader/config_loader.py
  12. +20
    -0
      model/base_model.py
  13. +6
    -12
      reproduction/CNN-sentence_classification/train.py
  14. +9
    -9
      reproduction/Char-aware_NLM/test.py
  15. +4
    -5
      reproduction/Char-aware_NLM/train.py
  16. +3
    -3
      saver/base_saver.py
  17. +1
    -1
      saver/logger.py
  18. +0
    -0
      tests/test_loader.py

+ 4
- 0
.idea/deployment.xml View File

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" persistUploadOnCheckin="false" />
</project>

+ 11
- 0
.idea/fastNLP.iml View File

@@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
</component>
</module>

+ 4
- 0
.idea/misc.xml View File

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.5 (PCA_emb)" project-jdk-type="Python SDK" />
</project>

+ 8
- 0
.idea/modules.xml View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/fastNLP.iml" filepath="$PROJECT_DIR$/.idea/fastNLP.iml" />
</modules>
</component>
</project>

+ 6
- 0
.idea/vcs.xml View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

+ 1
- 1
README.md View File

@@ -1,2 +1,2 @@
# FastNLP # FastNLP
FastNLP
FastNLP

model/empty.txt → action/README.md View File

@@ -1,4 +1,3 @@
Some useful reference:
SpaCy "Doc" SpaCy "Doc"
https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/doc.pyx#L80 https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/doc.pyx#L80



+ 4
- 4
action/action.py View File

@@ -8,10 +8,10 @@ class Action(object):
self.logger = None self.logger = None


def load_config(self, args): def load_config(self, args):
pass
raise NotImplementedError


def load_dataset(self, args): def load_dataset(self, args):
pass
raise NotImplementedError


def log(self, args): def log(self, args):
self.logger.log(args) self.logger.log(args)
@@ -22,7 +22,7 @@ class Action(object):


def batchify(self, X, Y=None): def batchify(self, X, Y=None):
# a generator # a generator
pass
raise NotImplementedError


def make_log(self, *args): def make_log(self, *args):
pass
raise NotImplementedError

+ 1
- 1
action/tester.py View File

@@ -29,7 +29,7 @@ class Tester(Action):
for step in range(iterations): for step in range(iterations):
batch_x, batch_y = test_batch_generator.__next__() batch_x, batch_y = test_batch_generator.__next__()


# forward pass from test input to predicted output
# forward pass from tests input to predicted output
prediction = network.data_forward(batch_x) prediction = network.data_forward(batch_x)


# get the loss # get the loss


+ 1
- 1
action/trainer.py View File

@@ -11,4 +11,4 @@ class Trainer(Action):
self.arg = arg self.arg = arg


def train(self, args): def train(self, args):
pass
raise NotImplementedError

+ 1
- 2
loader/config_loader.py View File

@@ -10,5 +10,4 @@ class ConfigLoader(BaseLoader):


@staticmethod @staticmethod
def parse(string): def parse(string):
# To do
return string
raise NotImplementedError

+ 20
- 0
model/base_model.py View File

@@ -0,0 +1,20 @@
class BaseModel(object):
    """Base model for all models.

    This class only declares the method names that concrete models are
    expected to provide. Apart from ``__init__``, every method raises
    ``NotImplementedError`` until a subclass overrides it.
    """

    def __init__(self):
        """Create the base model; no state is set up here."""
        pass

    def prepare_input(self, data):
        """Prepare ``data`` for the model; must be overridden.

        :param data: input to prepare. The expected format is defined by the
            subclass (presumably raw data to convert into model input --
            TODO confirm against callers).
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def mode(self, test=False):
        """Select the operating mode of the model; must be overridden.

        :param test: mode flag, ``False`` by default (presumably ``True``
            selects evaluation and ``False`` training -- TODO confirm).
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def data_forward(self, x):
        """Run the forward pass on ``x``; must be overridden.

        :param x: model input (presumably one batch -- TODO confirm).
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def grad_backward(self):
        """Run the backward (gradient) pass; must be overridden.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def loss(self, pred, truth):
        """Compute the loss for a prediction; must be overridden.

        :param pred: model output (presumably what ``data_forward``
            returns -- TODO confirm).
        :param truth: reference values to compare ``pred`` against.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

+ 6
- 12
reproduction/CNN-sentence_classification/train.py View File

@@ -1,17 +1,12 @@
import os import os
import torch

import
import torch
import torch.nn as nn import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import dataset as dst
from model import CNN_text
.dataset as dst
from .model import CNN_text
from torch.autograd import Variable from torch.autograd import Variable


from sklearn import cross_validation
from sklearn import datasets



# Hyper Parameters # Hyper Parameters
batch_size = 50 batch_size = 50
learning_rate = 0.0001 learning_rate = 0.0001
@@ -51,8 +46,7 @@ if cuda:
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate) optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)



#train and test
# train and tests
best_acc = None best_acc = None


for epoch in range(num_epochs): for epoch in range(num_epochs):


+ 9
- 9
reproduction/Char-aware_NLM/test.py View File

@@ -1,12 +1,12 @@
import os import os
from collections import namedtuple

import numpy as np
import torch import torch
from torch.autograd import Variable
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from model import charLM
from torch.autograd import Variable
from utilities import * from utilities import *
from collections import namedtuple



def to_var(x): def to_var(x):
if torch.cuda.is_available(): if torch.cuda.is_available():
@@ -76,18 +76,18 @@ if __name__ == "__main__":




if os.path.exists("cache/data_sets.pt") is False: if os.path.exists("cache/data_sets.pt") is False:
test_text = read_data("./test.txt")
test_text = read_data("./tests.txt")
test_set = np.array(text2vec(test_text, char_dict, max_word_len)) test_set = np.array(text2vec(test_text, char_dict, max_word_len))


# Labels are next-word index in word_dict with the same length as inputs # Labels are next-word index in word_dict with the same length as inputs
test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]]) test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]])


category = {"test": test_set, "tlabel":test_label}
category = {"tests": test_set, "tlabel": test_label}
torch.save(category, "cache/data_sets.pt") torch.save(category, "cache/data_sets.pt")
else: else:
data_sets = torch.load("cache/data_sets.pt") data_sets = torch.load("cache/data_sets.pt")
test_set = data_sets["test"]
test_set = data_sets["tests"]
test_label = data_sets["tlabel"] test_label = data_sets["tlabel"]
train_set = data_sets["tdata"] train_set = data_sets["tdata"]
train_label = data_sets["trlabel"] train_label = data_sets["trlabel"]


+ 4
- 5
reproduction/Char-aware_NLM/train.py View File

@@ -13,8 +13,7 @@ from .utilities import *




def preprocess(): def preprocess():
word_dict, char_dict = create_word_char_dict("valid.txt", "train.txt", "test.txt")
word_dict, char_dict = create_word_char_dict("valid.txt", "train.txt", "tests.txt")
num_words = len(word_dict) num_words = len(word_dict)
num_char = len(char_dict) num_char = len(char_dict)
char_dict["BOW"] = num_char+1 char_dict["BOW"] = num_char+1
@@ -195,7 +194,7 @@ if __name__=="__main__":
if os.path.exists("cache/data_sets.pt") is False: if os.path.exists("cache/data_sets.pt") is False:
train_text = read_data("./train.txt") train_text = read_data("./train.txt")
valid_text = read_data("./valid.txt") valid_text = read_data("./valid.txt")
test_text = read_data("./test.txt")
test_text = read_data("./tests.txt")


train_set = np.array(text2vec(train_text, char_dict, max_word_len)) train_set = np.array(text2vec(train_text, char_dict, max_word_len))
valid_set = np.array(text2vec(valid_text, char_dict, max_word_len)) valid_set = np.array(text2vec(valid_text, char_dict, max_word_len))
@@ -206,14 +205,14 @@ if __name__=="__main__":
valid_label = np.array([word_dict[w] for w in valid_text[1:]] + [word_dict[valid_text[-1]]]) valid_label = np.array([word_dict[w] for w in valid_text[1:]] + [word_dict[valid_text[-1]]])
test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]]) test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]])


category = {"tdata":train_set, "vdata":valid_set, "test": test_set,
category = {"tdata": train_set, "vdata": valid_set, "tests": test_set,
"trlabel":train_label, "vlabel":valid_label, "tlabel":test_label} "trlabel":train_label, "vlabel":valid_label, "tlabel":test_label}
torch.save(category, "cache/data_sets.pt") torch.save(category, "cache/data_sets.pt")
else: else:
data_sets = torch.load("cache/data_sets.pt") data_sets = torch.load("cache/data_sets.pt")
train_set = data_sets["tdata"] train_set = data_sets["tdata"]
valid_set = data_sets["vdata"] valid_set = data_sets["vdata"]
test_set = data_sets["test"]
test_set = data_sets["tests"]
train_label = data_sets["trlabel"] train_label = data_sets["trlabel"]
valid_label = data_sets["vlabel"] valid_label = data_sets["vlabel"]
test_label = data_sets["tlabel"] test_label = data_sets["tlabel"]


+ 3
- 3
saver/base_saver.py View File

@@ -5,10 +5,10 @@ class BaseSaver(object):
self.save_path = save_path self.save_path = save_path


def save_bytes(self): def save_bytes(self):
pass
raise NotImplementedError


def save_str(self): def save_str(self):
pass
raise NotImplementedError


def compress(self): def compress(self):
pass
raise NotImplementedError

+ 1
- 1
saver/logger.py View File

@@ -8,4 +8,4 @@ class Logger(BaseSaver):
super(Logger, self).__init__(save_path) super(Logger, self).__init__(save_path)


def log(self, string): def log(self, string):
pass
raise NotImplementedError

test/test_loader.py → tests/test_loader.py View File


Loading…
Cancel
Save