* upgrade the parser API and model
* update docs: add new pages for tutorials
* upgrade CWS api download source
* add a new method for dataset field access
* add introduction for bert
* add more unit tests for api/processor
* remove unused test data. Add new test data.

tags/v0.4.10
@@ -1,7 +1,8 @@
fastNLP Hands-on Tutorial
fastNLP 10-Minute Tutorial
==========================
The original tutorial is available at https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_10min_tutorial.ipynb
fastNLP provides convenient utilities for data preprocessing, model training, and model testing.
DataSet & Instance
@@ -2,6 +2,8 @@
FastNLP 1-Minute Tutorial
=========================
The original tutorial is available at https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_1min_tutorial.ipynb
step 1
------
@@ -0,0 +1,5 @@
fastNLP Advanced Tutorial
=========================
The original tutorial is available at https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb
@@ -0,0 +1,5 @@
fastNLP Developer Guide
=======================
The original document is available at https://github.com/fastnlp/fastNLP/blob/master/tutorials/tutorial_for_developer.md
@@ -5,6 +5,7 @@ Installation
.. contents::
:local:
Make sure your environment meets the requirements listed in https://github.com/fastnlp/fastNLP/blob/master/requirements.txt .
Run the following commands to install the fastNLP package:
@@ -6,4 +6,6 @@ Quickstart
../tutorials/fastnlp_1_minute_tutorial
../tutorials/fastnlp_10tmin_tutorial
../tutorials/fastnlp_advanced_tutorial
../tutorials/fastnlp_developer_guide
@@ -9,7 +9,7 @@ from fastNLP.core.dataset import DataSet
from fastNLP.api.utils import load_url
from fastNLP.api.processor import ModelProcessor
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader, add_seg_tag
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader
from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline
from fastNLP.core.metrics import SpanFPreRecMetric
@@ -17,9 +17,9 @@ from fastNLP.api.processor import IndexerProcessor
# TODO add pretrain urls
model_urls = {
"cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl",
"cws": "http://123.206.98.91:8888/download/cws_lstm_ctb9_1_20-09908656.pkl",
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl",
"parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl"
"parser": "http://123.206.98.91:8888/download/parser_20190204-c72ca5c0.pkl"
}
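The pretrained pipelines listed above are fetched on demand the first time an API object is built. A minimal sketch, assuming the CWS/POS/Parser classes are importable from fastNLP.api as in the example script further down:

from fastNLP.api import CWS, POS, Parser   # import path assumed

cws = CWS(device='cpu')        # downloads the "cws" checkpoint if it is not cached yet
pos = POS(device='cpu')        # downloads the "pos" checkpoint
parser = Parser(device='cpu')  # downloads the "parser" checkpoint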
@@ -90,38 +90,28 @@ class POS(API):
# 3. run the pipeline
self.pipeline(dataset)
# def decode_tags(ins):
# pred_tags = ins["tag"]
# chars = ins["words"]
# words = []
# start_idx = 0
# for idx, tag in enumerate(pred_tags):
# if tag[0] == "S":
# words.append(chars[start_idx:idx + 1] + "/" + tag[2:])
# start_idx = idx + 1
# elif tag[0] == "E":
# words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:])
# start_idx = idx + 1
# return words
#
# dataset.apply(decode_tags, new_field_name="tag_output")
def merge_tag(words_list, tags_list):
rtn = []
for words, tags in zip(words_list, tags_list):
rtn.append([w + "/" + t for w, t in zip(words, tags)])
return rtn
output = dataset.field_arrays["tag"].content
if isinstance(content, str):
return output[0]
elif isinstance(content, list):
return output
return merge_tag(content, output)
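With merge_tag, predict now pairs each word with its predicted tag. A minimal sketch of the expected input and output, assuming the pretrained POS pipeline is available; the tags shown are illustrative, not real model output:

from fastNLP.api import POS   # import path assumed

pos = POS(device='cpu')
words = [['编者', '按', ':', '7月', '12日']]   # one pre-segmented sentence, as in the example script below
print(pos.predict(words))
# roughly: [['编者/NN', '按/NN', ':/PU', '7月/NT', '12日/NT']]   (tags illustrative)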
def test(self, file_path):
test_data = ConllxDataLoader().load(file_path)
with open("model_pp_0117.pkl", "rb") as f:
save_dict = torch.load(f)
save_dict = self._dict
tag_vocab = save_dict["tag_vocab"]
pipeline = save_dict["pipeline"]
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
pipeline.pipeline = [index_tag] + pipeline.pipeline
test_data.rename_field("pos_tags", "tag")
pipeline(test_data)
test_data.set_target("truth")
prediction = test_data.field_arrays["predict"].content
@@ -235,7 +225,7 @@ class CWS(API):
rec = eval_res['BMESF1PreRecMetric']['rec']
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
return f1, pre, rec
return {"F1": f1, "precision": pre, "recall": rec}
class Parser(API):
@@ -260,6 +250,7 @@ class Parser(API):
dataset.add_field('wp', pos_out)
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words')
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos')
dataset.rename_field("words", "raw_words")
# 3. run the pipeline
self.pipeline(dataset)
@@ -269,31 +260,74 @@ class Parser(API):
# output like: [['2/top', '0/root', '4/nn', '2/dep']]
return dataset.field_arrays['output'].content
def test(self, filepath):
data = ConllxDataLoader().load(filepath)
ds = DataSet()
for ins1, ins2 in zip(add_seg_tag(data), data):
ds.append(Instance(words=ins1[0], tag=ins1[1],
gold_words=ins2[0], gold_pos=ins2[1],
gold_heads=ins2[2], gold_head_tags=ins2[3]))
def load_test_file(self, path):
def get_one(sample):
sample = list(map(list, zip(*sample)))
if len(sample) == 0:
return None
for w in sample[7]:
if w == '_':
print('Error Sample {}'.format(sample))
return None
# return word_seq, pos_seq, head_seq, head_tag_seq
return sample[1], sample[3], list(map(int, sample[6])), sample[7]
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)
data = [get_one(sample) for sample in datalist]
data_list = list(filter(lambda x: x is not None, data))
return data_list
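load_test_file keeps columns 1 (word), 3 (POS), 6 (head index) and 7 (dependency label) of each CoNLL-X block, and skips sentences whose label column is '_'. For the first sentence of the new zh_sample.conllx test file below, get_one returns:

# Values taken from the first sentence of zh_sample.conllx (see the new test data at the end of this diff).
words = ['上海', '积极', '准备', '迎接', '欧元', '诞生']        # column 1
pos = ['NR', 'AD', 'VV', 'VV', 'NN', 'NN']                      # column 3
heads = [3, 3, 0, 3, 6, 4]                                      # column 6, converted to int
labels = ['nsubj', 'advmod', 'root', 'ccomp', 'nn', 'dobj']     # column 7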
def test(self, filepath):
data = self.load_test_file(filepath)
def convert(data):
BOS = '<BOS>'
dataset = DataSet()
for sample in data:
word_seq = [BOS] + sample[0]
pos_seq = [BOS] + sample[1]
heads = [0] + sample[2]
head_tags = [BOS] + sample[3]
dataset.append(Instance(raw_words=word_seq,
pos=pos_seq,
gold_heads=heads,
arc_true=heads,
tags=head_tags))
return dataset
ds = convert(data)
pp = self.pipeline
for p in pp:
if p.field_name == 'word_list':
p.field_name = 'gold_words'
elif p.field_name == 'pos_list':
p.field_name = 'gold_pos'
# ds.rename_field("words", "raw_words")
# ds.rename_field("tag", "pos")
pp(ds)
head_cor, label_cor, total = 0, 0, 0
for ins in ds:
head_gold = ins['gold_heads']
head_pred = ins['heads']
head_pred = ins['arc_pred']
length = len(head_gold)
total += length
for i in range(length):
head_cor += 1 if head_pred[i] == head_gold[i] else 0
uas = head_cor / total
print('uas:{:.2f}'.format(uas))
# print('uas:{:.2f}'.format(uas))
for p in pp:
if p.field_name == 'gold_words':
@@ -301,7 +335,7 @@ class Parser(API):
elif p.field_name == 'gold_pos':
p.field_name = 'pos_list'
return uas
return {"UAS": round(uas, 5)}
class Analyzer:
@@ -15,19 +15,40 @@ def chinese_word_segmentation():
print(cws.predict(text))
def chinese_word_segmentation_test():
cws = CWS(device='cpu')
print(cws.test("../../test/data_for_tests/zh_sample.conllx"))
def pos_tagging():
# input: a pre-segmented word sequence
text = ['编者 按: 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一款 高科技 隐形 无人机 雷电之神 。']
text = [text[0].split()]
print(text)
pos = POS(device='cpu')
print(pos.predict(text))
def pos_tagging_test():
pos = POS(device='cpu')
print(pos.test("../../test/data_for_tests/zh_sample.conllx"))
def syntactic_parsing():
text = ['编者 按: 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一款 高科技 隐形 无人机 雷电之神 。']
text = [text[0].split()]
parser = Parser(device='cpu')
print(parser.predict(text))
def syntactic_parsing_test():
parser = Parser(device='cpu')
print(parser.test("../../test/data_for_tests/zh_sample.conllx"))
if __name__ == "__main__":
chinese_word_segmentation()
chinese_word_segmentation_test()
pos_tagging()
pos_tagging_test()
syntactic_parsing()
syntactic_parsing_test()
@@ -102,6 +102,7 @@ class PreAppendProcessor(Processor):
[data] + instance[field_name]
"""
def __init__(self, data, field_name, new_added_field_name=None):
super(PreAppendProcessor, self).__init__(field_name, new_added_field_name)
self.data = data
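A minimal usage sketch, mirroring the docstring's [data] + instance[field_name]; the DataSet contents and field names are hypothetical:

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

ds = DataSet([Instance(words=['上海', '欢迎', '你'])])   # hypothetical toy DataSet
proc = PreAppendProcessor(data='<BOS>', field_name='words', new_added_field_name='words_with_bos')
ds = proc(ds)   # 'words_with_bos' becomes ['<BOS>', '上海', '欢迎', '你']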
@@ -116,6 +117,7 @@ class SliceProcessor(Processor):
Take only part of a field's content; equivalent to instance[field_name][start:end:step].
"""
def __init__(self, start, end, step, field_name, new_added_field_name=None):
super(SliceProcessor, self).__init__(field_name, new_added_field_name)
for o in (start, end, step):
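A minimal sketch with hypothetical field names: keep only the first 10 items of each 'words' field, the equivalent of instance['words'][0:10:1]:

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

ds = DataSet([Instance(words=list('abcdefghijkl'))])   # hypothetical toy DataSet
proc = SliceProcessor(0, 10, 1, field_name='words', new_added_field_name='words_truncated')
ds = proc(ds)   # 'words_truncated' holds the first 10 items of 'words'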
@@ -132,6 +134,7 @@ class Num2TagProcessor(Processor):
Convert the numbers in a sentence to a given tag.
"""
def __init__(self, tag, field_name, new_added_field_name=None):
"""
@@ -163,6 +166,7 @@ class IndexerProcessor(Processor):
Given a vocabulary, convert the specified field to index form. The field should be a one-dimensional list, e.g.
['我', '是', xxx]
"""
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
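A minimal usage sketch with a hypothetical toy DataSet, mirroring the unit tests added below: build a Vocabulary and map a word field to an index field.

from fastNLP import Vocabulary, Instance
from fastNLP.core.dataset import DataSet

vocab = Vocabulary()
vocab.add_word_lst(['我', '是', '中国', '人'])
ds = DataSet([Instance(words=['我', '是', '中国', '人'])])   # hypothetical toy DataSet
proc = IndexerProcessor(vocab, field_name='words', new_added_field_name='word_ids')
ds = proc(ds)   # 'word_ids' holds the integer ids; 'words' is kept unless delete_old_field=True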
@@ -215,6 +219,7 @@ class SeqLenProcessor(Processor):
Add a new sequence-length field based on a given field, using that field's first dimension.
"""
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
self.is_input = is_input
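A minimal sketch with a hypothetical toy DataSet: record each instance's length in a new 'seq_lens' field.

from fastNLP import Instance
from fastNLP.core.dataset import DataSet

ds = DataSet([Instance(word_ids=[4, 5, 6, 7])])   # hypothetical toy DataSet
proc = SeqLenProcessor(field_name='word_ids')     # the new field name defaults to 'seq_lens'
ds = proc(ds)                                     # each instance now carries seq_lens == 4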
@@ -229,6 +234,7 @@ class SeqLenProcessor(Processor):
from fastNLP.core.utils import _build_args
class ModelProcessor(Processor):
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
"""
@@ -292,6 +298,7 @@ class Index2WordProcessor(Processor):
Convert an index field of a DataSet back to strings according to the vocab.
"""
def __init__(self, vocab, field_name, new_added_field_name):
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
self.vocab = vocab
@@ -303,7 +310,6 @@ class Index2WordProcessor(Processor):
class SetTargetProcessor(Processor):
# TODO: remove it.
def __init__(self, *fields, flag=True):
super(SetTargetProcessor, self).__init__(None, None)
self.fields = fields
@@ -313,6 +319,7 @@ class SetTargetProcessor(Processor):
dataset.set_target(*self.fields, flag=self.flag)
return dataset
class SetInputProcessor(Processor):
def __init__(self, *fields, flag=True):
super(SetInputProcessor, self).__init__(None, None)
@@ -92,6 +92,10 @@ class DataSet(object):
data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder,
is_input=field.is_input, is_target=field.is_target)
return data_set
elif isinstance(idx, str):
if idx not in self:
raise KeyError("No such field called {} in DataSet.".format(idx))
return self.field_arrays[idx]
else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -1,3 +1,7 @@
"""
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0.
"""
import copy
import json
import math
@@ -220,7 +224,23 @@ class BertPooler(nn.Module):
class BertModel(nn.Module):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
"""Bidirectional Encoder Representations from Transformers (BERT).
To use pre-trained weights, download them from the following sources provided by pytorch-pretrained-BERT.
sources::
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
Construct a BERT model with pre-trained weights::
model = BertModel.from_pretrained("path/to/weights/directory")
"""
@@ -1,5 +1,5 @@
[train]
n_epochs = 1
n_epochs = 20
batch_size = 32
use_cuda = true
use_tqdm=true
@@ -1,9 +1,12 @@
import random
import unittest
from fastNLP import Vocabulary
import numpy as np
from fastNLP import Vocabulary, Instance
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor, PreAppendProcessor, SliceProcessor, Num2TagProcessor, \
IndexerProcessor, VocabProcessor, SeqLenProcessor
IndexerProcessor, VocabProcessor, SeqLenProcessor, ModelProcessor, Index2WordProcessor, SetTargetProcessor, \
SetInputProcessor, VocabIndexerProcessor
from fastNLP.core.dataset import DataSet
@@ -53,3 +56,46 @@ class TestProcessor(unittest.TestCase):
ds = proc(ds)
for data in ds.field_arrays["len"].content:
self.assertEqual(data, 30)
def test_ModelProcessor(self):
from fastNLP.models.cnn_text_classification import CNNText
model = CNNText(100, 100, 5)
ins_list = []
for _ in range(64):
seq_len = np.random.randint(5, 30)
ins_list.append(Instance(word_seq=[np.random.randint(0, 100) for _ in range(seq_len)], seq_lens=seq_len))
data_set = DataSet(ins_list)
data_set.set_input("word_seq", "seq_lens")
proc = ModelProcessor(model)
data_set = proc(data_set)
self.assertTrue("pred" in data_set)
def test_Index2WordProcessor(self):
vocab = Vocabulary()
vocab.add_word_lst(["a", "b", "c", "d", "e"])
proc = Index2WordProcessor(vocab, "tag_id", "tag")
data_set = DataSet([Instance(tag_id=[np.random.randint(0, 7) for _ in range(32)])])
data_set = proc(data_set)
self.assertTrue("tag" in data_set)
def test_SetTargetProcessor(self):
proc = SetTargetProcessor("a", "b", "c")
data_set = DataSet({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
data_set = proc(data_set)
self.assertTrue(data_set["a"].is_target)
self.assertTrue(data_set["b"].is_target)
self.assertTrue(data_set["c"].is_target)
def test_SetInputProcessor(self):
proc = SetInputProcessor("a", "b", "c")
data_set = DataSet({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
data_set = proc(data_set)
self.assertTrue(data_set["a"].is_input)
self.assertTrue(data_set["b"].is_input)
self.assertTrue(data_set["c"].is_input)
def test_VocabIndexerProcessor(self):
proc = VocabIndexerProcessor("word_seq", "word_ids")
data_set = DataSet([Instance(word_seq=["a", "b", "c", "d", "e"])])
data_set = proc(data_set)
self.assertTrue("word_ids" in data_set)
@@ -1,2 +0,0 @@
迈向充满希望的新世纪——一九九八年新年讲话
(附图片1张)
@@ -0,0 +1,100 @@
1 上海 _ NR NR _ 3 nsubj _ _
2 积极 _ AD AD _ 3 advmod _ _
3 准备 _ VV VV _ 0 root _ _
4 迎接 _ VV VV _ 3 ccomp _ _
5 欧元 _ NN NN _ 6 nn _ _
6 诞生 _ NN NN _ 4 dobj _ _
1 新华社 _ NR NR _ 7 dep _ _
2 上海 _ NR NR _ 7 dep _ _
3 十二月 _ NT NT _ 7 dep _ _
4 三十日 _ NT NT _ 7 dep _ _
5 电 _ NN NN _ 7 dep _ _
6 ( _ PU PU _ 7 punct _ _
7 记者 _ NN NN _ 0 root _ _
8 潘清 _ NR NR _ 7 dep _ _
9 ) _ PU PU _ 7 punct _ _
1 即将 _ AD AD _ 2 advmod _ _
2 诞生 _ VV VV _ 4 rcmod _ _
3 的 _ DEC DEC _ 2 cpm _ _
4 欧元 _ NN NN _ 6 nsubj _ _
5 , _ PU PU _ 6 punct _ _
6 引起 _ VV VV _ 0 root _ _
7 了 _ AS AS _ 6 asp _ _
8 上海 _ NR NR _ 14 nn _ _
9 这 _ DT DT _ 14 det _ _
10 个 _ M M _ 9 clf _ _
11 中国 _ NR NR _ 13 nn _ _
12 金融 _ NN NN _ 13 nn _ _
13 中心 _ NN NN _ 14 nn _ _
14 城市 _ NN NN _ 16 assmod _ _
15 的 _ DEG DEG _ 14 assm _ _
16 关注 _ NN NN _ 6 dobj _ _
17 。 _ PU PU _ 6 punct _ _
1 上海 _ NR NR _ 2 nn _ _
2 银行界 _ NN NN _ 4 nsubj _ _
3 纷纷 _ AD AD _ 4 advmod _ _
4 推出 _ VV VV _ 0 root _ _
5 了 _ AS AS _ 4 asp _ _
6 与 _ P P _ 8 prep _ _
7 之 _ PN PN _ 6 pobj _ _
8 相关 _ VA VA _ 15 rcmod _ _
9 的 _ DEC DEC _ 8 cpm _ _
10 外汇 _ NN NN _ 15 nn _ _
11 业务 _ NN NN _ 15 nn _ _
12 品种 _ NN NN _ 15 conj _ _
13 和 _ CC CC _ 15 cc _ _
14 服务 _ NN NN _ 15 nn _ _
15 举措 _ NN NN _ 4 dobj _ _
16 , _ PU PU _ 4 punct _ _
17 积极 _ AD AD _ 18 advmod _ _
18 准备 _ VV VV _ 4 dep _ _
19 启动 _ VV VV _ 18 ccomp _ _
20 欧元 _ NN NN _ 21 nn _ _
21 业务 _ NN NN _ 19 dobj _ _
22 。 _ PU PU _ 4 punct _ _
1 一些 _ CD CD _ 8 nummod _ _
2 热衷于 _ VV VV _ 8 rcmod _ _
3 个人 _ NN NN _ 5 nn _ _
4 外汇 _ NN NN _ 5 nn _ _
5 交易 _ NN NN _ 2 dobj _ _
6 的 _ DEC DEC _ 2 cpm _ _
7 上海 _ NR NR _ 8 nn _ _
8 市民 _ NN NN _ 13 nsubj _ _
9 , _ PU PU _ 13 punct _ _
10 也 _ AD AD _ 13 advmod _ _
11 对 _ P P _ 13 prep _ _
12 欧元 _ NN NN _ 11 pobj _ _
13 表示 _ VV VV _ 0 root _ _
14 出 _ VV VV _ 13 rcomp _ _
15 极 _ AD AD _ 16 advmod _ _
16 大 _ VA VA _ 18 rcmod _ _
17 的 _ DEC DEC _ 16 cpm _ _
18 兴趣 _ NN NN _ 13 dobj _ _
19 。 _ PU PU _ 13 punct _ _
1 继 _ P P _ 38 prep _ _
2 上海 _ NR NR _ 6 nn _ _
3 大众 _ NR NR _ 6 nn _ _
4 汽车 _ NN NN _ 6 nn _ _
5 有限 _ JJ JJ _ 6 amod _ _
6 公司 _ NN NN _ 13 nsubj _ _
7 十八日 _ NT NT _ 13 tmod _ _
8 在 _ P P _ 13 prep _ _
9 中国 _ NR NR _ 10 nn _ _
10 银行 _ NN NN _ 12 nn _ _
11 上海 _ NR NR _ 12 nn _ _
12 分行 _ NN NN _ 8 pobj _ _
13 开立 _ VV VV _ 19 lccomp _ _
14 上海 _ NR NR _ 16 dep _ _
15 第一 _ OD OD _ 16 ordmod _ _
16 个 _ M M _ 18 clf _ _
17 欧元 _ NN NN _ 18 nn _ _
18 帐户 _ NN NN _ 13 dobj _ _
19 后 _ LC LC _ 1 plmod _ _
20 , _ PU PU _ 38 punct _ _
21 工商 _ NN NN _ 28 nn _ _
22 银行 _ NN NN _ 28 conj _ _