@@ -5,6 +5,7 @@ __all__ = [ | |||||
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping | from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping | ||||
import numpy as np | import numpy as np | ||||
from pkg_resources import parse_version | |||||
from fastNLP.core.dataset import DataSet, Instance | from fastNLP.core.dataset import DataSet, Instance | ||||
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler | from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler | ||||
@@ -12,6 +13,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
from torch import __version__ as torchversion | |||||
from torch.utils.data import DataLoader, Sampler | from torch.utils.data import DataLoader, Sampler | ||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | ||||
@@ -213,13 +215,21 @@ class MixDataLoader(DataLoader): | |||||
else: | else: | ||||
raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler") | raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler") | ||||
super(MixDataLoader, self).__init__( | |||||
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, | |||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||||
pin_memory=pin_memory, drop_last=False, timeout=0, | |||||
worker_init_fn=None, multiprocessing_context=None, generator=None, | |||||
prefetch_factor=2, persistent_workers=False | |||||
) | |||||
if parse_version(torchversion) >= parse_version('1.7'): | |||||
super(MixDataLoader, self).__init__( | |||||
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, | |||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||||
pin_memory=pin_memory, drop_last=False, timeout=0, | |||||
worker_init_fn=None, multiprocessing_context=None, generator=None, | |||||
prefetch_factor=2, persistent_workers=False | |||||
) | |||||
else: | |||||
super(MixDataLoader, self).__init__( | |||||
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, | |||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||||
pin_memory=pin_memory, drop_last=False, timeout=0, | |||||
worker_init_fn=None, multiprocessing_context=None, generator=None, | |||||
) | |||||
def __iter__(self): | def __iter__(self): | ||||
return super().__iter__() | return super().__iter__() |
@@ -41,7 +41,7 @@ class Driver(ABC): | |||||
r""" | r""" | ||||
根据输入的 ``dataloader`` 得到一个 支持分布式 (``distributed``) 与 可复现的 (``reproducible``) 的 dataloader。 | 根据输入的 ``dataloader`` 得到一个 支持分布式 (``distributed``) 与 可复现的 (``reproducible``) 的 dataloader。 | ||||
:param dataloader: 根据 ``dataloade``r 设置其对应的分布式版本以及可复现版本; | |||||
:param dataloader: 根据 ``dataloader`` 设置其对应的分布式版本以及可复现版本; | |||||
:param dist: 应当为一个字符串,其值应当为以下之一:``[None, "dist", "unrepeatdist"]``;为 ``None`` 时,表示不需要考虑当前 dataloader | :param dist: 应当为一个字符串,其值应当为以下之一:``[None, "dist", "unrepeatdist"]``;为 ``None`` 时,表示不需要考虑当前 dataloader | ||||
切换为分布式状态;为 ``dist`` 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 | 切换为分布式状态;为 ``dist`` 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 | ||||
不同 gpu 上出现重复;为 ``unrepeatdist`` 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | 不同 gpu 上出现重复;为 ``unrepeatdist`` 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | ||||
@@ -263,7 +263,6 @@ class Driver(ABC): | |||||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名)或一个 ``BytesIO`` 对象; | :param filepath: 需要被加载的对象的文件位置(需要包括文件名)或一个 ``BytesIO`` 对象; | ||||
:param load_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便是保存的完整的模型,此处也只能使用尝试加载filepath | :param load_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便是保存的完整的模型,此处也只能使用尝试加载filepath | ||||
模型中的权重到自身模型,而不会直接替代当前 Driver 中的模型。 | 模型中的权重到自身模型,而不会直接替代当前 Driver 中的模型。 | ||||
:return: 返回加载指定文件后的结果; | |||||
""" | """ | ||||
raise NotImplementedError("Each specific driver should implemented its own `load_model` function.") | raise NotImplementedError("Each specific driver should implemented its own `load_model` function.") | ||||
@@ -1,13 +1,17 @@ | |||||
import os | import os | ||||
import warnings | |||||
from typing import Optional, Callable, Dict | |||||
import random | |||||
from pathlib import Path | |||||
from typing import Union, Optional | |||||
from functools import partial | |||||
import numpy as np | |||||
from .utils import _build_fp16_env | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
from fastNLP.core.drivers.driver import Driver | from fastNLP.core.drivers.driver import Driver | ||||
from fastNLP.core.dataloaders import JittorDataLoader | from fastNLP.core.dataloaders import JittorDataLoader | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.utils import apply_to_collection | from fastNLP.core.utils import apply_to_collection | ||||
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_SEED_WORKERS | |||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
import jittor as jt | import jittor as jt | ||||
@@ -47,17 +51,18 @@ class JittorDriver(Driver): | |||||
f"`jittor.Module` type.") | f"`jittor.Module` type.") | ||||
super(JittorDriver, self).__init__(model) | super(JittorDriver, self).__init__(model) | ||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | |||||
self.grad_scaler = _grad_scaler() | |||||
if fp16: | |||||
jt.flags.auto_mixed_precision_level = 6 | |||||
else: | |||||
jt.flags.auto_mixed_precision_level = 0 | |||||
self.fp16 = fp16 | |||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | ||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | ||||
def check_dataloader_legality(self, dataloader): | def check_dataloader_legality(self, dataloader): | ||||
# 在fastnlp中实现了JittorDataLoader | |||||
if not isinstance(dataloader, Dataset): | |||||
raise TypeError(f"{Dataset} is expected, instead of `{type(dataloader)}`") | |||||
if not isinstance(dataloader, (Dataset, JittorDataLoader)): | |||||
raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`") | |||||
@staticmethod | @staticmethod | ||||
def _check_optimizer_legality(optimizers): | def _check_optimizer_legality(optimizers): | ||||
@@ -66,54 +71,102 @@ class JittorDriver(Driver): | |||||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " | raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " | ||||
f"not {type(each_optimizer)}.") | f"not {type(each_optimizer)}.") | ||||
def check_evaluator_mode(self, mode: str): | |||||
def step(self): | |||||
for optimizer in self.optimizers: | |||||
optimizer.step() | |||||
def backward(self, loss): | |||||
for optimizer in self.optimizers: | |||||
optimizer.backward(loss) | |||||
def zero_grad(self): | |||||
for optimizer in self.optimizers: | |||||
optimizer.zero_grad() | |||||
def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs): | |||||
r""" | |||||
将模型保存到 ``filepath`` 中。 | |||||
:param filepath: 保存文件的文件位置(需要包括文件名); | |||||
:param only_state_dict: 在 **Jittor** 中,该参数无效,**Jittor** 仅支持保存模型的 ``state_dict``。 | |||||
""" | |||||
if not only_state_dict: | |||||
logger.rank_zero_warning( | |||||
"Jittor only supports saving state_dict, and we will also save state_dict for you.", | |||||
once=True | |||||
) | |||||
if isinstance(filepath, Path): | |||||
filepath = str(filepath) | |||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if mode == "evaluate": | |||||
if not hasattr(model, "evaluate_step"): | |||||
if hasattr(model, "test_step"): | |||||
logger.warning_once( | |||||
"Your model does not have 'evaluate_step' method but has 'test_step' method, but you" | |||||
"are using 'evaluate_fn=validate', we are going to use 'test_step' to substitute for" | |||||
"'evaluate_step'.") | |||||
model.save(filepath) | |||||
else: | |||||
if not hasattr(model, "test_step"): | |||||
if hasattr(model, "evaluate_step"): | |||||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||||
"are using 'evaluate_fn=test', we are going to use 'evaluate_step' to substitute for" | |||||
"'test_step'.") | |||||
def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None): | |||||
if model_save_fn is not None: | |||||
outputs = model_save_fn(filepath) | |||||
if outputs is not None: | |||||
jt.save(outputs, filepath) | |||||
else: | |||||
if only_state_dict: | |||||
states = self.model.state_dict() | |||||
else: | |||||
warnings.warn("Saving the whole model is not supported now in Jittor. Save state dict instead.") | |||||
jt.save(states, filepath) | |||||
def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs): | |||||
r""" | |||||
加载模型的函数;将 ``filepath`` 中的模型加载并赋值给当前 ``model`` 。 | |||||
def load_model(self, filepath: str): | |||||
if not os.path.exists(filepath): | |||||
raise FileNotFoundError("Checkpoint at {} not found.".format(filepath)) | |||||
return jt.load(filepath) | |||||
:param filepath: 保存文件的文件位置(需要包括文件名); | |||||
:param load_state_dict: 在 **Jittor** 中,该参数无效,**Jittor** 仅支持加载模型的 ``state_dict``。 | |||||
""" | |||||
if not only_state_dict: | |||||
logger.rank_zero_warning( | |||||
"Jittor only supports loading state_dict, and we will also load state_dict for you.", | |||||
once=True | |||||
) | |||||
if isinstance(filepath, Path): | |||||
filepath = str(filepath) | |||||
model = self.unwrap_model() | |||||
model.load(filepath) | |||||
def save_checkpoint(self): | def save_checkpoint(self): | ||||
... | ... | ||||
def get_optimizer_state(self): | |||||
# optimizers_state_dict = {} | |||||
# for i in range(len(self.optimizers)): | |||||
# optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
# optimizer_state = optimizer.state_dict() | |||||
# optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) | |||||
# optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
# return optimizers_state_dict | |||||
... | |||||
def load_optimizer_state(self, states): | |||||
# assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||||
# f"checkpoint it is:{len(states)}" | |||||
# for i in range(len(self.optimizers)): | |||||
# optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
# optimizer.load_state_dict(states[f"optimizer{i}"]) | |||||
# logger.debug("Load optimizer state dict.") | |||||
... | |||||
def load_checkpoint(self): | def load_checkpoint(self): | ||||
... | ... | ||||
def get_evaluate_context(self): | def get_evaluate_context(self): | ||||
return jt.no_grad | return jt.no_grad | ||||
def get_model_device(self): | |||||
return self.model_device | |||||
@staticmethod | |||||
def move_model_to_device(model: "jt.Module", device): | |||||
r""" | |||||
将模型转移到指定的设备上。由于 **Jittor** 会自动为数据分配设备,因此该函数实际上无效。 | |||||
""" | |||||
... | |||||
def move_data_to_device(self, batch): | |||||
""" | |||||
将数据 ``batch`` 转移到指定的设备上。由于 **Jittor** 会自动为数据分配设备,因此该函数实际上无效。 | |||||
""" | |||||
return batch | |||||
@staticmethod | @staticmethod | ||||
def tensor_to_numeric(tensor, reduce=None): | def tensor_to_numeric(tensor, reduce=None): | ||||
r""" | |||||
将一个 :class:`jittor.Var` 对象转换为 转换成 python 中的数值类型; | |||||
:param tensor: :class:`jittor.Var` 类型的对象; | |||||
:param reduce: 当 tensor 是一个多数值的张量时,应当使用何种归一化操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``; | |||||
:return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等; | |||||
""" | |||||
if tensor is None: | if tensor is None: | ||||
return None | return None | ||||
@@ -145,7 +198,32 @@ class JittorDriver(Driver): | |||||
""" | """ | ||||
return batch | return batch | ||||
# def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx): | |||||
# # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | |||||
# if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | |||||
# dataloader.batch_sampler.set_epoch(cur_epoch_idx) | |||||
@staticmethod | |||||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | |||||
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||||
process_seed = jt.get_seed() | |||||
# back out the base seed so we can use all the bits | |||||
base_seed = process_seed - worker_id | |||||
ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) | |||||
# use 128 bits (4 x 32-bit words) | |||||
np.random.seed(ss.generate_state(4)) | |||||
# Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module | |||||
jittor_ss, stdlib_ss = ss.spawn(2) | |||||
jt.set_global_seed(jittor_ss.generate_state(1, dtype=np.uint64)[0]) | |||||
# use 128 bits expressed as an integer | |||||
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() | |||||
random.seed(stdlib_seed) | |||||
def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]): | |||||
if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: | |||||
dataloader.worker_init_fn = partial(self.worker_init_function, | |||||
rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))) | |||||
def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int): | |||||
# 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | |||||
if callable(getattr(dataloader.sampler, "set_epoch", None)): | |||||
dataloader.sampler.set_epoch(cur_epoch_idx) | |||||
@staticmethod | |||||
def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]): | |||||
pass |
@@ -146,7 +146,10 @@ class JittorMPIDriver(JittorDriver): | |||||
return self.model.no_sync | return self.model.no_sync | ||||
def unwrap_model(self): | def unwrap_model(self): | ||||
pass | |||||
""" | |||||
返回训练使用的模型。 | |||||
""" | |||||
return self.model | |||||
def get_local_rank(self) -> int: | def get_local_rank(self) -> int: | ||||
return self.local_rank | return self.local_rank | ||||
@@ -155,4 +158,14 @@ class JittorMPIDriver(JittorDriver): | |||||
pass | pass | ||||
def is_distributed(self): | def is_distributed(self): | ||||
""" | |||||
判断是否为分布式的 **Driver** ,在 ``JittorSingleDriver`` 中,返回 ``True``。 | |||||
""" | |||||
return True | return True | ||||
@property | |||||
def data_device(self) -> str: | |||||
""" | |||||
:return: 数据所在的设备; | |||||
""" | |||||
return self.model_device |
@@ -27,28 +27,36 @@ class JittorSingleDriver(JittorDriver): | |||||
支持 cpu 和 gpu 的切换; | 支持 cpu 和 gpu 的切换; | ||||
实现断点重训中替换 dataloader 的 set_dist_repro_dataloader 函数 | 实现断点重训中替换 dataloader 的 set_dist_repro_dataloader 函数 | ||||
:param model: 传入给 ``Trainer`` 的 ``model`` 参数; | |||||
:param device: 训练和模型所在的设备,在 **Jittor** 中,应当为以下值之一:``[None, 'cpu', 'gpu', 'cuda']``; | |||||
* 为 ``None`` 或 ``cpu`` 时 | |||||
表示在 ``cpu`` 上进行训练; | |||||
* 为 ``gpu`` 或 ``cuda`` 时 | |||||
表示在显卡设备上进行训练; | |||||
:param fp16: 是否开启 fp16; | |||||
""" | """ | ||||
def __init__(self, model, device=None, fp16: bool = False, **kwargs): | def __init__(self, model, device=None, fp16: bool = False, **kwargs): | ||||
if device not in [None, "cpu", "gpu", "cuda"]: | |||||
raise RuntimeError("Parameter `device` should be one of [None, 'cpu', 'gpu', 'cuda'] .") | |||||
super(JittorSingleDriver, self).__init__(model, fp16) | super(JittorSingleDriver, self).__init__(model, fp16) | ||||
self.model_device = device | |||||
self.model_device = device if device is not None else "cpu" | |||||
self.local_rank = 0 | self.local_rank = 0 | ||||
self.global_rank = 0 | self.global_rank = 0 | ||||
self.world_size = 1 | self.world_size = 1 | ||||
def step(self): | |||||
for optimizer in self.optimizers: | |||||
optimizer.step() | |||||
def backward(self, loss): | |||||
for optimizer in self.optimizers: | |||||
optimizer.backward(loss) | |||||
def zero_grad(self): | |||||
for optimizer in self.optimizers: | |||||
optimizer.zero_grad() | |||||
def setup(self): | |||||
r""" | |||||
初始化训练环境;根据传入的 ``device`` 值设置模型的训练场景为 ``cpu`` 或 ``gpu``; | |||||
""" | |||||
if self.model_device in ["cpu", None]: | |||||
jt.flags.use_cuda = 0 # 使用 cpu | |||||
else: | |||||
jt.flags.use_cuda = 1 # 使用 cuda | |||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | ||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | if isinstance(batch, Dict) and not self.wo_auto_param_call: | ||||
@@ -70,9 +78,15 @@ class JittorSingleDriver(JittorDriver): | |||||
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | ||||
def unwrap_model(self): | def unwrap_model(self): | ||||
""" | |||||
返回训练使用的模型。 | |||||
""" | |||||
return self.model | return self.model | ||||
def is_distributed(self): | def is_distributed(self): | ||||
""" | |||||
判断是否为分布式的 **Driver** ,在 ``JittorSingleDriver`` 中,返回 ``False``。 | |||||
""" | |||||
return False | return False | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | ||||
@@ -103,11 +117,15 @@ class JittorSingleDriver(JittorDriver): | |||||
else: | else: | ||||
return dataloader | return dataloader | ||||
def setup(self): | |||||
def unwrap_model(self): | |||||
""" | """ | ||||
支持 cpu 和 gpu 的切换 | |||||
返回训练使用的模型。 | |||||
""" | """ | ||||
if self.model_device in ["cpu", None]: | |||||
jt.flags.use_cuda = 0 # 使用 cpu | |||||
else: | |||||
jt.flags.use_cuda = 1 # 使用 cuda | |||||
return self.model | |||||
@property | |||||
def data_device(self) -> str: | |||||
""" | |||||
:return: 数据和模型所在的设备; | |||||
""" | |||||
return self.model_device |
@@ -1,56 +1,6 @@ | |||||
from contextlib import ExitStack | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
import jittor | import jittor | ||||
__all__ = [] | __all__ = [] | ||||
class DummyGradScaler: | |||||
""" | |||||
用于仿造的 **GradScaler** 对象,防止重复写大量的if判断 | |||||
""" | |||||
def __init__(self, *args, **kwargs): | |||||
pass | |||||
def get_scale(self): | |||||
return 1.0 | |||||
def is_enabled(self): | |||||
return False | |||||
def scale(self, outputs): | |||||
return outputs | |||||
def step(self, optimizer, *args, **kwargs): | |||||
optimizer.step(*args, **kwargs) | |||||
def update(self, new_scale=None): | |||||
pass | |||||
def unscale_(self, optimizer): | |||||
pass | |||||
def load_state_dict(self, state_dict): | |||||
pass | |||||
def state_dict(self): | |||||
return {} | |||||
def _build_fp16_env(dummy=False): | |||||
if dummy: | |||||
auto_cast = ExitStack | |||||
GradScaler = DummyGradScaler | |||||
else: | |||||
raise NotImplementedError("JittorDriver does not support fp16 now.") | |||||
# if not jt.flags.use_cuda: | |||||
# raise RuntimeError("No cuda") | |||||
# if paddle.device.cuda.get_device_capability(0)[0] < 7: | |||||
# log.warning( | |||||
# "NOTE: your device does NOT support faster training with fp16, " | |||||
# "please switch to FP32 which is likely to be faster" | |||||
# ) | |||||
# from paddle.amp import auto_cast, GradScaler | |||||
return auto_cast, GradScaler |
@@ -113,12 +113,11 @@ class PaddleDriver(Driver): | |||||
@staticmethod | @staticmethod | ||||
def tensor_to_numeric(tensor, reduce=None): | def tensor_to_numeric(tensor, reduce=None): | ||||
r""" | r""" | ||||
将一个 `tensor` 对象(类型为 `paddle.Tensor` )转换为 python 的 `numeric` 对象;如果 tensor 只包含一个元素则返回 float 或 int 。 | |||||
将一个 :class:`paddle.Tensor` 对象转换为 转换成 python 中的数值类型; | |||||
:param tensor: 需要被转换的 `tensor` 对象 | |||||
:param reduce: 可选 ['sum', 'max', 'mea', 'min'],如果不为 None 将使用该 reduce 方法来处理当前 tensor 再返回 | |||||
float 或 int 对象。 | |||||
:return: 转换后返回的结果 | |||||
:param tensor: :class:`paddle.Tensor` 类型的对象; | |||||
:param reduce: 当 tensor 是一个多数值的张量时,应当使用何种归一化操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``; | |||||
:return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等; | |||||
""" | """ | ||||
if tensor is None: | if tensor is None: | ||||
return None | return None | ||||
@@ -12,7 +12,7 @@ if _NEED_IMPORT_TORCH: | |||||
from torch.optim import Optimizer | from torch.optim import Optimizer | ||||
from torch.utils.data import RandomSampler as TorchRandomSampler | from torch.utils.data import RandomSampler as TorchRandomSampler | ||||
_reduces = { | _reduces = { | ||||
'sum': torch.max, | |||||
'sum': torch.sum, | |||||
'min': torch.min, | 'min': torch.min, | ||||
'max': torch.max, | 'max': torch.max, | ||||
'mean': torch.mean | 'mean': torch.mean | ||||
@@ -68,7 +68,8 @@ def paddle_to(data: "paddle.Tensor", device: Union[str, int, 'paddle.fluid.core_ | |||||
该函数只是集成了 :func:`paddle.Tensor.cpu` 和 :func:`paddle.Tensor.cuda` 两个函数。 | 该函数只是集成了 :func:`paddle.Tensor.cpu` 和 :func:`paddle.Tensor.cuda` 两个函数。 | ||||
:param data: 要迁移的张量; | :param data: 要迁移的张量; | ||||
:param device: 目标设备,可以是 ``str`` 或 ``int`` 类型; | |||||
:param device: 目标设备,可以是 ``str`` 或 ``int`` 及 **paddle** 自己的 :class:`paddle.fluid.core_avx.Place`、 | |||||
:class:`paddle.CPUPlace` 和 :class:`paddle.CUDAPlace` 类型; | |||||
:return: 迁移后的张量; | :return: 迁移后的张量; | ||||
""" | """ | ||||
if isinstance(device, paddle.fluid.core_avx.Place): | if isinstance(device, paddle.fluid.core_avx.Place): | ||||
@@ -74,7 +74,7 @@ def seq_len_to_mask(seq_len, max_len: Optional[int]=None): | |||||
if isinstance(seq_len, jittor.Var): | if isinstance(seq_len, jittor.Var): | ||||
assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim == 1}." | assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim == 1}." | ||||
batch_size = seq_len.shape[0] | batch_size = seq_len.shape[0] | ||||
broad_cast_seq_len = jittor.arange(max_len).expand(batch_size, -1) | |||||
broad_cast_seq_len = jittor.arange(max_len).reshape(1, max_len).expand(batch_size, -1) | |||||
mask = broad_cast_seq_len < seq_len.unsqueeze(1) | mask = broad_cast_seq_len < seq_len.unsqueeze(1) | ||||
return mask | return mask | ||||
except NameError as e: | except NameError as e: | ||||
@@ -3,7 +3,8 @@ __all__ = [ | |||||
"CONFIG_MAPPING", | "CONFIG_MAPPING", | ||||
"MODEL_NAMES_MAPPING", | "MODEL_NAMES_MAPPING", | ||||
"AutoConfig", | "AutoConfig", | ||||
"TOKENIZER_MAPPING_NAMES", | |||||
"TOKENIZER_MAPPING", | |||||
"AutoTokenizer", | |||||
"get_values", | "get_values", | ||||
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", | "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", | ||||
"MODEL_FOR_CAUSAL_LM_MAPPING", | "MODEL_FOR_CAUSAL_LM_MAPPING", | ||||
@@ -43,7 +44,7 @@ __all__ = [ | |||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \ | from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \ | ||||
AutoConfig | AutoConfig | ||||
from .tokenization_auto import TOKENIZER_MAPPING_NAMES | |||||
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer | |||||
from .auto_factory import get_values | from .auto_factory import get_values | ||||
from .modeling_auto import ( | from .modeling_auto import ( | ||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, | MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, | ||||
@@ -13,14 +13,31 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
""" Auto Tokenizer class. """ | """ Auto Tokenizer class. """ | ||||
import importlib | |||||
import json | |||||
import os | |||||
from collections import OrderedDict | from collections import OrderedDict | ||||
from typing import TYPE_CHECKING, Optional, Tuple | |||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union | |||||
from ...configuration_utils import PretrainedConfig | |||||
from ...file_utils import ( | from ...file_utils import ( | ||||
cached_path, | |||||
hf_bucket_url, | |||||
is_offline_mode, | |||||
is_sentencepiece_available, | is_sentencepiece_available, | ||||
is_tokenizers_available, | is_tokenizers_available, | ||||
) | ) | ||||
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE | |||||
from ..encoder_decoder import EncoderDecoderConfig | |||||
from .auto_factory import _LazyAutoMapping | |||||
from .configuration_auto import ( | |||||
CONFIG_MAPPING_NAMES, | |||||
AutoConfig, | |||||
config_class_to_model_type, | |||||
model_type_to_module_name, | |||||
replace_list_option_in_docstrings, | |||||
) | |||||
from fastNLP.core.log import logger | |||||
if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
# This significantly improves completion suggestion performance when | # This significantly improves completion suggestion performance when | ||||
@@ -34,4 +51,297 @@ else: | |||||
("bert", ("BertTokenizer", None)), | ("bert", ("BertTokenizer", None)), | ||||
("gpt2", ("GPT2Tokenizer", None)), | ("gpt2", ("GPT2Tokenizer", None)), | ||||
] | ] | ||||
) | |||||
) | |||||
TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES) | |||||
CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()} | |||||
def tokenizer_class_from_name(class_name: str): | |||||
if class_name == "PreTrainedTokenizerFast": | |||||
raise RuntimeError("fastNLP does not support TokenizerFast now.") | |||||
for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items(): | |||||
if class_name in tokenizers: | |||||
module_name = model_type_to_module_name(module_name) | |||||
try: | |||||
module = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models") | |||||
except ImportError: | |||||
raise ImportError(f"fastNLP transformers does not support {module_name} now, please install and import `transformers` to use it.") | |||||
return getattr(module, class_name) | |||||
return None | |||||
def get_tokenizer_config( | |||||
pretrained_model_name_or_path: Union[str, os.PathLike], | |||||
cache_dir: Optional[Union[str, os.PathLike]] = None, | |||||
force_download: bool = False, | |||||
resume_download: bool = False, | |||||
proxies: Optional[Dict[str, str]] = None, | |||||
use_auth_token: Optional[Union[bool, str]] = None, | |||||
revision: Optional[str] = None, | |||||
local_files_only: bool = False, | |||||
**kwargs, | |||||
): | |||||
""" | |||||
Loads the tokenizer configuration from a pretrained model tokenizer configuration. | |||||
Args: | |||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): | |||||
This can be either: | |||||
- a string, the `model id` of a pretrained model configuration hosted inside a model repo on | |||||
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or | |||||
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. | |||||
- a path to a `directory` containing a configuration file saved using the | |||||
:func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``. | |||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): | |||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard | |||||
cache should not be used. | |||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
Whether or not to force to (re-)download the configuration files and override the cached versions if they | |||||
exist. | |||||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. | |||||
proxies (:obj:`Dict[str, str]`, `optional`): | |||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', | |||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. | |||||
use_auth_token (:obj:`str` or `bool`, `optional`): | |||||
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token | |||||
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). | |||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||||
identifier allowed by git. | |||||
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
If :obj:`True`, will only try to load the tokenizer configuration from local files. | |||||
.. note:: | |||||
Passing :obj:`use_auth_token=True` is required when you want to use a private model. | |||||
Returns: | |||||
:obj:`Dict`: The configuration of the tokenizer. | |||||
Examples:: | |||||
# Download configuration from huggingface.co and cache. | |||||
tokenizer_config = get_tokenizer_config("bert-base-uncased") | |||||
# This model does not have a tokenizer config so the result will be an empty dict. | |||||
tokenizer_config = get_tokenizer_config("xlm-roberta-base") | |||||
# Save a pretrained tokenizer locally and you can reload its config | |||||
from transformers import AutoTokenizer | |||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") | |||||
tokenizer.save_pretrained("tokenizer-test") | |||||
tokenizer_config = get_tokenizer_config("tokenizer-test") | |||||
""" | |||||
if is_offline_mode() and not local_files_only: | |||||
logger.info("Offline mode: forcing local_files_only=True") | |||||
local_files_only = True | |||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path) | |||||
if os.path.isdir(pretrained_model_name_or_path): | |||||
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE) | |||||
else: | |||||
config_file = hf_bucket_url( | |||||
pretrained_model_name_or_path, filename=TOKENIZER_CONFIG_FILE, revision=revision, mirror=None | |||||
) | |||||
try: | |||||
# Load from URL or cache if already cached | |||||
resolved_config_file = cached_path( | |||||
config_file, | |||||
cache_dir=cache_dir, | |||||
force_download=force_download, | |||||
proxies=proxies, | |||||
resume_download=resume_download, | |||||
local_files_only=local_files_only, | |||||
use_auth_token=use_auth_token, | |||||
) | |||||
except EnvironmentError: | |||||
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.") | |||||
return {} | |||||
with open(resolved_config_file, encoding="utf-8") as reader: | |||||
return json.load(reader) | |||||
class AutoTokenizer: | |||||
r""" | |||||
This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when | |||||
created with the :meth:`AutoTokenizer.from_pretrained` class method. | |||||
This class cannot be instantiated directly using ``__init__()`` (throws an error). | |||||
""" | |||||
def __init__(self): | |||||
raise EnvironmentError( | |||||
"AutoTokenizer is designed to be instantiated " | |||||
"using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method." | |||||
) | |||||
@classmethod | |||||
@replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES) | |||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): | |||||
r""" | |||||
Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary. | |||||
The tokenizer class to instantiate is selected based on the :obj:`model_type` property of the config object | |||||
(either passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's | |||||
missing, by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: | |||||
List options | |||||
Params: | |||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): | |||||
Can be either: | |||||
- A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co. | |||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under | |||||
a user or organization name, like ``dbmdz/bert-base-german-cased``. | |||||
- A path to a `directory` containing vocabulary files required by the tokenizer, for instance saved | |||||
using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., | |||||
``./my_model_directory/``. | |||||
- A path or url to a single saved vocabulary file if and only if the tokenizer only requires a | |||||
single vocabulary file (like Bert or XLNet), e.g.: ``./my_model_directory/vocab.txt``. (Not | |||||
applicable to all derived classes) | |||||
inputs (additional positional arguments, `optional`): | |||||
Will be passed along to the Tokenizer ``__init__()`` method. | |||||
config (:class:`~transformers.PretrainedConfig`, `optional`) | |||||
The configuration object used to dertermine the tokenizer class to instantiate. | |||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): | |||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the | |||||
standard cache should not be used. | |||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
Whether or not to force the (re-)download the model weights and configuration files and override the | |||||
cached versions if they exist. | |||||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a | |||||
file exists. | |||||
proxies (:obj:`Dict[str, str]`, `optional`): | |||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', | |||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||||
identifier allowed by git. | |||||
subfolder (:obj:`str`, `optional`): | |||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for | |||||
facebook/rag-token-base), specify it here. | |||||
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`): | |||||
Whether or not to try to load the fast version of the tokenizer. | |||||
tokenizer_type (:obj:`str`, `optional`): | |||||
Tokenizer type to be loaded. | |||||
kwargs (additional keyword arguments, `optional`): | |||||
Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like | |||||
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, | |||||
``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__()`` for more details. | |||||
Examples:: | |||||
>>> from transformers import AutoTokenizer | |||||
>>> # Download vocabulary from huggingface.co and cache. | |||||
>>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') | |||||
>>> # Download vocabulary from huggingface.co (user-uploaded) and cache. | |||||
>>> tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-cased') | |||||
>>> # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`) | |||||
>>> tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/') | |||||
""" | |||||
config = kwargs.pop("config", None) | |||||
kwargs["_from_auto"] = True | |||||
use_fast = kwargs.pop("use_fast", True) | |||||
tokenizer_type = kwargs.pop("tokenizer_type", None) | |||||
# First, let's see whether the tokenizer_type is passed so that we can leverage it | |||||
if tokenizer_type is not None: | |||||
tokenizer_class = None | |||||
tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None) | |||||
if tokenizer_class_tuple is None: | |||||
raise ValueError( | |||||
f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of " | |||||
f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES.keys())}." | |||||
) | |||||
tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple | |||||
if use_fast and tokenizer_fast_class_name is not None: | |||||
tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name) | |||||
if tokenizer_class is None: | |||||
tokenizer_class = tokenizer_class_from_name(tokenizer_class_name) | |||||
if tokenizer_class is None: | |||||
raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.") | |||||
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) | |||||
# Next, let's try to use the tokenizer_config file to get the tokenizer class. | |||||
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs) | |||||
config_tokenizer_class = tokenizer_config.get("tokenizer_class") | |||||
# If that did not work, let's try to use the config. | |||||
if config_tokenizer_class is None: | |||||
if not isinstance(config, PretrainedConfig): | |||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) | |||||
config_tokenizer_class = config.tokenizer_class | |||||
# If we have the tokenizer class from the tokenizer config or the model config we're good! | |||||
if config_tokenizer_class is not None: | |||||
tokenizer_class = None | |||||
if use_fast and not config_tokenizer_class.endswith("Fast"): | |||||
tokenizer_class_candidate = f"{config_tokenizer_class}Fast" | |||||
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) | |||||
if tokenizer_class is None: | |||||
tokenizer_class_candidate = config_tokenizer_class | |||||
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) | |||||
if tokenizer_class is None: | |||||
raise ValueError( | |||||
f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported." | |||||
) | |||||
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) | |||||
# Otherwise we have to be creative. | |||||
# if model is an encoder decoder, the encoder tokenizer class is used by default | |||||
if isinstance(config, EncoderDecoderConfig): | |||||
if type(config.decoder) is not type(config.encoder): # noqa: E721 | |||||
logger.warning( | |||||
f"The encoder model config class: {config.encoder.__class__} is different from the decoder model " | |||||
f"config class: {config.decoder.__class__}. It is not recommended to use the " | |||||
"`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder " | |||||
"specific tokenizer classes." | |||||
) | |||||
config = config.encoder | |||||
model_type = config_class_to_model_type(type(config).__name__) | |||||
if model_type is not None: | |||||
tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)] | |||||
if tokenizer_class_fast and (use_fast or tokenizer_class_py is None): | |||||
return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) | |||||
else: | |||||
if tokenizer_class_py is not None: | |||||
return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) | |||||
else: | |||||
raise ValueError( | |||||
"This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed " | |||||
"in order to use this tokenizer." | |||||
) | |||||
raise ValueError( | |||||
f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n" | |||||
f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}." | |||||
) |
@@ -0,0 +1,5 @@ | |||||
__all__ = [ | |||||
"EncoderDecoderConfig", | |||||
] | |||||
from .configuration_encoder_decoder import EncoderDecoderConfig |
@@ -0,0 +1,114 @@ | |||||
# coding=utf-8 | |||||
# Copyright 2020 The HuggingFace Inc. team. | |||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
import copy | |||||
from ...configuration_utils import PretrainedConfig | |||||
from fastNLP.core.log import logger | |||||
class EncoderDecoderConfig(PretrainedConfig): | |||||
r""" | |||||
:class:`~transformers.EncoderDecoderConfig` is the configuration class to store the configuration of a | |||||
:class:`~transformers.EncoderDecoderModel`. It is used to instantiate an Encoder Decoder model according to the | |||||
specified arguments, defining the encoder and decoder configs. | |||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model | |||||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. | |||||
Args: | |||||
kwargs (`optional`): | |||||
Dictionary of keyword arguments. Notably: | |||||
- **encoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration | |||||
object that defines the encoder config. | |||||
- **decoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration | |||||
object that defines the decoder config. | |||||
Examples:: | |||||
>>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel | |||||
>>> # Initializing a BERT bert-base-uncased style configuration | |||||
>>> config_encoder = BertConfig() | |||||
>>> config_decoder = BertConfig() | |||||
>>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder) | |||||
>>> # Initializing a Bert2Bert model from the bert-base-uncased style configurations | |||||
>>> model = EncoderDecoderModel(config=config) | |||||
>>> # Accessing the model configuration | |||||
>>> config_encoder = model.config.encoder | |||||
>>> config_decoder = model.config.decoder | |||||
>>> # set decoder config to causal lm | |||||
>>> config_decoder.is_decoder = True | |||||
>>> config_decoder.add_cross_attention = True | |||||
>>> # Saving the model, including its configuration | |||||
>>> model.save_pretrained('my-model') | |||||
>>> # loading model and config from pretrained folder | |||||
>>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained('my-model') | |||||
>>> model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config) | |||||
""" | |||||
model_type = "encoder-decoder" | |||||
is_composition = True | |||||
def __init__(self, **kwargs): | |||||
super().__init__(**kwargs) | |||||
assert ( | |||||
"encoder" in kwargs and "decoder" in kwargs | |||||
), "Config has to be initialized with encoder and decoder config" | |||||
encoder_config = kwargs.pop("encoder") | |||||
encoder_model_type = encoder_config.pop("model_type") | |||||
decoder_config = kwargs.pop("decoder") | |||||
decoder_model_type = decoder_config.pop("model_type") | |||||
from ..auto.configuration_auto import AutoConfig | |||||
self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config) | |||||
self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config) | |||||
self.is_encoder_decoder = True | |||||
@classmethod | |||||
def from_encoder_decoder_configs( | |||||
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs | |||||
) -> PretrainedConfig: | |||||
r""" | |||||
Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model | |||||
configuration and decoder model configuration. | |||||
Returns: | |||||
:class:`EncoderDecoderConfig`: An instance of a configuration object | |||||
""" | |||||
logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config") | |||||
decoder_config.is_decoder = True | |||||
decoder_config.add_cross_attention = True | |||||
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs) | |||||
def to_dict(self): | |||||
""" | |||||
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`. | |||||
Returns: | |||||
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, | |||||
""" | |||||
output = copy.deepcopy(self.__dict__) | |||||
output["encoder"] = self.encoder.to_dict() | |||||
output["decoder"] = self.decoder.to_dict() | |||||
output["model_type"] = self.__class__.model_type | |||||
return output |
@@ -75,12 +75,12 @@ class TestPaddleDriverFunctions: | |||||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | ||||
""" | """ | ||||
dataloader = DataLoader(PaddleNormalDataset()) | dataloader = DataLoader(PaddleNormalDataset()) | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# batch_size 和 batch_sampler 均为 None 的情形 | # batch_size 和 batch_sampler 均为 None 的情形 | ||||
dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) | dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# 创建torch的dataloader | # 创建torch的dataloader | ||||
dataloader = torch.utils.data.DataLoader( | dataloader = torch.utils.data.DataLoader( | ||||
@@ -88,7 +88,7 @@ class TestPaddleDriverFunctions: | |||||
batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
) | ) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
@pytest.mark.torchpaddle | @pytest.mark.torchpaddle | ||||
def test_check_dataloader_legality_in_test(self): | def test_check_dataloader_legality_in_test(self): | ||||
@@ -100,7 +100,7 @@ class TestPaddleDriverFunctions: | |||||
"train": DataLoader(PaddleNormalDataset()), | "train": DataLoader(PaddleNormalDataset()), | ||||
"test":DataLoader(PaddleNormalDataset()) | "test":DataLoader(PaddleNormalDataset()) | ||||
} | } | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# batch_size 和 batch_sampler 均为 None 的情形 | # batch_size 和 batch_sampler 均为 None 的情形 | ||||
dataloader = { | dataloader = { | ||||
@@ -108,12 +108,12 @@ class TestPaddleDriverFunctions: | |||||
"test":DataLoader(PaddleNormalDataset(), batch_size=None) | "test":DataLoader(PaddleNormalDataset(), batch_size=None) | ||||
} | } | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# 传入的不是 dict ,应该报错 | # 传入的不是 dict ,应该报错 | ||||
dataloader = DataLoader(PaddleNormalDataset()) | dataloader = DataLoader(PaddleNormalDataset()) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# 创建 torch 的 dataloader | # 创建 torch 的 dataloader | ||||
train_loader = torch.utils.data.DataLoader( | train_loader = torch.utils.data.DataLoader( | ||||
@@ -126,7 +126,7 @@ class TestPaddleDriverFunctions: | |||||
) | ) | ||||
dataloader = {"train": train_loader, "test": test_loader} | dataloader = {"train": train_loader, "test": test_loader} | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
def test_tensor_to_numeric(self): | def test_tensor_to_numeric(self): | ||||
@@ -183,6 +183,21 @@ class TestPaddleDriverFunctions: | |||||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | ||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
def test_tensor_to_numeric_reduce(self): | |||||
tensor = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) | |||||
res_max = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="max") | |||||
res_min = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="min") | |||||
res_sum = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="sum") | |||||
res_mean = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="mean") | |||||
assert res_max == 6 | |||||
assert res_min == 1 | |||||
assert res_sum == 21 | |||||
assert res_mean == 3.5 | |||||
@pytest.mark.paddle | |||||
def test_set_model_mode(self): | def test_set_model_mode(self): | ||||
""" | """ | ||||
测试 set_model_mode 函数 | 测试 set_model_mode 函数 | ||||
@@ -117,7 +117,7 @@ class TestTorchDriverFunctions: | |||||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | ||||
""" | """ | ||||
dataloader = DataLoader(TorchNormalDataset()) | dataloader = DataLoader(TorchNormalDataset()) | ||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# 创建 paddle 的 dataloader | # 创建 paddle 的 dataloader | ||||
dataloader = paddle.io.DataLoader( | dataloader = paddle.io.DataLoader( | ||||
@@ -125,7 +125,7 @@ class TestTorchDriverFunctions: | |||||
batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
) | ) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
@pytest.mark.torchpaddle | @pytest.mark.torchpaddle | ||||
def test_check_dataloader_legality_in_test(self): | def test_check_dataloader_legality_in_test(self): | ||||
@@ -137,12 +137,12 @@ class TestTorchDriverFunctions: | |||||
"train": DataLoader(TorchNormalDataset()), | "train": DataLoader(TorchNormalDataset()), | ||||
"test": DataLoader(TorchNormalDataset()) | "test": DataLoader(TorchNormalDataset()) | ||||
} | } | ||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# 传入的不是 dict,应该报错 | # 传入的不是 dict,应该报错 | ||||
dataloader = DataLoader(TorchNormalDataset()) | dataloader = DataLoader(TorchNormalDataset()) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
# 创建 paddle 的 dataloader | # 创建 paddle 的 dataloader | ||||
train_loader = paddle.io.DataLoader( | train_loader = paddle.io.DataLoader( | ||||
@@ -155,7 +155,7 @@ class TestTorchDriverFunctions: | |||||
) | ) | ||||
dataloader = {"train": train_loader, "test": test_loader} | dataloader = {"train": train_loader, "test": test_loader} | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
def test_tensor_to_numeric(self): | def test_tensor_to_numeric(self): | ||||
@@ -212,6 +212,20 @@ class TestTorchDriverFunctions: | |||||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | ||||
@pytest.mark.torch | @pytest.mark.torch | ||||
def test_tensor_to_numeric_reduce(self): | |||||
tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) | |||||
res_max = TorchSingleDriver.tensor_to_numeric(tensor, reduce="max") | |||||
res_min = TorchSingleDriver.tensor_to_numeric(tensor, reduce="min") | |||||
res_sum = TorchSingleDriver.tensor_to_numeric(tensor, reduce="sum") | |||||
res_mean = TorchSingleDriver.tensor_to_numeric(tensor, reduce="mean") | |||||
assert res_max == 6 | |||||
assert res_min == 1 | |||||
assert res_sum == 21 | |||||
assert res_mean == 3.5 | |||||
@pytest.mark.torch | |||||
def test_set_model_mode(self): | def test_set_model_mode(self): | ||||
""" | """ | ||||
测试set_model_mode函数 | 测试set_model_mode函数 | ||||
@@ -78,7 +78,7 @@ class TestSeqLenToMask: | |||||
mask = seq_len_to_mask(seq_len) | mask = seq_len_to_mask(seq_len) | ||||
# 3. pad到指定长度 | # 3. pad到指定长度 | ||||
seq_len = paddle.randint(1, 10, size=(10,)) | |||||
seq_len = paddle.randint(1, 10, shape=(10,)) | |||||
mask = seq_len_to_mask(seq_len, 100) | mask = seq_len_to_mask(seq_len, 100) | ||||
assert 100 == mask.shape[1] | assert 100 == mask.shape[1] | ||||