|
|
@@ -17,7 +17,11 @@ __all__ = [ |
|
|
|
import re |
|
|
|
import warnings |
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
from nltk import Tree |
|
|
|
except ImportError:
|
|
|
# nltk is optional here; Tree is only needed when parsing tree-formatted data (e.g. the SST pipes)
|
|
|
pass |
|
|
|
|
|
|
|
from .pipe import Pipe |
|
|
|
from .utils import get_tokenizer, _indexize, _add_words_field, _add_chars_field, _granularize |
|
|
@@ -32,12 +36,12 @@ from ...core.instance import Instance |
|
|
|
|
|
|
|
|
|
|
|
class CLSBasePipe(Pipe): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, lower: bool = False, tokenizer: str = 'spacy', lang='en'): |
|
|
|
super().__init__() |
|
|
|
self.lower = lower |
|
|
|
self.tokenizer = get_tokenizer(tokenizer, lang=lang) |
|
|
|
|
|
|
|
|
|
|
|
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): |
|
|
|
r""" |
|
|
|
Tokenize the data in the DataBundle.
|
|
@@ -50,9 +54,9 @@ class CLSBasePipe(Pipe): |
|
|
|
new_field_name = new_field_name or field_name |
|
|
|
for name, dataset in data_bundle.datasets.items(): |
|
|
|
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) |
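# Illustrative note (not part of the original pipe): with the default 'spacy'
# tokenizer this call turns a raw sentence field such as "It works well." into
# ["It", "works", "well", "."], while tokenizer='raw' simply splits on whitespace.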
|
|
|
|
|
|
|
|
|
|
|
return data_bundle |
|
|
|
|
|
|
|
|
|
|
|
def process(self, data_bundle: DataBundle): |
|
|
|
r""" |
|
|
|
The DataSet passed in should have the following structure
|
|
@@ -73,15 +77,15 @@ class CLSBasePipe(Pipe): |
|
|
|
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) |
|
|
|
# build the vocabulary and convert tokens to indices
|
|
|
data_bundle = _indexize(data_bundle=data_bundle) |
|
|
|
|
|
|
|
|
|
|
|
for name, dataset in data_bundle.datasets.items(): |
|
|
|
dataset.add_seq_len(Const.INPUT) |
|
|
|
|
|
|
|
|
|
|
|
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) |
|
|
|
data_bundle.set_target(Const.TARGET) |
|
|
|
|
|
|
|
|
|
|
|
return data_bundle |
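# A minimal usage sketch (illustrative only; assumes an English classification
# loader such as fastNLP.io.IMDBLoader, a placeholder path, and the default
# Const field names):
#
#     data_bundle = IMDBLoader().load('/path/to/imdb')
#     data_bundle = CLSBasePipe(lower=True, tokenizer='spacy').process(data_bundle)
#     vocab = data_bundle.get_vocab(Const.INPUT)        # word -> index mapping
#     print(data_bundle.get_dataset('train')[0])        # indexed words, seq_len, target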
|
|
|
|
|
|
|
|
|
|
|
def process_from_file(self, paths) -> DataBundle: |
|
|
|
r""" |
|
|
|
Given file path(s), return a fully processed DataBundle object. For the path formats supported by ``paths``, see :meth:`fastNLP.io.Loader.load()`
|
|
@@ -151,7 +155,7 @@ class YelpFullPipe(CLSBasePipe): |
|
|
|
""" |
|
|
|
if self.tag_map is not None: |
|
|
|
data_bundle = _granularize(data_bundle, self.tag_map) |
|
|
|
|
|
|
|
|
|
|
|
data_bundle = super().process(data_bundle) |
|
|
|
|
|
|
|
return data_bundle |
|
|
@@ -231,7 +235,7 @@ class AGsNewsPipe(CLSBasePipe): |
|
|
|
+-------------+-----------+--------+-------+---------+ |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): |
|
|
|
r""" |
|
|
|
|
|
|
@@ -239,7 +243,7 @@ class AGsNewsPipe(CLSBasePipe): |
|
|
|
:param str tokenizer: which tokenizer to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
|
|
|
""" |
|
|
|
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') |
|
|
|
|
|
|
|
|
|
|
|
def process_from_file(self, paths=None): |
|
|
|
r""" |
|
|
|
:param str paths: |
|
|
@@ -272,7 +276,7 @@ class DBPediaPipe(CLSBasePipe): |
|
|
|
+-------------+-----------+--------+-------+---------+ |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): |
|
|
|
r""" |
|
|
|
|
|
|
@@ -280,7 +284,7 @@ class DBPediaPipe(CLSBasePipe): |
|
|
|
:param str tokenizer: which tokenizer to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
|
|
|
""" |
|
|
|
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') |
|
|
|
|
|
|
|
|
|
|
|
def process_from_file(self, paths=None): |
|
|
|
r""" |
|
|
|
:param str paths: |
|
|
@@ -369,7 +373,7 @@ class SSTPipe(CLSBasePipe): |
|
|
|
instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label()) |
|
|
|
ds.append(instance) |
|
|
|
data_bundle.set_dataset(ds, name) |
|
|
|
|
|
|
|
|
|
|
|
# set the tag field according to the chosen granularity
|
|
|
data_bundle = _granularize(data_bundle, tag_map=self.tag_map) |
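# For reference (a sketch of the nltk objects involved, not part of the original
# code): an SST line such as '(3 (2 It) (4 (3 works) (2 .)))' parsed with
# Tree.fromstring() yields tree.leaves() == ['It', 'works', '.'] and
# tree.label() == '3'; _granularize then remaps those string labels via tag_map.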
|
|
|
|
|
|
@@ -525,6 +529,7 @@ class ChnSentiCorpPipe(Pipe): |
|
|
|
+-------------+-----------+--------+-------+---------+ |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, bigrams=False, trigrams=False): |
|
|
|
r""" |
|
|
|
|
|
|
@@ -536,10 +541,10 @@ class ChnSentiCorpPipe(Pipe): |
|
|
|
retrieved via data_bundle.get_vocab('trigrams').
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
self.bigrams = bigrams |
|
|
|
self.trigrams = trigrams |
|
|
|
|
|
|
|
|
|
|
|
def _tokenize(self, data_bundle): |
|
|
|
r""" |
|
|
|
Split strings such as "复旦大学" in the DataSet into ["复", "旦", "大", "学"]. Word-level segmentation could later be supported by extending this function.
|
|
@@ -549,8 +554,8 @@ class ChnSentiCorpPipe(Pipe): |
|
|
|
""" |
|
|
|
data_bundle.apply_field(list, field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT) |
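# note: list() splits every string character by character, so English letters and
# digits are split too, e.g. list("NLP2020") == ['N', 'L', 'P', '2', '0', '2', '0']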
|
|
|
return data_bundle |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process(self, data_bundle: DataBundle): |
|
|
|
r""" |
|
|
|
A DataSet that can be processed should contain the following fields
|
|
|
|
|
|
@@ -565,9 +570,9 @@ class ChnSentiCorpPipe(Pipe): |
|
|
|
:return: |
|
|
|
""" |
|
|
|
_add_chars_field(data_bundle, lower=False) |
|
|
|
|
|
|
|
|
|
|
|
data_bundle = self._tokenize(data_bundle) |
|
|
|
|
|
|
|
|
|
|
|
input_field_names = [Const.CHAR_INPUT] |
|
|
|
if self.bigrams: |
|
|
|
for name, dataset in data_bundle.iter_datasets(): |
|
|
@@ -580,21 +585,21 @@ class ChnSentiCorpPipe(Pipe): |
|
|
|
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], |
|
|
|
field_name=Const.CHAR_INPUT, new_field_name='trigrams') |
|
|
|
input_field_names.append('trigrams') |
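# Sketch of the resulting n-gram fields for chars == ['复', '旦', '大', '学']
# (assuming the bigram branch pairs each char with its successor, mirroring the
# trigram construction above):
#     bigrams  -> ['复旦', '旦大', '大学', '学<eos>']
#     trigrams -> ['复旦大', '旦大学', '大学<eos>', '学<eos><eos>']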
|
|
|
|
|
|
|
|
|
|
|
# index |
|
|
|
_indexize(data_bundle, input_field_names, Const.TARGET) |
|
|
|
|
|
|
|
|
|
|
|
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names |
|
|
|
target_fields = [Const.TARGET] |
|
|
|
|
|
|
|
|
|
|
|
for name, dataset in data_bundle.datasets.items(): |
|
|
|
dataset.add_seq_len(Const.CHAR_INPUT) |
|
|
|
|
|
|
|
|
|
|
|
data_bundle.set_input(*input_fields) |
|
|
|
data_bundle.set_target(*target_fields) |
|
|
|
|
|
|
|
|
|
|
|
return data_bundle |
|
|
|
|
|
|
|
|
|
|
|
def process_from_file(self, paths=None): |
|
|
|
r""" |
|
|
|
|
|
|
@@ -604,7 +609,7 @@ class ChnSentiCorpPipe(Pipe): |
|
|
|
# load the data
|
|
|
data_bundle = ChnSentiCorpLoader().load(paths) |
|
|
|
data_bundle = self.process(data_bundle) |
|
|
|
|
|
|
|
|
|
|
|
return data_bundle |
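# Minimal usage sketch (illustrative; whether paths=None triggers an automatic
# download depends on ChnSentiCorpLoader and is assumed here):
#
#     pipe = ChnSentiCorpPipe(bigrams=True, trigrams=True)
#     data_bundle = pipe.process_from_file()
#     chars_vocab = data_bundle.get_vocab('chars')
#     bigram_vocab = data_bundle.get_vocab('bigrams')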
|
|
|
|
|
|
|
|
|
|
@@ -637,26 +642,26 @@ class THUCNewsPipe(CLSBasePipe): |
|
|
|
If set to True, the returned DataSet will contain a column named trigrams that has already been converted to indices and set as input; the corresponding vocab can be
|
|
|
retrieved via data_bundle.get_vocab('trigrams').
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, bigrams=False, trigrams=False): |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
self.bigrams = bigrams |
|
|
|
self.trigrams = trigrams |
|
|
|
|
|
|
|
|
|
|
|
def _chracter_split(self, sent): |
|
|
|
return list(sent) |
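# e.g. _chracter_split('今天天气很好') -> ['今', '天', '天', '气', '很', '好']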
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _raw_split(self, sent): |
|
|
|
return sent.split() |
|
|
|
|
|
|
|
|
|
|
|
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): |
|
|
|
new_field_name = new_field_name or field_name |
|
|
|
for name, dataset in data_bundle.datasets.items(): |
|
|
|
dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) |
|
|
|
return data_bundle |
|
|
|
|
|
|
|
|
|
|
|
def process(self, data_bundle: DataBundle): |
|
|
|
r""" |
|
|
|
A DataSet that can be processed should contain the following fields
|
|
@@ -673,14 +678,14 @@ class THUCNewsPipe(CLSBasePipe): |
|
|
|
# map the raw category names to integer tags
|
|
|
tag_map = {'体育': 0, '财经': 1, '房产': 2, '家居': 3, '教育': 4, '科技': 5, '时尚': 6, '时政': 7, '游戏': 8, '娱乐': 9} |
|
|
|
data_bundle = _granularize(data_bundle=data_bundle, tag_map=tag_map) |
|
|
|
|
|
|
|
|
|
|
|
# clean, lowercase
|
|
|
|
|
|
|
|
|
|
|
# tokenize (character-level split)
|
|
|
data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars') |
|
|
|
|
|
|
|
|
|
|
|
input_field_names = [Const.CHAR_INPUT] |
|
|
|
|
|
|
|
|
|
|
|
# n-grams |
|
|
|
if self.bigrams: |
|
|
|
for name, dataset in data_bundle.iter_datasets(): |
|
|
@@ -693,22 +698,22 @@ class THUCNewsPipe(CLSBasePipe): |
|
|
|
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], |
|
|
|
field_name=Const.CHAR_INPUT, new_field_name='trigrams') |
|
|
|
input_field_names.append('trigrams') |
|
|
|
|
|
|
|
|
|
|
|
# index |
|
|
|
data_bundle = _indexize(data_bundle=data_bundle, input_field_names=input_field_names)
|
|
|
|
|
|
|
|
|
|
|
# add length |
|
|
|
for name, dataset in data_bundle.datasets.items(): |
|
|
|
dataset.add_seq_len(field_name=Const.CHAR_INPUT, new_field_name=Const.INPUT_LEN) |
|
|
|
|
|
|
|
|
|
|
|
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names |
|
|
|
target_fields = [Const.TARGET] |
|
|
|
|
|
|
|
|
|
|
|
data_bundle.set_input(*input_fields) |
|
|
|
data_bundle.set_target(*target_fields) |
|
|
|
|
|
|
|
|
|
|
|
return data_bundle |
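# Rough usage sketch (illustrative; the path below is a placeholder and the
# listed columns follow the docstring above):
#
#     data_bundle = THUCNewsPipe(bigrams=True).process_from_file('/path/to/THUCNews')
#     print(data_bundle.get_dataset('train')[0])   # raw_chars, chars, bigrams, seq_len, target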
|
|
|
|
|
|
|
|
|
|
|
def process_from_file(self, paths=None): |
|
|
|
r""" |
|
|
|
:param paths: for the supported path formats, see the ``load`` method of :class:`fastNLP.io.loader.Loader`.
|
|
@@ -749,22 +754,22 @@ class WeiboSenti100kPipe(CLSBasePipe): |
|
|
|
If set to True, the returned DataSet will contain a column named trigrams that has already been converted to indices and set as input; the corresponding vocab can be
|
|
|
retrieved via data_bundle.get_vocab('trigrams').
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, bigrams=False, trigrams=False): |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
self.bigrams = bigrams |
|
|
|
self.trigrams = trigrams |
|
|
|
|
|
|
|
|
|
|
|
def _chracter_split(self, sent): |
|
|
|
return list(sent) |
|
|
|
|
|
|
|
|
|
|
|
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): |
|
|
|
new_field_name = new_field_name or field_name |
|
|
|
for name, dataset in data_bundle.datasets.items(): |
|
|
|
dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) |
|
|
|
return data_bundle |
|
|
|
|
|
|
|
|
|
|
|
def process(self, data_bundle: DataBundle): |
|
|
|
r""" |
|
|
|
A DataSet that can be processed should contain the following fields
|
|
@@ -779,12 +784,12 @@ class WeiboSenti100kPipe(CLSBasePipe): |
|
|
|
:return: |
|
|
|
""" |
|
|
|
# clean, lowercase
|
|
|
|
|
|
|
|
|
|
|
# tokenize (character-level split)
|
|
|
data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars') |
|
|
|
|
|
|
|
|
|
|
|
input_field_names = [Const.CHAR_INPUT] |
|
|
|
|
|
|
|
|
|
|
|
# n-grams |
|
|
|
if self.bigrams: |
|
|
|
for name, dataset in data_bundle.iter_datasets(): |
|
|
@@ -797,22 +802,22 @@ class WeiboSenti100kPipe(CLSBasePipe): |
|
|
|
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], |
|
|
|
field_name=Const.CHAR_INPUT, new_field_name='trigrams') |
|
|
|
input_field_names.append('trigrams') |
|
|
|
|
|
|
|
|
|
|
|
# index |
|
|
|
data_bundle = _indexize(data_bundle=data_bundle, input_field_names=input_field_names)
|
|
|
|
|
|
|
|
|
|
|
# add length |
|
|
|
for name, dataset in data_bundle.datasets.items(): |
|
|
|
dataset.add_seq_len(field_name=Const.CHAR_INPUT, new_field_name=Const.INPUT_LEN) |
|
|
|
|
|
|
|
|
|
|
|
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names |
|
|
|
target_fields = [Const.TARGET] |
|
|
|
|
|
|
|
|
|
|
|
data_bundle.set_input(*input_fields) |
|
|
|
data_bundle.set_target(*target_fields) |
|
|
|
|
|
|
|
|
|
|
|
return data_bundle |
|
|
|
|
|
|
|
|
|
|
|
def process_from_file(self, paths=None): |
|
|
|
r""" |
|
|
|
:param paths: for the supported path formats, see the ``load`` method of :class:`fastNLP.io.loader.Loader`.
|
|
@@ -822,4 +827,3 @@ class WeiboSenti100kPipe(CLSBasePipe): |
|
|
|
data_bundle = data_loader.load(paths) |
|
|
|
data_bundle = self.process(data_bundle) |
|
|
|
return data_bundle |
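# Minimal end-to-end sketch for this pipe (illustrative; the path is a placeholder):
#
#     data_bundle = WeiboSenti100kPipe(bigrams=True).process_from_file('/path/to/weibo_senti_100k')
#     train_ds = data_bundle.get_dataset('train')
#     print(train_ds[0])   # raw_chars, chars, bigrams, seq_len, target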
|
|
|
|