
train.py
# -*- coding:utf-8 -*-
from __future__ import absolute_import
import os
import random
from collections import defaultdict
import pickle
import logging

from AveragedPerceptron import AveragedPerceptron

PICKLE = "data/trontagger-0.1.0.pickle"


class PerceptronTagger():
    '''Greedy Averaged Perceptron tagger, as implemented by Matthew Honnibal.
    See more implementation details here:
    http://honnibal.wordpress.com/2013/09/11/a-good-part-of-speechpos-tagger-in-about-200-lines-of-python/
    :param load: Load the pickled model upon instantiation.
    '''

    START = ['-START-', '-START2-']
    END = ['-END-', '-END2-']
    AP_MODEL_LOC = os.path.join(os.path.dirname(__file__), PICKLE)

    def __init__(self, load=True):
        self.model = AveragedPerceptron()
        self.tagdict = {}
        self.classes = set()
        if load:
            self.load(self.AP_MODEL_LOC)

    def tag(self, corpus):
        '''Tags a string `corpus`.'''
        # Assume the untokenized corpus has '\n' between sentences and ' ' between words.
        s_split = lambda t: t.split('\n')
        w_split = lambda s: s.split()

        def split_sents(corpus):
            for s in s_split(corpus):
                yield w_split(s)

        prev, prev2 = self.START
        tokens = []
        for words in split_sents(corpus):
            context = self.START + [self._normalize(w) for w in words] + self.END
            for i, word in enumerate(words):
                # Unambiguous words are looked up directly; everything else
                # goes through the perceptron.
                tag = self.tagdict.get(word)
                if not tag:
                    features = self._get_features(i, word, context, prev, prev2)
                    tag = self.model.predict(features)
                tokens.append((word, tag))
                prev2 = prev
                prev = tag
        return tokens

    def train(self, sentences, save_loc=None, nr_iter=5):
        '''Train a model from sentences, and save it at ``save_loc``. ``nr_iter``
        controls the number of Perceptron training iterations.
        :param sentences: A list of (words, tags) tuples.
        :param save_loc: If not ``None``, saves a pickled model in this location.
        :param nr_iter: Number of training iterations.
        '''
        self._make_tagdict(sentences)
        self.model.classes = self.classes
        for iter_ in range(nr_iter):
            c = 0
            n = 0
            for words, tags in sentences:
                prev, prev2 = self.START
                context = self.START + [self._normalize(w) for w in words] + self.END
                for i, word in enumerate(words):
                    guess = self.tagdict.get(word)
                    if not guess:
                        feats = self._get_features(i, word, context, prev, prev2)
                        guess = self.model.predict(feats)
                        self.model.update(tags[i], guess, feats)
                    prev2 = prev
                    prev = guess
                    c += guess == tags[i]
                    n += 1
            random.shuffle(sentences)
            logging.info("Iter {0}: {1}/{2}={3}".format(iter_, c, n, _pc(c, n)))
        self.model.average_weights()
        # Pickle as a binary file
        if save_loc is not None:
            pickle.dump((self.model.weights, self.tagdict, self.classes),
                        open(save_loc, 'wb'), -1)
        return None

    def load(self, loc):
        '''Load a pickled model.'''
        try:
            w_td_c = pickle.load(open(loc, 'rb'))
        except IOError:
            msg = "Missing trontagger.pickle file."
            raise IOError(msg)
        self.model.weights, self.tagdict, self.classes = w_td_c
        self.model.classes = self.classes
        return None

    def _normalize(self, word):
        '''Normalization used in pre-processing.
        - All words are lowercased
        - Words with an internal hyphen are represented as !HYPHEN
        - Four-digit numbers (e.g. years) are represented as !YEAR
        - Other words starting with a digit are represented as !DIGITS
        :rtype: str
        '''
        if '-' in word and word[0] != '-':
            return '!HYPHEN'
        elif word.isdigit() and len(word) == 4:
            return '!YEAR'
        elif word[0].isdigit():
            return '!DIGITS'
        else:
            return word.lower()

    def _get_features(self, i, word, context, prev, prev2):
        '''Map tokens into a feature representation, implemented as a
        {hashable: float} dict. If the features change, a new model must be
        trained.
        '''
        def add(name, *args):
            features[' '.join((name,) + tuple(args))] += 1

        i += len(self.START)  # offset into the padded context
        features = defaultdict(int)
        # It's useful to have a constant feature, which acts sort of like a prior
        add('bias')
        add('i suffix', word[-3:])
        add('i pref1', word[0])
        add('i-1 tag', prev)
        add('i-2 tag', prev2)
        add('i tag+i-2 tag', prev, prev2)
        add('i word', context[i])
        add('i-1 tag+i word', prev, context[i])
        add('i-1 word', context[i - 1])
        add('i-1 suffix', context[i - 1][-3:])
        add('i-2 word', context[i - 2])
        add('i+1 word', context[i + 1])
        add('i+1 suffix', context[i + 1][-3:])
        add('i+2 word', context[i + 2])
        return features

    def _make_tagdict(self, sentences):
        '''Make a tag dictionary for single-tag words.'''
        counts = defaultdict(lambda: defaultdict(int))
        for words, tags in sentences:
            for word, tag in zip(words, tags):
                counts[word][tag] += 1
                self.classes.add(tag)
        freq_thresh = 20
        ambiguity_thresh = 0.97
        for word, tag_freqs in counts.items():
            tag, mode = max(tag_freqs.items(), key=lambda item: item[1])
            n = sum(tag_freqs.values())
            # Don't add rare words to the tag dictionary
            # Only add quite unambiguous words
            if n >= freq_thresh and (float(mode) / n) >= ambiguity_thresh:
                self.tagdict[word] = tag


def _pc(n, d):
    return (float(n) / d) * 100


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    tagger = PerceptronTagger(False)
    try:
        tagger.load(PICKLE)
        print(tagger.tag('how are you ?'))
        logging.info('Start testing...')
        right = 0.0
        total = 0.0
        sentence = ([], [])
        for line in open('data/test.txt'):
            params = line.split()
            if len(params) != 2:
                continue
            sentence[0].append(params[0])
            sentence[1].append(params[1])
            if params[0] == '.':
                text = ''
                words = sentence[0]
                tags = sentence[1]
                for i, word in enumerate(words):
                    text += word
                    if i < len(words) - 1:  # no trailing space after the last word
                        text += ' '
                outputs = tagger.tag(text)
                assert len(tags) == len(outputs)
                total += len(tags)
                for o, t in zip(outputs, tags):
                    if o[1].strip() == t:
                        right += 1
                sentence = ([], [])
        logging.info("Precision : %f", right / total)
    except IOError:
        logging.info('Reading corpus...')
        training_data = []
        sentence = ([], [])
        for line in open('data/train.txt'):
            params = line.strip().split('\t')  # strip the newline so tags stay clean
            if len(params) != 2:
                continue
            sentence[0].append(params[0])
            sentence[1].append(params[1])
            if params[0] == '.':
                training_data.append(sentence)
                sentence = ([], [])
        logging.info('training corpus size : %d', len(training_data))
        logging.info('Start training...')
        tagger.train(training_data, save_loc=PICKLE)
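
To make the _get_features contract concrete, here is a minimal sketch of the feature dict it produces for one token. It assumes the AveragedPerceptron module is importable so the tagger can be constructed; the sentence and the 'PRP' tag passed in as the previous prediction are illustrative only:

    tagger = PerceptronTagger(load=False)
    words = ['I', 'like', 'cats', '.']
    # Pad the normalized sentence the same way tag() does.
    context = tagger.START + [tagger._normalize(w) for w in words] + tagger.END

    # Features for 'like' (i=1), pretending the first word was tagged 'PRP';
    # prev2 is still the '-START-' padding at this position.
    feats = tagger._get_features(1, 'like', context, 'PRP', '-START-')
    # Selected keys (each mapped to 1): 'bias', 'i suffix ike', 'i pref1 l',
    # 'i-1 tag PRP', 'i word like', 'i-1 word i', 'i+1 word cats', 'i+2 word .'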

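As an end-to-end usage sketch, assuming this file is importable as train and AveragedPerceptron is on the path: the toy corpus below is illustrative only; real training needs a large tagged corpus such as the data/train.txt file read above.

    from train import PerceptronTagger

    # Each training item is a (words, tags) pair, mirroring what
    # _make_tagdict and train() expect.
    toy_corpus = [
        (['I', 'like', 'cats', '.'], ['PRP', 'VBP', 'NNS', '.']),
        (['You', 'like', 'dogs', '.'], ['PRP', 'VBP', 'NNS', '.']),
    ]

    tagger = PerceptronTagger(load=False)  # skip loading the pickled model
    tagger.train(toy_corpus, save_loc=None, nr_iter=5)

    # tag() expects ' ' between words and '\n' between sentences.
    print(tagger.tag('You like cats .'))
    # e.g. [('You', 'PRP'), ('like', 'VBP'), ('cats', 'NNS'), ('.', '.')]
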
Jiagu is trained on large-scale corpora. It provides common natural language processing functions such as Chinese word segmentation, part-of-speech tagging, named entity recognition, sentiment analysis, knowledge-graph relation extraction, keyword extraction, text summarization, new-word discovery, and text clustering. It was built with the strengths and weaknesses of the major toolkits in mind, and Jiagu is offered back to the community.