Browse Source

refine git commits

tags/v0.4.10
yunfan 6 years ago
parent
commit
2aaa381827
7 changed files with 32 additions and 39 deletions
  1. +12
    -21
      fastNLP/api/api.py
  2. +8
    -3
      fastNLP/core/dataset.py
  3. +4
    -4
      fastNLP/core/metrics.py
  4. +6
    -7
      fastNLP/core/trainer.py
  5. +0
    -2
      fastNLP/models/sequence_modeling.py
  6. +1
    -1
      reproduction/pos_tag_model/pos_tag.cfg
  7. +1
    -1
      setup.py

+ 12
- 21
fastNLP/api/api.py View File

@@ -6,6 +6,7 @@ warnings.filterwarnings('ignore')
import os

from fastNLP.core.dataset import DataSet

from fastNLP.api.model_zoo import load_url
from fastNLP.api.processor import ModelProcessor
from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader
@@ -120,7 +121,7 @@ class POS(API):
f1 = round(test_result['F'] * 100, 2)
pre = round(test_result['P'] * 100, 2)
rec = round(test_result['R'] * 100, 2)
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))

return f1, pre, rec

@@ -179,7 +180,7 @@ class CWS(API):
f1 = round(f1 * 100, 2)
pre = round(pre * 100, 2)
rec = round(rec * 100, 2)
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))

return f1, pre, rec

@@ -251,30 +252,23 @@ class Parser(API):


class Analyzer:
def __init__(self, seg=True, pos=True, parser=True, device='cpu'):

self.seg = seg
self.pos = pos
self.parser = parser
def __init__(self, device='cpu'):

if self.seg:
self.cws = CWS(device=device)
if self.pos:
self.pos = POS(device=device)
if parser:
self.parser = None
self.cws = CWS(device=device)
self.pos = POS(device=device)
self.parser = Parser(device=device)

def predict(self, content, seg=False, pos=False, parser=False):
if seg is False and pos is False and parser is False:
seg = True
output_dict = {}
if self.seg:
if seg:
seg_output = self.cws.predict(content)
output_dict['seg'] = seg_output
if self.pos:
if pos:
pos_output = self.pos.predict(content)
output_dict['pos'] = pos_output
if self.parser:
if parser:
parser_output = self.parser.predict(content)
output_dict['parser'] = parser_output

@@ -301,7 +295,7 @@ if __name__ == "__main__":
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' ,
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
# '那么这款无人机到底有多厉害?']
# print(pos.test('/Users/yh/Desktop/test_data/small_test.conll'))
# print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll'))
# print(pos.predict(s))

# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
@@ -317,7 +311,4 @@ if __name__ == "__main__":
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
'那么这款无人机到底有多厉害?']
print(cws.test('/Users/yh/Desktop/test_data/small_test.conll'))
print(cws.predict(s))


print(parser.predict(s))

+ 8
- 3
fastNLP/core/dataset.py View File

@@ -313,9 +313,14 @@ class DataSet(object):
for col in headers:
_dict[col] = []
for line_idx, line in enumerate(f, start_idx):
contents = line.rstrip('\r\n').split(sep)
assert len(contents)==len(headers), "Line {} has {} parts, while header has {}."\
.format(line_idx, len(contents), len(headers))
contents = line.split(sep)
if len(contents)!=len(headers):
if dropna:
continue
else:
#TODO change error type
raise ValueError("Line {} has {} parts, while header has {} parts."\
.format(line_idx, len(contents), len(headers)))
for header, content in zip(headers, contents):
_dict[header].append(content)
return cls(_dict)

+ 4
- 4
fastNLP/core/metrics.py View File

@@ -38,15 +38,15 @@ class SeqLabelEvaluator(Evaluator):
def __call__(self, predict, truth, **_):
"""

:param predict: list of dict, the network outputs from all batches.
:param predict: list of List, the network outputs from all batches.
:param truth: list of dict, the ground truths from all batch_y.
:return accuracy:
"""
total_correct, total_count = 0., 0.
total_correct, total_count = 0., 0.
for x, y in zip(predict, truth):
# x = torch.tensor(x)
x = torch.tensor(x)
y = y.to(x) # make sure they are in the same device
mask = (y > 0)
mask = (y > 0)
correct = torch.sum(((x == y) * mask).long())
total_correct += float(correct)
total_count += float(torch.sum(mask.long()))


+ 6
- 7
fastNLP/core/trainer.py View File

@@ -4,6 +4,7 @@ from datetime import datetime
import warnings
from collections import defaultdict
import os
import itertools
import shutil

from tensorboardX import SummaryWriter
@@ -121,10 +122,7 @@ class Trainer(object):
for batch_x, batch_y in data_iterator:
prediction = self.data_forward(model, batch_x)

# TODO: refactor self.get_loss
loss = prediction["loss"] if "loss" in prediction else self.get_loss(prediction, batch_y)
# acc = self._evaluator([{"predict": prediction["predict"]}], [{"truth": batch_x["truth"]}])

loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
@@ -133,7 +131,7 @@ class Trainer(object):
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
if n_print > 0 and self.step % n_print == 0:
if self.print_every > 0 and self.step % self.print_every == 0:
end = time.time()
diff = timedelta(seconds=round(end - start))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
@@ -241,7 +239,7 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No

batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch):
_syn_model_data(model, batch_x, batch_y)
_syn_model_data(model, batch_x, batch_y)
# forward check
if batch_count==0:
_check_forward_error(model_func=model.forward, check_level=check_level,
@@ -269,7 +267,8 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
model_name, loss.size()
))
loss.backward()
if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE:
model.zero_grad()
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
break

if dev_data is not None:


+ 0
- 2
fastNLP/models/sequence_modeling.py View File

@@ -1,4 +1,3 @@
import numpy as np
import torch
import numpy as np

@@ -141,7 +140,6 @@ class AdvSeqLabel(SeqLabeling):
idx_sort = idx_sort.cuda()
idx_unsort = idx_unsort.cuda()
self.mask = self.mask.cuda()
truth = truth.cuda() if truth is not None else None

x = self.Embedding(word_seq)
x = self.norm1(x)


+ 1
- 1
reproduction/pos_tag_model/pos_tag.cfg View File

@@ -36,4 +36,4 @@ pickle_path = "./save/"
use_crf = true
use_cuda = true
rnn_hidden_units = 100
word_emb_dim = 100
word_emb_dim = 100

+ 1
- 1
setup.py View File

@@ -13,7 +13,7 @@ with open('requirements.txt', encoding='utf-8') as f:

setup(
name='fastNLP',
version='0.1.0',
version='0.1.1',
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
long_description=readme,
license=license,


Loading…
Cancel
Save