@@ -184,7 +184,7 @@ def add_r(base_path='../fastNLP'): | |||
for f in files: | |||
if f.endswith(".py"): | |||
check_file_r(os.path.abspath(os.path.join(path,f))) | |||
sys.exit(0) | |||
# sys.exit(0) | |||
if __name__ == "__main__": | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
fastNLP 由 :mod:`~fastNLP.core` 、 :mod:`~fastNLP.io` 、:mod:`~fastNLP.embeddings` 、 :mod:`~fastNLP.modules`、 | |||
:mod:`~fastNLP.models` 等子模块组成,你可以查看每个模块的文档。 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fastNLP 包中直接 import。当然你也同样可以从 core 模块的子模块中 import, | |||
例如 :class:`~fastNLP.DataSetIter` 组件有两种 import 的方式:: | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
Logger 是fastNLP中记录日志的模块,logger封装了logging模块的Logger, | |||
具体使用方式与直接使用logging.Logger相同,同时也新增一些简单好用的API | |||
使用方式: | |||
@@ -124,11 +124,11 @@ class FastNLPLogger(logging.getLoggerClass()): | |||
super().__init__(name) | |||
def add_file(self, path='./log.txt', level='INFO'): | |||
"""add log output file and the output level""" | |||
r"""add log output file and the output level""" | |||
_add_file_handler(self, path, level) | |||
def set_stdout(self, stdout='tqdm', level='INFO'): | |||
"""set stdout format and the output level""" | |||
r"""set stdout format and the output level""" | |||
_set_stdout_handler(self, stdout, level) | |||
@@ -139,7 +139,7 @@ logging.setLoggerClass(FastNLPLogger) | |||
# print(logging.getLogger()) | |||
def _init_logger(path=None, stdout='tqdm', level='INFO'): | |||
"""initialize logger""" | |||
r"""initialize logger""" | |||
level = _get_level(level) | |||
# logger = logging.getLogger() | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [] | |||
@@ -74,7 +74,7 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None): | |||
def _data_parallel_wrapper(func_name, device_ids, output_device): | |||
""" | |||
r""" | |||
这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数 | |||
:param str, func_name: 对network中的这个函数进行多卡运行 | |||
@@ -95,7 +95,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device): | |||
def _model_contains_inner_module(model): | |||
""" | |||
r""" | |||
:param nn.Module model: 模型文件,判断是否内部包含model.module, 多用于check模型是否是nn.DataParallel, | |||
nn.parallel.DistributedDataParallel。主要是在做形参匹配的时候需要使用最内部的model的function。 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
batch 模块实现了 fastNLP 所需的 :class:`~fastNLP.core.batch.DataSetIter` 类。 | |||
""" | |||
@@ -49,7 +49,7 @@ def _pad(batch_dict, dataset, as_numpy): | |||
class DataSetGetter: | |||
""" | |||
r""" | |||
传递给torch.utils.data.DataLoader获取数据,DataLoder会传入int的idx获取数据(调用这里的__getitem__()函数)。 | |||
""" | |||
def __init__(self, dataset: DataSet, as_numpy=False): | |||
@@ -70,7 +70,7 @@ class DataSetGetter: | |||
return len(self.dataset) | |||
def collate_fn(self, ins_list: list): | |||
""" | |||
r""" | |||
:param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]] | |||
:return: | |||
@@ -104,7 +104,7 @@ class DataSetGetter: | |||
class SamplerAdapter(torch.utils.data.Sampler): | |||
""" | |||
r""" | |||
用于传入torch.utils.data.DataLoader中,DataLoader会调用__iter__()方法获取index(一次只取一个int) | |||
""" | |||
@@ -121,7 +121,7 @@ class SamplerAdapter(torch.utils.data.Sampler): | |||
class BatchIter: | |||
""" | |||
r""" | |||
Trainer用于迭代数据的类。继承该类,并实现get_num_batches(), get_batch_indices(), num_batches(), __iter__()方法以及dataset属性。 | |||
""" | |||
@@ -166,7 +166,7 @@ class BatchIter: | |||
@staticmethod | |||
def get_num_batches(num_samples, batch_size, drop_last): | |||
""" | |||
r""" | |||
计算batch的数量。用于前端显示进度 | |||
:param int num_samples: | |||
@@ -180,7 +180,7 @@ class BatchIter: | |||
return num_batches | |||
def get_batch_indices(self): | |||
""" | |||
r""" | |||
获取最近输出的batch的index。用于溯源当前batch的数据 | |||
:return: | |||
@@ -192,7 +192,7 @@ class BatchIter: | |||
@property | |||
def dataset(self): | |||
""" | |||
r""" | |||
获取正在参与iterate的dataset | |||
:return: | |||
@@ -201,7 +201,7 @@ class BatchIter: | |||
@abc.abstractmethod | |||
def __iter__(self): | |||
""" | |||
r""" | |||
用于实际数据循环的类,返回值需要为两个dict, 第一个dict中的内容会认为是input, 第二个dict中的内容会认为是target | |||
:return: | |||
@@ -210,7 +210,7 @@ class BatchIter: | |||
class DataSetIter(BatchIter): | |||
""" | |||
r""" | |||
DataSetIter 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出, | |||
组成 `x` 和 `y`:: | |||
@@ -223,7 +223,7 @@ class DataSetIter(BatchIter): | |||
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, | |||
num_workers=0, pin_memory=False, drop_last=False, | |||
timeout=0, worker_init_fn=None, collate_fn=None): | |||
""" | |||
r""" | |||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||
:param int batch_size: 取出的batch大小 | |||
@@ -258,7 +258,7 @@ class DataSetIter(BatchIter): | |||
class TorchLoaderIter(BatchIter): | |||
""" | |||
r""" | |||
与DataSetIter类似,但可以用于非fastNLP的数据容器对象,然后将其传入到Trainer中。 | |||
只需要保证数据容器实现了实现了以下的方法 | |||
@@ -387,7 +387,7 @@ class TorchLoaderIter(BatchIter): | |||
def __init__(self, dataset, batch_size=1, sampler=None, | |||
num_workers=0, pin_memory=False, drop_last=False, | |||
timeout=0, worker_init_fn=None, collate_fn=None): | |||
""" | |||
r""" | |||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||
:param int batch_size: 取出的batch大小 | |||
@@ -421,7 +421,7 @@ class TorchLoaderIter(BatchIter): | |||
def _to_tensor(batch, field_dtype): | |||
""" | |||
r""" | |||
:param batch: np.array() | |||
:param field_dtype: 数据类型 | |||
@@ -101,7 +101,7 @@ except: | |||
class Callback(object): | |||
""" | |||
r""" | |||
Callback是fastNLP中被设计用于增强 :class:`~fastNLP.Trainer` 的类。 | |||
如果Callback被传递给了 Trainer , 则 Trainer 会在对应的阶段调用Callback的函数, | |||
具体调用时机可以通过 :mod:`trainer 模块<fastNLP.core.trainer>` 查看。 | |||
@@ -116,60 +116,60 @@ class Callback(object): | |||
@property | |||
def trainer(self): | |||
""" | |||
r""" | |||
该属性可以通过self.trainer获取到,一般情况下不需要使用这个属性。 | |||
""" | |||
return self._trainer | |||
@property | |||
def step(self): | |||
"""当前运行到的step, 范围为[1, self.n_steps+1)""" | |||
r"""当前运行到的step, 范围为[1, self.n_steps+1)""" | |||
return self._trainer.step | |||
@property | |||
def n_steps(self): | |||
"""Trainer一共会采多少个batch。当Trainer中update_every设置为非1的值时,该值不等于update的次数""" | |||
r"""Trainer一共会采多少个batch。当Trainer中update_every设置为非1的值时,该值不等于update的次数""" | |||
return self._trainer.n_steps | |||
@property | |||
def batch_size(self): | |||
"""train和evaluate时的batch_size为多大""" | |||
r"""train和evaluate时的batch_size为多大""" | |||
return self._trainer.batch_size | |||
@property | |||
def epoch(self): | |||
"""当前运行的epoch数,范围是[1, self.n_epochs+1)""" | |||
r"""当前运行的epoch数,范围是[1, self.n_epochs+1)""" | |||
return self._trainer.epoch | |||
@property | |||
def n_epochs(self): | |||
"""一共会运行多少个epoch""" | |||
r"""一共会运行多少个epoch""" | |||
return self._trainer.n_epochs | |||
@property | |||
def optimizer(self): | |||
"""初始化Trainer时传递的Optimizer""" | |||
r"""初始化Trainer时传递的Optimizer""" | |||
return self._trainer.optimizer | |||
@property | |||
def model(self): | |||
"""正在被Trainer训练的模型""" | |||
r"""正在被Trainer训练的模型""" | |||
return self._trainer.model | |||
@property | |||
def pbar(self): | |||
"""如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。在 | |||
r"""如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。在 | |||
on_train_begin(), on_train_end(), on_exception()中请不要使用该属性,通过print输出即可。""" | |||
return self._trainer.pbar | |||
@property | |||
def update_every(self): | |||
"""Trainer中的模型多少次反向传播才进行一次梯度更新,在Trainer初始化时传入的。""" | |||
r"""Trainer中的模型多少次反向传播才进行一次梯度更新,在Trainer初始化时传入的。""" | |||
return self._trainer.update_every | |||
@property | |||
def batch_per_epoch(self): | |||
"""每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。""" | |||
r"""每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。""" | |||
return self._trainer.batch_per_epoch | |||
@property | |||
@@ -185,7 +185,7 @@ class Callback(object): | |||
return getattr(self._trainer, 'logger', logger) | |||
def on_train_begin(self): | |||
""" | |||
r""" | |||
在Train过程开始之前调用。 | |||
:return: | |||
@@ -193,7 +193,7 @@ class Callback(object): | |||
pass | |||
def on_epoch_begin(self): | |||
""" | |||
r""" | |||
在每个epoch开始之前调用一次 | |||
:return: | |||
@@ -201,7 +201,7 @@ class Callback(object): | |||
pass | |||
def on_batch_begin(self, batch_x, batch_y, indices): | |||
""" | |||
r""" | |||
每次采集到一个batch的数据则调用一次。这里对batch_x或batch_y删除添加内容是可以影响到Trainer中内容的。所以在这一步 | |||
可以进行一些负采样之类的操作 | |||
@@ -214,7 +214,7 @@ class Callback(object): | |||
pass | |||
def on_loss_begin(self, batch_y, predict_y): | |||
""" | |||
r""" | |||
在计算loss前调用,即这里修改batch_y或predict_y的值是可以影响到loss计算的。 | |||
:param dict batch_y: 在DataSet中被设置为target的field的batch集合。 | |||
@@ -224,7 +224,7 @@ class Callback(object): | |||
pass | |||
def on_backward_begin(self, loss): | |||
""" | |||
r""" | |||
在loss得到之后,但在反向传播之前。可能可以进行loss是否为NaN的检查。 | |||
:param torch.Tensor loss: 计算得到的loss值 | |||
@@ -233,7 +233,7 @@ class Callback(object): | |||
pass | |||
def on_backward_end(self): | |||
""" | |||
r""" | |||
反向梯度传播已完成,但由于update_every的设置,可能并不是每一次调用都有梯度。到这一步,还没有更新参数。 | |||
:return: | |||
@@ -241,7 +241,7 @@ class Callback(object): | |||
pass | |||
def on_step_end(self): | |||
""" | |||
r""" | |||
到这里模型的参数已经按照梯度更新。但可能受update_every影响,并不是每次都更新了。 | |||
:return: | |||
@@ -249,14 +249,14 @@ class Callback(object): | |||
pass | |||
def on_batch_end(self): | |||
""" | |||
r""" | |||
这一步与on_step_end是紧接着的。只是为了对称性加上了这一步。 | |||
""" | |||
pass | |||
def on_valid_begin(self): | |||
""" | |||
r""" | |||
如果Trainer中设置了验证,则发生验证前会调用该函数 | |||
:return: | |||
@@ -264,7 +264,7 @@ class Callback(object): | |||
pass | |||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||
""" | |||
r""" | |||
每次执行验证集的evaluation后会调用。 | |||
:param Dict[str: Dict[str: float]] eval_result: , evaluation的结果。一个例子为{'AccuracyMetric':{'acc':1.0}},即 | |||
@@ -277,19 +277,19 @@ class Callback(object): | |||
pass | |||
def on_epoch_end(self): | |||
""" | |||
r""" | |||
每个epoch结束将会调用该方法 | |||
""" | |||
pass | |||
def on_train_end(self): | |||
""" | |||
r""" | |||
训练结束,调用该方法 | |||
""" | |||
pass | |||
def on_exception(self, exception): | |||
""" | |||
r""" | |||
当训练过程出现异常,会触发该方法 | |||
:param exception: 某种类型的Exception,比如KeyboardInterrupt等 | |||
""" | |||
@@ -297,7 +297,7 @@ class Callback(object): | |||
def _transfer(func): | |||
"""装饰器,将对CallbackManager的调用转发到各个Callback子类. | |||
r"""装饰器,将对CallbackManager的调用转发到各个Callback子类. | |||
:param func: | |||
:return: | |||
@@ -315,11 +315,11 @@ def _transfer(func): | |||
class CallbackManager(Callback): | |||
""" | |||
r""" | |||
内部使用的Callback管理类 | |||
""" | |||
def __init__(self, env, callbacks=None): | |||
""" | |||
r""" | |||
:param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | |||
:param List[Callback] callbacks: | |||
@@ -433,12 +433,12 @@ class DistCallbackManager(CallbackManager): | |||
class GradientClipCallback(Callback): | |||
""" | |||
r""" | |||
每次backward前,将parameter的gradient clip到某个范围。 | |||
""" | |||
def __init__(self, parameters=None, clip_value=1, clip_type='norm'): | |||
""" | |||
r""" | |||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 | |||
如果为None则默认对Trainer的model中所有参数进行clip | |||
@@ -477,12 +477,12 @@ class GradientClipCallback(Callback): | |||
class EarlyStopCallback(Callback): | |||
""" | |||
r""" | |||
多少个epoch没有变好就停止训练,相关类 :class:`~fastNLP.core.callback.EarlyStopError` | |||
""" | |||
def __init__(self, patience): | |||
""" | |||
r""" | |||
:param int patience: epoch的数量 | |||
""" | |||
@@ -508,7 +508,7 @@ class EarlyStopCallback(Callback): | |||
class FitlogCallback(Callback): | |||
""" | |||
r""" | |||
该callback可将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 | |||
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 | |||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 | |||
@@ -516,7 +516,7 @@ class FitlogCallback(Callback): | |||
""" | |||
def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False): | |||
""" | |||
r""" | |||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要 | |||
传入多个DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。data的结果的名称以'data'开头。 | |||
@@ -608,13 +608,13 @@ class FitlogCallback(Callback): | |||
class EvaluateCallback(Callback): | |||
""" | |||
r""" | |||
通过使用该Callback可以使得Trainer在evaluate dev之外还可以evaluate其它数据集,比如测试集。每一次验证dev之前都会先验证EvaluateCallback | |||
中的数据。 | |||
""" | |||
def __init__(self, data=None, tester=None): | |||
""" | |||
r""" | |||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用Trainer中的metric对数据进行验证。如果需要传入多个 | |||
DataSet请通过dict的方式传入。 | |||
:param ~fastNLP.Tester,Dict[~fastNLP.DataSet] tester: Tester对象, 通过使用Tester对象,可以使得验证的metric与Trainer中 | |||
@@ -668,12 +668,12 @@ class EvaluateCallback(Callback): | |||
raise e | |||
class LRScheduler(Callback): | |||
""" | |||
r""" | |||
对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用 | |||
""" | |||
def __init__(self, lr_scheduler): | |||
""" | |||
r""" | |||
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler | |||
""" | |||
super(LRScheduler, self).__init__() | |||
@@ -688,12 +688,12 @@ class LRScheduler(Callback): | |||
class ControlC(Callback): | |||
""" | |||
r""" | |||
检测到 control+C 时的反馈 | |||
""" | |||
def __init__(self, quit_all): | |||
""" | |||
r""" | |||
:param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer | |||
""" | |||
super(ControlC, self).__init__() | |||
@@ -713,26 +713,26 @@ class ControlC(Callback): | |||
class SmoothValue(object): | |||
"""work for LRFinder""" | |||
r"""work for LRFinder""" | |||
def __init__(self, beta: float): | |||
self.beta, self.n, self.mov_avg = beta, 0, 0 | |||
self.smooth = None | |||
def add_value(self, val: float) -> None: | |||
"""Add `val` to calculate updated smoothed value.""" | |||
r"""Add `val` to calculate updated smoothed value.""" | |||
self.n += 1 | |||
self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val | |||
self.smooth = self.mov_avg / (1 - self.beta ** self.n) | |||
class LRFinder(Callback): | |||
""" | |||
r""" | |||
用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 | |||
""" | |||
def __init__(self, start_lr=1e-6, end_lr=10): | |||
""" | |||
r""" | |||
:param float start_lr: 学习率下界 | |||
:param float end_lr: 学习率上界 | |||
@@ -798,7 +798,7 @@ class LRFinder(Callback): | |||
class TensorboardCallback(Callback): | |||
""" | |||
r""" | |||
接受以下一个或多个字符串作为参数: | |||
- "model" | |||
- "loss" | |||
@@ -873,7 +873,7 @@ class TensorboardCallback(Callback): | |||
class CheckPointCallback(Callback): | |||
def __init__(self, save_path, delete_when_train_finish=True, recovery_fitlog=True): | |||
""" | |||
r""" | |||
用于在每个epoch结束的时候保存一下当前的Trainer状态,可以用于恢复之前的运行。使用最近的一个epoch继续训练 | |||
一段示例代码 | |||
Example1:: | |||
@@ -918,7 +918,7 @@ class CheckPointCallback(Callback): | |||
logger.error("Fail to recovery the fitlog states.") | |||
def on_train_begin(self): | |||
""" | |||
r""" | |||
当train开始时,且需要恢复上次训练时,会做以下的操作 | |||
(1) 重新加载model权重 | |||
(2) 重新加载optimizer的状态 | |||
@@ -944,7 +944,7 @@ class CheckPointCallback(Callback): | |||
self.trainer.best_metric_indicator = states['best_metric_indicator'] | |||
def on_epoch_end(self): | |||
""" | |||
r""" | |||
保存状态,使得结果可以被恢复 | |||
:param self: | |||
@@ -984,11 +984,11 @@ class CheckPointCallback(Callback): | |||
class WarmupCallback(Callback): | |||
""" | |||
r""" | |||
learning rate按照一定的速率从0上升到设置的learning rate。 | |||
""" | |||
def __init__(self, warmup=0.1, schedule='constant'): | |||
""" | |||
r""" | |||
:param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, | |||
如0.1, 则前10%的step是按照schedule策略调整learning rate。 | |||
@@ -1035,7 +1035,7 @@ class WarmupCallback(Callback): | |||
class SaveModelCallback(Callback): | |||
""" | |||
r""" | |||
由于Trainer在训练过程中只会保存最佳的模型, 该callback可实现多种方式的结果存储。 | |||
会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型:: | |||
@@ -1047,7 +1047,7 @@ class SaveModelCallback(Callback): | |||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能 | |||
""" | |||
def __init__(self, save_dir, top=3, only_param=False, save_on_exception=False): | |||
""" | |||
r""" | |||
:param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型。如果save_dir不存在将自动创建 | |||
:param int top: 保存dev表现top多少模型。-1为保存所有模型。 | |||
@@ -1116,12 +1116,12 @@ class SaveModelCallback(Callback): | |||
class CallbackException(BaseException): | |||
""" | |||
r""" | |||
当需要通过callback跳出训练的时候可以通过抛出CallbackException并在on_exception中捕获这个值。 | |||
""" | |||
def __init__(self, msg): | |||
""" | |||
r""" | |||
:param str msg: Exception的信息。 | |||
""" | |||
@@ -1129,7 +1129,7 @@ class CallbackException(BaseException): | |||
class EarlyStopError(CallbackException): | |||
""" | |||
r""" | |||
用于EarlyStop时从Trainer训练循环中跳出。 | |||
""" | |||
@@ -1139,7 +1139,7 @@ class EarlyStopError(CallbackException): | |||
class EchoCallback(Callback): | |||
""" | |||
r""" | |||
用于测试分布式训练 | |||
""" | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
from builtins import sorted | |||
import torch | |||
@@ -37,7 +37,7 @@ def batching(samples, max_len=0, padding_val=0): | |||
class Collector: | |||
""" | |||
r""" | |||
辅助DataSet管理collect_fn的类 | |||
""" | |||
@@ -45,7 +45,7 @@ class Collector: | |||
self.collect_fns = {} | |||
def add_fn(self, fn, name=None): | |||
""" | |||
r""" | |||
向collector新增一个collect_fn函数 | |||
:param callable fn: | |||
@@ -59,7 +59,7 @@ class Collector: | |||
self.collect_fns[name] = fn | |||
def is_empty(self): | |||
""" | |||
r""" | |||
返回是否包含collect_fn | |||
:return: | |||
@@ -67,7 +67,7 @@ class Collector: | |||
return len(self.collect_fns)==0 | |||
def delete_fn(self, name=None): | |||
""" | |||
r""" | |||
删除collect_fn | |||
:param str,int name: 如果为None就删除最近加入的collect_fn | |||
@@ -100,7 +100,7 @@ class Collector: | |||
class ConcatCollectFn: | |||
""" | |||
r""" | |||
field拼接collect_fn,将不同field按序拼接后,padding产生数据。 | |||
:param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field | |||
@@ -8,7 +8,7 @@ __all__ = [ | |||
class Const: | |||
""" | |||
r""" | |||
fastNLP中field命名常量。 | |||
.. todo:: | |||
@@ -37,48 +37,48 @@ class Const: | |||
@staticmethod | |||
def INPUTS(i): | |||
"""得到第 i 个 ``INPUT`` 的命名""" | |||
r"""得到第 i 个 ``INPUT`` 的命名""" | |||
i = int(i) + 1 | |||
return Const.INPUT + str(i) | |||
@staticmethod | |||
def CHAR_INPUTS(i): | |||
"""得到第 i 个 ``CHAR_INPUT`` 的命名""" | |||
r"""得到第 i 个 ``CHAR_INPUT`` 的命名""" | |||
i = int(i) + 1 | |||
return Const.CHAR_INPUT + str(i) | |||
@staticmethod | |||
def RAW_WORDS(i): | |||
"""得到第 i 个 ``RAW_WORDS`` 的命名""" | |||
r"""得到第 i 个 ``RAW_WORDS`` 的命名""" | |||
i = int(i) + 1 | |||
return Const.RAW_WORD + str(i) | |||
@staticmethod | |||
def RAW_CHARS(i): | |||
"""得到第 i 个 ``RAW_CHARS`` 的命名""" | |||
r"""得到第 i 个 ``RAW_CHARS`` 的命名""" | |||
i = int(i) + 1 | |||
return Const.RAW_CHAR + str(i) | |||
@staticmethod | |||
def INPUT_LENS(i): | |||
"""得到第 i 个 ``INPUT_LEN`` 的命名""" | |||
r"""得到第 i 个 ``INPUT_LEN`` 的命名""" | |||
i = int(i) + 1 | |||
return Const.INPUT_LEN + str(i) | |||
@staticmethod | |||
def OUTPUTS(i): | |||
"""得到第 i 个 ``OUTPUT`` 的命名""" | |||
r"""得到第 i 个 ``OUTPUT`` 的命名""" | |||
i = int(i) + 1 | |||
return Const.OUTPUT + str(i) | |||
@staticmethod | |||
def TARGETS(i): | |||
"""得到第 i 个 ``TARGET`` 的命名""" | |||
r"""得到第 i 个 ``TARGET`` 的命名""" | |||
i = int(i) + 1 | |||
return Const.TARGET + str(i) | |||
@staticmethod | |||
def LOSSES(i): | |||
"""得到第 i 个 ``LOSS`` 的命名""" | |||
r"""得到第 i 个 ``LOSS`` 的命名""" | |||
i = int(i) + 1 | |||
return Const.LOSS + str(i) |
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
:class:`~fastNLP.core.dataset.DataSet` 是fastNLP中用于承载数据的容器。可以将DataSet看做是一个表格, | |||
每一行是一个sample (在fastNLP中被称为 :mod:`~fastNLP.core.instance` ), | |||
每一列是一个feature (在fastNLP中称为 :mod:`~fastNLP.core.field` )。 | |||
@@ -380,12 +380,12 @@ class ApplyResultException(Exception): | |||
self.index = index # 标示在哪个数据遭遇到问题了 | |||
class DataSet(object): | |||
""" | |||
r""" | |||
fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset` | |||
""" | |||
def __init__(self, data=None): | |||
""" | |||
r""" | |||
:param data: 如果为dict类型,则每个key的value应该为等长的list; 如果为list, | |||
每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 | |||
@@ -447,7 +447,7 @@ class DataSet(object): | |||
return inner_iter_func() | |||
def __getitem__(self, idx): | |||
"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 | |||
r"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 | |||
:param idx: can be int or slice. | |||
:return: If `idx` is int, return an Instance object. | |||
@@ -495,7 +495,7 @@ class DataSet(object): | |||
return self.__dict__ | |||
def __len__(self): | |||
"""Fetch the length of the dataset. | |||
r"""Fetch the length of the dataset. | |||
:return length: | |||
""" | |||
@@ -508,7 +508,7 @@ class DataSet(object): | |||
return str(pretty_table_printer(self)) | |||
def print_field_meta(self): | |||
""" | |||
r""" | |||
输出当前field的meta信息, 形似下列的输出:: | |||
+-------------+-------+-------+ | |||
@@ -564,7 +564,7 @@ class DataSet(object): | |||
return table | |||
def append(self, instance): | |||
""" | |||
r""" | |||
将一个instance对象append到DataSet后面。 | |||
:param ~fastNLP.Instance instance: 若DataSet不为空,则instance应该拥有和DataSet完全一样的field。 | |||
@@ -589,7 +589,7 @@ class DataSet(object): | |||
raise e | |||
def add_fieldarray(self, field_name, fieldarray): | |||
""" | |||
r""" | |||
将fieldarray添加到DataSet中. | |||
:param str field_name: 新加入的field的名称 | |||
@@ -604,7 +604,7 @@ class DataSet(object): | |||
self.field_arrays[field_name] = fieldarray | |||
def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): | |||
""" | |||
r""" | |||
新增一个field | |||
:param str field_name: 新增的field的名称 | |||
@@ -623,7 +623,7 @@ class DataSet(object): | |||
padder=padder, ignore_type=ignore_type) | |||
def delete_instance(self, index): | |||
""" | |||
r""" | |||
删除第index个instance | |||
:param int index: 需要删除的instance的index,序号从0开始。 | |||
@@ -639,7 +639,7 @@ class DataSet(object): | |||
return self | |||
def delete_field(self, field_name): | |||
""" | |||
r""" | |||
删除名为field_name的field | |||
:param str field_name: 需要删除的field的名称. | |||
@@ -648,7 +648,7 @@ class DataSet(object): | |||
return self | |||
def copy_field(self, field_name, new_field_name): | |||
""" | |||
r""" | |||
深度copy名为field_name的field到new_field_name | |||
:param str field_name: 需要copy的field。 | |||
@@ -662,7 +662,7 @@ class DataSet(object): | |||
return self | |||
def has_field(self, field_name): | |||
""" | |||
r""" | |||
判断DataSet中是否有名为field_name这个field | |||
:param str field_name: field的名称 | |||
@@ -673,7 +673,7 @@ class DataSet(object): | |||
return False | |||
def get_field(self, field_name): | |||
""" | |||
r""" | |||
获取field_name这个field | |||
:param str field_name: field的名称 | |||
@@ -684,7 +684,7 @@ class DataSet(object): | |||
return self.field_arrays[field_name] | |||
def get_all_fields(self): | |||
""" | |||
r""" | |||
返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray` | |||
:return dict: 返回如上所述的字典 | |||
@@ -692,7 +692,7 @@ class DataSet(object): | |||
return self.field_arrays | |||
def get_field_names(self) -> list: | |||
""" | |||
r""" | |||
返回一个list,包含所有 field 的名字 | |||
:return list: 返回如上所述的列表 | |||
@@ -700,7 +700,7 @@ class DataSet(object): | |||
return sorted(self.field_arrays.keys()) | |||
def get_length(self): | |||
""" | |||
r""" | |||
获取DataSet的元素数量 | |||
:return: int: DataSet中Instance的个数。 | |||
@@ -708,7 +708,7 @@ class DataSet(object): | |||
return len(self) | |||
def rename_field(self, field_name, new_field_name): | |||
""" | |||
r""" | |||
将某个field重新命名. | |||
:param str field_name: 原来的field名称。 | |||
@@ -722,7 +722,7 @@ class DataSet(object): | |||
return self | |||
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | |||
""" | |||
r""" | |||
将field_names的field设置为target | |||
Example:: | |||
@@ -749,7 +749,7 @@ class DataSet(object): | |||
return self | |||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | |||
""" | |||
r""" | |||
将field_names的field设置为input:: | |||
dataset.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True | |||
@@ -773,7 +773,7 @@ class DataSet(object): | |||
return self | |||
def set_ignore_type(self, *field_names, flag=True): | |||
""" | |||
r""" | |||
将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查, | |||
默认情况下也不进行pad。如果仍需要pad该field,可通过自定义Padder实现,若该field需要转换为tensor,需要在padder | |||
中转换,但不需要在padder中移动到gpu。 | |||
@@ -791,7 +791,7 @@ class DataSet(object): | |||
return self | |||
def set_padder(self, field_name, padder): | |||
""" | |||
r""" | |||
为field_name设置padder:: | |||
from fastNLP import EngChar2DPadder | |||
@@ -807,7 +807,7 @@ class DataSet(object): | |||
return self | |||
def set_pad_val(self, field_name, pad_val): | |||
""" | |||
r""" | |||
为某个field设置对应的pad_val. | |||
:param str field_name: 修改该field的pad_val | |||
@@ -819,7 +819,7 @@ class DataSet(object): | |||
return self | |||
def get_input_name(self): | |||
""" | |||
r""" | |||
返回所有is_input被设置为True的field名称 | |||
:return list: 里面的元素为被设置为input的field名称 | |||
@@ -827,7 +827,7 @@ class DataSet(object): | |||
return [name for name, field in self.field_arrays.items() if field.is_input] | |||
def get_target_name(self): | |||
""" | |||
r""" | |||
返回所有is_target被设置为True的field名称 | |||
:return list: 里面的元素为被设置为target的field名称 | |||
@@ -835,7 +835,7 @@ class DataSet(object): | |||
return [name for name, field in self.field_arrays.items() if field.is_target] | |||
def apply_field(self, func, field_name, new_field_name=None, **kwargs): | |||
""" | |||
r""" | |||
将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。 | |||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
@@ -858,7 +858,7 @@ class DataSet(object): | |||
return self.apply(func, new_field_name, _apply_field=field_name, **kwargs) | |||
def apply_field_more(self, func, field_name, modify_fields=True, **kwargs): | |||
""" | |||
r""" | |||
将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。 | |||
func 可以返回一个或多个 field 上的结果。 | |||
@@ -885,7 +885,7 @@ class DataSet(object): | |||
return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs) | |||
def _add_apply_field(self, results, new_field_name, kwargs): | |||
""" | |||
r""" | |||
将results作为加入到新的field中,field名称为new_field_name | |||
:param List[str] results: 一般是apply*()之后的结果 | |||
@@ -917,7 +917,7 @@ class DataSet(object): | |||
ignore_type=extra_param.get("ignore_type", False)) | |||
def apply_more(self, func, modify_fields=True, **kwargs): | |||
""" | |||
r""" | |||
将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。 | |||
.. note:: | |||
@@ -978,7 +978,7 @@ class DataSet(object): | |||
return results | |||
def apply(self, func, new_field_name=None, **kwargs): | |||
""" | |||
r""" | |||
将DataSet中每个instance传入到func中,并获取它的返回值. | |||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` | |||
@@ -1015,7 +1015,7 @@ class DataSet(object): | |||
return results | |||
def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN): | |||
""" | |||
r""" | |||
将使用len()直接对field_name中每个元素作用,将其结果作为sequence length, 并放入seq_len这个field。 | |||
:param field_name: str. | |||
@@ -1029,7 +1029,7 @@ class DataSet(object): | |||
return self | |||
def drop(self, func, inplace=True): | |||
""" | |||
r""" | |||
func接受一个Instance,返回bool值。返回值为True时,该Instance会被移除或者不会包含在返回的DataSet中。 | |||
:param callable func: 接受一个Instance作为参数,返回bool值。为True时删除该instance | |||
@@ -1053,7 +1053,7 @@ class DataSet(object): | |||
return DataSet() | |||
def split(self, ratio, shuffle=True): | |||
""" | |||
r""" | |||
将DataSet按照ratio的比例拆分,返回两个DataSet | |||
:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `(1-ratio)` 这么多数据,第二个DataSet拥有`ratio`这么多数据 | |||
@@ -1088,7 +1088,7 @@ class DataSet(object): | |||
return train_set, dev_set | |||
def save(self, path): | |||
""" | |||
r""" | |||
保存DataSet. | |||
:param str path: 将DataSet存在哪个路径 | |||
@@ -1110,7 +1110,7 @@ class DataSet(object): | |||
return d | |||
def add_collect_fn(self, fn, name=None): | |||
""" | |||
r""" | |||
添加 CollectFn,collect_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的 | |||
这个)。支持依次添加多个collect_fn, 如果相同的key,后面的collect_fn的结果覆盖前面的collect_fn的结果。 | |||
@@ -1126,7 +1126,7 @@ class DataSet(object): | |||
self.collector.add_fn(fn, name=name) | |||
def delete_collect_fn(self, name=None): | |||
""" | |||
r""" | |||
删除某个collect_fn | |||
:param str,int name: 如果为None,则删除最近加入的collect_fn | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
分布式 Trainer | |||
使用步骤 | |||
1. 在代码中调用 DistTrainer,类似 Trainer,传入模型和数据等等参数 | |||
@@ -41,7 +41,7 @@ __all__ = [ | |||
] | |||
def get_local_rank(): | |||
""" | |||
r""" | |||
返回当前进程的 local rank, 0 到 N-1 ,N为当前分布式总进程数 | |||
""" | |||
if 'LOCAL_RANK' in os.environ: | |||
@@ -57,7 +57,7 @@ def get_local_rank(): | |||
class DistTrainer(): | |||
""" | |||
r""" | |||
分布式的 Trainer,支持分布式训练和混合精度的训练。具体实现原理请阅读 pytorch 官方文档。 | |||
Note: 使用分布式 Trainer 时会同时有多个进程执行训练代码。因此将单进程的训练代码改为多进程之前, | |||
@@ -71,7 +71,7 @@ class DistTrainer(): | |||
update_every=1, print_every=10, validate_every=-1, | |||
save_every=-1, save_path=None, device='auto', | |||
fp16='', backend=None, init_method=None, use_tqdm=True): | |||
""" | |||
r""" | |||
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | |||
:param nn.modules model: 待训练的模型 | |||
@@ -227,11 +227,11 @@ class DistTrainer(): | |||
@property | |||
def is_master(self): | |||
"""是否是主进程""" | |||
r"""是否是主进程""" | |||
return self.rank == 0 | |||
def train(self, load_best_model=True, on_exception='auto'): | |||
""" | |||
r""" | |||
使用该函数使Trainer开始训练。 | |||
:param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。 | |||
@@ -374,7 +374,7 @@ class DistTrainer(): | |||
# ============ tqdm end ============== # | |||
def _update(self): | |||
"""Perform weight update on a model. | |||
r"""Perform weight update on a model. | |||
""" | |||
if self.step % self.update_every == 0: | |||
@@ -390,7 +390,7 @@ class DistTrainer(): | |||
return y | |||
def _compute_loss(self, predict, truth): | |||
"""Compute loss given prediction and ground truth. | |||
r"""Compute loss given prediction and ground truth. | |||
:param predict: prediction dict, produced by model.forward | |||
:param truth: ground truth dict, produced by batch_y | |||
@@ -404,7 +404,7 @@ class DistTrainer(): | |||
return loss | |||
def save_check_point(self, name=None, only_params=False): | |||
"""保存当前模型""" | |||
r"""保存当前模型""" | |||
# only master save models | |||
if self.is_master: | |||
if name is None: | |||
@@ -446,5 +446,5 @@ class DistTrainer(): | |||
dist.barrier() | |||
def close(self): | |||
"""关闭Trainer,销毁进程""" | |||
r"""关闭Trainer,销毁进程""" | |||
dist.destroy_process_group() |
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -88,7 +88,7 @@ class FieldArray: | |||
@is_input.setter | |||
def is_input(self, value): | |||
""" | |||
r""" | |||
当 field_array.is_input = True / False 时被调用 | |||
""" | |||
# 如果(value为True)且(_is_input和_is_target都是False)且(ignore_type为False) | |||
@@ -107,7 +107,7 @@ class FieldArray: | |||
@is_target.setter | |||
def is_target(self, value): | |||
""" | |||
r""" | |||
当 field_array.is_target = True / False 时被调用 | |||
""" | |||
if value is True and \ | |||
@@ -120,7 +120,7 @@ class FieldArray: | |||
self._is_target = value | |||
def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True): | |||
""" | |||
r""" | |||
检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 | |||
通过将直接报错. | |||
@@ -150,7 +150,7 @@ class FieldArray: | |||
raise e | |||
def append(self, val: Any): | |||
""" | |||
r""" | |||
:param val: 把该val append到fieldarray。 | |||
:return: | |||
""" | |||
@@ -167,7 +167,7 @@ class FieldArray: | |||
self.content.append(val) | |||
def pop(self, index): | |||
""" | |||
r""" | |||
删除该field中index处的元素 | |||
:param int index: 从0开始的数据下标。 | |||
:return: | |||
@@ -190,7 +190,7 @@ class FieldArray: | |||
self.content[idx] = val | |||
def get(self, indices, pad=True): | |||
""" | |||
r""" | |||
根据给定的indices返回内容。 | |||
:param int,List[int] indices: 获取indices对应的内容。 | |||
@@ -210,7 +210,7 @@ class FieldArray: | |||
return np.array(contents) | |||
def pad(self, contents): | |||
""" | |||
r""" | |||
传入list的contents,将contents使用padder进行padding,contents必须为从本FieldArray中取出的。 | |||
:param list contents: | |||
@@ -219,7 +219,7 @@ class FieldArray: | |||
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) | |||
def set_padder(self, padder): | |||
""" | |||
r""" | |||
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 | |||
:param padder: :class:`~fastNLP.Padder` 类型,设置为None即删除padder。 | |||
@@ -231,7 +231,7 @@ class FieldArray: | |||
self.padder = None | |||
def set_pad_val(self, pad_val): | |||
""" | |||
r""" | |||
修改padder的pad_val. | |||
:param int pad_val: 该field的pad值设置为该值。 | |||
@@ -241,7 +241,7 @@ class FieldArray: | |||
return self | |||
def __len__(self): | |||
""" | |||
r""" | |||
Returns the size of FieldArray. | |||
:return int length: | |||
@@ -249,7 +249,7 @@ class FieldArray: | |||
return len(self.content) | |||
def to(self, other): | |||
""" | |||
r""" | |||
将other的属性复制给本FieldArray(other必须为FieldArray类型). | |||
属性包括 is_input, is_target, padder, ignore_type | |||
@@ -266,7 +266,7 @@ class FieldArray: | |||
return self | |||
def split(self, sep: str = None, inplace: bool = True): | |||
""" | |||
r""" | |||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值 | |||
:param sep: 分割符,如果为None则直接调用str.split()。 | |||
@@ -283,7 +283,7 @@ class FieldArray: | |||
return self._after_process(new_contents, inplace=inplace) | |||
def int(self, inplace: bool = True): | |||
""" | |||
r""" | |||
将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
@@ -303,7 +303,7 @@ class FieldArray: | |||
return self._after_process(new_contents, inplace=inplace) | |||
def float(self, inplace=True): | |||
""" | |||
r""" | |||
将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
@@ -323,7 +323,7 @@ class FieldArray: | |||
return self._after_process(new_contents, inplace=inplace) | |||
def bool(self, inplace=True): | |||
""" | |||
r""" | |||
将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
@@ -344,7 +344,7 @@ class FieldArray: | |||
return self._after_process(new_contents, inplace=inplace) | |||
def lower(self, inplace=True): | |||
""" | |||
r""" | |||
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
@@ -364,7 +364,7 @@ class FieldArray: | |||
return self._after_process(new_contents, inplace=inplace) | |||
def upper(self, inplace=True): | |||
""" | |||
r""" | |||
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
@@ -384,7 +384,7 @@ class FieldArray: | |||
return self._after_process(new_contents, inplace=inplace) | |||
def value_count(self): | |||
""" | |||
r""" | |||
返回该field下不同value的数量。多用于统计label数量 | |||
:return: Counter, key是label,value是出现次数 | |||
@@ -403,7 +403,7 @@ class FieldArray: | |||
return count | |||
def _after_process(self, new_contents, inplace): | |||
""" | |||
r""" | |||
当调用处理函数之后,决定是否要替换field。 | |||
:param new_contents: | |||
@@ -424,7 +424,7 @@ class FieldArray: | |||
def _get_ele_type_and_dim(cell: Any, dim=0): | |||
""" | |||
r""" | |||
识别cell的类别与dimension的数量 | |||
numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html | |||
@@ -470,7 +470,7 @@ def _get_ele_type_and_dim(cell: Any, dim=0): | |||
class Padder: | |||
""" | |||
r""" | |||
所有padder都需要继承这个类,并覆盖__call__方法。 | |||
用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 | |||
@@ -479,7 +479,7 @@ class Padder: | |||
""" | |||
def __init__(self, pad_val=0, **kwargs): | |||
""" | |||
r""" | |||
:param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | |||
deepcopy一份。 | |||
@@ -497,7 +497,7 @@ class Padder: | |||
@abstractmethod | |||
def __call__(self, contents, field_name, field_ele_dtype, dim: int): | |||
""" | |||
r""" | |||
传入的是List内容。假设有以下的DataSet。 | |||
:param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | |||
@@ -541,7 +541,7 @@ class Padder: | |||
class AutoPadder(Padder): | |||
""" | |||
r""" | |||
根据contents的数据自动判定是否需要做padding。 | |||
1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | |||
@@ -633,7 +633,7 @@ class AutoPadder(Padder): | |||
class EngChar2DPadder(Padder): | |||
""" | |||
r""" | |||
用于为英语执行character级别的2D padding操作。对应的field内容应该类似[['T', 'h', 'i', 's'], ['a'], ['d', 'e', 'm', 'o']], | |||
但这个Padder只能处理index为int的情况。 | |||
@@ -655,7 +655,7 @@ class EngChar2DPadder(Padder): | |||
""" | |||
def __init__(self, pad_val=0, pad_length=0): | |||
""" | |||
r""" | |||
:param pad_val: int, pad的位置使用该index | |||
:param pad_length: int, 如果为0则取一个batch中最大的单词长度作为padding长度。如果为大于0的数,则将所有单词的长度 | |||
都pad或截取到该长度. | |||
@@ -665,7 +665,7 @@ class EngChar2DPadder(Padder): | |||
self.pad_length = pad_length | |||
def __call__(self, contents, field_name, field_ele_dtype, dim): | |||
""" | |||
r""" | |||
期望输入类似于 | |||
[ | |||
[[0, 2], [2, 3, 4], ..], | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可以认为是一个Instance类型的对象。 | |||
便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 中的表格 | |||
@@ -12,7 +12,7 @@ from .utils import pretty_table_printer | |||
class Instance(object): | |||
""" | |||
r""" | |||
Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。 | |||
Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示:: | |||
@@ -29,7 +29,7 @@ class Instance(object): | |||
self.fields = fields | |||
def add_field(self, field_name, field): | |||
""" | |||
r""" | |||
向Instance中增加一个field | |||
:param str field_name: 新增field的名称 | |||
@@ -38,7 +38,7 @@ class Instance(object): | |||
self.fields[field_name] = field | |||
def items(self): | |||
""" | |||
r""" | |||
返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value | |||
:return: 一个迭代器 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
losses 模块定义了 fastNLP 中所需的各种损失函数,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 | |||
""" | |||
@@ -34,7 +34,7 @@ from ..core.const import Const | |||
class LossBase(object): | |||
""" | |||
r""" | |||
所有loss的基类。如果想了解其中的原理,请查看源码。 | |||
""" | |||
@@ -55,7 +55,7 @@ class LossBase(object): | |||
raise NotImplementedError | |||
def _init_param_map(self, key_map=None, **kwargs): | |||
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||
r"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||
:param dict key_map: 表示key的映射关系 | |||
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 | |||
@@ -102,7 +102,7 @@ class LossBase(object): | |||
# f"positional argument.).") | |||
def __call__(self, pred_dict, target_dict, check=False): | |||
""" | |||
r""" | |||
:param dict pred_dict: 模型的forward函数返回的dict | |||
:param dict target_dict: DataSet.batch_y里的键-值对所组成的dict | |||
:param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查 | |||
@@ -168,7 +168,7 @@ class LossBase(object): | |||
class LossFunc(LossBase): | |||
""" | |||
r""" | |||
提供给用户使用自定义损失函数的类 | |||
:param func: 用户自行定义的损失函数,应当为一个函数或者callable(func)为True的ojbect | |||
@@ -199,7 +199,7 @@ class LossFunc(LossBase): | |||
class CrossEntropyLoss(LossBase): | |||
""" | |||
r""" | |||
交叉熵损失函数 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
@@ -246,7 +246,7 @@ class CrossEntropyLoss(LossBase): | |||
class L1Loss(LossBase): | |||
""" | |||
r""" | |||
L1损失函数 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
@@ -266,7 +266,7 @@ class L1Loss(LossBase): | |||
class BCELoss(LossBase): | |||
""" | |||
r""" | |||
二分类交叉熵损失函数 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
@@ -285,12 +285,12 @@ class BCELoss(LossBase): | |||
class NLLLoss(LossBase): | |||
""" | |||
r""" | |||
负对数似然损失函数 | |||
""" | |||
def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'): | |||
""" | |||
r""" | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
@@ -309,12 +309,12 @@ class NLLLoss(LossBase): | |||
class LossInForward(LossBase): | |||
""" | |||
r""" | |||
从forward()函数返回结果中获取loss | |||
""" | |||
def __init__(self, loss_key=Const.LOSS): | |||
""" | |||
r""" | |||
:param str loss_key: 在forward函数中loss的键名,默认为loss | |||
""" | |||
@@ -349,7 +349,7 @@ class LossInForward(LossBase): | |||
class CMRC2018Loss(LossBase): | |||
""" | |||
r""" | |||
用于计算CMRC2018中文问答任务。 | |||
""" | |||
@@ -364,7 +364,7 @@ class CMRC2018Loss(LossBase): | |||
self.reduction = reduction | |||
def get_loss(self, target_start, target_end, context_len, pred_start, pred_end): | |||
""" | |||
r""" | |||
:param target_start: batch_size | |||
:param target_end: batch_size | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
metrics 模块实现了 fastNLP 所需的各种常用衡量指标,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 | |||
""" | |||
@@ -33,7 +33,7 @@ from .utils import ConfusionMatrix | |||
class MetricBase(object): | |||
""" | |||
r""" | |||
所有metrics的基类,所有的传入到Trainer, Tester的Metric需要继承自该对象,需要覆盖写入evaluate(), get_metric()方法。 | |||
evaluate(xxx)中传入的是一个batch的数据。 | |||
@@ -145,7 +145,7 @@ class MetricBase(object): | |||
raise NotImplemented | |||
def set_metric_name(self, name: str): | |||
""" | |||
r""" | |||
设置metric的名称,默认是Metric的class name. | |||
:param str name: | |||
@@ -155,7 +155,7 @@ class MetricBase(object): | |||
return self | |||
def get_metric_name(self): | |||
""" | |||
r""" | |||
返回metric的名称 | |||
:return: | |||
@@ -163,7 +163,7 @@ class MetricBase(object): | |||
return self._metric_name | |||
def _init_param_map(self, key_map=None, **kwargs): | |||
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||
r"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||
:param dict key_map: 表示key的映射关系 | |||
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 | |||
@@ -205,7 +205,7 @@ class MetricBase(object): | |||
f"initialization parameters, or change its signature.") | |||
def __call__(self, pred_dict, target_dict): | |||
""" | |||
r""" | |||
这个方法会调用self.evaluate 方法. | |||
在调用之前,会进行以下检测: | |||
1. self.evaluate当中是否有varargs, 这是不支持的. | |||
@@ -315,7 +315,7 @@ class ConfusionMatrixMetric(MetricBase): | |||
seq_len=None, | |||
print_ratio=False | |||
): | |||
""" | |||
r""" | |||
:param vocab: vocab词表类,要求有to_word()方法。 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
@@ -330,7 +330,7 @@ class ConfusionMatrixMetric(MetricBase): | |||
) | |||
def evaluate(self, pred, target, seq_len=None): | |||
""" | |||
r""" | |||
evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||
@@ -379,7 +379,7 @@ class ConfusionMatrixMetric(MetricBase): | |||
target.tolist()) | |||
def get_metric(self, reset=True): | |||
""" | |||
r""" | |||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | |||
:param bool reset: 在调用完get_metric后是否清空评价指标统计量. | |||
:return dict evaluate_result: {"confusion_matrix": ConfusionMatrix} | |||
@@ -394,12 +394,12 @@ class ConfusionMatrixMetric(MetricBase): | |||
class AccuracyMetric(MetricBase): | |||
""" | |||
r""" | |||
准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) | |||
""" | |||
def __init__(self, pred=None, target=None, seq_len=None): | |||
""" | |||
r""" | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
@@ -414,7 +414,7 @@ class AccuracyMetric(MetricBase): | |||
self.acc_count = 0 | |||
def evaluate(self, pred, target, seq_len=None): | |||
""" | |||
r""" | |||
evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||
@@ -463,7 +463,7 @@ class AccuracyMetric(MetricBase): | |||
self.total += np.prod(list(pred.size())) | |||
def get_metric(self, reset=True): | |||
""" | |||
r""" | |||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | |||
:param bool reset: 在调用完get_metric后是否清空评价指标统计量. | |||
@@ -537,7 +537,7 @@ class ClassifyFPreRecMetric(MetricBase): | |||
# tp: truth=T, classify=T; fp: truth=T, classify=F; fn: truth=F, classify=T | |||
def evaluate(self, pred, target, seq_len=None): | |||
""" | |||
r""" | |||
evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||
@@ -586,7 +586,7 @@ class ClassifyFPreRecMetric(MetricBase): | |||
self._fn[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target == target_idx, 0).masked_fill(masks, 0)).item() | |||
def get_metric(self, reset=True): | |||
""" | |||
r""" | |||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | |||
:param bool reset: 在调用完get_metric后是否清空评价指标统计量. | |||
@@ -646,7 +646,7 @@ class ClassifyFPreRecMetric(MetricBase): | |||
def _bmes_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
r""" | |||
给定一个tags的lis,比如['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。 | |||
返回[('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间) | |||
也可以是单纯的['S', 'B', 'M', 'E', 'B', 'M', 'M',...]序列 | |||
@@ -676,7 +676,7 @@ def _bmes_tag_to_spans(tags, ignore_labels=None): | |||
def _bmeso_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
r""" | |||
给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 | |||
返回[('singer', (1, 4))] (左闭右开区间) | |||
@@ -707,7 +707,7 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None): | |||
def _bioes_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
r""" | |||
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。 | |||
返回[('singer', (1, 4))] (左闭右开区间) | |||
@@ -738,7 +738,7 @@ def _bioes_tag_to_spans(tags, ignore_labels=None): | |||
def _bio_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
r""" | |||
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 | |||
返回[('singer', (1, 4))] (左闭右开区间) | |||
@@ -766,7 +766,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||
def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str: | |||
""" | |||
r""" | |||
给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio | |||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||
@@ -802,7 +802,7 @@ def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str | |||
def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str): | |||
""" | |||
r""" | |||
检查vocab中的tag是否与encoding_type是匹配的 | |||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||
@@ -913,7 +913,7 @@ class SpanFPreRecMetric(MetricBase): | |||
self._false_negatives = defaultdict(int) | |||
def evaluate(self, pred, target, seq_len): | |||
"""evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||
r"""evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||
:param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果 | |||
:param target: [batch, seq_len], 真实值 | |||
@@ -967,7 +967,7 @@ class SpanFPreRecMetric(MetricBase): | |||
self._false_negatives[span[0]] += 1 | |||
def get_metric(self, reset=True): | |||
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" | |||
r"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" | |||
evaluate_result = {} | |||
if not self.only_gross or self.f_type == 'macro': | |||
tags = set(self._false_negatives.keys()) | |||
@@ -1018,7 +1018,7 @@ class SpanFPreRecMetric(MetricBase): | |||
def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||
""" | |||
r""" | |||
:param tp: int, true positive | |||
:param fn: int, false negative | |||
@@ -1033,7 +1033,7 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||
def _prepare_metrics(metrics): | |||
""" | |||
r""" | |||
Prepare list of Metric based on input | |||
:param metrics: | |||
@@ -1064,7 +1064,7 @@ def _prepare_metrics(metrics): | |||
def _accuracy_topk(y_true, y_prob, k=1): | |||
"""Compute accuracy of y_true matching top-k probable labels in y_prob. | |||
r"""Compute accuracy of y_true matching top-k probable labels in y_prob. | |||
:param y_true: ndarray, true label, [n_samples] | |||
:param y_prob: ndarray, label probabilities, [n_samples, n_classes] | |||
@@ -1080,7 +1080,7 @@ def _accuracy_topk(y_true, y_prob, k=1): | |||
def _pred_topk(y_prob, k=1): | |||
"""Return top-k predicted labels and corresponding probabilities. | |||
r"""Return top-k predicted labels and corresponding probabilities. | |||
:param y_prob: ndarray, size [n_samples, n_classes], probabilities on labels | |||
:param k: int, k of top-k | |||
@@ -1110,7 +1110,7 @@ class CMRC2018Metric(MetricBase): | |||
self.f1 = 0 | |||
def evaluate(self, answers, raw_chars, pred_start, pred_end, context_len=None): | |||
""" | |||
r""" | |||
:param list[str] answers: 如[["答案1", "答案2", "答案3"], [...], ...] | |||
:param list[str] raw_chars: [["这", "是", ...], [...]] | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
optimizer 模块定义了 fastNLP 中所需的各种优化器,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 | |||
""" | |||
@@ -16,12 +16,12 @@ from torch.optim.optimizer import Optimizer as TorchOptimizer | |||
class Optimizer(object): | |||
""" | |||
r""" | |||
Optimizer | |||
""" | |||
def __init__(self, model_params, **kwargs): | |||
""" | |||
r""" | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
:param kwargs: additional parameters. | |||
@@ -36,7 +36,7 @@ class Optimizer(object): | |||
@staticmethod | |||
def _get_require_grads_param(params): | |||
""" | |||
r""" | |||
将params中不需要gradient的删除 | |||
:param iterable params: parameters | |||
@@ -46,7 +46,7 @@ class Optimizer(object): | |||
class NullOptimizer(Optimizer): | |||
""" | |||
r""" | |||
当不希望Trainer更新optimizer时,传入本optimizer,但请确保通过callback的方式对参数进行了更新。 | |||
""" | |||
@@ -64,12 +64,12 @@ class NullOptimizer(Optimizer): | |||
class SGD(Optimizer): | |||
""" | |||
r""" | |||
SGD | |||
""" | |||
def __init__(self, lr=0.001, momentum=0, model_params=None): | |||
""" | |||
r""" | |||
:param float lr: learning rate. Default: 0.01 | |||
:param float momentum: momentum. Default: 0 | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
@@ -87,12 +87,12 @@ class SGD(Optimizer): | |||
class Adam(Optimizer): | |||
""" | |||
r""" | |||
Adam | |||
""" | |||
def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | |||
""" | |||
r""" | |||
:param float lr: learning rate | |||
:param float weight_decay: | |||
@@ -133,7 +133,7 @@ class AdamW(TorchOptimizer): | |||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, | |||
weight_decay=1e-2, amsgrad=False): | |||
""" | |||
r""" | |||
:param params (iterable): iterable of parameters to optimize or dicts defining | |||
parameter groups | |||
@@ -164,7 +164,7 @@ class AdamW(TorchOptimizer): | |||
group.setdefault('amsgrad', False) | |||
def step(self, closure=None): | |||
"""Performs a single optimization step. | |||
r"""Performs a single optimization step. | |||
:param closure: (callable, optional) A closure that reevaluates the model | |||
and returns the loss. | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"Predictor" | |||
@@ -15,7 +15,7 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device | |||
class Predictor(object): | |||
""" | |||
r""" | |||
一个根据训练模型预测输出的预测器(Predictor) | |||
与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 | |||
@@ -23,7 +23,7 @@ class Predictor(object): | |||
""" | |||
def __init__(self, network): | |||
""" | |||
r""" | |||
:param torch.nn.Module network: 用来完成预测任务的模型 | |||
""" | |||
@@ -35,7 +35,7 @@ class Predictor(object): | |||
self.batch_output = [] | |||
def predict(self, data: DataSet, seq_len_field_name=None): | |||
"""用已经训练好的模型进行inference. | |||
r"""用已经训练好的模型进行inference. | |||
:param fastNLP.DataSet data: 待预测的数据集 | |||
:param str seq_len_field_name: 表示序列长度信息的field名字 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
sampler 子类实现了 fastNLP 所需的各种采样器。 | |||
""" | |||
__all__ = [ | |||
@@ -14,14 +14,14 @@ import numpy as np | |||
class Sampler(object): | |||
""" | |||
r""" | |||
`Sampler` 类的基类. 规定以何种顺序取出data中的元素 | |||
子类必须实现 ``__call__`` 方法. 输入 `DataSet` 对象, 返回其中元素的下标序列 | |||
""" | |||
def __call__(self, data_set): | |||
""" | |||
r""" | |||
:param DataSet data_set: `DataSet` 对象, 需要Sample的数据 | |||
:return result: list(int) 其中元素的下标序列, ``data_set`` 中元素会按 ``result`` 中顺序取出 | |||
""" | |||
@@ -29,7 +29,7 @@ class Sampler(object): | |||
class SequentialSampler(Sampler): | |||
""" | |||
r""" | |||
顺序取出元素的 `Sampler` | |||
""" | |||
@@ -39,7 +39,7 @@ class SequentialSampler(Sampler): | |||
class RandomSampler(Sampler): | |||
""" | |||
r""" | |||
随机化取元素的 `Sampler` | |||
""" | |||
@@ -49,12 +49,12 @@ class RandomSampler(Sampler): | |||
class BucketSampler(Sampler): | |||
""" | |||
r""" | |||
带Bucket的 `Random Sampler`. 可以随机地取出长度相似的元素 | |||
""" | |||
def __init__(self, num_buckets=10, batch_size=None, seq_len_field_name='seq_len'): | |||
""" | |||
r""" | |||
:param int num_buckets: bucket的数量 | |||
:param int batch_size: batch的大小. 默认为None,Trainer在调用BucketSampler时,会将该值正确设置,如果是非Trainer场景使用,需 | |||
@@ -66,7 +66,7 @@ class BucketSampler(Sampler): | |||
self.seq_len_field_name = seq_len_field_name | |||
def set_batch_size(self, batch_size): | |||
""" | |||
r""" | |||
:param int batch_size: 每个batch的大小 | |||
:return: | |||
@@ -111,7 +111,7 @@ class BucketSampler(Sampler): | |||
def simple_sort_bucketing(lengths): | |||
""" | |||
r""" | |||
:param lengths: list of int, the lengths of all examples. | |||
:return data: 2-level list | |||
@@ -131,7 +131,7 @@ def simple_sort_bucketing(lengths): | |||
def k_means_1d(x, k, max_iter=100): | |||
"""Perform k-means on 1-D data. | |||
r"""Perform k-means on 1-D data. | |||
:param x: list of int, representing points in 1-D. | |||
:param k: the number of clusters required. | |||
@@ -161,7 +161,7 @@ def k_means_1d(x, k, max_iter=100): | |||
def k_means_bucketing(lengths, buckets): | |||
"""Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths. | |||
r"""Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths. | |||
:param lengths: list of int, the length of all samples. | |||
:param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
tester模块实现了 fastNLP 所需的Tester类,能在提供数据、模型以及metric的情况下进行性能测试。 | |||
.. code-block:: | |||
@@ -64,12 +64,12 @@ __all__ = [ | |||
class Tester(object): | |||
""" | |||
r""" | |||
Tester是在提供数据,模型以及metric的情况下进行性能测试的类。需要传入模型,数据以及metric进行验证。 | |||
""" | |||
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True): | |||
""" | |||
r""" | |||
:param ~fastNLP.DataSet,~fastNLP.BatchIter data: 需要测试的数据集 | |||
:param torch.nn.Module model: 使用的模型 | |||
@@ -196,7 +196,7 @@ class Tester(object): | |||
return eval_results | |||
def _mode(self, model, is_test=False): | |||
"""Train mode or Test mode. This is for PyTorch currently. | |||
r"""Train mode or Test mode. This is for PyTorch currently. | |||
:param model: a PyTorch model | |||
:param is_test: bool, whether in test mode or not. | |||
@@ -208,13 +208,13 @@ class Tester(object): | |||
model.train() | |||
def _data_forward(self, func, x): | |||
"""A forward pass of the model. """ | |||
r"""A forward pass of the model. """ | |||
x = _build_args(func, **x) | |||
y = self._predict_func_wrapper(**x) | |||
return y | |||
def _format_eval_results(self, results): | |||
"""Override this method to support more print formats. | |||
r"""Override this method to support more print formats. | |||
:param results: dict, (str: float) is (metrics name: value) | |||
@@ -357,7 +357,7 @@ from ._logger import logger | |||
class Trainer(object): | |||
""" | |||
r""" | |||
Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在不同训练任务中重复撰写 | |||
(1) epoch循环; | |||
(2) 将数据分成不同的Batch; | |||
@@ -572,7 +572,7 @@ class Trainer(object): | |||
callbacks=callbacks) | |||
def train(self, load_best_model=True, on_exception='auto'): | |||
""" | |||
r""" | |||
使用该函数使Trainer开始训练。 | |||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||
@@ -728,7 +728,7 @@ class Trainer(object): | |||
return res | |||
def _mode(self, model, is_test=False): | |||
"""Train mode or Test mode. This is for PyTorch currently. | |||
r"""Train mode or Test mode. This is for PyTorch currently. | |||
:param model: a PyTorch model | |||
:param bool is_test: whether in test mode or not. | |||
@@ -740,7 +740,7 @@ class Trainer(object): | |||
model.train() | |||
def _update(self): | |||
"""Perform weight update on a model. | |||
r"""Perform weight update on a model. | |||
""" | |||
if self.step % self.update_every == 0: | |||
@@ -755,7 +755,7 @@ class Trainer(object): | |||
return y | |||
def _grad_backward(self, loss): | |||
"""Compute gradient with link rules. | |||
r"""Compute gradient with link rules. | |||
:param loss: a scalar where back-prop starts | |||
@@ -766,7 +766,7 @@ class Trainer(object): | |||
loss.backward() | |||
def _compute_loss(self, predict, truth): | |||
"""Compute loss given prediction and ground truth. | |||
r"""Compute loss given prediction and ground truth. | |||
:param predict: prediction dict, produced by model.forward | |||
:param truth: ground truth dict, produced by batch_y | |||
@@ -775,7 +775,7 @@ class Trainer(object): | |||
return self.losser(predict, truth) | |||
def _save_model(self, model, model_name, only_param=False): | |||
""" 存储不含有显卡信息的state_dict或model | |||
r""" 存储不含有显卡信息的state_dict或model | |||
:param model: | |||
:param model_name: | |||
:param only_param: | |||
@@ -816,7 +816,7 @@ class Trainer(object): | |||
return True | |||
def _better_eval_result(self, metrics): | |||
"""Check if the current epoch yields better validation results. | |||
r"""Check if the current epoch yields better validation results. | |||
:return bool value: True means current results on dev set is the best. | |||
""" | |||
@@ -842,7 +842,7 @@ class Trainer(object): | |||
@property | |||
def is_master(self): | |||
"""是否是主进程""" | |||
r"""是否是主进程""" | |||
return True | |||
DEFAULT_CHECK_BATCH_SIZE = 2 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 | |||
""" | |||
@@ -35,9 +35,9 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require | |||
class ConfusionMatrix: | |||
"""a dict can provide Confusion Matrix""" | |||
r"""a dict can provide Confusion Matrix""" | |||
def __init__(self, vocab=None, print_ratio=False): | |||
""" | |||
r""" | |||
:param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。 | |||
:param print_ratio: 限制print的输出,False只输出数量Confusion Matrix, True还会输出百分比Confusion Matrix, 分别为行/列 | |||
""" | |||
@@ -52,7 +52,7 @@ class ConfusionMatrix: | |||
self.print_ratio = print_ratio | |||
def add_pred_target(self, pred, target): # 一组结果 | |||
""" | |||
r""" | |||
通过这个函数向ConfusionMatrix加入一组预测结果 | |||
:param list pred: 预测的标签列表 | |||
:param list target: 真实值的标签列表 | |||
@@ -80,7 +80,7 @@ class ConfusionMatrix: | |||
return self.confusiondict | |||
def clear(self): | |||
""" | |||
r""" | |||
清空ConfusionMatrix,等待再次新加入 | |||
:return: | |||
""" | |||
@@ -89,7 +89,7 @@ class ConfusionMatrix: | |||
self.predcount = {} | |||
def get_result(self): | |||
""" | |||
r""" | |||
:return list output: ConfusionMatrix content,具体值与汇总统计 | |||
""" | |||
row2idx = {} | |||
@@ -121,7 +121,7 @@ class ConfusionMatrix: | |||
return output | |||
def get_percent(self, dim=0): | |||
""" | |||
r""" | |||
:param dim int: 0/1, 0 for row,1 for column | |||
:return list output: ConfusionMatrix content,具体值与汇总统计 | |||
""" | |||
@@ -139,7 +139,7 @@ class ConfusionMatrix: | |||
return tmp.tolist() | |||
def get_aligned_table(self, data, flag="result"): | |||
""" | |||
r""" | |||
:param data: highly recommend use get_percent/ get_result return as dataset here, or make sure data is a n*n list type data | |||
:param flag: only difference between result and other words is whether "%" is in output string | |||
:return: an aligned_table ready to print out | |||
@@ -197,7 +197,7 @@ class ConfusionMatrix: | |||
return "\n" + out | |||
def __repr__(self): | |||
""" | |||
r""" | |||
:return string output: ConfusionMatrix的格式化输出,包括表头各标签字段,具体值与汇总统计。 | |||
""" | |||
result = self.get_result() | |||
@@ -218,7 +218,7 @@ class ConfusionMatrix: | |||
class Option(dict): | |||
"""a dict can treat keys as attributes""" | |||
r"""a dict can treat keys as attributes""" | |||
def __getattr__(self, item): | |||
try: | |||
@@ -245,7 +245,7 @@ class Option(dict): | |||
def _prepare_cache_filepath(filepath): | |||
""" | |||
r""" | |||
检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径 | |||
:param filepath: str. | |||
:return: None, if not, this function will raise error | |||
@@ -259,7 +259,7 @@ def _prepare_cache_filepath(filepath): | |||
def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
""" | |||
r""" | |||
cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: | |||
import time | |||
@@ -358,7 +358,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
def _save_model(model, model_name, save_dir, only_param=False): | |||
""" 存储不含有显卡信息的state_dict或model | |||
r""" 存储不含有显卡信息的state_dict或model | |||
:param model: | |||
:param model_name: | |||
:param save_dir: 保存的directory | |||
@@ -383,7 +383,7 @@ def _save_model(model, model_name, save_dir, only_param=False): | |||
def _move_model_to_device(model, device): | |||
""" | |||
r""" | |||
将model移动到device | |||
:param model: torch.nn.DataParallel or torch.nn.Module. 当为torch.nn.DataParallel, 则只是调用一次cuda。device必须为 | |||
@@ -454,7 +454,7 @@ def _move_model_to_device(model, device): | |||
def _get_model_device(model): | |||
""" | |||
r""" | |||
传入一个nn.Module的模型,获取它所在的device | |||
:param model: nn.Module | |||
@@ -471,7 +471,7 @@ def _get_model_device(model): | |||
def _build_args(func, **kwargs): | |||
""" | |||
r""" | |||
根据func的初始化参数,从kwargs中选择func需要的参数 | |||
:param func: callable | |||
@@ -555,7 +555,7 @@ def _check_arg_dict_list(func, args): | |||
def _get_func_signature(func): | |||
""" | |||
r""" | |||
Given a function or method, return its signature. | |||
For example: | |||
@@ -596,7 +596,7 @@ def _get_func_signature(func): | |||
def _is_function_or_method(func): | |||
""" | |||
r""" | |||
:param func: | |||
:return: | |||
@@ -612,7 +612,7 @@ def _check_function_or_method(func): | |||
def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | |||
""" | |||
r""" | |||
move data to model's device, element in *args should be dict. This is a inplace change. | |||
:param device: torch.device | |||
@@ -636,7 +636,7 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | |||
class _CheckError(Exception): | |||
""" | |||
r""" | |||
_CheckError. Used in losses.LossBase, metrics.MetricBase. | |||
""" | |||
@@ -807,7 +807,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
def seq_len_to_mask(seq_len, max_len=None): | |||
""" | |||
r""" | |||
将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。 | |||
转变 1-d seq_len到2-d mask. | |||
@@ -851,7 +851,7 @@ def seq_len_to_mask(seq_len, max_len=None): | |||
class _pseudo_tqdm: | |||
""" | |||
r""" | |||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | |||
""" | |||
@@ -878,7 +878,7 @@ class _pseudo_tqdm: | |||
def iob2(tags: List[str]) -> List[str]: | |||
""" | |||
r""" | |||
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两者的差异见 | |||
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format | |||
@@ -902,7 +902,7 @@ def iob2(tags: List[str]) -> List[str]: | |||
def iob2bioes(tags: List[str]) -> List[str]: | |||
""" | |||
r""" | |||
将iob的tag转换为bioes编码 | |||
:param tags: List[str]. 编码需要是大写的。 | |||
:return: | |||
@@ -938,7 +938,7 @@ def _is_iterable(value): | |||
def get_seq_len(words, pad_value=0): | |||
""" | |||
r""" | |||
给定batch_size x max_len的words矩阵,返回句子长度 | |||
:param words: batch_size x max_len | |||
@@ -949,7 +949,7 @@ def get_seq_len(words, pad_value=0): | |||
def pretty_table_printer(dataset_or_ins) -> PrettyTable: | |||
""" | |||
r""" | |||
:param dataset_or_ins: 传入一个dataSet或者instance | |||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) | |||
+-----------+-----------+-----------------+ | |||
@@ -990,7 +990,7 @@ def pretty_table_printer(dataset_or_ins) -> PrettyTable: | |||
def sub_column(string: str, c: int, c_size: int, title: str) -> str: | |||
""" | |||
r""" | |||
:param string: 要被截断的字符串 | |||
:param c: 命令行列数 | |||
:param c_size: instance或dataset field数 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -33,7 +33,7 @@ class VocabularyOption(Option): | |||
def _check_build_vocab(func): | |||
"""A decorator to make sure the indexing is built before used. | |||
r"""A decorator to make sure the indexing is built before used. | |||
""" | |||
@@ -47,7 +47,7 @@ def _check_build_vocab(func): | |||
def _check_build_status(func): | |||
"""A decorator to check whether the vocabulary updates after the last build. | |||
r"""A decorator to check whether the vocabulary updates after the last build. | |||
""" | |||
@@ -65,7 +65,7 @@ def _check_build_status(func): | |||
class Vocabulary(object): | |||
""" | |||
r""" | |||
用于构建, 存储和使用 `str` 到 `int` 的一一映射:: | |||
vocab = Vocabulary() | |||
@@ -76,7 +76,7 @@ class Vocabulary(object): | |||
""" | |||
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | |||
""" | |||
r""" | |||
:param int max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量 | |||
若为 ``None`` , 则不限制大小. Default: ``None`` | |||
@@ -121,7 +121,7 @@ class Vocabulary(object): | |||
@_check_build_status | |||
def update(self, word_lst, no_create_entry=False): | |||
"""依次增加序列中词在词典中的出现频率 | |||
r"""依次增加序列中词在词典中的出现频率 | |||
:param list word_lst: a list of strings | |||
:param bool no_create_entry: 在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||
@@ -137,7 +137,7 @@ class Vocabulary(object): | |||
@_check_build_status | |||
def add(self, word, no_create_entry=False): | |||
""" | |||
r""" | |||
增加一个新词在词典中的出现频率 | |||
:param str word: 新词 | |||
@@ -153,7 +153,7 @@ class Vocabulary(object): | |||
return self | |||
def _add_no_create_entry(self, word, no_create_entry): | |||
""" | |||
r""" | |||
在新加入word时,检查_no_create_word的设置。 | |||
:param str List[str] word: | |||
@@ -170,7 +170,7 @@ class Vocabulary(object): | |||
@_check_build_status | |||
def add_word(self, word, no_create_entry=False): | |||
""" | |||
r""" | |||
增加一个新词在词典中的出现频率 | |||
:param str word: 新词 | |||
@@ -185,7 +185,7 @@ class Vocabulary(object): | |||
@_check_build_status | |||
def add_word_lst(self, word_lst, no_create_entry=False): | |||
""" | |||
r""" | |||
依次增加序列中词在词典中的出现频率 | |||
:param list[str] word_lst: 词的序列 | |||
@@ -200,7 +200,7 @@ class Vocabulary(object): | |||
return self | |||
def build_vocab(self): | |||
""" | |||
r""" | |||
根据已经出现的词和出现频率构建词典. 注意: 重复构建可能会改变词典的大小, | |||
但已经记录在词典中的词, 不会改变对应的 `int` | |||
@@ -225,7 +225,7 @@ class Vocabulary(object): | |||
return self | |||
def build_reverse_vocab(self): | |||
""" | |||
r""" | |||
基于 `word to index` dict, 构建 `index to word` dict. | |||
""" | |||
@@ -238,7 +238,7 @@ class Vocabulary(object): | |||
@_check_build_vocab | |||
def __contains__(self, item): | |||
""" | |||
r""" | |||
检查词是否被记录 | |||
:param item: the word | |||
@@ -247,7 +247,7 @@ class Vocabulary(object): | |||
return item in self._word2idx | |||
def has_word(self, w): | |||
""" | |||
r""" | |||
检查词是否被记录:: | |||
has_abc = vocab.has_word('abc') | |||
@@ -261,7 +261,7 @@ class Vocabulary(object): | |||
@_check_build_vocab | |||
def __getitem__(self, w): | |||
""" | |||
r""" | |||
To support usage like:: | |||
vocab[w] | |||
@@ -275,7 +275,7 @@ class Vocabulary(object): | |||
@_check_build_vocab | |||
def index_dataset(self, *datasets, field_name, new_field_name=None): | |||
""" | |||
r""" | |||
将DataSet中对应field的词转为数字,Example:: | |||
# remember to use `field_name` | |||
@@ -289,7 +289,7 @@ class Vocabulary(object): | |||
""" | |||
def index_instance(field): | |||
""" | |||
r""" | |||
有几种情况, str, 1d-list, 2d-list | |||
:param ins: | |||
:return: | |||
@@ -333,7 +333,7 @@ class Vocabulary(object): | |||
return len(self._no_create_word) | |||
def from_dataset(self, *datasets, field_name, no_create_entry_dataset=None): | |||
""" | |||
r""" | |||
使用dataset的对应field中词构建词典:: | |||
# remember to use `field_name` | |||
@@ -395,7 +395,7 @@ class Vocabulary(object): | |||
return self | |||
def _is_word_no_create_entry(self, word): | |||
""" | |||
r""" | |||
判断当前的word是否是不需要创建entry的,具体参见from_dataset的说明 | |||
:param word: str | |||
:return: bool | |||
@@ -403,7 +403,7 @@ class Vocabulary(object): | |||
return word in self._no_create_word | |||
def to_index(self, w): | |||
""" | |||
r""" | |||
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 ``ValueError`` :: | |||
index = vocab.to_index('abc') | |||
@@ -418,7 +418,7 @@ class Vocabulary(object): | |||
@property | |||
@_check_build_vocab | |||
def unknown_idx(self): | |||
""" | |||
r""" | |||
unknown 对应的数字. | |||
""" | |||
if self.unknown is None: | |||
@@ -428,7 +428,7 @@ class Vocabulary(object): | |||
@property | |||
@_check_build_vocab | |||
def padding_idx(self): | |||
""" | |||
r""" | |||
padding 对应的数字 | |||
""" | |||
if self.padding is None: | |||
@@ -437,7 +437,7 @@ class Vocabulary(object): | |||
@_check_build_vocab | |||
def to_word(self, idx): | |||
""" | |||
r""" | |||
给定一个数字, 将其转为对应的词. | |||
:param int idx: the index | |||
@@ -446,7 +446,7 @@ class Vocabulary(object): | |||
return self._idx2word[idx] | |||
def clear(self): | |||
""" | |||
r""" | |||
删除Vocabulary中的词表数据。相当于重新初始化一下。 | |||
:return: | |||
@@ -459,7 +459,7 @@ class Vocabulary(object): | |||
return self | |||
def __getstate__(self): | |||
"""Use to prepare data for pickle. | |||
r"""Use to prepare data for pickle. | |||
""" | |||
len(self) # make sure vocab has been built | |||
@@ -469,7 +469,7 @@ class Vocabulary(object): | |||
return state | |||
def __setstate__(self, state): | |||
"""Use to restore state from pickle. | |||
r"""Use to restore state from pickle. | |||
""" | |||
self.__dict__.update(state) | |||
@@ -484,7 +484,7 @@ class Vocabulary(object): | |||
yield word, index | |||
def save(self, filepath): | |||
""" | |||
r""" | |||
:param str filepath: Vocabulary的储存路径 | |||
:return: | |||
@@ -508,7 +508,7 @@ class Vocabulary(object): | |||
@staticmethod | |||
def load(filepath): | |||
""" | |||
r""" | |||
:param str filepath: Vocabulary的读取路径 | |||
:return: Vocabulary | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented | |||
r"""undocumented | |||
用于辅助生成 fastNLP 文档的代码 | |||
""" | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
embeddings 模块主要用于从各种预训练的模型中获取词语的分布式表示,目前支持的预训练模型包括word2vec, glove, ELMO, BERT等。这里所有 | |||
embedding的forward输入都是形状为 ``(batch_size, max_len)`` 的torch.LongTensor,输出都是 ``(batch_size, max_len, embedding_dim)`` 的 | |||
torch.FloatTensor。所有的embedding都可以使用 `self.num_embedding` 获取最大的输入index范围, 用 `self.embeddig_dim` 或 `self.embed_size` 获取embedding的 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -24,7 +24,7 @@ from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer | |||
class BertEmbedding(ContextualEmbedding): | |||
""" | |||
r""" | |||
使用BERT对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于 | |||
预训练的bert模型长度限制为512个token,而因为输入的word是未进行word piece分割的(word piece的分割有BertEmbedding在输入word | |||
时切分),在分割之后长度可能会超过最大长度限制。 | |||
@@ -57,7 +57,7 @@ class BertEmbedding(ContextualEmbedding): | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', | |||
pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False, | |||
pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False, **kwargs): | |||
""" | |||
r""" | |||
:param ~fastNLP.Vocabulary vocab: 词表 | |||
:param str model_dir_or_name: 模型所在目录或者模型的名称。当传入模型所在目录时,目录中应该包含一个词表文件(以.txt作为后缀名), | |||
@@ -112,7 +112,7 @@ class BertEmbedding(ContextualEmbedding): | |||
del self.model | |||
def forward(self, words): | |||
""" | |||
r""" | |||
计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要 | |||
删除这两个token的表示。 | |||
@@ -129,7 +129,7 @@ class BertEmbedding(ContextualEmbedding): | |||
return self.dropout(outputs) | |||
def drop_word(self, words): | |||
""" | |||
r""" | |||
按照设定随机将words设置为unknown_index。 | |||
:param torch.LongTensor words: batch_size x max_len | |||
@@ -151,7 +151,7 @@ class BertEmbedding(ContextualEmbedding): | |||
class BertWordPieceEncoder(nn.Module): | |||
""" | |||
r""" | |||
读取bert模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。 | |||
BertWordPieceEncoder可以支持自动下载权重,当前支持的模型: | |||
@@ -170,7 +170,7 @@ class BertWordPieceEncoder(nn.Module): | |||
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, | |||
word_dropout=0, dropout=0, requires_grad: bool = True): | |||
""" | |||
r""" | |||
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` | |||
:param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | |||
@@ -204,7 +204,7 @@ class BertWordPieceEncoder(nn.Module): | |||
return self.model.encoder.config.vocab_size | |||
def index_datasets(self, *datasets, field_name, add_cls_sep=True): | |||
""" | |||
r""" | |||
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了 | |||
bert的pad value。 | |||
@@ -216,7 +216,7 @@ class BertWordPieceEncoder(nn.Module): | |||
self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep) | |||
def forward(self, word_pieces, token_type_ids=None): | |||
""" | |||
r""" | |||
计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。 | |||
:param words: batch_size x max_len | |||
@@ -239,7 +239,7 @@ class BertWordPieceEncoder(nn.Module): | |||
return self.dropout_layer(outputs) | |||
def drop_word(self, words): | |||
""" | |||
r""" | |||
按照设定随机将words设置为unknown_index。 | |||
:param torch.LongTensor words: batch_size x max_len | |||
@@ -353,7 +353,7 @@ class _WordBertModel(nn.Module): | |||
logger.debug("Successfully generate word pieces.") | |||
def forward(self, words): | |||
""" | |||
r""" | |||
:param words: torch.LongTensor, batch_size x max_len | |||
:return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
该文件中主要包含的是character的Embedding,包括基于CNN与LSTM的character Embedding。与其它Embedding一样,这里的Embedding输入也是 | |||
词的index而不需要使用词语中的char的index来获取表达。 | |||
""" | |||
@@ -24,7 +24,7 @@ from ..modules.encoder.lstm import LSTM | |||
class CNNCharEmbedding(TokenEmbedding): | |||
""" | |||
r""" | |||
使用CNN生成character embedding。CNN的结构为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout. | |||
不同的kernel大小的fitler结果是concat起来然后通过一层fully connected layer, 然后输出word的表示。 | |||
@@ -46,7 +46,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||
dropout: float = 0, filter_nums: List[int] = (40, 30, 20), kernel_sizes: List[int] = (5, 3, 1), | |||
pool_method: str = 'max', activation='relu', min_char_freq: int = 2, pre_train_char_embed: str = None, | |||
requires_grad:bool=True, include_word_start_end:bool=True): | |||
""" | |||
r""" | |||
:param vocab: 词表 | |||
:param embed_size: 该CNNCharEmbedding的输出维度大小,默认值为50. | |||
@@ -122,7 +122,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||
self.requires_grad = requires_grad | |||
def forward(self, words): | |||
""" | |||
r""" | |||
输入words的index后,生成对应的words的表示。 | |||
:param words: [batch_size, max_len] | |||
@@ -155,7 +155,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||
class LSTMCharEmbedding(TokenEmbedding): | |||
""" | |||
r""" | |||
使用LSTM的方式对character进行encode. embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool -> Dropout | |||
Example:: | |||
@@ -176,7 +176,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
dropout: float = 0, hidden_size=50, pool_method: str = 'max', activation='relu', | |||
min_char_freq: int = 2, bidirectional=True, pre_train_char_embed: str = None, | |||
requires_grad:bool=True, include_word_start_end:bool=True): | |||
""" | |||
r""" | |||
:param vocab: 词表 | |||
:param embed_size: LSTMCharEmbedding的输出维度。默认值为50. | |||
@@ -250,7 +250,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
self.requires_grad = requires_grad | |||
def forward(self, words): | |||
""" | |||
r""" | |||
输入words的index后,生成对应的words的表示。 | |||
:param words: [batch_size, max_len] | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -28,7 +28,7 @@ class ContextualEmbedding(TokenEmbedding): | |||
super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool = True): | |||
""" | |||
r""" | |||
由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。 | |||
:param datasets: DataSet对象 | |||
@@ -77,7 +77,7 @@ class ContextualEmbedding(TokenEmbedding): | |||
self._delete_model_weights() | |||
def _get_sent_reprs(self, words): | |||
""" | |||
r""" | |||
获取sentence的表示,如果有缓存,则返回缓存的值; 没有缓存则返回None | |||
:param words: torch.LongTensor | |||
@@ -101,11 +101,11 @@ class ContextualEmbedding(TokenEmbedding): | |||
@abstractmethod | |||
def _delete_model_weights(self): | |||
"""删除计算表示的模型以节省资源""" | |||
r"""删除计算表示的模型以节省资源""" | |||
raise NotImplementedError | |||
def remove_sentence_cache(self): | |||
""" | |||
r""" | |||
删除缓存的句子表示. 删除之后如果模型权重没有被删除,将开始使用动态计算权重。 | |||
:return: | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -23,7 +23,7 @@ from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder | |||
class ElmoEmbedding(ContextualEmbedding): | |||
""" | |||
r""" | |||
使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。 | |||
当前支持的使用名称初始化的模型: | |||
@@ -56,7 +56,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '2', requires_grad: bool = True, | |||
word_dropout=0.0, dropout=0.0, cache_word_reprs: bool = False): | |||
""" | |||
r""" | |||
:param vocab: 词表 | |||
:param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo所在文件夹,该文件夹下面应该有两个文件, | |||
@@ -110,7 +110,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
return self.gamma.to(outputs) * outputs | |||
def set_mix_weights_requires_grad(self, flag=True): | |||
""" | |||
r""" | |||
当初始化ElmoEmbedding时layers被设置为mix时,可以通过调用该方法设置mix weights是否可训练。如果layers不是mix,调用 | |||
该方法没有用。 | |||
@@ -130,7 +130,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
return outputs | |||
def forward(self, words: torch.LongTensor): | |||
""" | |||
r""" | |||
计算words的elmo embedding表示。根据elmo文章中介绍的ELMO实际上是有2L+1层结果,但是为了让结果比较容易拆分,token的 | |||
被重复了一次,使得实际上layer=0的结果是[token_embedding;token_embedding], 而layer=1的结果是[forward_hiddens; | |||
backward_hiddens]. | |||
@@ -153,7 +153,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
class _ElmoModel(nn.Module): | |||
""" | |||
r""" | |||
该Module是ElmoEmbedding中进行所有的heavy lifting的地方。做的工作,包括 | |||
(1) 根据配置,加载模型; | |||
(2) 根据vocab,对模型中的embedding进行调整. 并将其正确初始化 | |||
@@ -295,7 +295,7 @@ class _ElmoModel(nn.Module): | |||
logger.info("There is no need to cache word representations, since no character information is used.") | |||
def forward(self, words): | |||
""" | |||
r""" | |||
:param words: batch_size x max_len | |||
:return: num_layers x batch_size x max_len x hidden_size | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
该模块中的Embedding主要用于随机初始化的embedding(更推荐使用 :class:`fastNLP.embeddings.StaticEmbedding` ),或按照预训练权重初始化Embedding。 | |||
""" | |||
@@ -17,7 +17,7 @@ from .utils import get_embeddings | |||
class Embedding(nn.Module): | |||
""" | |||
r""" | |||
词向量嵌入,支持输入多种方式初始化. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度. | |||
Example:: | |||
@@ -32,7 +32,7 @@ class Embedding(nn.Module): | |||
""" | |||
def __init__(self, init_embed, word_dropout=0, dropout=0.0, unk_index=None): | |||
""" | |||
r""" | |||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: 支持传入Embedding的大小(传入tuple(int, int), | |||
第一个int为vocab_zie, 第二个int为embed_dim); 或传入Tensor, Embedding, numpy.ndarray等则直接使用该值初始化Embedding; | |||
@@ -62,7 +62,7 @@ class Embedding(nn.Module): | |||
self.word_dropout = word_dropout | |||
def forward(self, words): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch, seq_len] | |||
:return: torch.Tensor : [batch, seq_len, embed_dim] | |||
""" | |||
@@ -93,7 +93,7 @@ class Embedding(nn.Module): | |||
@property | |||
def requires_grad(self): | |||
""" | |||
r""" | |||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||
:return: | |||
""" | |||
@@ -118,7 +118,7 @@ class Embedding(nn.Module): | |||
class TokenEmbedding(nn.Module): | |||
""" | |||
r""" | |||
fastNLP中各种Embedding的基类 | |||
""" | |||
@@ -136,7 +136,7 @@ class TokenEmbedding(nn.Module): | |||
self.dropout_layer = nn.Dropout(dropout) | |||
def drop_word(self, words): | |||
""" | |||
r""" | |||
按照设定随机将words设置为unknown_index。 | |||
:param torch.LongTensor words: batch_size x max_len | |||
@@ -151,7 +151,7 @@ class TokenEmbedding(nn.Module): | |||
return words | |||
def dropout(self, words): | |||
""" | |||
r""" | |||
对embedding后的word表示进行drop。 | |||
:param torch.FloatTensor words: batch_size x max_len x embed_size | |||
@@ -161,7 +161,7 @@ class TokenEmbedding(nn.Module): | |||
@property | |||
def requires_grad(self): | |||
""" | |||
r""" | |||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||
:return: | |||
""" | |||
@@ -189,14 +189,14 @@ class TokenEmbedding(nn.Module): | |||
@property | |||
def num_embedding(self) -> int: | |||
""" | |||
r""" | |||
这个值可能会大于实际的embedding矩阵的大小。 | |||
:return: | |||
""" | |||
return len(self._word_vocab) | |||
def get_word_vocab(self): | |||
""" | |||
r""" | |||
返回embedding的词典。 | |||
:return: Vocabulary | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -16,7 +16,7 @@ from .embedding import TokenEmbedding | |||
class StackEmbedding(TokenEmbedding): | |||
""" | |||
r""" | |||
支持将多个embedding集合成一个embedding。 | |||
Example:: | |||
@@ -31,7 +31,7 @@ class StackEmbedding(TokenEmbedding): | |||
""" | |||
def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0): | |||
""" | |||
r""" | |||
:param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致 | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。不同embedidng会在相同的位置 | |||
@@ -54,7 +54,7 @@ class StackEmbedding(TokenEmbedding): | |||
self._embed_size = sum([embed.embed_size for embed in self.embeds]) | |||
def append(self, embed: TokenEmbedding): | |||
""" | |||
r""" | |||
添加一个embedding到结尾。 | |||
:param embed: | |||
:return: | |||
@@ -65,7 +65,7 @@ class StackEmbedding(TokenEmbedding): | |||
return self | |||
def pop(self): | |||
""" | |||
r""" | |||
弹出最后一个embed | |||
:return: | |||
""" | |||
@@ -75,14 +75,14 @@ class StackEmbedding(TokenEmbedding): | |||
@property | |||
def embed_size(self): | |||
""" | |||
r""" | |||
该Embedding输出的vector的最后一维的维度。 | |||
:return: | |||
""" | |||
return self._embed_size | |||
def forward(self, words): | |||
""" | |||
r""" | |||
得到多个embedding的结果,并把结果按照顺序concat起来。 | |||
:param words: batch_size x max_len | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -23,7 +23,7 @@ from ..modules.utils import _get_file_name_base_on_postfix | |||
class StaticEmbedding(TokenEmbedding): | |||
""" | |||
r""" | |||
StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来, | |||
如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。 | |||
当前支持自动下载的预训练vector有: | |||
@@ -72,7 +72,7 @@ class StaticEmbedding(TokenEmbedding): | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True, | |||
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs): | |||
""" | |||
r""" | |||
:param vocab: Vocabulary. 若该项为None则会读取所有的embedding。 | |||
:param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个 | |||
@@ -204,7 +204,7 @@ class StaticEmbedding(TokenEmbedding): | |||
return self.embedding.weight | |||
def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None): | |||
""" | |||
r""" | |||
:param int num_embedding: embedding的entry的数量 | |||
:param int embedding_dim: embedding的维度大小 | |||
@@ -222,7 +222,7 @@ class StaticEmbedding(TokenEmbedding): | |||
def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', | |||
error='ignore', init_method=None): | |||
""" | |||
r""" | |||
从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | |||
word2vec(第一行只有两个元素)还是glove格式的数据。 | |||
@@ -309,7 +309,7 @@ class StaticEmbedding(TokenEmbedding): | |||
return vectors | |||
def forward(self, words): | |||
""" | |||
r""" | |||
传入words的index | |||
:param words: torch.LongTensor, [batch_size, max_len] | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -14,7 +14,7 @@ __all__ = [ | |||
def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, include_word_start_end=True): | |||
""" | |||
r""" | |||
给定一个word的vocabulary生成character的vocabulary. | |||
:param vocab: 从vocab | |||
@@ -32,7 +32,7 @@ def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, inclu | |||
def get_embeddings(init_embed): | |||
""" | |||
r""" | |||
根据输入的init_embed返回Embedding对象。如果输入是tuple, 则随机初始化一个nn.Embedding; 如果输入是numpy.ndarray, 则按照ndarray | |||
的值将nn.Embedding初始化; 如果输入是torch.Tensor, 则按该值初始化nn.Embedding; 如果输入是fastNLP中的embedding将不做处理 | |||
返回原对象。 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
用于IO的模块, 具体包括: | |||
1. 用于读入 embedding 的 :mod:`EmbedLoader <fastNLP.io.embed_loader>` 类, | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -14,7 +14,7 @@ from ..core._logger import logger | |||
class DataBundle: | |||
""" | |||
r""" | |||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 | |||
Loader的load函数生成,可以通过以下的方法获取里面的内容 | |||
@@ -28,7 +28,7 @@ class DataBundle: | |||
""" | |||
def __init__(self, vocabs: dict = None, datasets: dict = None): | |||
""" | |||
r""" | |||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | |||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | |||
@@ -37,7 +37,7 @@ class DataBundle: | |||
self.datasets = datasets or {} | |||
def set_vocab(self, vocab, field_name): | |||
""" | |||
r""" | |||
向DataBunlde中增加vocab | |||
:param ~fastNLP.Vocabulary vocab: 词表 | |||
@@ -49,7 +49,7 @@ class DataBundle: | |||
return self | |||
def set_dataset(self, dataset, name: str): | |||
""" | |||
r""" | |||
:param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet | |||
:param str name: dataset的名称 | |||
@@ -60,7 +60,7 @@ class DataBundle: | |||
return self | |||
def get_dataset(self, name: str) -> DataSet: | |||
""" | |||
r""" | |||
获取名为name的dataset | |||
:param str name: dataset的名称,一般为'train', 'dev', 'test' | |||
@@ -75,7 +75,7 @@ class DataBundle: | |||
raise KeyError(error_msg) | |||
def delete_dataset(self, name: str): | |||
""" | |||
r""" | |||
删除名为name的DataSet | |||
:param str name: | |||
@@ -85,7 +85,7 @@ class DataBundle: | |||
return self | |||
def get_vocab(self, field_name: str) -> Vocabulary: | |||
""" | |||
r""" | |||
获取field名为field_name对应的vocab | |||
:param str field_name: 名称 | |||
@@ -100,7 +100,7 @@ class DataBundle: | |||
raise KeyError(error_msg) | |||
def delete_vocab(self, field_name: str): | |||
""" | |||
r""" | |||
删除vocab | |||
:param str field_name: | |||
:return: self | |||
@@ -117,7 +117,7 @@ class DataBundle: | |||
return len(self.vocabs) | |||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): | |||
""" | |||
r""" | |||
将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: | |||
data_bundle.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True | |||
@@ -142,7 +142,7 @@ class DataBundle: | |||
return self | |||
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): | |||
""" | |||
r""" | |||
将field_names中的field设置为target, 对data_bundle中所有的dataset执行该操作:: | |||
data_bundle.set_target('target', 'seq_len') # 将words和target这两个field的input属性设置为True | |||
@@ -167,7 +167,7 @@ class DataBundle: | |||
return self | |||
def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True): | |||
""" | |||
r""" | |||
将DataBundle中所有的DataSet中名为field_name的Field的padding值设置为pad_val. | |||
:param str field_name: | |||
@@ -184,7 +184,7 @@ class DataBundle: | |||
return self | |||
def set_ignore_type(self, *field_names, flag=True, ignore_miss_dataset=True): | |||
""" | |||
r""" | |||
将DataBundle中所有的DataSet中名为*field_names的Field的ignore_type设置为flag状态 | |||
:param str field_names: | |||
@@ -202,7 +202,7 @@ class DataBundle: | |||
return self | |||
def copy_field(self, field_name, new_field_name, ignore_miss_dataset=True): | |||
""" | |||
r""" | |||
将DataBundle中所有的DataSet中名为field_name的Field复制一份并命名为叫new_field_name. | |||
:param str field_name: | |||
@@ -219,7 +219,7 @@ class DataBundle: | |||
return self | |||
def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True, rename_vocab=True): | |||
""" | |||
r""" | |||
将DataBundle中所有DataSet中名为field_name的field重命名为new_field_name. | |||
:param str field_name: | |||
@@ -241,7 +241,7 @@ class DataBundle: | |||
return self | |||
def delete_field(self, field_name, ignore_miss_dataset=True, delete_vocab=True): | |||
""" | |||
r""" | |||
将DataBundle中所有DataSet中名为field_name的field删除掉. | |||
:param str field_name: | |||
@@ -261,7 +261,7 @@ class DataBundle: | |||
return self | |||
def iter_datasets(self) -> Union[str, DataSet]: | |||
""" | |||
r""" | |||
迭代data_bundle中的DataSet | |||
Example:: | |||
@@ -275,7 +275,7 @@ class DataBundle: | |||
yield name, dataset | |||
def get_dataset_names(self) -> List[str]: | |||
""" | |||
r""" | |||
返回DataBundle中DataSet的名称 | |||
:return: | |||
@@ -283,7 +283,7 @@ class DataBundle: | |||
return list(self.datasets.keys()) | |||
def get_vocab_names(self)->List[str]: | |||
""" | |||
r""" | |||
返回DataBundle中Vocabulary的名称 | |||
:return: | |||
@@ -291,7 +291,7 @@ class DataBundle: | |||
return list(self.vocabs.keys()) | |||
def iter_vocabs(self) -> Union[str, Vocabulary]: | |||
""" | |||
r""" | |||
迭代data_bundle中的DataSet | |||
Example: | |||
@@ -305,7 +305,7 @@ class DataBundle: | |||
yield field_name, vocab | |||
def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs): | |||
""" | |||
r""" | |||
对DataBundle中所有的dataset使用apply_field方法 | |||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
@@ -330,7 +330,7 @@ class DataBundle: | |||
return self | |||
def apply(self, func, new_field_name:str, **kwargs): | |||
""" | |||
r""" | |||
对DataBundle中所有的dataset使用apply方法 | |||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
@@ -349,7 +349,7 @@ class DataBundle: | |||
return self | |||
def add_collect_fn(self, fn, name=None): | |||
""" | |||
r""" | |||
向所有DataSet增加collect_fn, collect_fn详见 :class:`~fastNLP.DataSet` 中相关说明. | |||
:param callable fn: | |||
@@ -360,7 +360,7 @@ class DataBundle: | |||
dataset.add_collect_fn(fn=fn, name=name) | |||
def delete_collect_fn(self, name=None): | |||
""" | |||
r""" | |||
删除DataSet中的collect_fn | |||
:param name: | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -32,7 +32,7 @@ class EmbeddingOption(Option): | |||
class EmbedLoader: | |||
""" | |||
r""" | |||
用于读取预训练的embedding, 读取结果可直接载入为模型参数。 | |||
""" | |||
@@ -42,7 +42,7 @@ class EmbedLoader: | |||
@staticmethod | |||
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | |||
error='ignore', init_method=None): | |||
""" | |||
r""" | |||
从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | |||
word2vec(第一行只有两个元素)还是glove格式的数据。 | |||
@@ -114,7 +114,7 @@ class EmbedLoader: | |||
@staticmethod | |||
def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | |||
error='ignore'): | |||
""" | |||
r""" | |||
从embed_filepath中读取预训练的word vector。根据预训练的词表读取embedding并生成一个对应的Vocabulary。 | |||
:param str embed_filepath: 预训练的embedding的路径。 | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented | |||
r"""undocumented | |||
此模块用于给其它模块提供读取文件的函数,没有为用户提供 API | |||
""" | |||
@@ -11,7 +11,7 @@ from ..core import logger | |||
def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | |||
""" | |||
r""" | |||
Construct a generator to read csv items. | |||
:param path: file path | |||
@@ -51,7 +51,7 @@ def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | |||
def _read_json(path, encoding='utf-8', fields=None, dropna=True): | |||
""" | |||
r""" | |||
Construct a generator to read json items. | |||
:param path: file path | |||
@@ -82,7 +82,7 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True): | |||
def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||
""" | |||
r""" | |||
Construct a generator to read conll items. | |||
:param path: file path | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -138,7 +138,7 @@ FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt', | |||
def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path: | |||
""" | |||
r""" | |||
给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件, | |||
1. 如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir | |||
@@ -183,7 +183,7 @@ def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path: | |||
def get_filepath(filepath): | |||
""" | |||
r""" | |||
如果filepath为文件夹, | |||
如果内含多个文件, 返回filepath | |||
@@ -210,7 +210,7 @@ def get_filepath(filepath): | |||
def get_cache_path(): | |||
""" | |||
r""" | |||
获取fastNLP默认cache的存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 | |||
:return str: 存放路径 | |||
@@ -226,7 +226,7 @@ def get_cache_path(): | |||
def _get_base_url(name): | |||
""" | |||
r""" | |||
根据name返回下载的url地址。 | |||
:param str name: 支持dataset和embedding两种 | |||
@@ -252,7 +252,7 @@ def _get_base_url(name): | |||
def _get_embedding_url(embed_type, name): | |||
""" | |||
r""" | |||
给定embedding类似和名称,返回下载url | |||
:param str embed_type: 支持static, bert, elmo。即embedding的类型 | |||
@@ -276,7 +276,7 @@ def _get_embedding_url(embed_type, name): | |||
raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static") | |||
def _read_extend_url_file(filename, name)->str: | |||
""" | |||
r""" | |||
filename中的内容使用制表符隔开,第一列是名称,第二列是下载的url地址 | |||
:param str filename: 在默认的路径下寻找file这个文件 | |||
@@ -297,7 +297,7 @@ def _read_extend_url_file(filename, name)->str: | |||
return None | |||
def _get_dataset_url(name): | |||
""" | |||
r""" | |||
给定dataset的名称,返回下载url | |||
:param str name: 给定dataset的名称,比如imdb, sst-2等 | |||
@@ -317,7 +317,7 @@ def _get_dataset_url(name): | |||
def split_filename_suffix(filepath): | |||
""" | |||
r""" | |||
给定filepath 返回对应的name和suffix. 如果后缀是多个点,仅支持.tar.gz类型 | |||
:param filepath: 文件路径 | |||
@@ -330,7 +330,7 @@ def split_filename_suffix(filepath): | |||
def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||
""" | |||
r""" | |||
尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的 | |||
文件解压,将解压后的文件全部放在cache_dir文件夹中。 | |||
@@ -469,7 +469,7 @@ def ungzip_file(file: str, to: str, filename:str): | |||
def match_file(dir_name: str, cache_dir: Path) -> str: | |||
""" | |||
r""" | |||
匹配的原则是: 在cache_dir下的文件与dir_name完全一致, 或除了后缀以外和dir_name完全一致。 | |||
如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的 | |||
三个方法: ``__init__`` , ``_load`` , ``loads`` . 其中 ``__init__(...)`` 用于申明读取参数,以及说明该Loader支持的数据格式, | |||
读取后 :class:`~fastNLP.DataSet` 中的 `field` ; ``_load(path)`` 方法传入文件路径读取单个文件,并返回 :class:`~fastNLP.DataSet` ; | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CLSBaseLoader", | |||
@@ -29,7 +29,7 @@ from ...core._logger import logger | |||
class CLSBaseLoader(Loader): | |||
""" | |||
r""" | |||
文本分类Loader的一个基类 | |||
原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 | |||
@@ -118,7 +118,7 @@ def _split_dev(dataset_name, data_dir, dev_ratio=0.0, re_download=False, suffix= | |||
class AGsNewsLoader(CLSBaseLoader): | |||
def download(self): | |||
""" | |||
r""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | |||
@@ -131,7 +131,7 @@ class AGsNewsLoader(CLSBaseLoader): | |||
class DBPediaLoader(CLSBaseLoader): | |||
def download(self, dev_ratio: float = 0.0, re_download: bool = False): | |||
""" | |||
r""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | |||
@@ -155,7 +155,7 @@ class DBPediaLoader(CLSBaseLoader): | |||
class IMDBLoader(CLSBaseLoader): | |||
""" | |||
r""" | |||
原始数据中内容应该为, 每一行为一个sample,制表符之前为target,制表符之后为文本内容。 | |||
Example:: | |||
@@ -178,7 +178,7 @@ class IMDBLoader(CLSBaseLoader): | |||
super().__init__(sep='\t') | |||
def download(self, dev_ratio: float = 0.0, re_download=False): | |||
""" | |||
r""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
http://www.aclweb.org/anthology/P11-1015 | |||
@@ -200,7 +200,7 @@ class IMDBLoader(CLSBaseLoader): | |||
class SSTLoader(Loader): | |||
""" | |||
r""" | |||
原始数据中内容应该为: | |||
Example:: | |||
@@ -225,7 +225,7 @@ class SSTLoader(Loader): | |||
super().__init__() | |||
def _load(self, path: str): | |||
""" | |||
r""" | |||
从path读取SST文件 | |||
:param str path: 文件路径 | |||
@@ -240,7 +240,7 @@ class SSTLoader(Loader): | |||
return ds | |||
def download(self): | |||
""" | |||
r""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf | |||
@@ -253,7 +253,7 @@ class SSTLoader(Loader): | |||
class YelpFullLoader(CLSBaseLoader): | |||
def download(self, dev_ratio: float = 0.0, re_download: bool = False): | |||
""" | |||
r""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | |||
@@ -278,7 +278,7 @@ class YelpFullLoader(CLSBaseLoader): | |||
class YelpPolarityLoader(CLSBaseLoader): | |||
def download(self, dev_ratio: float = 0.0, re_download: bool = False): | |||
""" | |||
r""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | |||
@@ -302,7 +302,7 @@ class YelpPolarityLoader(CLSBaseLoader): | |||
class SST2Loader(Loader): | |||
""" | |||
r""" | |||
原始数据中内容为:第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是句子,第一个制表符之后认为是label | |||
Example:: | |||
@@ -327,7 +327,7 @@ class SST2Loader(Loader): | |||
super().__init__() | |||
def _load(self, path: str): | |||
"""从path读取SST2文件 | |||
r"""从path读取SST2文件 | |||
:param str path: 数据路径 | |||
:return: DataSet | |||
@@ -357,7 +357,7 @@ class SST2Loader(Loader): | |||
return ds | |||
def download(self): | |||
""" | |||
r""" | |||
自动下载数据集,如果你使用了该数据集,请引用以下的文章 | |||
https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf | |||
:return: | |||
@@ -367,7 +367,7 @@ class SST2Loader(Loader): | |||
class ChnSentiCorpLoader(Loader): | |||
""" | |||
r""" | |||
支持读取的数据的格式为,第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是label,第 | |||
一个制表符之后认为是句子 | |||
@@ -391,7 +391,7 @@ class ChnSentiCorpLoader(Loader): | |||
super().__init__() | |||
def _load(self, path: str): | |||
""" | |||
r""" | |||
从path中读取数据 | |||
:param path: | |||
@@ -411,7 +411,7 @@ class ChnSentiCorpLoader(Loader): | |||
return ds | |||
def download(self) -> str: | |||
""" | |||
r""" | |||
自动下载数据,该数据取自https://github.com/pengming617/bert_classification/tree/master/data,在 | |||
https://arxiv.org/pdf/1904.09223.pdf与https://arxiv.org/pdf/1906.08101.pdf有使用 | |||
@@ -422,7 +422,7 @@ class ChnSentiCorpLoader(Loader): | |||
class THUCNewsLoader(Loader): | |||
""" | |||
r""" | |||
数据集简介:document-level分类任务,新闻10分类 | |||
原始数据内容为:每行一个sample,第一个'\t'之前为target,第一个'\t'之后为raw_words | |||
@@ -456,7 +456,7 @@ class THUCNewsLoader(Loader): | |||
return ds | |||
def download(self) -> str: | |||
""" | |||
r""" | |||
自动下载数据,该数据取自 | |||
http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews | |||
@@ -468,7 +468,7 @@ class THUCNewsLoader(Loader): | |||
class WeiboSenti100kLoader(Loader): | |||
""" | |||
r""" | |||
别名: | |||
数据集简介:微博sentiment classification,二分类 | |||
@@ -505,7 +505,7 @@ class WeiboSenti100kLoader(Loader): | |||
return ds | |||
def download(self) -> str: | |||
""" | |||
r""" | |||
自动下载数据,该数据取自 https://github.com/SophonPlus/ChineseNlpCorpus/ | |||
在 https://arxiv.org/abs/1906.08101 有使用 | |||
:return: | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"ConllLoader", | |||
@@ -26,7 +26,7 @@ from ...core.instance import Instance | |||
class ConllLoader(Loader): | |||
""" | |||
r""" | |||
ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: | |||
Example:: | |||
@@ -56,7 +56,7 @@ class ConllLoader(Loader): | |||
""" | |||
def __init__(self, headers, indexes=None, dropna=True): | |||
""" | |||
r""" | |||
:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||
@@ -76,7 +76,7 @@ class ConllLoader(Loader): | |||
self.indexes = indexes | |||
def _load(self, path): | |||
""" | |||
r""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||
:param str path: 文件的路径 | |||
@@ -90,7 +90,7 @@ class ConllLoader(Loader): | |||
class Conll2003Loader(ConllLoader): | |||
""" | |||
r""" | |||
用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。 | |||
Example:: | |||
@@ -123,7 +123,7 @@ class Conll2003Loader(ConllLoader): | |||
super(Conll2003Loader, self).__init__(headers=headers) | |||
def _load(self, path): | |||
""" | |||
r""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||
:param str path: 文件的路径 | |||
@@ -148,7 +148,7 @@ class Conll2003Loader(ConllLoader): | |||
class Conll2003NERLoader(ConllLoader): | |||
""" | |||
r""" | |||
用于读取conll2003任务的NER数据。每一行有4列内容,空行意味着隔开两个句子 | |||
支持读取的内容如下 | |||
@@ -182,7 +182,7 @@ class Conll2003NERLoader(ConllLoader): | |||
super().__init__(headers=headers, indexes=[0, 3]) | |||
def _load(self, path): | |||
""" | |||
r""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||
:param str path: 文件的路径 | |||
@@ -209,7 +209,7 @@ class Conll2003NERLoader(ConllLoader): | |||
class OntoNotesNERLoader(ConllLoader): | |||
""" | |||
r""" | |||
用以读取OntoNotes的NER数据,同时也是Conll2012的NER任务数据。将OntoNote数据处理为conll格式的过程可以参考 | |||
https://github.com/yhcc/OntoNotes-5.0-NER。OntoNoteNERLoader将取第4列和第11列的内容。 | |||
@@ -287,7 +287,7 @@ class OntoNotesNERLoader(ConllLoader): | |||
class CTBLoader(Loader): | |||
""" | |||
r""" | |||
支持加载的数据应该具备以下格式, 其中第二列为词语,第四列为pos tag,第七列为依赖树的head,第八列为依赖树的label | |||
Example:: | |||
@@ -328,7 +328,7 @@ class CTBLoader(Loader): | |||
return dataset | |||
def download(self): | |||
""" | |||
r""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://catalog.ldc.upenn.edu/LDC2013T21 | |||
@@ -340,7 +340,7 @@ class CTBLoader(Loader): | |||
class CNNERLoader(Loader): | |||
def _load(self, path: str): | |||
""" | |||
r""" | |||
支持加载形如以下格式的内容,一行两列,以空格隔开两个sample | |||
Example:: | |||
@@ -378,7 +378,7 @@ class CNNERLoader(Loader): | |||
class MsraNERLoader(CNNERLoader): | |||
""" | |||
r""" | |||
读取MSRA-NER数据,数据中的格式应该类似与下列的内容 | |||
Example:: | |||
@@ -416,7 +416,7 @@ class MsraNERLoader(CNNERLoader): | |||
super().__init__() | |||
def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str: | |||
""" | |||
r""" | |||
自动下载MSAR-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language | |||
Processing Bakeoff: Word Segmentation and Named Entity Recognition. | |||
@@ -466,7 +466,7 @@ class MsraNERLoader(CNNERLoader): | |||
class WeiboNERLoader(CNNERLoader): | |||
""" | |||
r""" | |||
读取WeiboNER数据,数据中的格式应该类似与下列的内容 | |||
Example:: | |||
@@ -494,7 +494,7 @@ class WeiboNERLoader(CNNERLoader): | |||
super().__init__() | |||
def download(self) -> str: | |||
""" | |||
r""" | |||
自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for | |||
Chinese Social Media with Jointly Trained Embeddings. | |||
@@ -507,7 +507,7 @@ class WeiboNERLoader(CNNERLoader): | |||
class PeopleDailyNERLoader(CNNERLoader): | |||
""" | |||
r""" | |||
支持加载的数据格式如下 | |||
Example:: | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CoReferenceLoader", | |||
@@ -12,7 +12,7 @@ from .json import JsonLoader | |||
class CoReferenceLoader(JsonLoader): | |||
""" | |||
r""" | |||
原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 | |||
Example:: | |||
@@ -38,7 +38,7 @@ class CoReferenceLoader(JsonLoader): | |||
"sentences": Const.RAW_WORDS(3)} | |||
def _load(self, path): | |||
""" | |||
r""" | |||
加载数据 | |||
:param path: 数据文件路径,文件为json | |||
@@ -54,7 +54,7 @@ class CoReferenceLoader(JsonLoader): | |||
return dataset | |||
def download(self): | |||
""" | |||
r""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://www.aclweb.org/anthology/W12-4501 | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CSVLoader", | |||
@@ -11,13 +11,13 @@ from ...core.instance import Instance | |||
class CSVLoader(Loader): | |||
""" | |||
r""" | |||
读取CSV格式的数据集, 返回 ``DataSet`` 。 | |||
""" | |||
def __init__(self, headers=None, sep=",", dropna=False): | |||
""" | |||
r""" | |||
:param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 | |||
若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CWSLoader" | |||
@@ -16,7 +16,7 @@ from ...core.instance import Instance | |||
class CWSLoader(Loader): | |||
""" | |||
r""" | |||
CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: | |||
Example:: | |||
@@ -36,7 +36,7 @@ class CWSLoader(Loader): | |||
""" | |||
def __init__(self, dataset_name:str=None): | |||
""" | |||
r""" | |||
:param str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None | |||
""" | |||
@@ -57,7 +57,7 @@ class CWSLoader(Loader): | |||
return ds | |||
def download(self, dev_ratio=0.1, re_download=False)->str: | |||
""" | |||
r""" | |||
如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, | |||
2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"JsonLoader" | |||
@@ -11,7 +11,7 @@ from ...core.instance import Instance | |||
class JsonLoader(Loader): | |||
""" | |||
r""" | |||
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader` | |||
读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"Loader" | |||
@@ -13,7 +13,7 @@ from ...core.dataset import DataSet | |||
class Loader: | |||
""" | |||
r""" | |||
各种数据 Loader 的基类,提供了 API 的参考. | |||
Loader支持以下的三个函数 | |||
@@ -27,7 +27,7 @@ class Loader: | |||
pass | |||
def _load(self, path: str) -> DataSet: | |||
""" | |||
r""" | |||
给定一个路径,返回读取的DataSet。 | |||
:param str path: 路径 | |||
@@ -71,7 +71,7 @@ class Loader: | |||
return data_bundle | |||
def download(self) -> str: | |||
""" | |||
r""" | |||
自动下载该数据集 | |||
:return: 下载后解压目录 | |||
@@ -80,7 +80,7 @@ class Loader: | |||
@staticmethod | |||
def _get_dataset_path(dataset_name): | |||
""" | |||
r""" | |||
传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存(如果支持的话) | |||
:param str dataset_name: 数据集的名称 | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"MNLILoader", | |||
@@ -26,7 +26,7 @@ from ...core.instance import Instance | |||
class MNLILoader(Loader): | |||
""" | |||
r""" | |||
读取的数据格式为: | |||
Example:: | |||
@@ -80,7 +80,7 @@ class MNLILoader(Loader): | |||
return ds | |||
def load(self, paths: str = None): | |||
""" | |||
r""" | |||
:param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, | |||
test_mismatched.tsv, train.tsv文件夹 | |||
@@ -112,7 +112,7 @@ class MNLILoader(Loader): | |||
return data_bundle | |||
def download(self): | |||
""" | |||
r""" | |||
如果你使用了这个数据,请引用 | |||
https://www.nyu.edu/projects/bowman/multinli/paper.pdf | |||
@@ -123,7 +123,7 @@ class MNLILoader(Loader): | |||
class SNLILoader(JsonLoader): | |||
""" | |||
r""" | |||
文件每一行是一个sample,每一行都为一个json对象,其数据格式为: | |||
Example:: | |||
@@ -157,7 +157,7 @@ class SNLILoader(JsonLoader): | |||
}) | |||
def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | |||
""" | |||
r""" | |||
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | |||
读取的field根据Loader初始化时传入的field决定。 | |||
@@ -187,7 +187,7 @@ class SNLILoader(JsonLoader): | |||
return data_bundle | |||
def download(self): | |||
""" | |||
r""" | |||
如果您的文章使用了这份数据,请引用 | |||
http://nlp.stanford.edu/pubs/snli_paper.pdf | |||
@@ -198,7 +198,7 @@ class SNLILoader(JsonLoader): | |||
class QNLILoader(JsonLoader): | |||
""" | |||
r""" | |||
第一行为标题(具体内容会被忽略),之后每一行是一个sample,由index、问题、句子和标签构成(以制表符分割),数据结构如下: | |||
Example:: | |||
@@ -250,7 +250,7 @@ class QNLILoader(JsonLoader): | |||
return ds | |||
def download(self): | |||
""" | |||
r""" | |||
如果您的实验使用到了该数据,请引用 | |||
https://arxiv.org/pdf/1809.05053.pdf | |||
@@ -261,7 +261,7 @@ class QNLILoader(JsonLoader): | |||
class RTELoader(Loader): | |||
""" | |||
r""" | |||
第一行为标题(具体内容会被忽略),之后每一行是一个sample,由index、句子1、句子2和标签构成(以制表符分割),数据结构如下: | |||
Example:: | |||
@@ -312,7 +312,7 @@ class RTELoader(Loader): | |||
return ds | |||
def download(self): | |||
""" | |||
r""" | |||
如果您的实验使用到了该数据,请引用GLUE Benchmark | |||
https://openreview.net/pdf?id=rJ4km2R5t7 | |||
@@ -323,7 +323,7 @@ class RTELoader(Loader): | |||
class QuoraLoader(Loader): | |||
""" | |||
r""" | |||
Quora matching任务的数据集Loader | |||
支持读取的文件中的内容,应该有以下的形式, 以制表符分隔,且前三列的内容必须是:第一列是label,第二列和第三列是句子 | |||
@@ -364,7 +364,7 @@ class QuoraLoader(Loader): | |||
return ds | |||
def download(self): | |||
""" | |||
r""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://www.kaggle.com/c/quora-question-pairs/data | |||
@@ -375,7 +375,7 @@ class QuoraLoader(Loader): | |||
class CNXNLILoader(Loader): | |||
""" | |||
r""" | |||
数据集简介:中文句对NLI(本为multi-lingual的数据集,但是这里只取了中文的数据集)。原句子已被MOSES tokenizer处理,这里我们将其还原并重新按字tokenize | |||
原始数据数据为: | |||
@@ -459,7 +459,7 @@ class CNXNLILoader(Loader): | |||
return data_bundle | |||
def download(self) -> str: | |||
""" | |||
r""" | |||
自动下载数据,该数据取自 https://arxiv.org/abs/1809.05053 | |||
在 https://arxiv.org/pdf/1905.05526.pdf https://arxiv.org/pdf/1901.10125.pdf | |||
https://arxiv.org/pdf/1809.05053.pdf 有使用 | |||
@@ -470,7 +470,7 @@ class CNXNLILoader(Loader): | |||
class BQCorpusLoader(Loader): | |||
""" | |||
r""" | |||
别名: | |||
数据集简介:句子对二分类任务(判断是否具有相同的语义) | |||
原始数据结构为: | |||
@@ -511,7 +511,7 @@ class BQCorpusLoader(Loader): | |||
return ds | |||
def download(self): | |||
""" | |||
r""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://github.com/ymcui/Chinese-BERT-wwm | |||
@@ -566,7 +566,7 @@ class LCQMCLoader(Loader): | |||
return ds | |||
def download(self): | |||
""" | |||
r""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://github.com/ymcui/Chinese-BERT-wwm | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
该文件中的Loader主要用于读取问答式任务的数据 | |||
""" | |||
@@ -12,7 +12,7 @@ __all__ = ['CMRC2018Loader'] | |||
class CMRC2018Loader(Loader): | |||
""" | |||
r""" | |||
请直接使用从fastNLP下载的数据进行处理。该数据集未提供测试集,测试需要通过上传到对应的系统进行评测 | |||
读取之后训练集DataSet将具备以下的内容,每个问题的答案只有一个 | |||
@@ -64,7 +64,7 @@ class CMRC2018Loader(Loader): | |||
return ds | |||
def download(self) -> str: | |||
""" | |||
r""" | |||
如果您使用了本数据,请引用A Span-Extraction Dataset for Chinese Machine Reading Comprehension. Yiming Cui, Ting Liu, etc. | |||
:return: | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"ExtCNNDMLoader" | |||
@@ -13,7 +13,7 @@ from .json import JsonLoader | |||
class ExtCNNDMLoader(JsonLoader): | |||
""" | |||
r""" | |||
读取之后的DataSet中的field情况为 | |||
.. csv-table:: | |||
@@ -30,7 +30,7 @@ class ExtCNNDMLoader(JsonLoader): | |||
super(ExtCNNDMLoader, self).__init__(fields=fields) | |||
def load(self, paths: Union[str, Dict[str, str]] = None): | |||
""" | |||
r""" | |||
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | |||
读取的field根据ExtCNNDMLoader初始化时传入的headers决定。 | |||
@@ -53,7 +53,7 @@ class ExtCNNDMLoader(JsonLoader): | |||
return data_bundle | |||
def download(self): | |||
""" | |||
r""" | |||
如果你使用了这个数据,请引用 | |||
https://arxiv.org/pdf/1506.03340.pdf | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
用于载入和保存模型 | |||
""" | |||
__all__ = [ | |||
@@ -10,7 +10,7 @@ import torch | |||
class ModelLoader: | |||
""" | |||
r""" | |||
用于读取模型 | |||
""" | |||
@@ -19,7 +19,7 @@ class ModelLoader: | |||
@staticmethod | |||
def load_pytorch(empty_model, model_path): | |||
""" | |||
r""" | |||
从 ".pkl" 文件读取 PyTorch 模型 | |||
:param empty_model: 初始化参数的 PyTorch 模型 | |||
@@ -29,7 +29,7 @@ class ModelLoader: | |||
@staticmethod | |||
def load_pytorch_model(model_path): | |||
""" | |||
r""" | |||
读取整个模型 | |||
:param str model_path: 模型保存的路径 | |||
@@ -38,7 +38,7 @@ class ModelLoader: | |||
class ModelSaver(object): | |||
""" | |||
r""" | |||
用于保存模型 | |||
Example:: | |||
@@ -49,14 +49,14 @@ class ModelSaver(object): | |||
""" | |||
def __init__(self, save_path): | |||
""" | |||
r""" | |||
:param save_path: 模型保存的路径 | |||
""" | |||
self.save_path = save_path | |||
def save_pytorch(self, model, param_only=True): | |||
""" | |||
r""" | |||
把 PyTorch 模型存入 ".pkl" 文件 | |||
:param model: PyTorch 模型 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``process`` 和 ``process_from_file`` 两种方法。 | |||
``process(data_bundle)`` 传入一个 :class:`~fastNLP.io.DataBundle` 类型的对象, 在传入的 `data_bundle` 上进行原位修改,并将其返回; | |||
``process_from_file(paths)`` 传入的文件路径,返回一个 :class:`~fastNLP.io.DataBundle` 类型的对象。 | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CLSBasePipe", | |||
@@ -39,7 +39,7 @@ class CLSBasePipe(Pipe): | |||
self.tokenizer = get_tokenizer(tokenizer, lang=lang) | |||
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): | |||
""" | |||
r""" | |||
将DataBundle中的数据进行tokenize | |||
:param DataBundle data_bundle: | |||
@@ -54,7 +54,7 @@ class CLSBasePipe(Pipe): | |||
return data_bundle | |||
def process(self, data_bundle: DataBundle): | |||
""" | |||
r""" | |||
传入的DataSet应该具备如下的结构 | |||
.. csv-table:: | |||
@@ -83,7 +83,7 @@ class CLSBasePipe(Pipe): | |||
return data_bundle | |||
def process_from_file(self, paths) -> DataBundle: | |||
""" | |||
r""" | |||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||
:param paths: | |||
@@ -93,7 +93,7 @@ class CLSBasePipe(Pipe): | |||
class YelpFullPipe(CLSBasePipe): | |||
""" | |||
r""" | |||
处理YelpFull的数据, 处理之后DataSet中的内容如下 | |||
.. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field | |||
@@ -117,7 +117,7 @@ class YelpFullPipe(CLSBasePipe): | |||
""" | |||
def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'): | |||
""" | |||
r""" | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将1、2归为1类,4、5归为一类,丢掉2;若为3, 则有3分类问题,将 | |||
@@ -136,7 +136,7 @@ class YelpFullPipe(CLSBasePipe): | |||
self.tag_map = None | |||
def process(self, data_bundle): | |||
""" | |||
r""" | |||
传入的DataSet应该具备如下的结构 | |||
.. csv-table:: | |||
@@ -157,7 +157,7 @@ class YelpFullPipe(CLSBasePipe): | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
r""" | |||
:param paths: | |||
:return: DataBundle | |||
@@ -167,7 +167,7 @@ class YelpFullPipe(CLSBasePipe): | |||
class YelpPolarityPipe(CLSBasePipe): | |||
""" | |||
r""" | |||
处理YelpPolarity的数据, 处理之后DataSet中的内容如下 | |||
.. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field | |||
@@ -191,7 +191,7 @@ class YelpPolarityPipe(CLSBasePipe): | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
""" | |||
r""" | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
@@ -199,7 +199,7 @@ class YelpPolarityPipe(CLSBasePipe): | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
def process_from_file(self, paths=None): | |||
""" | |||
r""" | |||
:param str paths: | |||
:return: DataBundle | |||
@@ -209,7 +209,7 @@ class YelpPolarityPipe(CLSBasePipe): | |||
class AGsNewsPipe(CLSBasePipe): | |||
""" | |||
r""" | |||
处理AG's News的数据, 处理之后DataSet中的内容如下 | |||
.. csv-table:: 下面是使用AGsNewsPipe处理后的DataSet所具备的field | |||
@@ -233,7 +233,7 @@ class AGsNewsPipe(CLSBasePipe): | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
""" | |||
r""" | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
@@ -241,7 +241,7 @@ class AGsNewsPipe(CLSBasePipe): | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
def process_from_file(self, paths=None): | |||
""" | |||
r""" | |||
:param str paths: | |||
:return: DataBundle | |||
""" | |||
@@ -250,7 +250,7 @@ class AGsNewsPipe(CLSBasePipe): | |||
class DBPediaPipe(CLSBasePipe): | |||
""" | |||
r""" | |||
处理DBPedia的数据, 处理之后DataSet中的内容如下 | |||
.. csv-table:: 下面是使用DBPediaPipe处理后的DataSet所具备的field | |||
@@ -274,7 +274,7 @@ class DBPediaPipe(CLSBasePipe): | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
""" | |||
r""" | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
@@ -282,7 +282,7 @@ class DBPediaPipe(CLSBasePipe): | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
def process_from_file(self, paths=None): | |||
""" | |||
r""" | |||
:param str paths: | |||
:return: DataBundle | |||
""" | |||
@@ -291,7 +291,7 @@ class DBPediaPipe(CLSBasePipe): | |||
class SSTPipe(CLSBasePipe): | |||
""" | |||
r""" | |||
经过该Pipe之后,DataSet中具备的field如下所示 | |||
.. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field | |||
@@ -315,7 +315,7 @@ class SSTPipe(CLSBasePipe): | |||
""" | |||
def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): | |||
""" | |||
r""" | |||
:param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False`` | |||
:param bool train_subtree: 是否将train集通过子树扩展数据。 | |||
@@ -339,7 +339,7 @@ class SSTPipe(CLSBasePipe): | |||
self.tag_map = None | |||
def process(self, data_bundle: DataBundle): | |||
""" | |||
r""" | |||
对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 | |||
.. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field | |||
@@ -383,7 +383,7 @@ class SSTPipe(CLSBasePipe): | |||
class SST2Pipe(CLSBasePipe): | |||
""" | |||
r""" | |||
加载SST2的数据, 处理完成之后DataSet将拥有以下的field | |||
.. csv-table:: | |||
@@ -407,7 +407,7 @@ class SST2Pipe(CLSBasePipe): | |||
""" | |||
def __init__(self, lower=False, tokenizer='spacy'): | |||
""" | |||
r""" | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
@@ -415,7 +415,7 @@ class SST2Pipe(CLSBasePipe): | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
def process_from_file(self, paths=None): | |||
""" | |||
r""" | |||
:param str paths: 如果为None,则自动下载并缓存到fastNLP的缓存地址。 | |||
:return: DataBundle | |||
@@ -425,7 +425,7 @@ class SST2Pipe(CLSBasePipe): | |||
class IMDBPipe(CLSBasePipe): | |||
""" | |||
r""" | |||
经过本Pipe处理后DataSet将如下 | |||
.. csv-table:: 输出DataSet的field | |||
@@ -452,7 +452,7 @@ class IMDBPipe(CLSBasePipe): | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
""" | |||
r""" | |||
:param bool lower: 是否将words列的数据小写。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
@@ -461,7 +461,7 @@ class IMDBPipe(CLSBasePipe): | |||
self.lower = lower | |||
def process(self, data_bundle: DataBundle): | |||
""" | |||
r""" | |||
期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型 | |||
.. csv-table:: 输入DataSet的field | |||
@@ -489,7 +489,7 @@ class IMDBPipe(CLSBasePipe): | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
r""" | |||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | |||
:return: DataBundle | |||
@@ -502,7 +502,7 @@ class IMDBPipe(CLSBasePipe): | |||
class ChnSentiCorpPipe(Pipe): | |||
""" | |||
r""" | |||
处理之后的DataSet有以下的结构 | |||
.. csv-table:: | |||
@@ -526,7 +526,7 @@ class ChnSentiCorpPipe(Pipe): | |||
""" | |||
def __init__(self, bigrams=False, trigrams=False): | |||
""" | |||
r""" | |||
:param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | |||
设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | |||
@@ -541,7 +541,7 @@ class ChnSentiCorpPipe(Pipe): | |||
self.trigrams = trigrams | |||
def _tokenize(self, data_bundle): | |||
""" | |||
r""" | |||
将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。 | |||
:param data_bundle: | |||
@@ -551,7 +551,7 @@ class ChnSentiCorpPipe(Pipe): | |||
return data_bundle | |||
def process(self, data_bundle:DataBundle): | |||
""" | |||
r""" | |||
可以处理的DataSet应该具备以下的field | |||
.. csv-table:: | |||
@@ -596,7 +596,7 @@ class ChnSentiCorpPipe(Pipe): | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
r""" | |||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | |||
:return: DataBundle | |||
@@ -609,7 +609,7 @@ class ChnSentiCorpPipe(Pipe): | |||
class THUCNewsPipe(CLSBasePipe): | |||
""" | |||
r""" | |||
处理之后的DataSet有以下的结构 | |||
.. csv-table:: | |||
@@ -658,7 +658,7 @@ class THUCNewsPipe(CLSBasePipe): | |||
return data_bundle | |||
def process(self, data_bundle: DataBundle): | |||
""" | |||
r""" | |||
可处理的DataSet应具备如下的field | |||
.. csv-table:: | |||
@@ -710,7 +710,7 @@ class THUCNewsPipe(CLSBasePipe): | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
r""" | |||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | |||
:return: DataBundle | |||
""" | |||
@@ -721,7 +721,7 @@ class THUCNewsPipe(CLSBasePipe): | |||
class WeiboSenti100kPipe(CLSBasePipe): | |||
""" | |||
r""" | |||
处理之后的DataSet有以下的结构 | |||
.. csv-table:: | |||
@@ -766,7 +766,7 @@ class WeiboSenti100kPipe(CLSBasePipe): | |||
return data_bundle | |||
def process(self, data_bundle: DataBundle): | |||
""" | |||
r""" | |||
可处理的DataSet应具备以下的field | |||
.. csv-table:: | |||
@@ -814,7 +814,7 @@ class WeiboSenti100kPipe(CLSBasePipe): | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
r""" | |||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | |||
:return: DataBundle | |||
""" | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"Conll2003NERPipe", | |||
@@ -21,7 +21,7 @@ from ...core.vocabulary import Vocabulary | |||
class _NERPipe(Pipe): | |||
""" | |||
r""" | |||
NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 | |||
(创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 | |||
Vocabulary转换为index。 | |||
@@ -31,7 +31,7 @@ class _NERPipe(Pipe): | |||
""" | |||
def __init__(self, encoding_type: str = 'bio', lower: bool = False): | |||
""" | |||
r""" | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | |||
@@ -45,7 +45,7 @@ class _NERPipe(Pipe): | |||
self.lower = lower | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
""" | |||
r""" | |||
支持的DataSet的field为 | |||
.. csv-table:: | |||
@@ -80,7 +80,7 @@ class _NERPipe(Pipe): | |||
class Conll2003NERPipe(_NERPipe): | |||
""" | |||
r""" | |||
Conll2003的NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 | |||
(创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 | |||
Vocabulary转换为index。 | |||
@@ -110,7 +110,7 @@ class Conll2003NERPipe(_NERPipe): | |||
""" | |||
def process_from_file(self, paths) -> DataBundle: | |||
""" | |||
r""" | |||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。 | |||
:return: DataBundle | |||
@@ -123,7 +123,7 @@ class Conll2003NERPipe(_NERPipe): | |||
class Conll2003Pipe(Pipe): | |||
""" | |||
r""" | |||
经过该Pipe后,DataSet中的内容如下 | |||
.. csv-table:: | |||
@@ -148,7 +148,7 @@ class Conll2003Pipe(Pipe): | |||
""" | |||
def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): | |||
""" | |||
r""" | |||
:param str chunk_encoding_type: 支持bioes, bio。 | |||
:param str ner_encoding_type: 支持bioes, bio。 | |||
@@ -169,7 +169,7 @@ class Conll2003Pipe(Pipe): | |||
self.lower = lower | |||
def process(self, data_bundle) -> DataBundle: | |||
""" | |||
r""" | |||
输入的DataSet应该类似于如下的形式 | |||
.. csv-table:: | |||
@@ -210,7 +210,7 @@ class Conll2003Pipe(Pipe): | |||
return data_bundle | |||
def process_from_file(self, paths): | |||
""" | |||
r""" | |||
:param paths: | |||
:return: | |||
@@ -220,7 +220,7 @@ class Conll2003Pipe(Pipe): | |||
class OntoNotesNERPipe(_NERPipe): | |||
""" | |||
r""" | |||
处理OntoNotes的NER数据,处理之后DataSet中的field情况为 | |||
.. csv-table:: | |||
@@ -252,7 +252,7 @@ class OntoNotesNERPipe(_NERPipe): | |||
class _CNNERPipe(Pipe): | |||
""" | |||
r""" | |||
中文NER任务的处理Pipe, 该Pipe会(1)复制raw_chars列,并命名为chars; (2)在chars, target列建立词表 | |||
(创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将chars,target列根据相应的 | |||
Vocabulary转换为index。 | |||
@@ -263,7 +263,7 @@ class _CNNERPipe(Pipe): | |||
""" | |||
def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False): | |||
""" | |||
r""" | |||
:param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | |||
@@ -284,7 +284,7 @@ class _CNNERPipe(Pipe): | |||
self.trigrams = trigrams | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
""" | |||
r""" | |||
支持的DataSet的field为 | |||
.. csv-table:: | |||
@@ -335,7 +335,7 @@ class _CNNERPipe(Pipe): | |||
class MsraNERPipe(_CNNERPipe): | |||
""" | |||
r""" | |||
处理MSRA-NER的数据,处理之后的DataSet的field情况为 | |||
.. csv-table:: | |||
@@ -367,7 +367,7 @@ class MsraNERPipe(_CNNERPipe): | |||
class PeopleDailyPipe(_CNNERPipe): | |||
""" | |||
r""" | |||
处理people daily的ner的数据,处理之后的DataSet的field情况为 | |||
.. csv-table:: | |||
@@ -399,7 +399,7 @@ class PeopleDailyPipe(_CNNERPipe): | |||
class WeiboNERPipe(_CNNERPipe): | |||
""" | |||
r""" | |||
处理weibo的ner的数据,处理之后的DataSet的field情况为 | |||
.. csv-table:: | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CoReferencePipe" | |||
@@ -16,7 +16,7 @@ from ...core.const import Const | |||
class CoReferencePipe(Pipe): | |||
""" | |||
r""" | |||
对Coreference resolution问题进行处理,得到文章种类/说话者/字符级信息/序列长度。 | |||
处理完成后数据包含文章类别、speaker信息、句子信息、句子对应的index、char、句子长度、target: | |||
@@ -45,7 +45,7 @@ class CoReferencePipe(Pipe): | |||
self.config = config | |||
def process(self, data_bundle: DataBundle): | |||
""" | |||
r""" | |||
对load进来的数据进一步处理原始数据包含:raw_key,raw_speaker,raw_words,raw_clusters | |||
.. csv-table:: | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CWSPipe" | |||
@@ -15,7 +15,7 @@ from ...core.const import Const | |||
def _word_lens_to_bmes(word_lens): | |||
""" | |||
r""" | |||
:param list word_lens: List[int], 每个词语的长度 | |||
:return: List[str], BMES的序列 | |||
@@ -32,7 +32,7 @@ def _word_lens_to_bmes(word_lens): | |||
def _word_lens_to_segapp(word_lens): | |||
""" | |||
r""" | |||
:param list word_lens: List[int], 每个词语的长度 | |||
:return: List[str], BMES的序列 | |||
@@ -48,7 +48,7 @@ def _word_lens_to_segapp(word_lens): | |||
def _alpha_span_to_special_tag(span): | |||
""" | |||
r""" | |||
将span替换成特殊的字符 | |||
:param str span: | |||
@@ -63,7 +63,7 @@ def _alpha_span_to_special_tag(span): | |||
def _find_and_replace_alpha_spans(line): | |||
""" | |||
r""" | |||
传入原始句子,替换其中的字母为特殊标记 | |||
:param str line:原始数据 | |||
@@ -82,7 +82,7 @@ def _find_and_replace_alpha_spans(line): | |||
def _digit_span_to_special_tag(span): | |||
""" | |||
r""" | |||
:param str span: 需要替换的str | |||
:return: | |||
@@ -108,7 +108,7 @@ def _digit_span_to_special_tag(span): | |||
def _find_and_replace_digit_spans(line): | |||
""" | |||
r""" | |||
only consider words start with number, contains '.', characters. | |||
If ends with space, will be processed | |||
@@ -134,7 +134,7 @@ def _find_and_replace_digit_spans(line): | |||
class CWSPipe(Pipe): | |||
""" | |||
r""" | |||
对CWS数据进行预处理, 处理之后的数据,具备以下的结构 | |||
.. csv-table:: | |||
@@ -158,7 +158,7 @@ class CWSPipe(Pipe): | |||
""" | |||
def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False): | |||
""" | |||
r""" | |||
:param str,None dataset_name: 支持'pku', 'msra', 'cityu', 'as', None | |||
:param str encoding_type: 可以选择'bmes', 'segapp'两种。"我 来自 复旦大学...", bmes的tag为[S, B, E, B, M, M, E...]; segapp | |||
@@ -178,7 +178,7 @@ class CWSPipe(Pipe): | |||
self.replace_num_alpha = replace_num_alpha | |||
def _tokenize(self, data_bundle): | |||
""" | |||
r""" | |||
将data_bundle中的'chars'列切分成一个一个的word. | |||
例如输入是"共同 创造 美好.."->[[共, 同], [创, 造], [...], ] | |||
@@ -216,7 +216,7 @@ class CWSPipe(Pipe): | |||
return data_bundle | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
""" | |||
r""" | |||
可以处理的DataSet需要包含raw_words列 | |||
.. csv-table:: | |||
@@ -268,7 +268,7 @@ class CWSPipe(Pipe): | |||
return data_bundle | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
""" | |||
r""" | |||
:param str paths: | |||
:return: | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"MatchingBertPipe", | |||
@@ -37,7 +37,7 @@ from ...core.vocabulary import Vocabulary | |||
class MatchingBertPipe(Pipe): | |||
""" | |||
r""" | |||
Matching任务的Bert pipe,输出的DataSet将包含以下的field | |||
.. csv-table:: | |||
@@ -65,7 +65,7 @@ class MatchingBertPipe(Pipe): | |||
""" | |||
def __init__(self, lower=False, tokenizer: str = 'raw'): | |||
""" | |||
r""" | |||
:param bool lower: 是否将word小写化。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
@@ -76,7 +76,7 @@ class MatchingBertPipe(Pipe): | |||
self.tokenizer = get_tokenizer(tokenize_method=tokenizer) | |||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||
""" | |||
r""" | |||
:param DataBundle data_bundle: DataBundle. | |||
:param list field_names: List[str], 需要tokenize的field名称 | |||
@@ -90,7 +90,7 @@ class MatchingBertPipe(Pipe): | |||
return data_bundle | |||
def process(self, data_bundle): | |||
""" | |||
r""" | |||
输入的data_bundle中的dataset需要具有以下结构: | |||
.. csv-table:: | |||
@@ -201,7 +201,7 @@ class MNLIBertPipe(MatchingBertPipe): | |||
class MatchingPipe(Pipe): | |||
""" | |||
r""" | |||
Matching任务的Pipe。输出的DataSet将包含以下的field | |||
.. csv-table:: | |||
@@ -229,7 +229,7 @@ class MatchingPipe(Pipe): | |||
""" | |||
def __init__(self, lower=False, tokenizer: str = 'raw'): | |||
""" | |||
r""" | |||
:param bool lower: 是否将所有raw_words转为小写。 | |||
:param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。 | |||
@@ -240,7 +240,7 @@ class MatchingPipe(Pipe): | |||
self.tokenizer = get_tokenizer(tokenize_method=tokenizer) | |||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||
""" | |||
r""" | |||
:param ~fastNLP.DataBundle data_bundle: DataBundle. | |||
:param list field_names: List[str], 需要tokenize的field名称 | |||
@@ -254,7 +254,7 @@ class MatchingPipe(Pipe): | |||
return data_bundle | |||
def process(self, data_bundle): | |||
""" | |||
r""" | |||
接受的DataBundle中的DataSet应该具有以下的field, target列可以没有 | |||
.. csv-table:: | |||
@@ -433,7 +433,7 @@ class GranularizePipe(Pipe): | |||
self.task = task | |||
def _granularize(self, data_bundle, tag_map): | |||
""" | |||
r""" | |||
该函数对data_bundle中'target'列中的内容进行转换。 | |||
:param data_bundle: | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"Pipe", | |||
@@ -8,7 +8,7 @@ from .. import DataBundle | |||
class Pipe: | |||
""" | |||
r""" | |||
Pipe是fastNLP中用于处理DataBundle的类,但实际是处理DataBundle中的DataSet。所有Pipe都会在其process()函数的文档中指出该Pipe可处理的DataSet应该具备怎样的格式;在Pipe | |||
文档中说明该Pipe返回后DataSet的格式以及其field的信息;以及新增的Vocabulary的信息。 | |||
@@ -23,7 +23,7 @@ class Pipe: | |||
""" | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
""" | |||
r""" | |||
对输入的DataBundle进行处理,然后返回该DataBundle。 | |||
:param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象 | |||
@@ -32,7 +32,7 @@ class Pipe: | |||
raise NotImplementedError | |||
def process_from_file(self, paths) -> DataBundle: | |||
""" | |||
r""" | |||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||
:param paths: | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
本文件中的Pipe主要用于处理问答任务的数据。 | |||
""" | |||
@@ -17,7 +17,7 @@ __all__ = ['CMRC2018BertPipe'] | |||
def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'): | |||
""" | |||
r""" | |||
处理data_bundle中的DataSet,将context与question按照character进行tokenize,然后使用[SEP]将两者连接起来。 | |||
会新增field: context_len(int), raw_words(list[str]), target_start(int), target_end(int)其中target_start | |||
@@ -78,7 +78,7 @@ def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'): | |||
class CMRC2018BertPipe(Pipe): | |||
""" | |||
r""" | |||
处理之后的DataSet将新增以下的field(传入的field仍然保留) | |||
.. csv-table:: | |||
@@ -111,7 +111,7 @@ class CMRC2018BertPipe(Pipe): | |||
self.max_len = max_len | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
""" | |||
r""" | |||
传入的DataSet应该具备以下的field | |||
.. csv-table:: | |||
@@ -1,197 +1,197 @@ | |||
"""undocumented""" | |||
import os | |||
import numpy as np | |||
from .pipe import Pipe | |||
from .utils import _drop_empty_instance | |||
from ..loader.summarization import ExtCNNDMLoader | |||
from ..data_bundle import DataBundle | |||
from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
from ...core._logger import logger | |||
WORD_PAD = "[PAD]" | |||
WORD_UNK = "[UNK]" | |||
DOMAIN_UNK = "X" | |||
TAG_UNK = "X" | |||
class ExtCNNDMPipe(Pipe): | |||
""" | |||
对CNN/Daily Mail数据进行适用于extractive summarization task的预处理,预处理之后的数据,具备以下结构: | |||
.. csv-table:: | |||
:header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" | |||
""" | |||
def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False): | |||
""" | |||
:param vocab_size: int, 词表大小 | |||
:param sent_max_len: int, 句子最大长度,不足的句子将padding,超出的将截断 | |||
:param doc_max_timesteps: int, 文章最多句子个数,不足的将padding,超出的将截断 | |||
:param vocab_path: str, 外部词表路径 | |||
:param domain: bool, 是否需要建立domain词表 | |||
""" | |||
self.vocab_size = vocab_size | |||
self.vocab_path = vocab_path | |||
self.sent_max_len = sent_max_len | |||
self.doc_max_timesteps = doc_max_timesteps | |||
self.domain = domain | |||
def process(self, data_bundle: DataBundle): | |||
""" | |||
传入的DataSet应该具备如下的结构 | |||
.. csv-table:: | |||
:header: "text", "summary", "label", "publication" | |||
["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" | |||
["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" | |||
["..."], ["..."], [], "cnndm" | |||
:param data_bundle: | |||
:return: 处理得到的数据包括 | |||
.. csv-table:: | |||
:header: "text_wd", "words", "seq_len", "target" | |||
[["I","got",..."."],...,["..."]], [[54,89,...,5],...,[9,43,..,0]], [1,1,...,0], [0,1,...,0] | |||
[["Don't","waste",...,"."],...,["..."]], [[5234,653,...,5],...,[87,234,..,0]], [1,1,...,0], [1,1,...,0] | |||
[[""],...,[""]], [[],...,[]], [], [] | |||
""" | |||
if self.vocab_path is None: | |||
error_msg = 'vocab file is not defined!' | |||
logger.error(error_msg) | |||
raise RuntimeError(error_msg) | |||
data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||
data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||
data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||
data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET) | |||
data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT) | |||
# db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") | |||
# pad document | |||
data_bundle.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT) | |||
data_bundle.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN) | |||
data_bundle.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET) | |||
data_bundle = _drop_empty_instance(data_bundle, "label") | |||
# set input and target | |||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||
data_bundle.set_target(Const.TARGET, Const.INPUT_LEN) | |||
# print("[INFO] Load existing vocab from %s!" % self.vocab_path) | |||
word_list = [] | |||
with open(self.vocab_path, 'r', encoding='utf8') as vocab_f: | |||
cnt = 2 # pad and unk | |||
for line in vocab_f: | |||
pieces = line.split("\t") | |||
word_list.append(pieces[0]) | |||
cnt += 1 | |||
if cnt > self.vocab_size: | |||
break | |||
vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
vocabs.add_word_lst(word_list) | |||
vocabs.build_vocab() | |||
data_bundle.set_vocab(vocabs, "vocab") | |||
if self.domain is True: | |||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||
domaindict.from_dataset(data_bundle.get_dataset("train"), field_name="publication") | |||
data_bundle.set_vocab(domaindict, "domain") | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
:param paths: dict or string | |||
:return: DataBundle | |||
""" | |||
loader = ExtCNNDMLoader() | |||
if self.vocab_path is None: | |||
if paths is None: | |||
paths = loader.download() | |||
if not os.path.isdir(paths): | |||
error_msg = 'vocab file is not defined!' | |||
logger.error(error_msg) | |||
raise RuntimeError(error_msg) | |||
self.vocab_path = os.path.join(paths, 'vocab') | |||
db = loader.load(paths=paths) | |||
db = self.process(db) | |||
for ds in db.datasets.values(): | |||
db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
return db | |||
def _lower_text(text_list): | |||
return [text.lower() for text in text_list] | |||
def _split_list(text_list): | |||
return [text.split() for text in text_list] | |||
def _convert_label(label, sent_len): | |||
np_label = np.zeros(sent_len, dtype=int) | |||
if label != []: | |||
np_label[np.array(label)] = 1 | |||
return np_label.tolist() | |||
def _pad_sent(text_wd, sent_max_len): | |||
pad_text_wd = [] | |||
for sent_wd in text_wd: | |||
if len(sent_wd) < sent_max_len: | |||
pad_num = sent_max_len - len(sent_wd) | |||
sent_wd.extend([WORD_PAD] * pad_num) | |||
else: | |||
sent_wd = sent_wd[:sent_max_len] | |||
pad_text_wd.append(sent_wd) | |||
return pad_text_wd | |||
def _token_mask(text_wd, sent_max_len): | |||
token_mask_list = [] | |||
for sent_wd in text_wd: | |||
token_num = len(sent_wd) | |||
if token_num < sent_max_len: | |||
mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||
else: | |||
mask = [1] * sent_max_len | |||
token_mask_list.append(mask) | |||
return token_mask_list | |||
def _pad_label(label, doc_max_timesteps): | |||
text_len = len(label) | |||
if text_len < doc_max_timesteps: | |||
pad_label = label + [0] * (doc_max_timesteps - text_len) | |||
else: | |||
pad_label = label[:doc_max_timesteps] | |||
return pad_label | |||
def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
padding = [WORD_PAD] * sent_max_len | |||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||
else: | |||
pad_text = text_wd[:doc_max_timesteps] | |||
return pad_text | |||
def _sent_mask(text_wd, doc_max_timesteps): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||
else: | |||
sent_mask = [1] * doc_max_timesteps | |||
return sent_mask | |||
r"""undocumented""" | |||
import os | |||
import numpy as np | |||
from .pipe import Pipe | |||
from .utils import _drop_empty_instance | |||
from ..loader.summarization import ExtCNNDMLoader | |||
from ..data_bundle import DataBundle | |||
from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
from ...core._logger import logger | |||
WORD_PAD = "[PAD]" | |||
WORD_UNK = "[UNK]" | |||
DOMAIN_UNK = "X" | |||
TAG_UNK = "X" | |||
class ExtCNNDMPipe(Pipe): | |||
r""" | |||
对CNN/Daily Mail数据进行适用于extractive summarization task的预处理,预处理之后的数据,具备以下结构: | |||
.. csv-table:: | |||
:header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" | |||
""" | |||
def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False): | |||
r""" | |||
:param vocab_size: int, 词表大小 | |||
:param sent_max_len: int, 句子最大长度,不足的句子将padding,超出的将截断 | |||
:param doc_max_timesteps: int, 文章最多句子个数,不足的将padding,超出的将截断 | |||
:param vocab_path: str, 外部词表路径 | |||
:param domain: bool, 是否需要建立domain词表 | |||
""" | |||
self.vocab_size = vocab_size | |||
self.vocab_path = vocab_path | |||
self.sent_max_len = sent_max_len | |||
self.doc_max_timesteps = doc_max_timesteps | |||
self.domain = domain | |||
def process(self, data_bundle: DataBundle): | |||
r""" | |||
传入的DataSet应该具备如下的结构 | |||
.. csv-table:: | |||
:header: "text", "summary", "label", "publication" | |||
["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" | |||
["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" | |||
["..."], ["..."], [], "cnndm" | |||
:param data_bundle: | |||
:return: 处理得到的数据包括 | |||
.. csv-table:: | |||
:header: "text_wd", "words", "seq_len", "target" | |||
[["I","got",..."."],...,["..."]], [[54,89,...,5],...,[9,43,..,0]], [1,1,...,0], [0,1,...,0] | |||
[["Don't","waste",...,"."],...,["..."]], [[5234,653,...,5],...,[87,234,..,0]], [1,1,...,0], [1,1,...,0] | |||
[[""],...,[""]], [[],...,[]], [], [] | |||
""" | |||
if self.vocab_path is None: | |||
error_msg = 'vocab file is not defined!' | |||
logger.error(error_msg) | |||
raise RuntimeError(error_msg) | |||
data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||
data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||
data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||
data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET) | |||
data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT) | |||
# db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") | |||
# pad document | |||
data_bundle.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT) | |||
data_bundle.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN) | |||
data_bundle.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET) | |||
data_bundle = _drop_empty_instance(data_bundle, "label") | |||
# set input and target | |||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||
data_bundle.set_target(Const.TARGET, Const.INPUT_LEN) | |||
# print("[INFO] Load existing vocab from %s!" % self.vocab_path) | |||
word_list = [] | |||
with open(self.vocab_path, 'r', encoding='utf8') as vocab_f: | |||
cnt = 2 # pad and unk | |||
for line in vocab_f: | |||
pieces = line.split("\t") | |||
word_list.append(pieces[0]) | |||
cnt += 1 | |||
if cnt > self.vocab_size: | |||
break | |||
vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
vocabs.add_word_lst(word_list) | |||
vocabs.build_vocab() | |||
data_bundle.set_vocab(vocabs, "vocab") | |||
if self.domain is True: | |||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||
domaindict.from_dataset(data_bundle.get_dataset("train"), field_name="publication") | |||
data_bundle.set_vocab(domaindict, "domain") | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param paths: dict or string | |||
:return: DataBundle | |||
""" | |||
loader = ExtCNNDMLoader() | |||
if self.vocab_path is None: | |||
if paths is None: | |||
paths = loader.download() | |||
if not os.path.isdir(paths): | |||
error_msg = 'vocab file is not defined!' | |||
logger.error(error_msg) | |||
raise RuntimeError(error_msg) | |||
self.vocab_path = os.path.join(paths, 'vocab') | |||
db = loader.load(paths=paths) | |||
db = self.process(db) | |||
for ds in db.datasets.values(): | |||
db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
return db | |||
def _lower_text(text_list): | |||
return [text.lower() for text in text_list] | |||
def _split_list(text_list): | |||
return [text.split() for text in text_list] | |||
def _convert_label(label, sent_len): | |||
np_label = np.zeros(sent_len, dtype=int) | |||
if label != []: | |||
np_label[np.array(label)] = 1 | |||
return np_label.tolist() | |||
def _pad_sent(text_wd, sent_max_len): | |||
pad_text_wd = [] | |||
for sent_wd in text_wd: | |||
if len(sent_wd) < sent_max_len: | |||
pad_num = sent_max_len - len(sent_wd) | |||
sent_wd.extend([WORD_PAD] * pad_num) | |||
else: | |||
sent_wd = sent_wd[:sent_max_len] | |||
pad_text_wd.append(sent_wd) | |||
return pad_text_wd | |||
def _token_mask(text_wd, sent_max_len): | |||
token_mask_list = [] | |||
for sent_wd in text_wd: | |||
token_num = len(sent_wd) | |||
if token_num < sent_max_len: | |||
mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||
else: | |||
mask = [1] * sent_max_len | |||
token_mask_list.append(mask) | |||
return token_mask_list | |||
def _pad_label(label, doc_max_timesteps): | |||
text_len = len(label) | |||
if text_len < doc_max_timesteps: | |||
pad_label = label + [0] * (doc_max_timesteps - text_len) | |||
else: | |||
pad_label = label[:doc_max_timesteps] | |||
return pad_label | |||
def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
padding = [WORD_PAD] * sent_max_len | |||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||
else: | |||
pad_text = text_wd[:doc_max_timesteps] | |||
return pad_text | |||
def _sent_mask(text_wd, doc_max_timesteps): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||
else: | |||
sent_mask = [1] * doc_max_timesteps | |||
return sent_mask | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"iob2", | |||
@@ -15,7 +15,7 @@ from ...core._logger import logger | |||
def iob2(tags: List[str]) -> List[str]: | |||
""" | |||
r""" | |||
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见 | |||
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format | |||
@@ -39,7 +39,7 @@ def iob2(tags: List[str]) -> List[str]: | |||
def iob2bioes(tags: List[str]) -> List[str]: | |||
""" | |||
r""" | |||
将iob的tag转换为bioes编码 | |||
:param tags: | |||
:return: | |||
@@ -66,7 +66,7 @@ def iob2bioes(tags: List[str]) -> List[str]: | |||
def get_tokenizer(tokenize_method: str, lang='en'): | |||
""" | |||
r""" | |||
:param str tokenize_method: 获取tokenzier方法 | |||
:param str lang: 语言,当前仅支持en | |||
@@ -100,7 +100,7 @@ def _raw_split(sent): | |||
def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Const.TARGET): | |||
""" | |||
r""" | |||
在dataset中的field_name列建立词表,Const.TARGET列建立词表,并把词表加入到data_bundle中。 | |||
:param ~fastNLP.DataBundle data_bundle: | |||
@@ -143,7 +143,7 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con | |||
def _add_words_field(data_bundle, lower=False): | |||
""" | |||
r""" | |||
给data_bundle中的dataset中复制一列words. 并根据lower参数判断是否需要小写化 | |||
:param data_bundle: | |||
@@ -159,7 +159,7 @@ def _add_words_field(data_bundle, lower=False): | |||
def _add_chars_field(data_bundle, lower=False): | |||
""" | |||
r""" | |||
给data_bundle中的dataset中复制一列chars. 并根据lower参数判断是否需要小写化 | |||
:param data_bundle: | |||
@@ -175,7 +175,7 @@ def _add_chars_field(data_bundle, lower=False): | |||
def _drop_empty_instance(data_bundle, field_name): | |||
""" | |||
r""" | |||
删除data_bundle的DataSet中存在的某个field为空的情况 | |||
:param ~fastNLP.DataBundle data_bundle: | |||
@@ -201,7 +201,7 @@ def _drop_empty_instance(data_bundle, field_name): | |||
def _granularize(data_bundle, tag_map): | |||
""" | |||
r""" | |||
该函数对data_bundle中'target'列中的内容进行转换。 | |||
:param data_bundle: | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -15,7 +15,7 @@ from ..core import logger | |||
def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: | |||
""" | |||
r""" | |||
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: | |||
{ | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models.CNNText` 、 | |||
:class:`~fastNLP.models.SeqLabeling` 等完整的模型,以供用户直接使用。 | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [] | |||
@@ -8,7 +8,7 @@ from ..modules.decoder.mlp import MLP | |||
class BaseModel(torch.nn.Module): | |||
"""Base PyTorch model for all models. | |||
r"""Base PyTorch model for all models. | |||
""" | |||
def __init__(self): | |||
@@ -22,7 +22,7 @@ class BaseModel(torch.nn.Module): | |||
class NaiveClassifier(BaseModel): | |||
""" | |||
r""" | |||
一个简单的分类器例子,可用于各种测试 | |||
""" | |||
def __init__(self, in_feature_dim, out_feature_dim): | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
fastNLP提供了BERT应用到五个下游任务的模型代码,可以直接调用。这五个任务分别为 | |||
- 文本分类任务: :class:`~fastNLP.models.BertForSequenceClassification` | |||
@@ -43,12 +43,12 @@ from ..embeddings import BertEmbedding | |||
class BertForSequenceClassification(BaseModel): | |||
""" | |||
r""" | |||
BERT model for classification. | |||
""" | |||
def __init__(self, embed: BertEmbedding, num_labels: int=2, dropout=0.1): | |||
""" | |||
r""" | |||
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | |||
:param int num_labels: 文本分类类别数目,默认值为2. | |||
@@ -69,7 +69,7 @@ class BertForSequenceClassification(BaseModel): | |||
warnings.warn(warn_msg) | |||
def forward(self, words): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, seq_len] | |||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels] | |||
""" | |||
@@ -80,7 +80,7 @@ class BertForSequenceClassification(BaseModel): | |||
return {Const.OUTPUT: logits} | |||
def predict(self, words): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, seq_len] | |||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size] | |||
""" | |||
@@ -89,12 +89,12 @@ class BertForSequenceClassification(BaseModel): | |||
class BertForSentenceMatching(BaseModel): | |||
""" | |||
r""" | |||
BERT model for sentence matching. | |||
""" | |||
def __init__(self, embed: BertEmbedding, num_labels: int=2, dropout=0.1): | |||
""" | |||
r""" | |||
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | |||
:param int num_labels: Matching任务类别数目,默认值为2. | |||
@@ -114,7 +114,7 @@ class BertForSentenceMatching(BaseModel): | |||
warnings.warn(warn_msg) | |||
def forward(self, words): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, seq_len] | |||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels] | |||
""" | |||
@@ -125,7 +125,7 @@ class BertForSentenceMatching(BaseModel): | |||
return {Const.OUTPUT: logits} | |||
def predict(self, words): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, seq_len] | |||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size] | |||
""" | |||
@@ -134,12 +134,12 @@ class BertForSentenceMatching(BaseModel): | |||
class BertForMultipleChoice(BaseModel): | |||
""" | |||
r""" | |||
BERT model for multiple choice. | |||
""" | |||
def __init__(self, embed: BertEmbedding, num_choices=2, dropout=0.1): | |||
""" | |||
r""" | |||
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | |||
:param int num_choices: 多选任务选项数目,默认值为2. | |||
@@ -160,7 +160,7 @@ class BertForMultipleChoice(BaseModel): | |||
warnings.warn(warn_msg) | |||
def forward(self, words): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, num_choices, seq_len] | |||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size, num_choices] | |||
""" | |||
@@ -175,7 +175,7 @@ class BertForMultipleChoice(BaseModel): | |||
return {Const.OUTPUT: reshaped_logits} | |||
def predict(self, words): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, num_choices, seq_len] | |||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size] | |||
""" | |||
@@ -184,12 +184,12 @@ class BertForMultipleChoice(BaseModel): | |||
class BertForTokenClassification(BaseModel): | |||
""" | |||
r""" | |||
BERT model for token classification. | |||
""" | |||
def __init__(self, embed: BertEmbedding, num_labels, dropout=0.1): | |||
""" | |||
r""" | |||
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | |||
:param int num_labels: 序列标注标签数目,无默认值. | |||
@@ -210,7 +210,7 @@ class BertForTokenClassification(BaseModel): | |||
warnings.warn(warn_msg) | |||
def forward(self, words): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, seq_len] | |||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, seq_len, num_labels] | |||
""" | |||
@@ -221,7 +221,7 @@ class BertForTokenClassification(BaseModel): | |||
return {Const.OUTPUT: logits} | |||
def predict(self, words): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, seq_len] | |||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size, seq_len] | |||
""" | |||
@@ -230,12 +230,12 @@ class BertForTokenClassification(BaseModel): | |||
class BertForQuestionAnswering(BaseModel): | |||
""" | |||
r""" | |||
用于做Q&A的Bert模型,如果是Squad2.0请将BertEmbedding的include_cls_sep设置为True,Squad1.0或CMRC则设置为False | |||
""" | |||
def __init__(self, embed: BertEmbedding): | |||
""" | |||
r""" | |||
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | |||
:param int num_labels: 抽取式QA列数,默认值为2(即第一列为start_span, 第二列为end_span). | |||
@@ -246,7 +246,7 @@ class BertForQuestionAnswering(BaseModel): | |||
self.qa_outputs = nn.Linear(self.bert.embedding_dim, 2) | |||
def forward(self, words): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, seq_len] | |||
:return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len + 2] | |||
""" | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
Biaffine Dependency Parser 的 Pytorch 实现. | |||
""" | |||
__all__ = [ | |||
@@ -26,7 +26,7 @@ from ..modules.utils import initial_parameter | |||
def _mst(scores): | |||
""" | |||
r""" | |||
with some modification to support parser output for MST decoding | |||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | |||
""" | |||
@@ -85,7 +85,7 @@ def _mst(scores): | |||
def _find_cycle(vertices, edges): | |||
""" | |||
r""" | |||
https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm | |||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py | |||
""" | |||
@@ -129,7 +129,7 @@ def _find_cycle(vertices, edges): | |||
class GraphParser(BaseModel): | |||
""" | |||
r""" | |||
基于图的parser base class, 支持贪婪解码和最大生成树解码 | |||
""" | |||
@@ -138,7 +138,7 @@ class GraphParser(BaseModel): | |||
@staticmethod | |||
def greedy_decoder(arc_matrix, mask=None): | |||
""" | |||
r""" | |||
贪心解码方式, 输入图, 输出贪心解码的parsing结果, 不保证合法的构成树 | |||
:param arc_matrix: [batch, seq_len, seq_len] 输入图矩阵 | |||
@@ -157,7 +157,7 @@ class GraphParser(BaseModel): | |||
@staticmethod | |||
def mst_decoder(arc_matrix, mask=None): | |||
""" | |||
r""" | |||
用最大生成树算法, 计算parsing结果, 保证输出合法的树结构 | |||
:param arc_matrix: [batch, seq_len, seq_len] 输入图矩阵 | |||
@@ -178,13 +178,13 @@ class GraphParser(BaseModel): | |||
class ArcBiaffine(nn.Module): | |||
""" | |||
r""" | |||
Biaffine Dependency Parser 的子模块, 用于构建预测边的图 | |||
""" | |||
def __init__(self, hidden_size, bias=True): | |||
""" | |||
r""" | |||
:param hidden_size: 输入的特征维度 | |||
:param bias: 是否使用bias. Default: ``True`` | |||
@@ -199,7 +199,7 @@ class ArcBiaffine(nn.Module): | |||
initial_parameter(self) | |||
def forward(self, head, dep): | |||
""" | |||
r""" | |||
:param head: arc-head tensor [batch, length, hidden] | |||
:param dep: arc-dependent tensor [batch, length, hidden] | |||
@@ -213,13 +213,13 @@ class ArcBiaffine(nn.Module): | |||
class LabelBilinear(nn.Module): | |||
""" | |||
r""" | |||
Biaffine Dependency Parser 的子模块, 用于构建预测边类别的图 | |||
""" | |||
def __init__(self, in1_features, in2_features, num_label, bias=True): | |||
""" | |||
r""" | |||
:param in1_features: 输入的特征1维度 | |||
:param in2_features: 输入的特征2维度 | |||
@@ -231,7 +231,7 @@ class LabelBilinear(nn.Module): | |||
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) | |||
def forward(self, x1, x2): | |||
""" | |||
r""" | |||
:param x1: [batch, seq_len, hidden] 输入特征1, 即label-head | |||
:param x2: [batch, seq_len, hidden] 输入特征2, 即label-dep | |||
@@ -243,7 +243,7 @@ class LabelBilinear(nn.Module): | |||
class BiaffineParser(GraphParser): | |||
""" | |||
r""" | |||
Biaffine Dependency Parser 实现. | |||
论文参考 `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ . | |||
@@ -261,7 +261,7 @@ class BiaffineParser(GraphParser): | |||
dropout=0.3, | |||
encoder='lstm', | |||
use_greedy_infer=False): | |||
""" | |||
r""" | |||
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | |||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, | |||
@@ -347,7 +347,7 @@ class BiaffineParser(GraphParser): | |||
nn.init.normal_(p, 0, 0.1) | |||
def forward(self, words1, words2, seq_len, target1=None): | |||
"""模型forward阶段 | |||
r"""模型forward阶段 | |||
:param words1: [batch_size, seq_len] 输入word序列 | |||
:param words2: [batch_size, seq_len] 输入pos序列 | |||
@@ -428,7 +428,7 @@ class BiaffineParser(GraphParser): | |||
@staticmethod | |||
def loss(pred1, pred2, target1, target2, seq_len): | |||
""" | |||
r""" | |||
计算parser的loss | |||
:param pred1: [batch_size, seq_len, seq_len] 边预测logits | |||
@@ -458,7 +458,7 @@ class BiaffineParser(GraphParser): | |||
return arc_nll + label_nll | |||
def predict(self, words1, words2, seq_len): | |||
"""模型预测API | |||
r"""模型预测API | |||
:param words1: [batch_size, seq_len] 输入word序列 | |||
:param words2: [batch_size, seq_len] 输入pos序列 | |||
@@ -479,7 +479,7 @@ class BiaffineParser(GraphParser): | |||
class ParserLoss(LossFunc): | |||
""" | |||
r""" | |||
计算parser的loss | |||
""" | |||
@@ -487,7 +487,7 @@ class ParserLoss(LossFunc): | |||
def __init__(self, pred1=None, pred2=None, | |||
target1=None, target2=None, | |||
seq_len=None): | |||
""" | |||
r""" | |||
:param pred1: [batch_size, seq_len, seq_len] 边预测logits | |||
:param pred2: [batch_size, seq_len, num_label] label预测logits | |||
@@ -505,14 +505,14 @@ class ParserLoss(LossFunc): | |||
class ParserMetric(MetricBase): | |||
""" | |||
r""" | |||
评估parser的性能 | |||
""" | |||
def __init__(self, pred1=None, pred2=None, | |||
target1=None, target2=None, seq_len=None): | |||
""" | |||
r""" | |||
:param pred1: 边预测logits | |||
:param pred2: label预测logits | |||
@@ -539,7 +539,7 @@ class ParserMetric(MetricBase): | |||
return res | |||
def evaluate(self, pred1, pred2, target1, target2, seq_len=None): | |||
"""Evaluate the performance of prediction. | |||
r"""Evaluate the performance of prediction. | |||
""" | |||
if seq_len is None: | |||
seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long) | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -17,7 +17,7 @@ from ..modules import encoder | |||
class CNNText(torch.nn.Module): | |||
""" | |||
r""" | |||
使用CNN进行文本分类的模型 | |||
'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.' | |||
@@ -28,7 +28,7 @@ class CNNText(torch.nn.Module): | |||
kernel_nums=(30, 40, 50), | |||
kernel_sizes=(1, 3, 5), | |||
dropout=0.5): | |||
""" | |||
r""" | |||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), | |||
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding | |||
@@ -48,7 +48,7 @@ class CNNText(torch.nn.Module): | |||
self.fc = nn.Linear(sum(kernel_nums), num_classes) | |||
def forward(self, words, seq_len=None): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, seq_len],句子中word的index | |||
:param torch.LongTensor seq_len: [batch,] 每个句子的长度 | |||
@@ -65,7 +65,7 @@ class CNNText(torch.nn.Module): | |||
return {C.OUTPUT: x} | |||
def predict(self, words, seq_len=None): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, seq_len],句子中word的index | |||
:param torch.LongTensor seq_len: [batch,] 每个句子的长度 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
本模块实现了几种序列标注模型 | |||
""" | |||
__all__ = [ | |||
@@ -22,13 +22,13 @@ from ..modules.decoder.crf import allowed_transitions | |||
class BiLSTMCRF(BaseModel): | |||
""" | |||
r""" | |||
结构为embedding + BiLSTM + FC + Dropout + CRF. | |||
""" | |||
def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5, | |||
target_vocab=None): | |||
""" | |||
r""" | |||
:param embed: 支持(1)fastNLP的各种Embedding, (2) tuple, 指明num_embedding, dimension, 如(1000, 100) | |||
:param num_classes: 一共多少个类 | |||
@@ -79,14 +79,14 @@ class BiLSTMCRF(BaseModel): | |||
class SeqLabeling(BaseModel): | |||
""" | |||
r""" | |||
一个基础的Sequence labeling的模型。 | |||
用于做sequence labeling的基础类。结构包含一层Embedding,一层LSTM(单向,一层),一层FC,以及一层CRF。 | |||
""" | |||
def __init__(self, embed, hidden_size, num_classes): | |||
""" | |||
r""" | |||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), | |||
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, embedding, ndarray等则直接使用该值初始化Embedding | |||
@@ -101,7 +101,7 @@ class SeqLabeling(BaseModel): | |||
self.crf = decoder.ConditionalRandomField(num_classes) | |||
def forward(self, words, seq_len, target): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, max_len],序列的index | |||
:param torch.LongTensor seq_len: [batch_size,], 这个序列的长度 | |||
:param torch.LongTensor target: [batch_size, max_len], 序列的目标值 | |||
@@ -118,7 +118,7 @@ class SeqLabeling(BaseModel): | |||
return {C.LOSS: self._internal_loss(x, target, mask)} | |||
def predict(self, words, seq_len): | |||
""" | |||
r""" | |||
用于在预测时使用 | |||
:param torch.LongTensor words: [batch_size, max_len] | |||
@@ -137,7 +137,7 @@ class SeqLabeling(BaseModel): | |||
return {C.OUTPUT: pred} | |||
def _internal_loss(self, x, y, mask): | |||
""" | |||
r""" | |||
Negative log likelihood loss. | |||
:param x: Tensor, [batch_size, max_len, tag_size] | |||
:param y: Tensor, [batch_size, max_len] | |||
@@ -150,7 +150,7 @@ class SeqLabeling(BaseModel): | |||
return torch.mean(total_loss) | |||
def _decode(self, x, mask): | |||
""" | |||
r""" | |||
:param torch.FloatTensor x: [batch_size, max_len, tag_size] | |||
:return prediction: [batch_size, max_len] | |||
""" | |||
@@ -159,12 +159,12 @@ class SeqLabeling(BaseModel): | |||
class AdvSeqLabel(nn.Module): | |||
""" | |||
r""" | |||
更复杂的Sequence Labelling模型。结构为Embedding, LayerNorm, 双向LSTM(两层),FC,LayerNorm,DropOut,FC,CRF。 | |||
""" | |||
def __init__(self, embed, hidden_size, num_classes, dropout=0.3, id2words=None, encoding_type='bmes'): | |||
""" | |||
r""" | |||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), | |||
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding | |||
@@ -197,7 +197,7 @@ class AdvSeqLabel(nn.Module): | |||
encoding_type=encoding_type)) | |||
def _decode(self, x, mask): | |||
""" | |||
r""" | |||
:param torch.FloatTensor x: [batch_size, max_len, tag_size] | |||
:param torch.ByteTensor mask: [batch_size, max_len] | |||
:return torch.LongTensor, [batch_size, max_len] | |||
@@ -206,7 +206,7 @@ class AdvSeqLabel(nn.Module): | |||
return tag_seq | |||
def _internal_loss(self, x, y, mask): | |||
""" | |||
r""" | |||
Negative log likelihood loss. | |||
:param x: Tensor, [batch_size, max_len, tag_size] | |||
:param y: Tensor, [batch_size, max_len] | |||
@@ -220,7 +220,7 @@ class AdvSeqLabel(nn.Module): | |||
return torch.mean(total_loss) | |||
def _forward(self, words, seq_len, target=None): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, mex_len] | |||
:param torch.LongTensor seq_len:[batch_size, ] | |||
:param torch.LongTensor target: [batch_size, max_len] | |||
@@ -254,7 +254,7 @@ class AdvSeqLabel(nn.Module): | |||
return {"pred": self._decode(x, mask)} | |||
def forward(self, words, seq_len, target): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, mex_len] | |||
:param torch.LongTensor seq_len: [batch_size, ] | |||
@@ -264,7 +264,7 @@ class AdvSeqLabel(nn.Module): | |||
return self._forward(words, seq_len, target) | |||
def predict(self, words, seq_len): | |||
""" | |||
r""" | |||
:param torch.LongTensor words: [batch_size, mex_len] | |||
:param torch.LongTensor seq_len: [batch_size, ] | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -19,7 +19,7 @@ from ..modules.encoder import BiAttention | |||
class ESIM(BaseModel): | |||
""" | |||
r""" | |||
ESIM model的一个PyTorch实现 | |||
论文参见: https://arxiv.org/pdf/1609.06038.pdf | |||
@@ -27,7 +27,7 @@ class ESIM(BaseModel): | |||
def __init__(self, embed, hidden_size=None, num_labels=3, dropout_rate=0.3, | |||
dropout_embed=0.1): | |||
""" | |||
r""" | |||
:param embed: 初始化的Embedding | |||
:param int hidden_size: 隐藏层大小,默认值为Embedding的维度 | |||
@@ -68,7 +68,7 @@ class ESIM(BaseModel): | |||
nn.init.xavier_uniform_(self.classifier[4].weight.data) | |||
def forward(self, words1, words2, seq_len1, seq_len2, target=None): | |||
""" | |||
r""" | |||
:param words1: [batch, seq_len] | |||
:param words2: [batch, seq_len] | |||
:param seq_len1: [batch] | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
Star-Transformer 的 Pytorch 实现。 | |||
""" | |||
__all__ = [ | |||
@@ -18,7 +18,7 @@ from ..modules.encoder.star_transformer import StarTransformer | |||
class StarTransEnc(nn.Module): | |||
""" | |||
r""" | |||
带word embedding的Star-Transformer Encoder | |||
""" | |||
@@ -31,7 +31,7 @@ class StarTransEnc(nn.Module): | |||
max_len, | |||
emb_dropout, | |||
dropout): | |||
""" | |||
r""" | |||
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | |||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,此时就以传入的对象作为embedding | |||
@@ -56,7 +56,7 @@ class StarTransEnc(nn.Module): | |||
max_len=max_len) | |||
def forward(self, x, mask): | |||
""" | |||
r""" | |||
:param FloatTensor x: [batch, length, hidden] 输入的序列 | |||
:param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, | |||
否则为 1 | |||
@@ -103,7 +103,7 @@ class _NLICls(nn.Module): | |||
class STSeqLabel(nn.Module): | |||
""" | |||
r""" | |||
用于序列标注的Star-Transformer模型 | |||
""" | |||
@@ -117,7 +117,7 @@ class STSeqLabel(nn.Module): | |||
cls_hidden_size=600, | |||
emb_dropout=0.1, | |||
dropout=0.1, ): | |||
""" | |||
r""" | |||
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | |||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding | |||
@@ -143,7 +143,7 @@ class STSeqLabel(nn.Module): | |||
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size) | |||
def forward(self, words, seq_len): | |||
""" | |||
r""" | |||
:param words: [batch, seq_len] 输入序列 | |||
:param seq_len: [batch,] 输入序列的长度 | |||
@@ -156,7 +156,7 @@ class STSeqLabel(nn.Module): | |||
return {Const.OUTPUT: output} # [bsz, n_cls, seq_len] | |||
def predict(self, words, seq_len): | |||
""" | |||
r""" | |||
:param words: [batch, seq_len] 输入序列 | |||
:param seq_len: [batch,] 输入序列的长度 | |||
@@ -168,7 +168,7 @@ class STSeqLabel(nn.Module): | |||
class STSeqCls(nn.Module): | |||
""" | |||
r""" | |||
用于分类任务的Star-Transformer | |||
""" | |||
@@ -182,7 +182,7 @@ class STSeqCls(nn.Module): | |||
cls_hidden_size=600, | |||
emb_dropout=0.1, | |||
dropout=0.1, ): | |||
""" | |||
r""" | |||
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | |||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding | |||
@@ -208,7 +208,7 @@ class STSeqCls(nn.Module): | |||
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout) | |||
def forward(self, words, seq_len): | |||
""" | |||
r""" | |||
:param words: [batch, seq_len] 输入序列 | |||
:param seq_len: [batch,] 输入序列的长度 | |||
@@ -221,7 +221,7 @@ class STSeqCls(nn.Module): | |||
return {Const.OUTPUT: output} | |||
def predict(self, words, seq_len): | |||
""" | |||
r""" | |||
:param words: [batch, seq_len] 输入序列 | |||
:param seq_len: [batch,] 输入序列的长度 | |||
@@ -233,7 +233,7 @@ class STSeqCls(nn.Module): | |||
class STNLICls(nn.Module): | |||
""" | |||
r""" | |||
用于自然语言推断(NLI)的Star-Transformer | |||
""" | |||
@@ -247,7 +247,7 @@ class STNLICls(nn.Module): | |||
cls_hidden_size=600, | |||
emb_dropout=0.1, | |||
dropout=0.1, ): | |||
""" | |||
r""" | |||
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | |||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding | |||
@@ -273,7 +273,7 @@ class STNLICls(nn.Module): | |||
self.cls = _NLICls(hidden_size, num_cls, cls_hidden_size) | |||
def forward(self, words1, words2, seq_len1, seq_len2): | |||
""" | |||
r""" | |||
:param words1: [batch, seq_len] 输入序列1 | |||
:param words2: [batch, seq_len] 输入序列2 | |||
@@ -294,7 +294,7 @@ class STNLICls(nn.Module): | |||
return {Const.OUTPUT: output} | |||
def predict(self, words1, words2, seq_len1, seq_len2): | |||
""" | |||
r""" | |||
:param words1: [batch, seq_len] 输入序列1 | |||
:param words2: [batch, seq_len] 输入序列2 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. image:: figures/text_classification.png | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"ConditionalRandomField", | |||
@@ -16,7 +16,7 @@ from ...core.vocabulary import Vocabulary | |||
def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type=None, include_start_end=False): | |||
""" | |||
r""" | |||
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | |||
:param ~fastNLP.Vocabulary,dict tag_vocab: 支持类型为tag或tag-label。只有tag的,比如"B", "M"; 也可以是"B-NN", "M-NN", | |||
@@ -73,7 +73,7 @@ def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type=None, i | |||
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
""" | |||
r""" | |||
:param str encoding_type: 支持"BIO", "BMES", "BEMSO", 'bioes'。 | |||
:param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
@@ -86,7 +86,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label | |||
return False | |||
encoding_type = encoding_type.lower() | |||
if encoding_type == 'bio': | |||
""" | |||
r""" | |||
第一行是to_tag, 第一列是from_tag. y任意条件下可转,-只有在label相同时可转,n不可转 | |||
+-------+---+---+---+-------+-----+ | |||
| | B | I | O | start | end | | |||
@@ -112,7 +112,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label | |||
raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) | |||
elif encoding_type == 'bmes': | |||
""" | |||
r""" | |||
第一行是to_tag, 第一列是from_tag,y任意条件下可转,-只有在label相同时可转,n不可转 | |||
+-------+---+---+---+---+-------+-----+ | |||
| | B | M | E | S | start | end | | |||
@@ -167,14 +167,14 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label | |||
class ConditionalRandomField(nn.Module): | |||
""" | |||
r""" | |||
条件随机场。提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 | |||
""" | |||
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, | |||
initial_method=None): | |||
""" | |||
r""" | |||
:param int num_tags: 标签的数量 | |||
:param bool include_start_end_trans: 是否考虑各个tag作为开始以及结尾的分数。 | |||
@@ -205,7 +205,7 @@ class ConditionalRandomField(nn.Module): | |||
initial_parameter(self, initial_method) | |||
def _normalizer_likelihood(self, logits, mask): | |||
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||
r"""Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||
sum of the likelihoods across all possible state sequences. | |||
:param logits:FloatTensor, max_len x batch_size x num_tags | |||
@@ -232,7 +232,7 @@ class ConditionalRandomField(nn.Module): | |||
return torch.logsumexp(alpha, 1) | |||
def _gold_score(self, logits, tags, mask): | |||
""" | |||
r""" | |||
Compute the score for the gold path. | |||
:param logits: FloatTensor, max_len x batch_size x num_tags | |||
:param tags: LongTensor, max_len x batch_size | |||
@@ -261,7 +261,7 @@ class ConditionalRandomField(nn.Module): | |||
return score | |||
def forward(self, feats, tags, mask): | |||
""" | |||
r""" | |||
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 | |||
:param torch.FloatTensor feats: batch_size x max_len x num_tags,特征矩阵。 | |||
@@ -278,7 +278,7 @@ class ConditionalRandomField(nn.Module): | |||
return all_path_score - gold_path_score | |||
def viterbi_decode(self, logits, mask, unpad=False): | |||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
r"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | |||
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"MLP" | |||
@@ -11,7 +11,7 @@ from ..utils import initial_parameter | |||
class MLP(nn.Module): | |||
""" | |||
r""" | |||
多层感知器 | |||
@@ -36,7 +36,7 @@ class MLP(nn.Module): | |||
""" | |||
def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0): | |||
""" | |||
r""" | |||
:param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层是hidden数目。MLP的层数为 len(size_layer) - 1 | |||
:param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和 | |||
@@ -87,7 +87,7 @@ class MLP(nn.Module): | |||
initial_parameter(self, initial_method) | |||
def forward(self, x): | |||
""" | |||
r""" | |||
:param torch.Tensor x: MLP接受的输入 | |||
:return: torch.Tensor : MLP的输出结果 | |||
""" | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"viterbi_decode" | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"TimestepDropout" | |||
@@ -8,7 +8,7 @@ import torch | |||
class TimestepDropout(torch.nn.Dropout): | |||
""" | |||
r""" | |||
传入参数的shape为 ``(batch_size, num_timesteps, embedding_dim)`` | |||
使用同一个shape为 ``(batch_size, embedding_dim)`` 的mask在每个timestamp上做dropout。 | |||
""" | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented | |||
r"""undocumented | |||
这个页面的代码大量参考了 allenNLP | |||
""" | |||
@@ -15,7 +15,7 @@ from ..utils import get_dropout_mask | |||
class LstmCellWithProjection(torch.nn.Module): | |||
""" | |||
r""" | |||
An LSTM with Recurrent Dropout and a projected and clipped hidden state and | |||
memory. Note: this implementation is slower than the native Pytorch LSTM because | |||
it cannot make use of CUDNN optimizations for stacked RNNs due to and | |||
@@ -96,7 +96,7 @@ class LstmCellWithProjection(torch.nn.Module): | |||
inputs: torch.FloatTensor, | |||
batch_lengths: List[int], | |||
initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None): | |||
""" | |||
r""" | |||
Parameters | |||
---------- | |||
inputs : ``torch.FloatTensor``, required. | |||
@@ -307,7 +307,7 @@ class ElmobiLm(torch.nn.Module): | |||
self.backward_layers = backward_layers | |||
def forward(self, inputs, seq_len): | |||
""" | |||
r""" | |||
:param inputs: batch_size x max_len x embed_size | |||
:param seq_len: batch_size | |||
@@ -326,7 +326,7 @@ class ElmobiLm(torch.nn.Module): | |||
inputs: PackedSequence, | |||
initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> \ | |||
Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | |||
""" | |||
r""" | |||
Parameters | |||
---------- | |||
inputs : ``PackedSequence``, required. | |||
@@ -451,7 +451,7 @@ class ConvTokenEmbedder(nn.Module): | |||
self._projection = torch.nn.Linear(n_filters, self.output_dim, bias=True) | |||
def forward(self, words, chars): | |||
""" | |||
r""" | |||
:param words: | |||
:param chars: Tensor Shape ``(batch_size, sequence_length, 50)``: | |||
:return Tensor Shape ``(batch_size, sequence_length + 2, embedding_dim)`` : | |||
@@ -491,7 +491,7 @@ class ConvTokenEmbedder(nn.Module): | |||
class Highway(torch.nn.Module): | |||
""" | |||
r""" | |||
A `Highway layer <https://arxiv.org/abs/1505.00387>`_ does a gated combination of a linear | |||
transformation and a non-linear transformation of its input. :math:`y = g * x + (1 - g) * | |||
f(A(x))`, where :math:`A` is a linear transformation, :math:`f` is an element-wise | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"MultiHeadAttention", | |||
@@ -16,7 +16,7 @@ from fastNLP.modules.utils import initial_parameter | |||
class DotAttention(nn.Module): | |||
""" | |||
r""" | |||
Transformer当中的DotAttention | |||
""" | |||
@@ -29,7 +29,7 @@ class DotAttention(nn.Module): | |||
self.softmax = nn.Softmax(dim=-1) | |||
def forward(self, Q, K, V, mask_out=None): | |||
""" | |||
r""" | |||
:param Q: [..., seq_len_q, key_size] | |||
:param K: [..., seq_len_k, key_size] | |||
@@ -45,12 +45,12 @@ class DotAttention(nn.Module): | |||
class MultiHeadAttention(nn.Module): | |||
""" | |||
r""" | |||
Transformer当中的MultiHeadAttention | |||
""" | |||
def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): | |||
""" | |||
r""" | |||
:param input_size: int, 输入维度的大小。同时也是输出维度的大小。 | |||
:param key_size: int, 每个head的维度大小。 | |||
@@ -80,7 +80,7 @@ class MultiHeadAttention(nn.Module): | |||
nn.init.normal_(self.out.weight, mean=0, std=sqrt(1.0 / self.input_size)) | |||
def forward(self, Q, K, V, atte_mask_out=None): | |||
""" | |||
r""" | |||
:param Q: [batch, seq_len_q, model_size] | |||
:param K: [batch, seq_len_k, model_size] | |||
@@ -147,7 +147,7 @@ class BiAttention(nn.Module): | |||
""" | |||
def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): | |||
""" | |||
r""" | |||
:param torch.Tensor premise_batch: [batch_size, a_seq_len, hidden_size] | |||
:param torch.Tensor premise_mask: [batch_size, a_seq_len] | |||
:param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size] | |||
@@ -173,13 +173,13 @@ class BiAttention(nn.Module): | |||
class SelfAttention(nn.Module): | |||
""" | |||
r""" | |||
这是一个基于论文 `A structured self-attentive sentence embedding <https://arxiv.org/pdf/1703.03130.pdf>`_ | |||
的Self Attention Module. | |||
""" | |||
def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None, ): | |||
""" | |||
r""" | |||
:param int input_size: 输入tensor的hidden维度 | |||
:param int attention_unit: 输出tensor的hidden维度 | |||
@@ -199,7 +199,7 @@ class SelfAttention(nn.Module): | |||
initial_parameter(self, initial_method) | |||
def _penalization(self, attention): | |||
""" | |||
r""" | |||
compute the penalization term for attention module | |||
""" | |||
baz = attention.size(0) | |||
@@ -213,7 +213,7 @@ class SelfAttention(nn.Module): | |||
return torch.sum(ret) / size[0] | |||
def forward(self, input, input_origin): | |||
""" | |||
r""" | |||
:param torch.Tensor input: [batch_size, seq_len, hidden_size] 要做attention的矩阵 | |||
:param torch.Tensor input_origin: [batch_size, seq_len] 原始token的index组成的矩阵,含有pad部分内容 | |||
:return torch.Tensor output1: [batch_size, multi-head, hidden_size] 经过attention操作后输入矩阵的结果 | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented | |||
r"""undocumented | |||
这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你 | |||
有用,也请引用一下他们。 | |||
""" | |||
@@ -45,7 +45,7 @@ BERT_KEY_RENAME_MAP_2 = { | |||
class BertConfig(object): | |||
"""Configuration class to store the configuration of a `BertModel`. | |||
r"""Configuration class to store the configuration of a `BertModel`. | |||
""" | |||
def __init__(self, | |||
@@ -61,7 +61,7 @@ class BertConfig(object): | |||
type_vocab_size=2, | |||
initializer_range=0.02, | |||
layer_norm_eps=1e-12): | |||
"""Constructs BertConfig. | |||
r"""Constructs BertConfig. | |||
Args: | |||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. | |||
@@ -110,7 +110,7 @@ class BertConfig(object): | |||
@classmethod | |||
def from_dict(cls, json_object): | |||
"""Constructs a `BertConfig` from a Python dictionary of parameters.""" | |||
r"""Constructs a `BertConfig` from a Python dictionary of parameters.""" | |||
config = BertConfig(vocab_size_or_config_json_file=-1) | |||
for key, value in json_object.items(): | |||
config.__dict__[key] = value | |||
@@ -118,7 +118,7 @@ class BertConfig(object): | |||
@classmethod | |||
def from_json_file(cls, json_file): | |||
"""Constructs a `BertConfig` from a json file of parameters.""" | |||
r"""Constructs a `BertConfig` from a json file of parameters.""" | |||
with open(json_file, "r", encoding='utf-8') as reader: | |||
text = reader.read() | |||
return cls.from_dict(json.loads(text)) | |||
@@ -127,16 +127,16 @@ class BertConfig(object): | |||
return str(self.to_json_string()) | |||
def to_dict(self): | |||
"""Serializes this instance to a Python dictionary.""" | |||
r"""Serializes this instance to a Python dictionary.""" | |||
output = copy.deepcopy(self.__dict__) | |||
return output | |||
def to_json_string(self): | |||
"""Serializes this instance to a JSON string.""" | |||
r"""Serializes this instance to a JSON string.""" | |||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" | |||
def to_json_file(self, json_file_path): | |||
""" Save this instance to a json file.""" | |||
r""" Save this instance to a json file.""" | |||
with open(json_file_path, "w", encoding='utf-8') as writer: | |||
writer.write(self.to_json_string()) | |||
@@ -167,7 +167,7 @@ def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'): | |||
class BertLayerNorm(nn.Module): | |||
def __init__(self, hidden_size, eps=1e-12): | |||
"""Construct a layernorm module in the TF style (epsilon inside the square root). | |||
r"""Construct a layernorm module in the TF style (epsilon inside the square root). | |||
""" | |||
super(BertLayerNorm, self).__init__() | |||
self.weight = nn.Parameter(torch.ones(hidden_size)) | |||
@@ -206,7 +206,7 @@ class DistilBertEmbeddings(nn.Module): | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
def forward(self, input_ids, token_type_ids): | |||
""" | |||
r""" | |||
Parameters | |||
---------- | |||
input_ids: torch.tensor(bs, max_seq_length) | |||
@@ -231,7 +231,7 @@ class DistilBertEmbeddings(nn.Module): | |||
class BertEmbeddings(nn.Module): | |||
"""Construct the embeddings from word, position and token_type embeddings. | |||
r"""Construct the embeddings from word, position and token_type embeddings. | |||
""" | |||
def __init__(self, config): | |||
@@ -415,7 +415,7 @@ class BertPooler(nn.Module): | |||
class BertModel(nn.Module): | |||
""" | |||
r""" | |||
BERT(Bidirectional Embedding Representations from Transformers). | |||
用预训练权重矩阵来建立BERT模型:: | |||
@@ -470,7 +470,7 @@ class BertModel(nn.Module): | |||
self.apply(self.init_bert_weights) | |||
def init_bert_weights(self, module): | |||
""" Initialize the weights. | |||
r""" Initialize the weights. | |||
""" | |||
if isinstance(module, (nn.Linear, nn.Embedding)): | |||
# Slightly different from the TF version which uses truncated_normal for initialization | |||
@@ -613,7 +613,7 @@ class BertModel(nn.Module): | |||
def whitespace_tokenize(text): | |||
"""Runs basic whitespace cleaning and splitting on a piece of text.""" | |||
r"""Runs basic whitespace cleaning and splitting on a piece of text.""" | |||
text = text.strip() | |||
if not text: | |||
return [] | |||
@@ -622,7 +622,7 @@ def whitespace_tokenize(text): | |||
class WordpieceTokenizer(object): | |||
"""Runs WordPiece tokenization.""" | |||
r"""Runs WordPiece tokenization.""" | |||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): | |||
self.vocab = vocab | |||
@@ -630,7 +630,7 @@ class WordpieceTokenizer(object): | |||
self.max_input_chars_per_word = max_input_chars_per_word | |||
def tokenize(self, text): | |||
"""Tokenizes a piece of text into its word pieces. | |||
r"""Tokenizes a piece of text into its word pieces. | |||
This uses a greedy longest-match-first algorithm to perform tokenization | |||
using the given vocabulary. | |||
@@ -684,7 +684,7 @@ class WordpieceTokenizer(object): | |||
def load_vocab(vocab_file): | |||
"""Loads a vocabulary file into a dictionary.""" | |||
r"""Loads a vocabulary file into a dictionary.""" | |||
vocab = collections.OrderedDict() | |||
index = 0 | |||
with open(vocab_file, "r", encoding="utf-8") as reader: | |||
@@ -699,12 +699,12 @@ def load_vocab(vocab_file): | |||
class BasicTokenizer(object): | |||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | |||
r"""Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | |||
def __init__(self, | |||
do_lower_case=True, | |||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | |||
"""Constructs a BasicTokenizer. | |||
r"""Constructs a BasicTokenizer. | |||
Args: | |||
do_lower_case: Whether to lower case the input. | |||
@@ -713,7 +713,7 @@ class BasicTokenizer(object): | |||
self.never_split = never_split | |||
def tokenize(self, text): | |||
"""Tokenizes a piece of text.""" | |||
r"""Tokenizes a piece of text.""" | |||
text = self._clean_text(text) | |||
# This was added on November 1st, 2018 for the multilingual and Chinese | |||
# models. This is also applied to the English models now, but it doesn't | |||
@@ -734,7 +734,7 @@ class BasicTokenizer(object): | |||
return output_tokens | |||
def _run_strip_accents(self, text): | |||
"""Strips accents from a piece of text.""" | |||
r"""Strips accents from a piece of text.""" | |||
text = unicodedata.normalize("NFD", text) | |||
output = [] | |||
for char in text: | |||
@@ -745,7 +745,7 @@ class BasicTokenizer(object): | |||
return "".join(output) | |||
def _run_split_on_punc(self, text): | |||
"""Splits punctuation on a piece of text.""" | |||
r"""Splits punctuation on a piece of text.""" | |||
if text in self.never_split: | |||
return [text] | |||
chars = list(text) | |||
@@ -767,7 +767,7 @@ class BasicTokenizer(object): | |||
return ["".join(x) for x in output] | |||
def _tokenize_chinese_chars(self, text): | |||
"""Adds whitespace around any CJK character.""" | |||
r"""Adds whitespace around any CJK character.""" | |||
output = [] | |||
for char in text: | |||
cp = ord(char) | |||
@@ -780,7 +780,7 @@ class BasicTokenizer(object): | |||
return "".join(output) | |||
def _is_chinese_char(self, cp): | |||
"""Checks whether CP is the codepoint of a CJK character.""" | |||
r"""Checks whether CP is the codepoint of a CJK character.""" | |||
# This defines a "chinese character" as anything in the CJK Unicode block: | |||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | |||
# | |||
@@ -802,7 +802,7 @@ class BasicTokenizer(object): | |||
return False | |||
def _clean_text(self, text): | |||
"""Performs invalid character removal and whitespace cleanup on text.""" | |||
r"""Performs invalid character removal and whitespace cleanup on text.""" | |||
output = [] | |||
for char in text: | |||
cp = ord(char) | |||
@@ -816,7 +816,7 @@ class BasicTokenizer(object): | |||
def _is_whitespace(char): | |||
"""Checks whether `chars` is a whitespace character.""" | |||
r"""Checks whether `chars` is a whitespace character.""" | |||
# \t, \n, and \r are technically contorl characters but we treat them | |||
# as whitespace since they are generally considered as such. | |||
if char == " " or char == "\t" or char == "\n" or char == "\r": | |||
@@ -828,7 +828,7 @@ def _is_whitespace(char): | |||
def _is_control(char): | |||
"""Checks whether `chars` is a control character.""" | |||
r"""Checks whether `chars` is a control character.""" | |||
# These are technically control characters but we count them as whitespace | |||
# characters. | |||
if char == "\t" or char == "\n" or char == "\r": | |||
@@ -840,7 +840,7 @@ def _is_control(char): | |||
def _is_punctuation(char): | |||
"""Checks whether `chars` is a punctuation character.""" | |||
r"""Checks whether `chars` is a punctuation character.""" | |||
cp = ord(char) | |||
# We treat all non-letter/number ASCII as punctuation. | |||
# Characters such as "^", "$", and "`" are not in the Unicode | |||
@@ -856,11 +856,11 @@ def _is_punctuation(char): | |||
class BertTokenizer(object): | |||
"""Runs end-to-end tokenization: punctuation splitting + wordpiece""" | |||
r"""Runs end-to-end tokenization: punctuation splitting + wordpiece""" | |||
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, | |||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | |||
"""Constructs a BertTokenizer. | |||
r"""Constructs a BertTokenizer. | |||
Args: | |||
vocab_file: Path to a one-wordpiece-per-line vocabulary file | |||
@@ -889,7 +889,7 @@ class BertTokenizer(object): | |||
self.max_len = max_len if max_len is not None else int(1e12) | |||
def _reinit_on_new_vocab(self, vocab): | |||
""" | |||
r""" | |||
在load bert之后,可能会对vocab进行重新排列。重新排列之后调用这个函数重新初始化与vocab相关的性质 | |||
:param vocab: | |||
@@ -909,7 +909,7 @@ class BertTokenizer(object): | |||
return split_tokens | |||
def convert_tokens_to_ids(self, tokens): | |||
"""Converts a sequence of tokens into ids using the vocab.""" | |||
r"""Converts a sequence of tokens into ids using the vocab.""" | |||
ids = [] | |||
for token in tokens: | |||
ids.append(self.vocab[token]) | |||
@@ -922,14 +922,14 @@ class BertTokenizer(object): | |||
return ids | |||
def convert_ids_to_tokens(self, ids): | |||
"""Converts a sequence of ids in wordpiece tokens using the vocab.""" | |||
r"""Converts a sequence of ids in wordpiece tokens using the vocab.""" | |||
tokens = [] | |||
for i in ids: | |||
tokens.append(self.ids_to_tokens[i]) | |||
return tokens | |||
def save_vocabulary(self, vocab_path): | |||
"""Save the tokenizer vocabulary to a directory or file.""" | |||
r"""Save the tokenizer vocabulary to a directory or file.""" | |||
index = 0 | |||
if os.path.isdir(vocab_path): | |||
vocab_file = os.path.join(vocab_path, VOCAB_NAME) | |||
@@ -947,7 +947,7 @@ class BertTokenizer(object): | |||
@classmethod | |||
def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | |||
""" | |||
r""" | |||
给定模型的名字或者路径,直接读取vocab. | |||
""" | |||
model_dir = _get_bert_dir(model_dir_or_name) | |||
@@ -961,7 +961,7 @@ class BertTokenizer(object): | |||
class _WordPieceBertModel(nn.Module): | |||
""" | |||
r""" | |||
这个模块用于直接计算word_piece的结果. | |||
""" | |||
@@ -989,7 +989,7 @@ class _WordPieceBertModel(nn.Module): | |||
self.pooled_cls = pooled_cls | |||
def index_dataset(self, *datasets, field_name, add_cls_sep=True): | |||
""" | |||
r""" | |||
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是 | |||
[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。 | |||
@@ -1024,7 +1024,7 @@ class _WordPieceBertModel(nn.Module): | |||
raise e | |||
def forward(self, word_pieces, token_type_ids=None): | |||
""" | |||
r""" | |||
:param word_pieces: torch.LongTensor, batch_size x max_len | |||
:param token_type_ids: torch.LongTensor, batch_size x max_len | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"ConvolutionCharEncoder", | |||
@@ -12,13 +12,13 @@ from ..utils import initial_parameter | |||
# from torch.nn.init import xavier_uniform | |||
class ConvolutionCharEncoder(nn.Module): | |||
""" | |||
r""" | |||
char级别的卷积编码器. | |||
""" | |||
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(1, 3, 5), initial_method=None): | |||
""" | |||
r""" | |||
:param int char_emb_size: char级别embedding的维度. Default: 50 | |||
:例: 有26个字符, 每一个的embedding是一个50维的向量, 所以输入的向量维度为50. | |||
@@ -35,7 +35,7 @@ class ConvolutionCharEncoder(nn.Module): | |||
initial_parameter(self, initial_method) | |||
def forward(self, x): | |||
""" | |||
r""" | |||
:param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` 输入字符的embedding | |||
:return: torch.Tensor : 卷积计算的结果, 维度为[batch_size * sent_length, sum(feature_maps), 1] | |||
""" | |||
@@ -60,12 +60,12 @@ class ConvolutionCharEncoder(nn.Module): | |||
class LSTMCharEncoder(nn.Module): | |||
""" | |||
r""" | |||
char级别基于LSTM的encoder. | |||
""" | |||
def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None): | |||
""" | |||
r""" | |||
:param int char_emb_size: char级别embedding的维度. Default: 50 | |||
例: 有26个字符, 每一个的embedding是一个50维的向量, 所以输入的向量维度为50. | |||
:param int hidden_size: LSTM隐层的大小, 默认为char的embedding维度 | |||
@@ -82,7 +82,7 @@ class LSTMCharEncoder(nn.Module): | |||
initial_parameter(self, initial_method) | |||
def forward(self, x): | |||
""" | |||
r""" | |||
:param torch.Tensor x: ``[ n_batch*n_word, word_length, char_emb_size]`` 输入字符的embedding | |||
:return: torch.Tensor : [ n_batch*n_word, char_emb_size]经过LSTM编码的结果 | |||
""" | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"ConvMaxpool" | |||
@@ -9,7 +9,7 @@ import torch.nn.functional as F | |||
class ConvMaxpool(nn.Module): | |||
""" | |||
r""" | |||
集合了Convolution和Max-Pooling于一体的层。给定一个batch_size x max_len x input_size的输入,返回batch_size x | |||
sum(output_channels) 大小的matrix。在内部,是先使用CNN给输入做卷积,然后经过activation激活层,在通过在长度(max_len) | |||
这一维进行max_pooling。最后得到每个sample的一个向量表示。 | |||
@@ -17,7 +17,7 @@ class ConvMaxpool(nn.Module): | |||
""" | |||
def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"): | |||
""" | |||
r""" | |||
:param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度 | |||
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 | |||
@@ -68,7 +68,7 @@ class ConvMaxpool(nn.Module): | |||
"Undefined activation function: choose from: relu, tanh, sigmoid") | |||
def forward(self, x, mask=None): | |||
""" | |||
r""" | |||
:param torch.FloatTensor x: batch_size x max_len x input_size, 一般是经过embedding后的值 | |||
:param mask: batch_size x max_len, pad的地方为0。不影响卷积运算,max-pool一定不会pool到pad为0的位置 | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented | |||
r"""undocumented | |||
轻量封装的 Pytorch LSTM 模块. | |||
可在 forward 时传入序列的长度, 自动对padding做合适的处理. | |||
""" | |||
@@ -13,7 +13,7 @@ import torch.nn.utils.rnn as rnn | |||
class LSTM(nn.Module): | |||
""" | |||
r""" | |||
LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | |||
为1; 且可以应对DataParallel中LSTM的使用问题。 | |||
@@ -21,7 +21,7 @@ class LSTM(nn.Module): | |||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | |||
bidirectional=False, bias=True): | |||
""" | |||
r""" | |||
:param input_size: 输入 `x` 的特征维度 | |||
:param hidden_size: 隐状态 `h` 的特征维度. 如果bidirectional为True,则输出的维度会是hidde_size*2 | |||
@@ -50,7 +50,7 @@ class LSTM(nn.Module): | |||
nn.init.xavier_uniform_(param) | |||
def forward(self, x, seq_len=None, h0=None, c0=None): | |||
""" | |||
r""" | |||
:param x: [batch, seq_len, input_size] 输入序列 | |||
:param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"MaxPool", | |||
@@ -12,13 +12,13 @@ import torch.nn as nn | |||
class MaxPool(nn.Module): | |||
""" | |||
r""" | |||
Max-pooling模块。 | |||
""" | |||
def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, ceil_mode=False): | |||
""" | |||
r""" | |||
:param stride: 窗口移动大小,默认为kernel_size | |||
:param padding: padding的内容,默认为0 | |||
@@ -61,7 +61,7 @@ class MaxPool(nn.Module): | |||
class MaxPoolWithMask(nn.Module): | |||
""" | |||
r""" | |||
带mask矩阵的max pooling。在做max-pooling的时候不会考虑mask值为0的位置。 | |||
""" | |||
@@ -70,7 +70,7 @@ class MaxPoolWithMask(nn.Module): | |||
self.inf = 10e12 | |||
def forward(self, tensor, mask, dim=1): | |||
""" | |||
r""" | |||
:param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor | |||
:param torch.LongTensor mask: [batch_size, seq_len] 0/1的mask矩阵 | |||
:param int dim: 需要进行max pooling的维度 | |||
@@ -82,14 +82,14 @@ class MaxPoolWithMask(nn.Module): | |||
class KMaxPool(nn.Module): | |||
"""K max-pooling module.""" | |||
r"""K max-pooling module.""" | |||
def __init__(self, k=1): | |||
super(KMaxPool, self).__init__() | |||
self.k = k | |||
def forward(self, x): | |||
""" | |||
r""" | |||
:param torch.Tensor x: [N, C, L] 初始tensor | |||
:return: torch.Tensor x: [N, C*k] k-max pool后的结果 | |||
""" | |||
@@ -99,7 +99,7 @@ class KMaxPool(nn.Module): | |||
class AvgPool(nn.Module): | |||
""" | |||
r""" | |||
给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size] | |||
""" | |||
@@ -109,7 +109,7 @@ class AvgPool(nn.Module): | |||
self.padding = padding | |||
def forward(self, x): | |||
""" | |||
r""" | |||
:param torch.Tensor x: [N, C, L] 初始tensor | |||
:return: torch.Tensor x: [N, C] avg pool后的结果 | |||
""" | |||
@@ -124,7 +124,7 @@ class AvgPool(nn.Module): | |||
class AvgPoolWithMask(nn.Module): | |||
""" | |||
r""" | |||
给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size], pooling | |||
的时候只会考虑mask为1的位置 | |||
""" | |||
@@ -134,7 +134,7 @@ class AvgPoolWithMask(nn.Module): | |||
self.inf = 10e12 | |||
def forward(self, tensor, mask, dim=1): | |||
""" | |||
r""" | |||
:param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor | |||
:param torch.LongTensor mask: [batch_size, seq_len] 0/1的mask矩阵 | |||
:param int dim: 需要进行max pooling的维度 | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented | |||
r"""undocumented | |||
Star-Transformer 的encoder部分的 Pytorch 实现 | |||
""" | |||
@@ -13,7 +13,7 @@ from torch.nn import functional as F | |||
class StarTransformer(nn.Module): | |||
""" | |||
r""" | |||
Star-Transformer 的encoder部分。 输入3d的文本输入, 返回相同长度的文本编码 | |||
paper: https://arxiv.org/abs/1902.09113 | |||
@@ -21,7 +21,7 @@ class StarTransformer(nn.Module): | |||
""" | |||
def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | |||
""" | |||
r""" | |||
:param int hidden_size: 输入维度的大小。同时也是输出维度的大小。 | |||
:param int num_layers: star-transformer的层数 | |||
@@ -51,7 +51,7 @@ class StarTransformer(nn.Module): | |||
self.pos_emb = None | |||
def forward(self, data, mask): | |||
""" | |||
r""" | |||
:param FloatTensor data: [batch, length, hidden] 输入的序列 | |||
:param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, | |||
否则为 1 | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented""" | |||
r"""undocumented""" | |||
__all__ = [ | |||
"TransformerEncoder" | |||
@@ -9,7 +9,7 @@ from .attention import MultiHeadAttention | |||
class TransformerEncoder(nn.Module): | |||
""" | |||
r""" | |||
transformer的encoder模块,不包含embedding层 | |||
""" | |||
@@ -27,7 +27,7 @@ class TransformerEncoder(nn.Module): | |||
self.dropout = nn.Dropout(dropout) | |||
def forward(self, input, seq_mask=None, atte_mask_out=None): | |||
""" | |||
r""" | |||
:param input: [batch, seq_len, model_size] | |||
:param seq_mask: [batch, seq_len] | |||
@@ -46,7 +46,7 @@ class TransformerEncoder(nn.Module): | |||
return input | |||
def __init__(self, num_layers, **kargs): | |||
""" | |||
r""" | |||
:param int num_layers: transformer的层数 | |||
:param int model_size: 输入维度的大小。同时也是输出维度的大小。 | |||
@@ -61,7 +61,7 @@ class TransformerEncoder(nn.Module): | |||
self.norm = nn.LayerNorm(kargs['model_size'], eps=1e-6) | |||
def forward(self, x, seq_mask=None): | |||
""" | |||
r""" | |||
:param x: [batch, seq_len, model_size] 输入序列 | |||
:param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. | |||
Default: ``None`` | |||
@@ -1,4 +1,4 @@ | |||
"""undocumented | |||
r"""undocumented | |||
Variational RNN 及相关模型的 fastNLP实现,相关论文参考: | |||
`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
""" | |||
@@ -27,7 +27,7 @@ from ..utils import initial_parameter | |||
class VarRnnCellWrapper(nn.Module): | |||
""" | |||
r""" | |||
Wrapper for normal RNN Cells, make it support variational dropout | |||
""" | |||
@@ -39,7 +39,7 @@ class VarRnnCellWrapper(nn.Module): | |||
self.hidden_p = hidden_p | |||
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): | |||
""" | |||
r""" | |||
:param PackedSequence input_x: [seq_len, batch_size, input_size] | |||
:param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | |||
for other RNN, h_0, [batch_size, hidden_size] | |||
@@ -101,7 +101,7 @@ class VarRnnCellWrapper(nn.Module): | |||
class VarRNNBase(nn.Module): | |||
""" | |||
r""" | |||
Variational Dropout RNN 实现. | |||
论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) | |||
@@ -112,7 +112,7 @@ class VarRNNBase(nn.Module): | |||
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | |||
bias=True, batch_first=False, | |||
input_dropout=0, hidden_dropout=0, bidirectional=False): | |||
""" | |||
r""" | |||
:param mode: rnn 模式, (lstm or not) | |||
:param Cell: rnn cell 类型, (lstm, gru, etc) | |||
@@ -157,7 +157,7 @@ class VarRNNBase(nn.Module): | |||
return output_x, hidden_x | |||
def forward(self, x, hx=None): | |||
""" | |||
r""" | |||
:param x: [batch, seq_len, input_size] 输入序列 | |||
:param hx: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全1向量. Default: ``None`` | |||
@@ -226,14 +226,14 @@ class VarRNNBase(nn.Module): | |||
class VarLSTM(VarRNNBase): | |||
""" | |||
r""" | |||
Variational Dropout LSTM. | |||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
""" | |||
r""" | |||
:param input_size: 输入 `x` 的特征维度 | |||
:param hidden_size: 隐状态 `h` 的特征维度 | |||
@@ -253,14 +253,14 @@ class VarLSTM(VarRNNBase): | |||
class VarRNN(VarRNNBase): | |||
""" | |||
r""" | |||
Variational Dropout RNN. | |||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
""" | |||
r""" | |||
:param input_size: 输入 `x` 的特征维度 | |||
:param hidden_size: 隐状态 `h` 的特征维度 | |||
@@ -280,14 +280,14 @@ class VarRNN(VarRNNBase): | |||
class VarGRU(VarRNNBase): | |||
""" | |||
r""" | |||
Variational Dropout GRU. | |||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
""" | |||
r""" | |||
:param input_size: 输入 `x` 的特征维度 | |||
:param hidden_size: 隐状态 `h` 的特征维度 | |||
@@ -1,4 +1,4 @@ | |||
""" | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
@@ -17,7 +17,7 @@ import torch.nn.init as init | |||
def initial_parameter(net, initial_method=None): | |||
"""A method used to initialize the weights of PyTorch models. | |||
r"""A method used to initialize the weights of PyTorch models. | |||
:param net: a PyTorch model | |||
:param str initial_method: one of the following initializations. | |||
@@ -81,7 +81,7 @@ def initial_parameter(net, initial_method=None): | |||
def summary(model: nn.Module): | |||
""" | |||
r""" | |||
得到模型的总参数量 | |||
:params model: Pytorch 模型 | |||
@@ -122,7 +122,7 @@ def summary(model: nn.Module): | |||
def get_dropout_mask(drop_p: float, tensor: torch.Tensor): | |||
""" | |||
r""" | |||
根据tensor的形状,生成一个mask | |||
:param drop_p: float, 以多大的概率置为0。 | |||
@@ -136,7 +136,7 @@ def get_dropout_mask(drop_p: float, tensor: torch.Tensor): | |||
def _get_file_name_base_on_postfix(dir_path, postfix): | |||
""" | |||
r""" | |||
在dir_path中寻找后缀为postfix的文件. | |||
:param dir_path: str, 文件夹 | |||
:param postfix: 形如".bin", ".json"等 | |||