@@ -37,10 +37,11 @@ from fastNLP.core.log import logger
 class CLSBasePipe(Pipe):
 
-    def __init__(self, lower: bool = False, tokenizer: str = 'raw', lang='en'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'raw', lang='en', num_proc=0):
         super().__init__()
         self.lower = lower
         self.tokenizer = get_tokenizer(tokenizer, lang=lang)
+        self.num_proc = num_proc
 
     def _tokenize(self, data_bundle, field_name='words', new_field_name=None):
         r"""
@@ -53,7 +54,8 @@ class CLSBasePipe(Pipe):
         """
         new_field_name = new_field_name or field_name
         for name, dataset in data_bundle.iter_datasets():
-            dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
+            dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name,
+                                num_proc=self.num_proc)
        return data_bundle
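Taken together, this first hunk threads a `num_proc` argument from the pipe constructor down to `DataSet.apply_field`, so field-level preprocessing can be fanned out to multiple worker processes instead of running single-threaded. A minimal usage sketch, mirroring the tests further below (the path is the test fixture used there; `num_proc=0` keeps the previous behaviour):

```python
from fastNLP.io import IMDBPipe

# num_proc=2 asks apply_field to dispatch per-instance work to two worker
# processes; num_proc=0 falls back to the old single-process code path.
data_bundle = IMDBPipe(tokenizer='raw', num_proc=2).process_from_file('tests/data_for_tests/io/imdb')
print(data_bundle)
```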
@@ -117,7 +119,7 @@ class YelpFullPipe(CLSBasePipe):
     """
 
-    def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: whether to lowercase the input.
@@ -125,7 +127,7 @@ class YelpFullPipe(CLSBasePipe):
            labels 1 and 2 are merged into one class, 3 forms its own class, and 4 and 5 are merged into one class; if 5, it is a 5-class problem.
         :param str tokenizer: which tokenization method to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
         """
-        super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
+        super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc)
         assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
         self.granularity = granularity
@@ -191,13 +193,13 @@ class YelpPolarityPipe(CLSBasePipe):
     """
 
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: whether to lowercase the input.
         :param str tokenizer: which tokenization method to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
         """
-        super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
+        super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc)
 
     def process_from_file(self, paths=None):
         r"""
@@ -233,13 +235,13 @@ class AGsNewsPipe(CLSBasePipe):
     """
 
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: whether to lowercase the input.
         :param str tokenizer: which tokenization method to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
         """
-        super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
+        super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc)
 
     def process_from_file(self, paths=None):
         r"""
@@ -274,13 +276,13 @@ class DBPediaPipe(CLSBasePipe):
     """
 
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: whether to lowercase the input.
         :param str tokenizer: which tokenization method to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
         """
-        super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
+        super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc)
 
     def process_from_file(self, paths=None):
         r"""
@@ -315,7 +317,7 @@ class SSTPipe(CLSBasePipe):
     """
 
-    def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'):
+    def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy', num_proc=0):
         r"""
         :param bool subtree: whether to expand the train, test and dev data into subtrees to enlarge the dataset. Default: ``False``
@@ -325,7 +327,7 @@ class SSTPipe(CLSBasePipe):
            labels 0 and 1 are merged into one class, 2 forms its own class, and 3 and 4 are merged into one class; if 5, it is a 5-class problem.
         :param str tokenizer: which tokenization method to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.subtree = subtree
         self.train_tree = train_subtree
         self.lower = lower
@@ -407,13 +409,13 @@ class SST2Pipe(CLSBasePipe):
     """
 
-    def __init__(self, lower=False, tokenizer='raw'):
+    def __init__(self, lower=False, tokenizer='raw', num_proc=0):
         r"""
         :param bool lower: whether to lowercase the input.
         :param str tokenizer: which tokenization method to use to split the data into words.
         """
-        super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
+        super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc)
 
     def process_from_file(self, paths=None):
         r"""
@@ -452,13 +454,13 @@ class IMDBPipe(CLSBasePipe):
     """
 
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: whether to lowercase the data in the words field.
         :param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.lower = lower
 
     def process(self, data_bundle: DataBundle):
@@ -483,7 +485,7 @@ class IMDBPipe(CLSBasePipe):
             return raw_words
 
         for name, dataset in data_bundle.datasets.items():
-            dataset.apply_field(replace_br, field_name='raw_words', new_field_name='raw_words')
+            dataset.apply_field(replace_br, field_name='raw_words', new_field_name='raw_words', num_proc=self.num_proc)
 
         data_bundle = super().process(data_bundle)
@@ -527,7 +529,7 @@ class ChnSentiCorpPipe(Pipe):
     """
 
-    def __init__(self, bigrams=False, trigrams=False):
+    def __init__(self, bigrams=False, trigrams=False, num_proc: int = 0):
         r"""
         :param bool bigrams: whether to add a bigrams column. Bigrams are built as ['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]. If
@@ -541,10 +543,11 @@ class ChnSentiCorpPipe(Pipe):
         self.bigrams = bigrams
         self.trigrams = trigrams
+        self.num_proc = num_proc
 
     def _tokenize(self, data_bundle):
         r"""
-        将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。
+        将 DataSet 中的"复旦大学"拆分为 ["复", "旦", "大", "学"] . 未来可以通过扩展这个函数实现分词。
 
         :param data_bundle:
         :return:
@@ -571,24 +574,26 @@ class ChnSentiCorpPipe(Pipe):
         data_bundle = self._tokenize(data_bundle)
 
         input_field_names = ['chars']
 
+        def bigrams(chars):
+            return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
+
+        def trigrams(chars):
+            return [c1 + c2 + c3 for c1, c2, c3 in
+                    zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
+
         if self.bigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
-                                    field_name='chars', new_field_name='bigrams')
+                dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc)
             input_field_names.append('bigrams')
         if self.trigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
-                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
-                                    field_name='chars', new_field_name='trigrams')
+                dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc)
             input_field_names.append('trigrams')
 
         # index
         _indexize(data_bundle, input_field_names, 'target')
 
-        input_fields = ['target', 'seq_len'] + input_field_names
-        target_fields = ['target']
 
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len('chars')
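For reference, the `bigrams`/`trigrams` helpers factored out of the deleted lambdas behave like this (the `<eos>` padding is taken from the diff; the sample characters are arbitrary):

```python
def bigrams(chars):
    # pair each character with its right neighbour, padding the end with <eos>
    return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]

def trigrams(chars):
    # same idea over windows of three characters
    return [c1 + c2 + c3 for c1, c2, c3 in
            zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]

chars = ['复', '旦', '大', '学']
print(bigrams(chars))   # ['复旦', '旦大', '大学', '学<eos>']
print(trigrams(chars))  # ['复旦大', '旦大学', '大学<eos>', '学<eos><eos>']
```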
@@ -637,8 +642,8 @@ class THUCNewsPipe(CLSBasePipe):
         can be obtained via data_bundle.get_vocab('trigrams').
     """
 
-    def __init__(self, bigrams=False, trigrams=False):
-        super().__init__()
+    def __init__(self, bigrams=False, trigrams=False, num_proc=0):
+        super().__init__(num_proc=num_proc)
         self.bigrams = bigrams
         self.trigrams = trigrams
@@ -653,7 +658,7 @@ class THUCNewsPipe(CLSBasePipe):
     def _tokenize(self, data_bundle, field_name='words', new_field_name=None):
         new_field_name = new_field_name or field_name
         for name, dataset in data_bundle.datasets.items():
-            dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name)
+            dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name,
+                                num_proc=self.num_proc)
         return data_bundle
     def process(self, data_bundle: DataBundle):
@@ -680,17 +685,21 @@ class THUCNewsPipe(CLSBasePipe):
         input_field_names = ['chars']
 
+        def bigrams(chars):
+            return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
+
+        def trigrams(chars):
+            return [c1 + c2 + c3 for c1, c2, c3 in
+                    zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
+
         # n-grams
         if self.bigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
-                                    field_name='chars', new_field_name='bigrams')
+                dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc)
             input_field_names.append('bigrams')
         if self.trigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
-                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
-                                    field_name='chars', new_field_name='trigrams')
+                dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc)
             input_field_names.append('trigrams')
 
         # index
@@ -700,9 +709,6 @@ class THUCNewsPipe(CLSBasePipe):
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(field_name='chars', new_field_name='seq_len')
 
-        input_fields = ['target', 'seq_len'] + input_field_names
-        target_fields = ['target']
 
         return data_bundle
     def process_from_file(self, paths=None):
@@ -746,8 +752,8 @@ class WeiboSenti100kPipe(CLSBasePipe):
         can be obtained via data_bundle.get_vocab('trigrams').
     """
 
-    def __init__(self, bigrams=False, trigrams=False):
-        super().__init__()
+    def __init__(self, bigrams=False, trigrams=False, num_proc=0):
+        super().__init__(num_proc=num_proc)
         self.bigrams = bigrams
         self.trigrams = trigrams
@@ -758,7 +764,8 @@ class WeiboSenti100kPipe(CLSBasePipe):
     def _tokenize(self, data_bundle, field_name='words', new_field_name=None):
         new_field_name = new_field_name or field_name
         for name, dataset in data_bundle.datasets.items():
-            dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name)
+            dataset.apply_field(self._chracter_split, field_name=field_name,
+                                new_field_name=new_field_name, num_proc=self.num_proc)
         return data_bundle
     def process(self, data_bundle: DataBundle):
@@ -779,20 +786,19 @@ class WeiboSenti100kPipe(CLSBasePipe):
         # CWS(tokenize)
         data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars')
 
-        input_field_names = ['chars']
+        def bigrams(chars):
+            return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
+
+        def trigrams(chars):
+            return [c1 + c2 + c3 for c1, c2, c3 in
+                    zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
 
         # n-grams
         if self.bigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
-                                    field_name='chars', new_field_name='bigrams')
-            input_field_names.append('bigrams')
+                dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc)
         if self.trigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
-                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
-                                    field_name='chars', new_field_name='trigrams')
-            input_field_names.append('trigrams')
+                dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc)
 
         # index
         data_bundle = _indexize(data_bundle=data_bundle, input_field_names='chars')
@@ -801,9 +807,6 @@ class WeiboSenti100kPipe(CLSBasePipe):
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(field_name='chars', new_field_name='seq_len')
 
-        input_fields = ['target', 'seq_len'] + input_field_names
-        target_fields = ['target']
 
         return data_bundle
     def process_from_file(self, paths=None):
@@ -817,13 +820,13 @@ class WeiboSenti100kPipe(CLSBasePipe):
         return data_bundle
 
 class MRPipe(CLSBasePipe):
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: whether to lowercase the data in the words field.
         :param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.lower = lower
 
     def process_from_file(self, paths=None):
@@ -840,13 +843,13 @@ class MRPipe(CLSBasePipe):
 
 class R8Pipe(CLSBasePipe):
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: whether to lowercase the data in the words field.
         :param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.lower = lower
 
     def process_from_file(self, paths=None):
@@ -863,13 +866,13 @@ class R8Pipe(CLSBasePipe):
 
 class R52Pipe(CLSBasePipe):
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0):
         r"""
         :param bool lower: whether to lowercase the data in the words field.
         :param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.lower = lower
 
     def process_from_file(self, paths=None):
@@ -886,13 +889,13 @@ class R52Pipe(CLSBasePipe):
 
 class OhsumedPipe(CLSBasePipe):
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0):
         r"""
         :param bool lower: whether to lowercase the data in the words field.
         :param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.lower = lower
 
     def process_from_file(self, paths=None):
@@ -909,13 +912,13 @@ class OhsumedPipe(CLSBasePipe):
 
 class NG20Pipe(CLSBasePipe):
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0):
         r"""
         :param bool lower: whether to lowercase the data in the words field.
         :param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.lower = lower
 
     def process_from_file(self, paths=None):
@@ -30,7 +30,7 @@ class _NERPipe(Pipe):
        target. In the returned DataSet, words, target and seq_len are set as input; target and seq_len are set as target.
     """
 
-    def __init__(self, encoding_type: str = 'bio', lower: bool = False):
+    def __init__(self, encoding_type: str = 'bio', lower: bool = False, num_proc=0):
         r"""
         :param str encoding_type: which encoding type to use for the target column; supports bioes and bio.
@@ -39,10 +39,14 @@ class _NERPipe(Pipe):
         if encoding_type == 'bio':
             self.convert_tag = iob2
         elif encoding_type == 'bioes':
-            self.convert_tag = lambda words: iob2bioes(iob2(words))
+            def func(words):
+                return iob2bioes(iob2(words))
+            # self.convert_tag = lambda words: iob2bioes(iob2(words))
+            self.convert_tag = func
         else:
             raise ValueError("encoding_type only supports `bio` and `bioes`.")
         self.lower = lower
+        self.num_proc = num_proc
 
     def process(self, data_bundle: DataBundle) -> DataBundle:
         r"""
@@ -60,16 +64,13 @@ class _NERPipe(Pipe):
         """
         # convert the tags
         for name, dataset in data_bundle.iter_datasets():
-            dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target')
+            dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target', num_proc=self.num_proc)
 
         _add_words_field(data_bundle, lower=self.lower)
 
         # index
         _indexize(data_bundle)
 
-        input_fields = ['target', 'words', 'seq_len']
-        target_fields = ['target', 'seq_len']
 
         for name, dataset in data_bundle.iter_datasets():
             dataset.add_seq_len('words')
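The `convert_tag` hook configured above rewrites BIO tags into BIOES before they are indexed. As a rough, self-contained illustration of that conversion (a toy stand-in, not fastNLP's `iob2`/`iob2bioes` helpers):

```python
def bio_to_bioes(tags):
    """Toy BIO -> BIOES conversion: single-token spans become S-, span ends become E-."""
    bioes = []
    for i, tag in enumerate(tags):
        if tag == 'O':
            bioes.append(tag)
            continue
        prefix, label = tag.split('-', 1)
        nxt = tags[i + 1] if i + 1 < len(tags) else 'O'
        span_continues = (nxt == 'I-' + label)
        if prefix == 'B':
            bioes.append(('B-' if span_continues else 'S-') + label)
        else:  # prefix == 'I'
            bioes.append(('I-' if span_continues else 'E-') + label)
    return bioes

print(bio_to_bioes(['B-PER', 'I-PER', 'O', 'B-LOC']))
# -> ['B-PER', 'E-PER', 'O', 'S-LOC']
```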
@@ -144,7 +145,7 @@ class Conll2003Pipe(Pipe):
     """
 
-    def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False):
+    def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False, num_proc: int = 0):
         r"""
         :param str chunk_encoding_type: supports bioes and bio.
@@ -154,16 +155,23 @@ class Conll2003Pipe(Pipe):
         if chunk_encoding_type == 'bio':
             self.chunk_convert_tag = iob2
         elif chunk_encoding_type == 'bioes':
-            self.chunk_convert_tag = lambda tags: iob2bioes(iob2(tags))
+            def func1(tags):
+                return iob2bioes(iob2(tags))
+            # self.chunk_convert_tag = lambda tags: iob2bioes(iob2(tags))
+            self.chunk_convert_tag = func1
         else:
             raise ValueError("chunk_encoding_type only supports `bio` and `bioes`.")
         if ner_encoding_type == 'bio':
             self.ner_convert_tag = iob2
         elif ner_encoding_type == 'bioes':
-            self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags))
+            def func2(tags):
+                return iob2bioes(iob2(tags))
+            # self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags))
+            self.ner_convert_tag = func2
         else:
             raise ValueError("ner_encoding_type only supports `bio` and `bioes`.")
         self.lower = lower
+        self.num_proc = num_proc
 
     def process(self, data_bundle) -> DataBundle:
         r"""
@@ -182,8 +190,8 @@ class Conll2003Pipe(Pipe):
         # convert the tags
         for name, dataset in data_bundle.datasets.items():
             dataset.drop(lambda x: "-DOCSTART-" in x['raw_words'])
-            dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk')
-            dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner')
+            dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk', num_proc=self.num_proc)
+            dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner', num_proc=self.num_proc)
 
         _add_words_field(data_bundle, lower=self.lower)
@@ -194,10 +202,7 @@ class Conll2003Pipe(Pipe):
         tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk')
         tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk')
         data_bundle.set_vocab(tgt_vocab, 'chunk')
 
-        input_fields = ['words', 'seq_len']
-        target_fields = ['pos', 'ner', 'chunk', 'seq_len']
 
         for name, dataset in data_bundle.iter_datasets():
             dataset.add_seq_len('words')
@@ -256,7 +261,7 @@ class _CNNERPipe(Pipe):
     """
 
-    def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False):
+    def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False, num_proc: int = 0):
         r"""
         :param str encoding_type: which encoding type to use for the target column; supports bioes and bio.
@@ -270,12 +275,16 @@ class _CNNERPipe(Pipe):
         if encoding_type == 'bio':
             self.convert_tag = iob2
         elif encoding_type == 'bioes':
-            self.convert_tag = lambda words: iob2bioes(iob2(words))
+            def func(words):
+                return iob2bioes(iob2(words))
+            # self.convert_tag = lambda words: iob2bioes(iob2(words))
+            self.convert_tag = func
         else:
             raise ValueError("encoding_type only supports `bio` and `bioes`.")
         self.bigrams = bigrams
         self.trigrams = trigrams
+        self.num_proc = num_proc
 
     def process(self, data_bundle: DataBundle) -> DataBundle:
         r"""
@@ -296,29 +305,31 @@ class _CNNERPipe(Pipe):
         """
         # convert the tags
         for name, dataset in data_bundle.datasets.items():
-            dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target')
+            dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target', num_proc=self.num_proc)
 
         _add_chars_field(data_bundle, lower=False)
 
         input_field_names = ['chars']
 
+        def bigrams(chars):
+            return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
+
+        def trigrams(chars):
+            return [c1 + c2 + c3 for c1, c2, c3 in
+                    zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
+
         if self.bigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
-                                    field_name='chars', new_field_name='bigrams')
+                dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc)
             input_field_names.append('bigrams')
         if self.trigrams:
             for name, dataset in data_bundle.datasets.items():
-                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
-                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
-                                    field_name='chars', new_field_name='trigrams')
+                dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc)
             input_field_names.append('trigrams')
 
         # index
         _indexize(data_bundle, input_field_names, 'target')
 
-        input_fields = ['target', 'seq_len'] + input_field_names
-        target_fields = ['target', 'seq_len']
 
         for name, dataset in data_bundle.iter_datasets():
             dataset.add_seq_len('chars')
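A pattern that repeats throughout these hunks: every lambda previously handed to `apply_field` is replaced by a named function (`func`, `bigrams`, `trigrams`, ...). A plausible motivation, not spelled out in the diff itself, is that once `num_proc > 0` the callable has to be serialized and shipped to worker processes, and Python's standard `pickle` refuses lambdas outright, while serialization backends used for multi-process dispatch generally cope better with named functions. A minimal demonstration of the lambda limitation with plain `pickle`:

```python
import pickle

def to_bigrams(chars):
    # named, module-level function: picklable by reference
    return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]

pickle.dumps(to_bigrams)  # works

try:
    pickle.dumps(lambda chars: chars)  # a lambda has no importable name
except (pickle.PicklingError, AttributeError) as err:
    print('cannot pickle a lambda:', err)
```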
@@ -157,7 +157,8 @@ class CWSPipe(Pipe):
     """
 
-    def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False):
+    def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True,
+                 bigrams=False, trigrams=False, num_proc: int = 0):
         r"""
         :param str,None dataset_name: supports 'pku', 'msra', 'cityu', 'as', None
@@ -176,6 +177,7 @@ class CWSPipe(Pipe):
         self.bigrams = bigrams
         self.trigrams = trigrams
         self.replace_num_alpha = replace_num_alpha
+        self.num_proc = num_proc
 
     def _tokenize(self, data_bundle):
         r"""
@@ -213,7 +215,7 @@ class CWSPipe(Pipe):
         for name, dataset in data_bundle.iter_datasets():
             dataset.apply_field(split_word_into_chars, field_name='chars',
-                                new_field_name='chars')
+                                new_field_name='chars', num_proc=self.num_proc)
         return data_bundle
     def process(self, data_bundle: DataBundle) -> DataBundle:
@@ -233,33 +235,40 @@ class CWSPipe(Pipe):
         data_bundle.copy_field('raw_words', 'chars')
         if self.replace_num_alpha:
-            data_bundle.apply_field(_find_and_replace_alpha_spans, 'chars', 'chars')
-            data_bundle.apply_field(_find_and_replace_digit_spans, 'chars', 'chars')
+            data_bundle.apply_field(_find_and_replace_alpha_spans, 'chars', 'chars', num_proc=self.num_proc)
+            data_bundle.apply_field(_find_and_replace_digit_spans, 'chars', 'chars', num_proc=self.num_proc)
 
         self._tokenize(data_bundle)
 
+        def func1(chars):
+            return self.word_lens_to_tags(map(len, chars))
+
+        def func2(chars):
+            return list(chain(*chars))
+
         for name, dataset in data_bundle.iter_datasets():
-            dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name='chars',
-                                new_field_name='target')
-            dataset.apply_field(lambda chars: list(chain(*chars)), field_name='chars',
-                                new_field_name='chars')
+            dataset.apply_field(func1, field_name='chars', new_field_name='target', num_proc=self.num_proc)
+            dataset.apply_field(func2, field_name='chars', new_field_name='chars', num_proc=self.num_proc)
 
         input_field_names = ['chars']
 
+        def bigram(chars):
+            return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
+
+        def trigrams(chars):
+            return [c1 + c2 + c3 for c1, c2, c3 in
+                    zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
+
         if self.bigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
-                                    field_name='chars', new_field_name='bigrams')
+                dataset.apply_field(bigram, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc)
             input_field_names.append('bigrams')
         if self.trigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
-                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
-                                    field_name='chars', new_field_name='trigrams')
+                dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc)
             input_field_names.append('trigrams')
 
         _indexize(data_bundle, input_field_names, 'target')
 
-        input_fields = ['target', 'seq_len'] + input_field_names
-        target_fields = ['target', 'seq_len']
 
         for name, dataset in data_bundle.iter_datasets():
             dataset.add_seq_len('chars')
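`func1` above delegates to `CWSPipe.word_lens_to_tags`, which turns the length of each segmented word into character-level segmentation tags (the default `encoding_type` in this pipe is `'bmes'`). A toy stand-in showing the idea, not fastNLP's implementation:

```python
def word_lens_to_bmes(word_lens):
    # 1-character words become 'S'; longer words become 'B', 'M' * (n - 2), 'E'
    tags = []
    for n in word_lens:
        tags.extend(['S'] if n == 1 else ['B'] + ['M'] * (n - 2) + ['E'])
    return tags

print(word_lens_to_bmes(map(len, ['复旦大学', '的', '学生'])))
# -> ['B', 'M', 'M', 'E', 'S', 'B', 'E']
```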
@@ -23,6 +23,7 @@ __all__ = [
     "GranularizePipe",
     "MachingTruncatePipe",
 ]
+from functools import partial
 
 from fastNLP.core.log import logger
 from .pipe import Pipe
@@ -63,7 +64,7 @@ class MatchingBertPipe(Pipe):
     """
 
-    def __init__(self, lower=False, tokenizer: str = 'raw'):
+    def __init__(self, lower=False, tokenizer: str = 'raw', num_proc: int = 0):
         r"""
         :param bool lower: whether to lowercase the words.
@@ -73,6 +74,7 @@ class MatchingBertPipe(Pipe):
         self.lower = bool(lower)
         self.tokenizer = get_tokenizer(tokenize_method=tokenizer)
+        self.num_proc = num_proc
 
     def _tokenize(self, data_bundle, field_names, new_field_names):
         r"""
@@ -84,8 +86,7 @@ class MatchingBertPipe(Pipe):
         """
         for name, dataset in data_bundle.iter_datasets():
             for field_name, new_field_name in zip(field_names, new_field_names):
-                dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
-                                    new_field_name=new_field_name)
+                dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc)
         return data_bundle
 
     def process(self, data_bundle):
@@ -124,8 +125,8 @@ class MatchingBertPipe(Pipe):
             words = words0 + ['[SEP]'] + words1
             return words
 
-        for name, dataset in data_bundle.datasets.items():
-            dataset.apply(concat, new_field_name='words')
+        for name, dataset in data_bundle.iter_datasets():
+            dataset.apply(concat, new_field_name='words', num_proc=self.num_proc)
             dataset.delete_field('words1')
             dataset.delete_field('words2')
@@ -155,10 +156,7 @@ class MatchingBertPipe(Pipe):
         data_bundle.set_vocab(word_vocab, 'words')
         data_bundle.set_vocab(target_vocab, 'target')
 
-        input_fields = ['words', 'seq_len']
-        target_fields = ['target']
 
         for name, dataset in data_bundle.iter_datasets():
             dataset.add_seq_len('words')
@@ -223,7 +221,7 @@ class MatchingPipe(Pipe):
     """
 
-    def __init__(self, lower=False, tokenizer: str = 'raw'):
+    def __init__(self, lower=False, tokenizer: str = 'raw', num_proc: int = 0):
         r"""
         :param bool lower: whether to lowercase all raw_words.
@@ -233,6 +231,7 @@ class MatchingPipe(Pipe):
         self.lower = bool(lower)
         self.tokenizer = get_tokenizer(tokenize_method=tokenizer)
+        self.num_proc = num_proc
 
     def _tokenize(self, data_bundle, field_names, new_field_names):
         r"""
@@ -244,8 +243,7 @@ class MatchingPipe(Pipe):
         """
         for name, dataset in data_bundle.iter_datasets():
             for field_name, new_field_name in zip(field_names, new_field_names):
-                dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
-                                    new_field_name=new_field_name)
+                dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc)
         return data_bundle
 
     def process(self, data_bundle):
@@ -300,10 +298,7 @@ class MatchingPipe(Pipe):
         data_bundle.set_vocab(word_vocab, 'words1')
         data_bundle.set_vocab(target_vocab, 'target')
 
-        input_fields = ['words1', 'words2', 'seq_len1', 'seq_len2']
-        target_fields = ['target']
 
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len('words1', 'seq_len1')
             dataset.add_seq_len('words2', 'seq_len2')
@@ -342,8 +337,8 @@ class MNLIPipe(MatchingPipe):
 
 class LCQMCPipe(MatchingPipe):
-    def __init__(self, tokenizer='cn=char'):
-        super().__init__(tokenizer=tokenizer)
+    def __init__(self, tokenizer='cn=char', num_proc=0):
+        super().__init__(tokenizer=tokenizer, num_proc=num_proc)
 
     def process_from_file(self, paths=None):
         data_bundle = LCQMCLoader().load(paths)
@@ -354,8 +349,8 @@ class LCQMCPipe(MatchingPipe):
 
 class CNXNLIPipe(MatchingPipe):
-    def __init__(self, tokenizer='cn-char'):
-        super().__init__(tokenizer=tokenizer)
+    def __init__(self, tokenizer='cn-char', num_proc=0):
+        super().__init__(tokenizer=tokenizer, num_proc=num_proc)
 
     def process_from_file(self, paths=None):
         data_bundle = CNXNLILoader().load(paths)
@@ -367,8 +362,8 @@ class CNXNLIPipe(MatchingPipe):
 
 class BQCorpusPipe(MatchingPipe):
-    def __init__(self, tokenizer='cn-char'):
-        super().__init__(tokenizer=tokenizer)
+    def __init__(self, tokenizer='cn-char', num_proc=0):
+        super().__init__(tokenizer=tokenizer, num_proc=num_proc)
 
     def process_from_file(self, paths=None):
         data_bundle = BQCorpusLoader().load(paths)
@@ -379,9 +374,10 @@ class BQCorpusPipe(MatchingPipe):
 
 class RenamePipe(Pipe):
-    def __init__(self, task='cn-nli'):
+    def __init__(self, task='cn-nli', num_proc=0):
         super().__init__()
         self.task = task
+        self.num_proc = num_proc
 
     def process(self, data_bundle: DataBundle):  # rename field name for Chinese Matching dataset
         if (self.task == 'cn-nli'):
@@ -419,9 +415,10 @@ class RenamePipe(Pipe):
 
 class GranularizePipe(Pipe):
-    def __init__(self, task=None):
+    def __init__(self, task=None, num_proc=0):
         super().__init__()
         self.task = task
+        self.num_proc = num_proc
 
     def _granularize(self, data_bundle, tag_map):
         r"""
@@ -434,8 +431,7 @@ class GranularizePipe(Pipe):
         """
         for name in list(data_bundle.datasets.keys()):
             dataset = data_bundle.get_dataset(name)
-            dataset.apply_field(lambda target: tag_map.get(target, -100), field_name='target',
-                                new_field_name='target')
+            dataset.apply_field(lambda target: tag_map.get(target, -100), field_name='target', new_field_name='target')
             dataset.drop(lambda ins: ins['target'] == -100)
             data_bundle.set_dataset(dataset, name)
         return data_bundle
@@ -462,8 +458,8 @@ class MachingTruncatePipe(Pipe): # truncate sentence for bert, modify seq_len
 
 class LCQMCBertPipe(MatchingBertPipe):
-    def __init__(self, tokenizer='cn=char'):
-        super().__init__(tokenizer=tokenizer)
+    def __init__(self, tokenizer='cn=char', num_proc=0):
+        super().__init__(tokenizer=tokenizer, num_proc=num_proc)
 
     def process_from_file(self, paths=None):
         data_bundle = LCQMCLoader().load(paths)
@@ -475,8 +471,8 @@ class LCQMCBertPipe(MatchingBertPipe):
 
 class BQCorpusBertPipe(MatchingBertPipe):
-    def __init__(self, tokenizer='cn-char'):
-        super().__init__(tokenizer=tokenizer)
+    def __init__(self, tokenizer='cn-char', num_proc=0):
+        super().__init__(tokenizer=tokenizer, num_proc=num_proc)
 
     def process_from_file(self, paths=None):
         data_bundle = BQCorpusLoader().load(paths)
@@ -488,8 +484,8 @@ class BQCorpusBertPipe(MatchingBertPipe):
 
 class CNXNLIBertPipe(MatchingBertPipe):
-    def __init__(self, tokenizer='cn-char'):
-        super().__init__(tokenizer=tokenizer)
+    def __init__(self, tokenizer='cn-char', num_proc=0):
+        super().__init__(tokenizer=tokenizer, num_proc=num_proc)
 
     def process_from_file(self, paths=None):
         data_bundle = CNXNLILoader().load(paths)
@@ -502,9 +498,10 @@ class CNXNLIBertPipe(MatchingBertPipe):
 
 class TruncateBertPipe(Pipe):
-    def __init__(self, task='cn'):
+    def __init__(self, task='cn', num_proc=0):
         super().__init__()
         self.task = task
+        self.num_proc = num_proc
 
     def _truncate(self, sentence_index:list, sep_index_vocab):
         # use the index of [SEP] in the vocab to locate the [SEP] positions in the dataset's 'words' field
@@ -528,7 +525,8 @@ class TruncateBertPipe(Pipe):
         for name in data_bundle.datasets.keys():
             dataset = data_bundle.get_dataset(name)
             sep_index_vocab = data_bundle.get_vocab('words').to_index('[SEP]')
-            dataset.apply_field(lambda sent_index: self._truncate(sentence_index=sent_index, sep_index_vocab=sep_index_vocab), field_name='words', new_field_name='words')
+            dataset.apply_field(partial(self._truncate, sep_index_vocab=sep_index_vocab), field_name='words',
+                                new_field_name='words', num_proc=self.num_proc)
 
             # seq_len needs to be updated after truncation
             dataset.add_seq_len(field_name='words')
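`functools.partial` plays the same role as the deleted lambda: it binds the fixed `sep_index_vocab` argument once so that `apply_field` receives a single-argument callable, and a `partial` over a picklable function or method can itself be pickled, which a lambda cannot. A small sketch of the pattern with toy truncation logic (not fastNLP's `_truncate`):

```python
from functools import partial

def truncate(sentence_index, sep_index_vocab):
    # toy rule: keep everything up to and including the first [SEP] id
    if sep_index_vocab in sentence_index:
        return sentence_index[:sentence_index.index(sep_index_vocab) + 1]
    return sentence_index

truncate_at_sep = partial(truncate, sep_index_vocab=102)
print(truncate_at_sep([5, 8, 102, 9, 3]))  # -> [5, 8, 102]
```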
@@ -1,6 +1,7 @@
 r"""undocumented"""
 import os
 import numpy as np
+from functools import partial
 
 from .pipe import Pipe
 from .utils import _drop_empty_instance
@@ -25,7 +26,7 @@ class ExtCNNDMPipe(Pipe):
        :header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target"
     """
 
-    def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False):
+    def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False, num_proc=0):
         r"""
         :param vocab_size: int, vocabulary size
@@ -39,6 +40,7 @@ class ExtCNNDMPipe(Pipe):
         self.sent_max_len = sent_max_len
         self.doc_max_timesteps = doc_max_timesteps
         self.domain = domain
+        self.num_proc = num_proc
 
     def process(self, data_bundle: DataBundle):
         r"""
@@ -65,18 +67,29 @@ class ExtCNNDMPipe(Pipe):
                 error_msg = 'vocab file is not defined!'
                 print(error_msg)
                 raise RuntimeError(error_msg)
 
-        data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text')
-        data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
-        data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
+        data_bundle.apply_field(_lower_text, field_name='text', new_field_name='text', num_proc=self.num_proc)
+        data_bundle.apply_field(_lower_text, field_name='summary', new_field_name='summary', num_proc=self.num_proc)
+        data_bundle.apply_field(_split_list, field_name='text', new_field_name='text_wd', num_proc=self.num_proc)
+        # data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text')
+        # data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
+        # data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
         data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name='target')
 
-        data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name='words')
+        data_bundle.apply_field(partial(_pad_sent, sent_max_len=self.sent_max_len), field_name="text_wd",
+                                new_field_name="words", num_proc=self.num_proc)
+        # data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name='words')
         # db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask")
 
         # pad document
-        data_bundle.apply(lambda x: _pad_doc(x['words'], self.sent_max_len, self.doc_max_timesteps), new_field_name='words')
-        data_bundle.apply(lambda x: _sent_mask(x['words'], self.doc_max_timesteps), new_field_name='seq_len')
-        data_bundle.apply(lambda x: _pad_label(x['target'], self.doc_max_timesteps), new_field_name='target')
+        data_bundle.apply_field(partial(_pad_doc, sent_max_len=self.sent_max_len, doc_max_timesteps=self.doc_max_timesteps),
+                                field_name="words", new_field_name="words", num_proc=self.num_proc)
+        data_bundle.apply_field(partial(_sent_mask, doc_max_timesteps=self.doc_max_timesteps), field_name="words",
+                                new_field_name="seq_len", num_proc=self.num_proc)
+        data_bundle.apply_field(partial(_pad_label, doc_max_timesteps=self.doc_max_timesteps), field_name="target",
+                                new_field_name="target", num_proc=self.num_proc)
+        # data_bundle.apply(lambda x: _pad_doc(x['words'], self.sent_max_len, self.doc_max_timesteps), new_field_name='words')
+        # data_bundle.apply(lambda x: _sent_mask(x['words'], self.doc_max_timesteps), new_field_name='seq_len')
+        # data_bundle.apply(lambda x: _pad_label(x['target'], self.doc_max_timesteps), new_field_name='target')
 
         data_bundle = _drop_empty_instance(data_bundle, "label")
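The same `partial` pattern converts the old row-level `apply(lambda x: ...)` calls into field-level `apply_field` calls that can take `num_proc`. A toy sketch of the `_pad_sent` binding (the padding function here is an illustrative stand-in, not fastNLP's `_pad_sent`):

```python
from functools import partial

def pad_sent(words, sent_max_len, pad='<pad>'):
    # pad or truncate a token list to exactly sent_max_len items
    return (words + [pad] * sent_max_len)[:sent_max_len]

pad_to_8 = partial(pad_sent, sent_max_len=8)
print(pad_to_8(['an', 'example', 'sentence']))
# -> ['an', 'example', 'sentence', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
```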
@@ -12,14 +12,24 @@ class TestClassificationPipe:
     def test_process_from_file(self):
         for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
             print(pipe)
-            data_bundle = pipe(tokenizer='raw').process_from_file()
+            data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file()
             print(data_bundle)
 
+    def test_process_from_file_proc(self, num_proc=2):
+        for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
+            print(pipe)
+            data_bundle = pipe(tokenizer='raw', num_proc=num_proc).process_from_file()
+            print(data_bundle)
+
 
 class TestRunPipe:
     def test_load(self):
         for pipe in [IMDBPipe]:
-            data_bundle = pipe(tokenizer='raw').process_from_file('tests/data_for_tests/io/imdb')
+            data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file('tests/data_for_tests/io/imdb')
+            print(data_bundle)
+
+    def test_load_proc(self):
+        for pipe in [IMDBPipe]:
+            data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file('tests/data_for_tests/io/imdb')
             print(data_bundle)
@@ -31,7 +41,7 @@ class TestCNClassificationPipe:
             print(data_bundle)
 
-@pytest.mark.skipif('download' not in os.environ, reason="Skip download")
+# @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
 class TestRunClassificationPipe:
     def test_process_from_file(self):
         data_set_dict = {
@@ -71,9 +81,9 @@ class TestRunClassificationPipe:
             path, pipe, data_set, vocab, warns = v
             if 'Chn' not in k:
                 if warns:
-                    data_bundle = pipe(tokenizer='raw').process_from_file(path)
+                    data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file(path)
                 else:
-                    data_bundle = pipe(tokenizer='raw').process_from_file(path)
+                    data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file(path)
             else:
                 data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path)
@@ -87,3 +97,61 @@ class TestRunClassificationPipe:
             for name, vocabs in data_bundle.iter_vocabs():
                 assert(name in vocab.keys())
                 assert(vocab[name] == len(vocabs))
def test_process_from_file_proc(self): | |||||
data_set_dict = { | |||||
'yelp.p': ('tests/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe, | |||||
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2}, | |||||
False), | |||||
'yelp.f': ('tests/data_for_tests/io/yelp_review_full', YelpFullPipe, | |||||
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5}, | |||||
False), | |||||
'sst-2': ('tests/data_for_tests/io/SST-2', SST2Pipe, | |||||
{'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2}, | |||||
True), | |||||
'sst': ('tests/data_for_tests/io/SST', SSTPipe, | |||||
{'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5}, | |||||
False), | |||||
'imdb': ('tests/data_for_tests/io/imdb', IMDBPipe, | |||||
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2}, | |||||
False), | |||||
'ag': ('tests/data_for_tests/io/ag', AGsNewsPipe, | |||||
{'train': 4, 'test': 5}, {'words': 257, 'target': 4}, | |||||
False), | |||||
'dbpedia': ('tests/data_for_tests/io/dbpedia', DBPediaPipe, | |||||
{'train': 14, 'test': 5}, {'words': 496, 'target': 14}, | |||||
False), | |||||
'ChnSentiCorp': ('tests/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe, | |||||
{'train': 6, 'dev': 6, 'test': 6}, | |||||
{'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2}, | |||||
False), | |||||
'Chn-THUCNews': ('tests/data_for_tests/io/THUCNews', THUCNewsPipe, | |||||
{'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9}, | |||||
False), | |||||
'Chn-WeiboSenti100k': ('tests/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe, | |||||
{'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2}, | |||||
False), | |||||
} | |||||
for k, v in data_set_dict.items(): | |||||
path, pipe, data_set, vocab, warns = v | |||||
if 'Chn' not in k: | |||||
if warns: | |||||
data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file(path) | |||||
else: | |||||
data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file(path) | |||||
else: | |||||
# if k == 'ChnSentiCorp': | |||||
# data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path) | |||||
# else: | |||||
data_bundle = pipe(bigrams=True, trigrams=True, num_proc=2).process_from_file(path) | |||||
assert(isinstance(data_bundle, DataBundle)) | |||||
assert(len(data_set) == data_bundle.num_dataset) | |||||
for name, dataset in data_bundle.iter_datasets(): | |||||
assert(name in data_set.keys()) | |||||
assert(data_set[name] == len(dataset)) | |||||
assert(len(vocab) == data_bundle.num_vocab) | |||||
for name, vocabs in data_bundle.iter_vocabs(): | |||||
assert(name in vocab.keys()) | |||||
assert(vocab[name] == len(vocabs)) |
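
For readers skimming the diff, a minimal usage sketch of the path these classification tests exercise. The import module paths are assumptions on my part; the pipe name, arguments, and fixture directory come from the test code above, and the sketch itself is not part of the patch:

    from fastNLP.io.data_bundle import DataBundle
    from fastNLP.io.pipe.classification import IMDBPipe

    # num_proc=0 keeps the original single-process tokenization;
    # num_proc=2 runs the tokenization step in two worker processes.
    data_bundle = IMDBPipe(tokenizer='raw', num_proc=2).process_from_file(
        'tests/data_for_tests/io/imdb')
    assert isinstance(data_bundle, DataBundle)
    print(data_bundle)

The tests check the same dataset and vocabulary sizes for both settings, so the parallel path is expected to be a pure speed-up, not a change in output.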
@@ -22,6 +22,12 @@ class TestRunPipe:
             data_bundle = pipe().process_from_file('tests/data_for_tests/conll_2003_example.txt')
             print(data_bundle)

+    def test_conll2003_proc(self):
+        for pipe in [Conll2003Pipe, Conll2003NERPipe]:
+            print(pipe)
+            data_bundle = pipe(num_proc=2).process_from_file('tests/data_for_tests/conll_2003_example.txt')
+            print(data_bundle)
+

 class TestNERPipe:
     def test_process_from_file(self):
@@ -37,12 +43,33 @@ class TestNERPipe:
             data_bundle = pipe(encoding_type='bioes').process_from_file(f'tests/data_for_tests/io/{k}')
             print(data_bundle)

+    def test_process_from_file_proc(self):
+        data_dict = {
+            'weibo_NER': WeiboNERPipe,
+            'peopledaily': PeopleDailyPipe,
+            'MSRA_NER': MsraNERPipe,
+        }
+        for k, v in data_dict.items():
+            pipe = v
+            data_bundle = pipe(bigrams=True, trigrams=True, num_proc=2).process_from_file(f'tests/data_for_tests/io/{k}')
+            print(data_bundle)
+            data_bundle = pipe(encoding_type='bioes', num_proc=2).process_from_file(f'tests/data_for_tests/io/{k}')
+            print(data_bundle)
+

 class TestConll2003Pipe:
     def test_conll(self):
         data_bundle = Conll2003Pipe().process_from_file('tests/data_for_tests/io/conll2003')
         print(data_bundle)

+    def test_conll_proc(self):
+        data_bundle = Conll2003Pipe(num_proc=2).process_from_file('tests/data_for_tests/io/conll2003')
+        print(data_bundle)
+
     def test_OntoNotes(self):
         data_bundle = OntoNotesNERPipe().process_from_file('tests/data_for_tests/io/OntoNotes')
         print(data_bundle)
+
+    def test_OntoNotes_proc(self):
+        data_bundle = OntoNotesNERPipe(num_proc=2).process_from_file('tests/data_for_tests/io/OntoNotes')
+        print(data_bundle)
@@ -28,12 +28,25 @@ class TestRunCWSPipe:
     def test_process_from_file(self):
         dataset_names = ['msra', 'cityu', 'as', 'pku']
         for dataset_name in dataset_names:
-            data_bundle = CWSPipe(bigrams=True, trigrams=True).\
+            data_bundle = CWSPipe(bigrams=True, trigrams=True, num_proc=0).\
                 process_from_file(f'tests/data_for_tests/io/cws_{dataset_name}')
             print(data_bundle)

     def test_replace_number(self):
-        data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True).\
+        data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True, num_proc=0).\
+            process_from_file(f'tests/data_for_tests/io/cws_pku')
+        for word in ['<', '>', '<NUM>']:
+            assert(data_bundle.get_vocab('chars').to_index(word) != 1)
+
+    def test_process_from_file_proc(self):
+        dataset_names = ['msra', 'cityu', 'as', 'pku']
+        for dataset_name in dataset_names:
+            data_bundle = CWSPipe(bigrams=True, trigrams=True, num_proc=2).\
+                process_from_file(f'tests/data_for_tests/io/cws_{dataset_name}')
+            print(data_bundle)
+
+    def test_replace_number_proc(self):
+        data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True, num_proc=2).\
             process_from_file(f'tests/data_for_tests/io/cws_pku')
         for word in ['<', '>', '<NUM>']:
             assert(data_bundle.get_vocab('chars').to_index(word) != 1)
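
One check the CWS tests above could be extended with, offered only as a sketch and not something this patch adds: process the same fixture with num_proc=0 and num_proc=2 and assert the two runs agree, which would catch workers seeing different slices of the data. The import module path is an assumption; the pipe arguments, fixture path, and DataBundle accessors are taken from the tests in this diff:

    from fastNLP.io.pipe.cws import CWSPipe

    path = 'tests/data_for_tests/io/cws_pku'
    single = CWSPipe(bigrams=True, num_proc=0).process_from_file(path)
    parallel = CWSPipe(bigrams=True, num_proc=2).process_from_file(path)

    # The parallel run should only change scheduling, not the result:
    # vocabularies and dataset sizes must match the single-process run.
    assert len(single.get_vocab('chars')) == len(parallel.get_vocab('chars'))
    for name, dataset in single.iter_datasets():
        assert len(dataset) == len(parallel.get_dataset(name))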
@@ -69,6 +69,47 @@ class TestRunMatchingPipe:
                 name, vocabs = y
                 assert(x + 1 if name == 'words' else x == len(vocabs))

+    def test_load_proc(self):
+        data_set_dict = {
+            'RTE': ('tests/data_for_tests/io/RTE', RTEPipe, RTEBertPipe, (5, 5, 5), (449, 2), True),
+            'SNLI': ('tests/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False),
+            'QNLI': ('tests/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
+            'MNLI': ('tests/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
+            'BQCorpus': ('tests/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
+            'XNLI': ('tests/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False),
+            'LCQMC': ('tests/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False),
+        }
+        for k, v in data_set_dict.items():
+            path, pipe1, pipe2, data_set, vocab, warns = v
+            if warns:
+                data_bundle1 = pipe1(tokenizer='raw', num_proc=2).process_from_file(path)
+                data_bundle2 = pipe2(tokenizer='raw', num_proc=2).process_from_file(path)
+            else:
+                data_bundle1 = pipe1(tokenizer='raw', num_proc=2).process_from_file(path)
+                data_bundle2 = pipe2(tokenizer='raw', num_proc=2).process_from_file(path)
+
+            assert (isinstance(data_bundle1, DataBundle))
+            assert (len(data_set) == data_bundle1.num_dataset)
+            print(k)
+            print(data_bundle1)
+            print(data_bundle2)
+
+            for x, y in zip(data_set, data_bundle1.iter_datasets()):
+                name, dataset = y
+                assert (x == len(dataset))
+
+            assert (len(data_set) == data_bundle2.num_dataset)
+            for x, y in zip(data_set, data_bundle2.iter_datasets()):
+                name, dataset = y
+                assert (x == len(dataset))
+
+            assert (len(vocab) == data_bundle1.num_vocab)
+            for x, y in zip(vocab, data_bundle1.iter_vocabs()):
+                name, vocabs = y
+                assert (x == len(vocabs))
+
+            assert (len(vocab) == data_bundle2.num_vocab)
+            for x, y in zip(vocab, data_bundle1.iter_vocabs()):
+                name, vocabs = y
+                assert (x + 1 if name == 'words' else x == len(vocabs))
+
     @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
     def test_spacy(self):
         data_set_dict = {
@@ -69,3 +69,45 @@ class TestRunExtCNNDMPipe:
         db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl'))
         assert(isinstance(db5, DataBundle))

+    def test_load_proc(self):
+        data_dir = 'tests/data_for_tests/io/cnndm'
+        vocab_size = 100000
+        VOCAL_FILE = 'tests/data_for_tests/io/cnndm/vocab'
+        sent_max_len = 100
+        doc_max_timesteps = 50
+        dbPipe = ExtCNNDMPipe(vocab_size=vocab_size,
+                              vocab_path=VOCAL_FILE,
+                              sent_max_len=sent_max_len,
+                              doc_max_timesteps=doc_max_timesteps, num_proc=2)
+        dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size,
+                               vocab_path=VOCAL_FILE,
+                               sent_max_len=sent_max_len,
+                               doc_max_timesteps=doc_max_timesteps,
+                               domain=True, num_proc=2)
+        db = dbPipe.process_from_file(data_dir)
+        db2 = dbPipe2.process_from_file(data_dir)
+
+        assert(isinstance(db, DataBundle))
+        assert(isinstance(db2, DataBundle))
+
+        dbPipe3 = ExtCNNDMPipe(vocab_size=vocab_size,
+                               sent_max_len=sent_max_len,
+                               doc_max_timesteps=doc_max_timesteps,
+                               domain=True, num_proc=2)
+        db3 = dbPipe3.process_from_file(data_dir)
+        assert(isinstance(db3, DataBundle))
+
+        with pytest.raises(RuntimeError):
+            dbPipe4 = ExtCNNDMPipe(vocab_size=vocab_size,
+                                   sent_max_len=sent_max_len,
+                                   doc_max_timesteps=doc_max_timesteps, num_proc=2)
+            db4 = dbPipe4.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl'))
+
+        dbPipe5 = ExtCNNDMPipe(vocab_size=vocab_size,
+                               vocab_path=VOCAL_FILE,
+                               sent_max_len=sent_max_len,
+                               doc_max_timesteps=doc_max_timesteps, num_proc=2)
+        db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl'))
+        assert(isinstance(db5, DataBundle))