You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_CRF.py 16 kB

Dev0.4.0 (#149) * 1. CRF增加支持bmeso类型的tag 2. vocabulary中增加注释 * BucketSampler增加一条错误检测 * 1.修改ClipGradientCallback的bug;删除LRSchedulerCallback中的print,之后应该传入pbar进行打印;2.增加MLP注释 * update MLP module * 增加metric注释;修改trainer save过程中的bug * Update README.md fix tutorial link * Add ENAS (Efficient Neural Architecture Search) * add ignore_type in DataSet.add_field * * AutoPadder will not pad when dtype is None * add ignore_type in DataSet.apply * 修复fieldarray中padder潜在bug * 修复crf中typo; 以及可能导致数值不稳定的地方 * 修复CRF中可能存在的bug * change two default init arguments of Trainer into None * Changes to Callbacks: * 给callback添加给定几个只读属性 * 通过manager设置这些属性 * 代码优化,减轻@transfer的负担 * * 将enas相关代码放到automl目录下 * 修复fast_param_mapping的一个bug * Trainer添加自动创建save目录 * Vocabulary的打印,显示内容 * * 给vocabulary添加遍历方法 * 修复CRF为负数的bug * add SQuAD metric * add sigmoid activate function in MLP * - add star transformer model - add ConllLoader, for all kinds of conll-format files - add JsonLoader, for json-format files - add SSTLoader, for SST-2 & SST-5 - change Callback interface - fix batch multi-process when killed - add README to list models and their performance * - fix test * - fix callback & tests * - update README * 修改部分bug;调整callback * 准备发布0.4.0版本“ * update readme * support parallel loss * 防止多卡的情况导致无法正确计算loss“ * update advance_tutorial jupyter notebook * 1. 在embedding_loader中增加新的读取函数load_with_vocab(), load_without_vocab, 比之前的函数改变主要在(1)不再需要传入embed_dim(2)自动判断当前是word2vec还是glove. 2. vocabulary增加from_dataset(), index_dataset()函数。避免需要多行写index dataset的问题。 3. 在utils中新增一个cache_result()修饰器,用于cache函数的返回值。 4. callback中新增update_every属性 * 1.DataSet.apply()报错时提供错误的index 2.Vocabulary.from_dataset(), index_dataset()提供报错时的vocab顺序 3.embedloader在embed读取时遇到不规则的数据跳过这一行. * update attention * doc tools * fix some doc errors * 修改为中文注释,增加viterbi解码方法 * 样例版本 * - add pad sequence for lstm - add csv, conll, json filereader - update dataloader - remove useless dataloader - fix trainer loss print - fix tests * - fix test_tutorial * 注释增加 * 测试文档 * 本地暂存 * 本地暂存 * 修改文档的顺序 * - add document * 本地暂存 * update pooling * update bert * update documents in MLP * update documents in snli * combine self attention module to attention.py * update documents on losses.py * 对DataSet的文档进行更新 * update documents on metrics * 1. 删除了LSTM中print的内容; 2. 将Trainer和Tester的use_cuda修改为了device; 3.补充Trainer的文档 * 增加对Trainer的注释 * 完善了trainer,callback等的文档; 修改了部分代码的命名以使得代码从文档中隐藏 * update char level encoder * update documents on embedding.py * - update doc * 补充注释,并修改部分代码 * - update doc - add get_embeddings * 修改了文档配置项 * 修改embedding为init_embed初始化 * 1.增加对Trainer和Tester的多卡支持; * - add test - fix jsonloader * 删除了注释教程 * 给 dataset 增加了get_field_names * 修复bug * - add Const - fix bugs * 修改部分注释 * - add model runner for easier test models - add model tests * 修改了 docs 的配置和架构 * 修改了核心部分的一大部分文档,TODO: 1. 完善 trainer 和 tester 部分的文档 2. 研究注释样例与测试 * core部分的注释基本检查完成 * 修改了 io 部分的注释 * 全部改为相对路径引用 * 全部改为相对路径引用 * small change * 1. 从安装文件中删除api/automl的安装 2. metric中存在seq_len的bug 3. sampler中存在命名错误,已修改 * 修复 bug :兼容 cpu 版本的 PyTorch TODO:其它地方可能也存在类似的 bug * 修改文档中的引用部分 * 把 tqdm.autonotebook 换成tqdm.auto * - fix batch & vocab * 上传了文档文件 *.rst * 上传了文档文件和若干 TODO * 讨论并整合了若干模块 * core部分的测试和一些小修改 * 删除了一些冗余文档 * update init files * update const files * update const files * 增加cnn的测试 * fix a little bug * - update attention - fix tests * 完善测试 * 完成快速入门教程 * 修改了sequence_modeling 命名为 sequence_labeling 的文档 * 重新 apidoc 解决改名的遗留问题 * 修改文档格式 * 统一不同位置的seq_len_to_mask, 现统一到core.utils.seq_len_to_mask * 增加了一行提示 * 在文档中展示 dataset_loader * 提示 Dataset.read_csv 会被 CSVLoader 替换 * 完成 Callback 和 Trainer 之间的文档 * index更新了部分 * 删除冗余的print * 删除用于分词的metric,因为有可能引起错误 * 修改文档中的中文名称 * 完成了详细介绍文档 * tutorial 的 ipynb 文件 * 修改了一些介绍文档 * 修改了 models 和 modules 的主页介绍 * 加上了 titlesonly 这个设置 * 修改了模块文档展示的标题 * 修改了 core 和 io 的开篇介绍 * 修改了 modules 和 models 开篇介绍 * 使用 .. todo:: 隐藏了可能被抽到文档中的 TODO 注释 * 修改了一些注释 * delete an old metric in test * 修改 tutorials 的测试文件 * 把暂不发布的功能移到 legacy 文件夹 * 删除了不能运行的测试 * 修改 callback 的测试文件 * 删除了过时的教程和测试文件 * cache_results 参数的修改 * 修改 io 的测试文件; 删除了一些过时的测试 * 修复bug * 修复无法通过test_utils.py的测试 * 修复与pytorch1.1中的padsequence的兼容问题; 修改Trainer的pbar * 1. 修复metric中的bug; 2.增加metric测试 * add model summary * 增加别名 * 删除encoder中的嵌套层 * 修改了 core 部分 import 的顺序,__all__ 暴露的内容 * 修改了 models 部分 import 的顺序,__all__ 暴露的内容 * 修改了文件名 * 修改了 modules 模块的__all__ 和 import * fix var runn * 增加vocab的clear方法 * 一些符合 PEP8 的微调 * 更新了cache_results的例子 * 1. 对callback中indices潜在None作出提示;2.DataSet支持通过List进行index * 修改了一个typo * 修改了 README.md * update documents on bert * update documents on encoder/bert * 增加一个fitlog callback,实现与fitlog实验记录 * typo * - update dataset_loader * 增加了到 fitlog 文档的链接。 * 增加了 DataSet Loader 的文档 * - add star-transformer reproduction
6 years ago
Dev0.4.0 (#149) * 1. CRF增加支持bmeso类型的tag 2. vocabulary中增加注释 * BucketSampler增加一条错误检测 * 1.修改ClipGradientCallback的bug;删除LRSchedulerCallback中的print,之后应该传入pbar进行打印;2.增加MLP注释 * update MLP module * 增加metric注释;修改trainer save过程中的bug * Update README.md fix tutorial link * Add ENAS (Efficient Neural Architecture Search) * add ignore_type in DataSet.add_field * * AutoPadder will not pad when dtype is None * add ignore_type in DataSet.apply * 修复fieldarray中padder潜在bug * 修复crf中typo; 以及可能导致数值不稳定的地方 * 修复CRF中可能存在的bug * change two default init arguments of Trainer into None * Changes to Callbacks: * 给callback添加给定几个只读属性 * 通过manager设置这些属性 * 代码优化,减轻@transfer的负担 * * 将enas相关代码放到automl目录下 * 修复fast_param_mapping的一个bug * Trainer添加自动创建save目录 * Vocabulary的打印,显示内容 * * 给vocabulary添加遍历方法 * 修复CRF为负数的bug * add SQuAD metric * add sigmoid activate function in MLP * - add star transformer model - add ConllLoader, for all kinds of conll-format files - add JsonLoader, for json-format files - add SSTLoader, for SST-2 & SST-5 - change Callback interface - fix batch multi-process when killed - add README to list models and their performance * - fix test * - fix callback & tests * - update README * 修改部分bug;调整callback * 准备发布0.4.0版本“ * update readme * support parallel loss * 防止多卡的情况导致无法正确计算loss“ * update advance_tutorial jupyter notebook * 1. 在embedding_loader中增加新的读取函数load_with_vocab(), load_without_vocab, 比之前的函数改变主要在(1)不再需要传入embed_dim(2)自动判断当前是word2vec还是glove. 2. vocabulary增加from_dataset(), index_dataset()函数。避免需要多行写index dataset的问题。 3. 在utils中新增一个cache_result()修饰器,用于cache函数的返回值。 4. callback中新增update_every属性 * 1.DataSet.apply()报错时提供错误的index 2.Vocabulary.from_dataset(), index_dataset()提供报错时的vocab顺序 3.embedloader在embed读取时遇到不规则的数据跳过这一行. * update attention * doc tools * fix some doc errors * 修改为中文注释,增加viterbi解码方法 * 样例版本 * - add pad sequence for lstm - add csv, conll, json filereader - update dataloader - remove useless dataloader - fix trainer loss print - fix tests * - fix test_tutorial * 注释增加 * 测试文档 * 本地暂存 * 本地暂存 * 修改文档的顺序 * - add document * 本地暂存 * update pooling * update bert * update documents in MLP * update documents in snli * combine self attention module to attention.py * update documents on losses.py * 对DataSet的文档进行更新 * update documents on metrics * 1. 删除了LSTM中print的内容; 2. 将Trainer和Tester的use_cuda修改为了device; 3.补充Trainer的文档 * 增加对Trainer的注释 * 完善了trainer,callback等的文档; 修改了部分代码的命名以使得代码从文档中隐藏 * update char level encoder * update documents on embedding.py * - update doc * 补充注释,并修改部分代码 * - update doc - add get_embeddings * 修改了文档配置项 * 修改embedding为init_embed初始化 * 1.增加对Trainer和Tester的多卡支持; * - add test - fix jsonloader * 删除了注释教程 * 给 dataset 增加了get_field_names * 修复bug * - add Const - fix bugs * 修改部分注释 * - add model runner for easier test models - add model tests * 修改了 docs 的配置和架构 * 修改了核心部分的一大部分文档,TODO: 1. 完善 trainer 和 tester 部分的文档 2. 研究注释样例与测试 * core部分的注释基本检查完成 * 修改了 io 部分的注释 * 全部改为相对路径引用 * 全部改为相对路径引用 * small change * 1. 从安装文件中删除api/automl的安装 2. metric中存在seq_len的bug 3. sampler中存在命名错误,已修改 * 修复 bug :兼容 cpu 版本的 PyTorch TODO:其它地方可能也存在类似的 bug * 修改文档中的引用部分 * 把 tqdm.autonotebook 换成tqdm.auto * - fix batch & vocab * 上传了文档文件 *.rst * 上传了文档文件和若干 TODO * 讨论并整合了若干模块 * core部分的测试和一些小修改 * 删除了一些冗余文档 * update init files * update const files * update const files * 增加cnn的测试 * fix a little bug * - update attention - fix tests * 完善测试 * 完成快速入门教程 * 修改了sequence_modeling 命名为 sequence_labeling 的文档 * 重新 apidoc 解决改名的遗留问题 * 修改文档格式 * 统一不同位置的seq_len_to_mask, 现统一到core.utils.seq_len_to_mask * 增加了一行提示 * 在文档中展示 dataset_loader * 提示 Dataset.read_csv 会被 CSVLoader 替换 * 完成 Callback 和 Trainer 之间的文档 * index更新了部分 * 删除冗余的print * 删除用于分词的metric,因为有可能引起错误 * 修改文档中的中文名称 * 完成了详细介绍文档 * tutorial 的 ipynb 文件 * 修改了一些介绍文档 * 修改了 models 和 modules 的主页介绍 * 加上了 titlesonly 这个设置 * 修改了模块文档展示的标题 * 修改了 core 和 io 的开篇介绍 * 修改了 modules 和 models 开篇介绍 * 使用 .. todo:: 隐藏了可能被抽到文档中的 TODO 注释 * 修改了一些注释 * delete an old metric in test * 修改 tutorials 的测试文件 * 把暂不发布的功能移到 legacy 文件夹 * 删除了不能运行的测试 * 修改 callback 的测试文件 * 删除了过时的教程和测试文件 * cache_results 参数的修改 * 修改 io 的测试文件; 删除了一些过时的测试 * 修复bug * 修复无法通过test_utils.py的测试 * 修复与pytorch1.1中的padsequence的兼容问题; 修改Trainer的pbar * 1. 修复metric中的bug; 2.增加metric测试 * add model summary * 增加别名 * 删除encoder中的嵌套层 * 修改了 core 部分 import 的顺序,__all__ 暴露的内容 * 修改了 models 部分 import 的顺序,__all__ 暴露的内容 * 修改了文件名 * 修改了 modules 模块的__all__ 和 import * fix var runn * 增加vocab的clear方法 * 一些符合 PEP8 的微调 * 更新了cache_results的例子 * 1. 对callback中indices潜在None作出提示;2.DataSet支持通过List进行index * 修改了一个typo * 修改了 README.md * update documents on bert * update documents on encoder/bert * 增加一个fitlog callback,实现与fitlog实验记录 * typo * - update dataset_loader * 增加了到 fitlog 文档的链接。 * 增加了 DataSet Loader 的文档 * - add star-transformer reproduction
6 years ago
Dev0.4.0 (#149) * 1. CRF增加支持bmeso类型的tag 2. vocabulary中增加注释 * BucketSampler增加一条错误检测 * 1.修改ClipGradientCallback的bug;删除LRSchedulerCallback中的print,之后应该传入pbar进行打印;2.增加MLP注释 * update MLP module * 增加metric注释;修改trainer save过程中的bug * Update README.md fix tutorial link * Add ENAS (Efficient Neural Architecture Search) * add ignore_type in DataSet.add_field * * AutoPadder will not pad when dtype is None * add ignore_type in DataSet.apply * 修复fieldarray中padder潜在bug * 修复crf中typo; 以及可能导致数值不稳定的地方 * 修复CRF中可能存在的bug * change two default init arguments of Trainer into None * Changes to Callbacks: * 给callback添加给定几个只读属性 * 通过manager设置这些属性 * 代码优化,减轻@transfer的负担 * * 将enas相关代码放到automl目录下 * 修复fast_param_mapping的一个bug * Trainer添加自动创建save目录 * Vocabulary的打印,显示内容 * * 给vocabulary添加遍历方法 * 修复CRF为负数的bug * add SQuAD metric * add sigmoid activate function in MLP * - add star transformer model - add ConllLoader, for all kinds of conll-format files - add JsonLoader, for json-format files - add SSTLoader, for SST-2 & SST-5 - change Callback interface - fix batch multi-process when killed - add README to list models and their performance * - fix test * - fix callback & tests * - update README * 修改部分bug;调整callback * 准备发布0.4.0版本“ * update readme * support parallel loss * 防止多卡的情况导致无法正确计算loss“ * update advance_tutorial jupyter notebook * 1. 在embedding_loader中增加新的读取函数load_with_vocab(), load_without_vocab, 比之前的函数改变主要在(1)不再需要传入embed_dim(2)自动判断当前是word2vec还是glove. 2. vocabulary增加from_dataset(), index_dataset()函数。避免需要多行写index dataset的问题。 3. 在utils中新增一个cache_result()修饰器,用于cache函数的返回值。 4. callback中新增update_every属性 * 1.DataSet.apply()报错时提供错误的index 2.Vocabulary.from_dataset(), index_dataset()提供报错时的vocab顺序 3.embedloader在embed读取时遇到不规则的数据跳过这一行. * update attention * doc tools * fix some doc errors * 修改为中文注释,增加viterbi解码方法 * 样例版本 * - add pad sequence for lstm - add csv, conll, json filereader - update dataloader - remove useless dataloader - fix trainer loss print - fix tests * - fix test_tutorial * 注释增加 * 测试文档 * 本地暂存 * 本地暂存 * 修改文档的顺序 * - add document * 本地暂存 * update pooling * update bert * update documents in MLP * update documents in snli * combine self attention module to attention.py * update documents on losses.py * 对DataSet的文档进行更新 * update documents on metrics * 1. 删除了LSTM中print的内容; 2. 将Trainer和Tester的use_cuda修改为了device; 3.补充Trainer的文档 * 增加对Trainer的注释 * 完善了trainer,callback等的文档; 修改了部分代码的命名以使得代码从文档中隐藏 * update char level encoder * update documents on embedding.py * - update doc * 补充注释,并修改部分代码 * - update doc - add get_embeddings * 修改了文档配置项 * 修改embedding为init_embed初始化 * 1.增加对Trainer和Tester的多卡支持; * - add test - fix jsonloader * 删除了注释教程 * 给 dataset 增加了get_field_names * 修复bug * - add Const - fix bugs * 修改部分注释 * - add model runner for easier test models - add model tests * 修改了 docs 的配置和架构 * 修改了核心部分的一大部分文档,TODO: 1. 完善 trainer 和 tester 部分的文档 2. 研究注释样例与测试 * core部分的注释基本检查完成 * 修改了 io 部分的注释 * 全部改为相对路径引用 * 全部改为相对路径引用 * small change * 1. 从安装文件中删除api/automl的安装 2. metric中存在seq_len的bug 3. sampler中存在命名错误,已修改 * 修复 bug :兼容 cpu 版本的 PyTorch TODO:其它地方可能也存在类似的 bug * 修改文档中的引用部分 * 把 tqdm.autonotebook 换成tqdm.auto * - fix batch & vocab * 上传了文档文件 *.rst * 上传了文档文件和若干 TODO * 讨论并整合了若干模块 * core部分的测试和一些小修改 * 删除了一些冗余文档 * update init files * update const files * update const files * 增加cnn的测试 * fix a little bug * - update attention - fix tests * 完善测试 * 完成快速入门教程 * 修改了sequence_modeling 命名为 sequence_labeling 的文档 * 重新 apidoc 解决改名的遗留问题 * 修改文档格式 * 统一不同位置的seq_len_to_mask, 现统一到core.utils.seq_len_to_mask * 增加了一行提示 * 在文档中展示 dataset_loader * 提示 Dataset.read_csv 会被 CSVLoader 替换 * 完成 Callback 和 Trainer 之间的文档 * index更新了部分 * 删除冗余的print * 删除用于分词的metric,因为有可能引起错误 * 修改文档中的中文名称 * 完成了详细介绍文档 * tutorial 的 ipynb 文件 * 修改了一些介绍文档 * 修改了 models 和 modules 的主页介绍 * 加上了 titlesonly 这个设置 * 修改了模块文档展示的标题 * 修改了 core 和 io 的开篇介绍 * 修改了 modules 和 models 开篇介绍 * 使用 .. todo:: 隐藏了可能被抽到文档中的 TODO 注释 * 修改了一些注释 * delete an old metric in test * 修改 tutorials 的测试文件 * 把暂不发布的功能移到 legacy 文件夹 * 删除了不能运行的测试 * 修改 callback 的测试文件 * 删除了过时的教程和测试文件 * cache_results 参数的修改 * 修改 io 的测试文件; 删除了一些过时的测试 * 修复bug * 修复无法通过test_utils.py的测试 * 修复与pytorch1.1中的padsequence的兼容问题; 修改Trainer的pbar * 1. 修复metric中的bug; 2.增加metric测试 * add model summary * 增加别名 * 删除encoder中的嵌套层 * 修改了 core 部分 import 的顺序,__all__ 暴露的内容 * 修改了 models 部分 import 的顺序,__all__ 暴露的内容 * 修改了文件名 * 修改了 modules 模块的__all__ 和 import * fix var runn * 增加vocab的clear方法 * 一些符合 PEP8 的微调 * 更新了cache_results的例子 * 1. 对callback中indices潜在None作出提示;2.DataSet支持通过List进行index * 修改了一个typo * 修改了 README.md * update documents on bert * update documents on encoder/bert * 增加一个fitlog callback,实现与fitlog实验记录 * typo * - update dataset_loader * 增加了到 fitlog 文档的链接。 * 增加了 DataSet Loader 的文档 * - add star-transformer reproduction
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. import unittest
  2. from fastNLP import Vocabulary
  3. class TestCRF(unittest.TestCase):
  4. def test_case1(self):
  5. # 检查allowed_transitions()能否正确使用
  6. from fastNLP.modules.decoder.crf import allowed_transitions
  7. id2label = {0: 'B', 1: 'I', 2:'O'}
  8. expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
  9. (2, 4), (3, 0), (3, 2)}
  10. self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True)))
  11. id2label = {0: 'B', 1:'M', 2:'E', 3:'S'}
  12. expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
  13. self.assertSetEqual(expected_res, set(
  14. allowed_transitions(id2label, encoding_type='BMES', include_start_end=True)))
  15. id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"}
  16. allowed_transitions(id2label, include_start_end=True)
  17. labels = ['O']
  18. for label in ['X', 'Y']:
  19. for tag in 'BI':
  20. labels.append('{}-{}'.format(tag, label))
  21. id2label = {idx:label for idx, label in enumerate(labels)}
  22. expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
  23. (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
  24. (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
  25. self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True)))
  26. labels = []
  27. for label in ['X', 'Y']:
  28. for tag in 'BMES':
  29. labels.append('{}-{}'.format(tag, label))
  30. id2label = {idx: label for idx, label in enumerate(labels)}
  31. expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
  32. (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
  33. (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
  34. self.assertSetEqual(expected_res, set(
  35. allowed_transitions(id2label, include_start_end=True)))
  36. def test_case11(self):
  37. # 测试自动推断encoding类型
  38. from fastNLP.modules.decoder.crf import allowed_transitions
  39. id2label = {0: 'B', 1: 'I', 2: 'O'}
  40. expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
  41. (2, 4), (3, 0), (3, 2)}
  42. self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True)))
  43. id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
  44. expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
  45. self.assertSetEqual(expected_res, set(
  46. allowed_transitions(id2label, include_start_end=True)))
  47. id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"}
  48. allowed_transitions(id2label, include_start_end=True)
  49. labels = ['O']
  50. for label in ['X', 'Y']:
  51. for tag in 'BI':
  52. labels.append('{}-{}'.format(tag, label))
  53. id2label = {idx: label for idx, label in enumerate(labels)}
  54. expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
  55. (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
  56. (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
  57. self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True)))
  58. labels = []
  59. for label in ['X', 'Y']:
  60. for tag in 'BMES':
  61. labels.append('{}-{}'.format(tag, label))
  62. id2label = {idx: label for idx, label in enumerate(labels)}
  63. expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
  64. (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
  65. (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
  66. self.assertSetEqual(expected_res, set(
  67. allowed_transitions(id2label, include_start_end=True)))
  68. def test_case12(self):
  69. # 测试能否通过vocab生成转移矩阵
  70. from fastNLP.modules.decoder.crf import allowed_transitions
  71. id2label = {0: 'B', 1: 'I', 2: 'O'}
  72. vocab = Vocabulary(unknown=None, padding=None)
  73. for idx, tag in id2label.items():
  74. vocab.add_word(tag)
  75. expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
  76. (2, 4), (3, 0), (3, 2)}
  77. self.assertSetEqual(expected_res, set(allowed_transitions(vocab, include_start_end=True)))
  78. id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
  79. vocab = Vocabulary(unknown=None, padding=None)
  80. for idx, tag in id2label.items():
  81. vocab.add_word(tag)
  82. expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
  83. self.assertSetEqual(expected_res, set(
  84. allowed_transitions(vocab, include_start_end=True)))
  85. id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"}
  86. vocab = Vocabulary()
  87. for idx, tag in id2label.items():
  88. vocab.add_word(tag)
  89. allowed_transitions(vocab, include_start_end=True)
  90. labels = ['O']
  91. for label in ['X', 'Y']:
  92. for tag in 'BI':
  93. labels.append('{}-{}'.format(tag, label))
  94. id2label = {idx: label for idx, label in enumerate(labels)}
  95. expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
  96. (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
  97. (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
  98. vocab = Vocabulary(unknown=None, padding=None)
  99. for idx, tag in id2label.items():
  100. vocab.add_word(tag)
  101. self.assertSetEqual(expected_res, set(allowed_transitions(vocab, include_start_end=True)))
  102. labels = []
  103. for label in ['X', 'Y']:
  104. for tag in 'BMES':
  105. labels.append('{}-{}'.format(tag, label))
  106. id2label = {idx: label for idx, label in enumerate(labels)}
  107. vocab = Vocabulary(unknown=None, padding=None)
  108. for idx, tag in id2label.items():
  109. vocab.add_word(tag)
  110. expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
  111. (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
  112. (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
  113. self.assertSetEqual(expected_res, set(
  114. allowed_transitions(vocab, include_start_end=True)))
  115. # def test_case2(self):
  116. # # 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。
  117. # pass
  118. # import torch
  119. # from fastNLP import seq_len_to_mask
  120. #
  121. # labels = ['O']
  122. # for label in ['X', 'Y']:
  123. # for tag in 'BI':
  124. # labels.append('{}-{}'.format(tag, label))
  125. # id2label = {idx: label for idx, label in enumerate(labels)}
  126. # num_tags = len(id2label)
  127. # max_len = 10
  128. # batch_size = 4
  129. # bio_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log()
  130. # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
  131. # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label),
  132. # include_start_end_transitions=False)
  133. # bio_trans_m = allen_CRF.transitions
  134. # bio_seq_lens = torch.randint(1, max_len, size=(batch_size,))
  135. # bio_seq_lens[0] = 1
  136. # bio_seq_lens[-1] = max_len
  137. # mask = seq_len_to_mask(bio_seq_lens)
  138. # allen_res = allen_CRF.viterbi_tags(bio_logits, mask)
  139. #
  140. # from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
  141. # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
  142. # include_start_end=True))
  143. # fast_CRF.trans_m = bio_trans_m
  144. # fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True)
  145. # bio_scores = [round(score, 4) for _, score in allen_res]
  146. # # score equal
  147. # self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()])
  148. # # seq equal
  149. # bio_path = [_ for _, score in allen_res]
  150. # self.assertListEqual(bio_path, fast_res[0])
  151. #
  152. # labels = []
  153. # for label in ['X', 'Y']:
  154. # for tag in 'BMES':
  155. # labels.append('{}-{}'.format(tag, label))
  156. # id2label = {idx: label for idx, label in enumerate(labels)}
  157. # num_tags = len(id2label)
  158. #
  159. # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
  160. # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label),
  161. # include_start_end_transitions=False)
  162. # bmes_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log()
  163. # bmes_trans_m = allen_CRF.transitions
  164. # bmes_seq_lens = torch.randint(1, max_len, size=(batch_size,))
  165. # bmes_seq_lens[0] = 1
  166. # bmes_seq_lens[-1] = max_len
  167. # mask = seq_len_to_mask(bmes_seq_lens)
  168. # allen_res = allen_CRF.viterbi_tags(bmes_logits, mask)
  169. #
  170. # from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
  171. # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
  172. # encoding_type='BMES',
  173. # include_start_end=True))
  174. # fast_CRF.trans_m = bmes_trans_m
  175. # fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True)
  176. # # score equal
  177. # bmes_scores = [round(score, 4) for _, score in allen_res]
  178. # self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()])
  179. # # seq equal
  180. # bmes_path = [_ for _, score in allen_res]
  181. # self.assertListEqual(bmes_path, fast_res[0])
  182. #
  183. # data = {
  184. # 'bio_logits': bio_logits.tolist(),
  185. # 'bio_scores': bio_scores,
  186. # 'bio_path': bio_path,
  187. # 'bio_trans_m': bio_trans_m.tolist(),
  188. # 'bio_seq_lens': bio_seq_lens.tolist(),
  189. # 'bmes_logits': bmes_logits.tolist(),
  190. # 'bmes_scores': bmes_scores,
  191. # 'bmes_path': bmes_path,
  192. # 'bmes_trans_m': bmes_trans_m.tolist(),
  193. # 'bmes_seq_lens': bmes_seq_lens.tolist(),
  194. # }
  195. #
  196. # with open('weights.json', 'w') as f:
  197. # import json
  198. # json.dump(data, f)
  199. def test_case2(self):
  200. # 测试CRF是否正常work。
  201. import json
  202. import torch
  203. from fastNLP import seq_len_to_mask
  204. with open('test/data_for_tests/modules/decoder/crf.json', 'r') as f:
  205. data = json.load(f)
  206. bio_logits = torch.FloatTensor(data['bio_logits'])
  207. bio_scores = data['bio_scores']
  208. bio_path = data['bio_path']
  209. bio_trans_m = torch.FloatTensor(data['bio_trans_m'])
  210. bio_seq_lens = torch.LongTensor(data['bio_seq_lens'])
  211. bmes_logits = torch.FloatTensor(data['bmes_logits'])
  212. bmes_scores = data['bmes_scores']
  213. bmes_path = data['bmes_path']
  214. bmes_trans_m = torch.FloatTensor(data['bmes_trans_m'])
  215. bmes_seq_lens = torch.LongTensor(data['bmes_seq_lens'])
  216. labels = ['O']
  217. for label in ['X', 'Y']:
  218. for tag in 'BI':
  219. labels.append('{}-{}'.format(tag, label))
  220. id2label = {idx: label for idx, label in enumerate(labels)}
  221. num_tags = len(id2label)
  222. mask = seq_len_to_mask(bio_seq_lens)
  223. from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
  224. fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
  225. include_start_end=True))
  226. fast_CRF.trans_m.data = bio_trans_m
  227. fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True)
  228. # score equal
  229. self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()])
  230. # seq equal
  231. self.assertListEqual(bio_path, fast_res[0])
  232. labels = []
  233. for label in ['X', 'Y']:
  234. for tag in 'BMES':
  235. labels.append('{}-{}'.format(tag, label))
  236. id2label = {idx: label for idx, label in enumerate(labels)}
  237. num_tags = len(id2label)
  238. mask = seq_len_to_mask(bmes_seq_lens)
  239. from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
  240. fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
  241. encoding_type='BMES',
  242. include_start_end=True))
  243. fast_CRF.trans_m.data = bmes_trans_m
  244. fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True)
  245. # score equal
  246. self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()])
  247. # seq equal
  248. self.assertListEqual(bmes_path, fast_res[0])
  249. def test_case3(self):
  250. # 测试crf的loss不会出现负数
  251. import torch
  252. from fastNLP.modules.decoder.crf import ConditionalRandomField
  253. from fastNLP.core.utils import seq_len_to_mask
  254. from torch import optim
  255. from torch import nn
  256. num_tags, include_start_end_trans = 4, True
  257. num_samples = 4
  258. lengths = torch.randint(3, 50, size=(num_samples, )).long()
  259. max_len = lengths.max()
  260. tags = torch.randint(num_tags, size=(num_samples, max_len))
  261. masks = seq_len_to_mask(lengths)
  262. feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
  263. crf = ConditionalRandomField(num_tags, include_start_end_trans)
  264. optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1)
  265. for _ in range(10):
  266. loss = crf(feats, tags, masks).mean()
  267. optimizer.zero_grad()
  268. loss.backward()
  269. optimizer.step()
  270. if _%1000==0:
  271. print(loss)
  272. self.assertGreater(loss.item(), 0, "CRF loss cannot be less than 0.")
  273. def test_masking(self):
  274. # 测试crf的pad masking正常运行
  275. import torch
  276. from fastNLP.modules.decoder.crf import ConditionalRandomField
  277. max_len = 5
  278. n_tags = 5
  279. pad_len = 5
  280. torch.manual_seed(4)
  281. logit = torch.rand(1, max_len+pad_len, n_tags)
  282. # logit[0, -1, :] = 0.0
  283. mask = torch.ones(1, max_len+pad_len)
  284. mask[0,-pad_len] = 0
  285. model = ConditionalRandomField(n_tags)
  286. pred, score = model.viterbi_decode(logit[:,:-pad_len], mask[:,:-pad_len])
  287. mask_pred, mask_score = model.viterbi_decode(logit, mask)
  288. self.assertEqual(pred[0].tolist(), mask_pred[0,:-pad_len].tolist())