@@ -5,6 +5,7 @@ __all__ = [ | |||
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping | |||
import numpy as np | |||
from pkg_resources import parse_version | |||
from fastNLP.core.dataset import DataSet, Instance | |||
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler | |||
@@ -12,6 +13,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from fastNLP.core.collators import Collator | |||
if _NEED_IMPORT_TORCH: | |||
from torch import __version__ as torchversion | |||
from torch.utils.data import DataLoader, Sampler | |||
else: | |||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | |||
@@ -213,13 +215,21 @@ class MixDataLoader(DataLoader): | |||
else: | |||
raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler") | |||
super(MixDataLoader, self).__init__( | |||
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, | |||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||
pin_memory=pin_memory, drop_last=False, timeout=0, | |||
worker_init_fn=None, multiprocessing_context=None, generator=None, | |||
prefetch_factor=2, persistent_workers=False | |||
) | |||
if parse_version(torchversion) >= parse_version('1.7'): | |||
super(MixDataLoader, self).__init__( | |||
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, | |||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||
pin_memory=pin_memory, drop_last=False, timeout=0, | |||
worker_init_fn=None, multiprocessing_context=None, generator=None, | |||
prefetch_factor=2, persistent_workers=False | |||
) | |||
else: | |||
super(MixDataLoader, self).__init__( | |||
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, | |||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||
pin_memory=pin_memory, drop_last=False, timeout=0, | |||
worker_init_fn=None, multiprocessing_context=None, generator=None, | |||
) | |||
def __iter__(self): | |||
return super().__iter__() |
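The branch above exists because ``prefetch_factor`` and ``persistent_workers`` were only added to ``torch.utils.data.DataLoader`` in torch 1.7. A minimal standalone sketch of the same guard (the helper name and defaults are illustrative, not part of fastNLP)::

    from pkg_resources import parse_version
    import torch
    from torch.utils.data import DataLoader

    def build_loader(dataset, num_workers=0, **kwargs):
        # forward the newer keyword arguments only when the installed torch supports them
        if parse_version(torch.__version__) >= parse_version('1.7'):
            kwargs.setdefault('persistent_workers', False)
            if num_workers > 0:
                # prefetch_factor is only accepted together with worker processes
                kwargs.setdefault('prefetch_factor', 2)
        return DataLoader(dataset, num_workers=num_workers, **kwargs)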
@@ -41,7 +41,7 @@ class Driver(ABC): | |||
r""" | |||
根据输入的 ``dataloader`` 得到一个支持分布式 (``distributed``) 与可复现的 (``reproducible``) 的 dataloader。 | |||
:param dataloader: 根据 ``dataloade``r 设置其对应的分布式版本以及可复现版本; | |||
:param dataloader: 根据 ``dataloader`` 设置其对应的分布式版本以及可复现版本; | |||
:param dist: 应当为一个字符串,其值应当为以下之一:``[None, "dist", "unrepeatdist"]``;为 ``None`` 时,表示不需要考虑当前 dataloader | |||
切换为分布式状态;为 ``dist`` 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 数量是一样多的,允许少量 sample 在 | |||
不同 gpu 上重复出现;为 ``unrepeatdist`` 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来刚好等于原始的 | |||
@@ -263,7 +263,6 @@ class Driver(ABC): | |||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名)或一个 ``BytesIO`` 对象; | |||
:param load_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便保存的是完整的模型,此处也只会尝试将 filepath | |||
中模型的权重加载到自身模型上,而不会直接替换当前 Driver 中的模型。 | |||
:return: 返回加载指定文件后的结果; | |||
""" | |||
raise NotImplementedError("Each specific driver should implement its own `load_model` function.") | |||
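A hedged sketch of what a concrete override might look like in a torch-backed driver; the class name and the weight-only fallback below are assumptions, not fastNLP's actual implementation::

    import torch
    from fastNLP.core.drivers.driver import Driver

    class MyTorchDriver(Driver):
        def load_model(self, filepath, load_state_dict: bool = True, **kwargs):
            states = torch.load(filepath, map_location="cpu")
            if not load_state_dict:
                # a whole model object was saved; only reuse its weights,
                # never replace the model the driver currently holds
                states = states.state_dict()
            self.model.load_state_dict(states)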
@@ -1,13 +1,17 @@ | |||
import os | |||
import warnings | |||
from typing import Optional, Callable, Dict | |||
import random | |||
from pathlib import Path | |||
from typing import Union, Optional | |||
from functools import partial | |||
import numpy as np | |||
from .utils import _build_fp16_env | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.drivers.driver import Driver | |||
from fastNLP.core.dataloaders import JittorDataLoader | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.utils import apply_to_collection | |||
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_SEED_WORKERS | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor as jt | |||
@@ -47,17 +51,18 @@ class JittorDriver(Driver): | |||
f"`jittor.Module` type.") | |||
super(JittorDriver, self).__init__(model) | |||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | |||
self.grad_scaler = _grad_scaler() | |||
if fp16: | |||
jt.flags.auto_mixed_precision_level = 6 | |||
else: | |||
jt.flags.auto_mixed_precision_level = 0 | |||
self.fp16 = fp16 | |||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | |||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||
def check_dataloader_legality(self, dataloader): | |||
# fastNLP 中实现了 JittorDataLoader,因此这里同时接受 Dataset 和 JittorDataLoader | |||
if not isinstance(dataloader, Dataset): | |||
raise TypeError(f"{Dataset} is expected, instead of `{type(dataloader)}`") | |||
if not isinstance(dataloader, (Dataset, JittorDataLoader)): | |||
raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`") | |||
@staticmethod | |||
def _check_optimizer_legality(optimizers): | |||
@@ -66,54 +71,102 @@ class JittorDriver(Driver): | |||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " | |||
f"not {type(each_optimizer)}.") | |||
def check_evaluator_mode(self, mode: str): | |||
def step(self): | |||
for optimizer in self.optimizers: | |||
optimizer.step() | |||
def backward(self, loss): | |||
for optimizer in self.optimizers: | |||
optimizer.backward(loss) | |||
def zero_grad(self): | |||
for optimizer in self.optimizers: | |||
optimizer.zero_grad() | |||
def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs): | |||
r""" | |||
将模型保存到 ``filepath`` 中。 | |||
:param filepath: 保存文件的文件位置(需要包括文件名); | |||
:param only_state_dict: 在 **Jittor** 中,该参数无效,**Jittor** 仅支持保存模型的 ``state_dict``。 | |||
""" | |||
if not only_state_dict: | |||
logger.rank_zero_warning( | |||
"Jittor only supports saving state_dict, and we will also save state_dict for you.", | |||
once=True | |||
) | |||
if isinstance(filepath, Path): | |||
filepath = str(filepath) | |||
model = self.unwrap_model() | |||
if mode == "evaluate": | |||
if not hasattr(model, "evaluate_step"): | |||
if hasattr(model, "test_step"): | |||
logger.warning_once( | |||
"Your model does not have 'evaluate_step' method but has 'test_step' method, but you" | |||
"are using 'evaluate_fn=validate', we are going to use 'test_step' to substitute for" | |||
"'evaluate_step'.") | |||
model.save(filepath) | |||
else: | |||
if not hasattr(model, "test_step"): | |||
if hasattr(model, "evaluate_step"): | |||
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
"are using 'evaluate_fn=test', we are going to use 'evaluate_step' to substitute for" | |||
"'test_step'.") | |||
def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None): | |||
if model_save_fn is not None: | |||
outputs = model_save_fn(filepath) | |||
if outputs is not None: | |||
jt.save(outputs, filepath) | |||
else: | |||
if only_state_dict: | |||
states = self.model.state_dict() | |||
else: | |||
warnings.warn("Saving the whole model is not supported now in Jittor. Save state dict instead.") | |||
jt.save(states, filepath) | |||
def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs): | |||
r""" | |||
加载模型的函数;将 ``filepath`` 中的模型加载并赋值给当前 ``model`` 。 | |||
def load_model(self, filepath: str): | |||
if not os.path.exists(filepath): | |||
raise FileNotFoundError("Checkpoint at {} not found.".format(filepath)) | |||
return jt.load(filepath) | |||
:param filepath: 需要被加载的文件的位置(需要包括文件名); | |||
:param only_state_dict: 在 **Jittor** 中,该参数无效,**Jittor** 仅支持加载模型的 ``state_dict``。 | |||
""" | |||
if not only_state_dict: | |||
logger.rank_zero_warning( | |||
"Jittor only supports loading state_dict, and we will also load state_dict for you.", | |||
once=True | |||
) | |||
if isinstance(filepath, Path): | |||
filepath = str(filepath) | |||
model = self.unwrap_model() | |||
model.load(filepath) | |||
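A short usage sketch of the two methods above; the driver construction and the ``model`` placeholder (any ``jittor.Module``) are assumptions::

    driver = JittorSingleDriver(model, device="cpu")
    driver.save_model("ckpt/model.pkl")    # Jittor only ever saves the state_dict
    driver.load_model("ckpt/model.pkl")    # loads the state_dict back into the wrapped model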
def save_checkpoint(self): | |||
... | |||
def get_optimizer_state(self): | |||
# optimizers_state_dict = {} | |||
# for i in range(len(self.optimizers)): | |||
# optimizer: torch.optim.Optimizer = self.optimizers[i] | |||
# optimizer_state = optimizer.state_dict() | |||
# optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) | |||
# optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||
# return optimizers_state_dict | |||
... | |||
def load_optimizer_state(self, states): | |||
# assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||
# f"checkpoint it is:{len(states)}" | |||
# for i in range(len(self.optimizers)): | |||
# optimizer: torch.optim.Optimizer = self.optimizers[i] | |||
# optimizer.load_state_dict(states[f"optimizer{i}"]) | |||
# logger.debug("Load optimizer state dict.") | |||
... | |||
def load_checkpoint(self): | |||
... | |||
def get_evaluate_context(self): | |||
return jt.no_grad | |||
def get_model_device(self): | |||
return self.model_device | |||
@staticmethod | |||
def move_model_to_device(model: "jt.Module", device): | |||
r""" | |||
将模型转移到指定的设备上。由于 **Jittor** 会自动管理模型和数据所在的设备,因此该函数实际上无效。 | |||
""" | |||
... | |||
def move_data_to_device(self, batch): | |||
""" | |||
将数据 ``batch`` 转移到指定的设备上。由于 **Jittor** 会自动为数据分配设备,因此该函数实际上无效。 | |||
""" | |||
return batch | |||
@staticmethod | |||
def tensor_to_numeric(tensor, reduce=None): | |||
r""" | |||
将一个 :class:`jittor.Var` 对象转换成 python 中的数值类型; | |||
:param tensor: :class:`jittor.Var` 类型的对象; | |||
:param reduce: 当 tensor 是一个包含多个数值的张量时,应当使用何种归约操作将其转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``; | |||
:return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等; | |||
""" | |||
if tensor is None: | |||
return None | |||
@@ -145,7 +198,32 @@ class JittorDriver(Driver): | |||
""" | |||
return batch | |||
# def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx): | |||
# # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | |||
# if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | |||
# dataloader.batch_sampler.set_epoch(cur_epoch_idx) | |||
@staticmethod | |||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | |||
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||
process_seed = jt.get_seed() | |||
# back out the base seed so we can use all the bits | |||
base_seed = process_seed - worker_id | |||
ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) | |||
# use 128 bits (4 x 32-bit words) | |||
np.random.seed(ss.generate_state(4)) | |||
# Spawn distinct SeedSequences for the Jittor PRNG and the stdlib random module | |||
jittor_ss, stdlib_ss = ss.spawn(2) | |||
jt.set_global_seed(jittor_ss.generate_state(1, dtype=np.uint64)[0]) | |||
# use 128 bits expressed as an integer | |||
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() | |||
random.seed(stdlib_seed) | |||
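The stdlib seed above packs two 64-bit words into a single 128-bit integer; a small numpy-only check of that arithmetic::

    import numpy as np

    ss = np.random.SeedSequence(1234)
    words = ss.generate_state(2, dtype=np.uint64).astype(object)  # two 64-bit words
    seed_128 = (words * [1 << 64, 1]).sum()                       # high word << 64 | low word
    assert seed_128 == (int(words[0]) << 64) + int(words[1])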
def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]): | |||
if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: | |||
dataloader.worker_init_fn = partial(self.worker_init_function, | |||
rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))) | |||
def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int): | |||
# 保证 ddp 训练时 shuffle=True 的正确性,因为需要保证每一个进程上 sampler 的 shuffle 随机数种子是一样的; | |||
if callable(getattr(dataloader.sampler, "set_epoch", None)): | |||
dataloader.sampler.set_epoch(cur_epoch_idx) | |||
@staticmethod | |||
def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]): | |||
pass |
@@ -146,7 +146,10 @@ class JittorMPIDriver(JittorDriver): | |||
return self.model.no_sync | |||
def unwrap_model(self): | |||
pass | |||
""" | |||
返回训练使用的模型。 | |||
""" | |||
return self.model | |||
def get_local_rank(self) -> int: | |||
return self.local_rank | |||
@@ -155,4 +158,14 @@ class JittorMPIDriver(JittorDriver): | |||
pass | |||
def is_distributed(self): | |||
""" | |||
判断是否为分布式的 **Driver** ,在 ``JittorMPIDriver`` 中,返回 ``True``。 | |||
""" | |||
return True | |||
@property | |||
def data_device(self) -> str: | |||
""" | |||
:return: 数据所在的设备; | |||
""" | |||
return self.model_device |
@@ -27,28 +27,36 @@ class JittorSingleDriver(JittorDriver): | |||
支持 cpu 和 gpu 的切换; | |||
实现断点重训中替换 dataloader 的 set_dist_repro_dataloader 函数 | |||
:param model: 传入给 ``Trainer`` 的 ``model`` 参数; | |||
:param device: 训练和模型所在的设备,在 **Jittor** 中,应当为以下值之一:``[None, 'cpu', 'gpu', 'cuda']``; | |||
* 为 ``None`` 或 ``cpu`` 时 | |||
表示在 ``cpu`` 上进行训练; | |||
* 为 ``gpu`` 或 ``cuda`` 时 | |||
表示在显卡设备上进行训练; | |||
:param fp16: 是否开启 fp16; | |||
""" | |||
def __init__(self, model, device=None, fp16: bool = False, **kwargs): | |||
if device not in [None, "cpu", "gpu", "cuda"]: | |||
raise RuntimeError("Parameter `device` should be one of [None, 'cpu', 'gpu', 'cuda'] .") | |||
super(JittorSingleDriver, self).__init__(model, fp16) | |||
self.model_device = device | |||
self.model_device = device if device is not None else "cpu" | |||
self.local_rank = 0 | |||
self.global_rank = 0 | |||
self.world_size = 1 | |||
def step(self): | |||
for optimizer in self.optimizers: | |||
optimizer.step() | |||
def backward(self, loss): | |||
for optimizer in self.optimizers: | |||
optimizer.backward(loss) | |||
def zero_grad(self): | |||
for optimizer in self.optimizers: | |||
optimizer.zero_grad() | |||
def setup(self): | |||
r""" | |||
初始化训练环境;根据传入的 ``device`` 值将训练设备设置为 ``cpu`` 或 ``gpu``; | |||
""" | |||
if self.model_device in ["cpu", None]: | |||
jt.flags.use_cuda = 0 # 使用 cpu | |||
else: | |||
jt.flags.use_cuda = 1 # 使用 cuda | |||
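A minimal usage sketch of ``setup`` (assumes jittor and a CUDA-capable environment; the toy model is an assumption)::

    import jittor as jt

    driver = JittorSingleDriver(jt.nn.Linear(4, 2), device="gpu")
    driver.setup()                   # sets jt.flags.use_cuda = 1
    assert jt.flags.use_cuda == 1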
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
@@ -70,9 +78,15 @@ class JittorSingleDriver(JittorDriver): | |||
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | |||
def unwrap_model(self): | |||
""" | |||
返回训练使用的模型。 | |||
""" | |||
return self.model | |||
def is_distributed(self): | |||
""" | |||
判断是否为分布式的 **Driver** ,在 ``JittorSingleDriver`` 中,返回 ``False``。 | |||
""" | |||
return False | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||
@@ -103,11 +117,15 @@ class JittorSingleDriver(JittorDriver): | |||
else: | |||
return dataloader | |||
def setup(self): | |||
def unwrap_model(self): | |||
""" | |||
支持 cpu 和 gpu 的切换 | |||
返回训练使用的模型。 | |||
""" | |||
if self.model_device in ["cpu", None]: | |||
jt.flags.use_cuda = 0 # 使用 cpu | |||
else: | |||
jt.flags.use_cuda = 1 # 使用 cuda | |||
return self.model | |||
@property | |||
def data_device(self) -> str: | |||
""" | |||
:return: 数据和模型所在的设备; | |||
""" | |||
return self.model_device |
@@ -1,56 +1,6 @@ | |||
from contextlib import ExitStack | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
__all__ = [] | |||
class DummyGradScaler: | |||
""" | |||
用于仿造的 **GradScaler** 对象,防止重复写大量的if判断 | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
pass | |||
def get_scale(self): | |||
return 1.0 | |||
def is_enabled(self): | |||
return False | |||
def scale(self, outputs): | |||
return outputs | |||
def step(self, optimizer, *args, **kwargs): | |||
optimizer.step(*args, **kwargs) | |||
def update(self, new_scale=None): | |||
pass | |||
def unscale_(self, optimizer): | |||
pass | |||
def load_state_dict(self, state_dict): | |||
pass | |||
def state_dict(self): | |||
return {} | |||
def _build_fp16_env(dummy=False): | |||
if dummy: | |||
auto_cast = ExitStack | |||
GradScaler = DummyGradScaler | |||
else: | |||
raise NotImplementedError("JittorDriver does not support fp16 now.") | |||
# if not jt.flags.use_cuda: | |||
# raise RuntimeError("No cuda") | |||
# if paddle.device.cuda.get_device_capability(0)[0] < 7: | |||
# log.warning( | |||
# "NOTE: your device does NOT support faster training with fp16, " | |||
# "please switch to FP32 which is likely to be faster" | |||
# ) | |||
# from paddle.amp import auto_cast, GradScaler | |||
return auto_cast, GradScaler |
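The dummy objects let the same training-step code run whether or not fp16 is requested; a sketch of how a caller might consume them (the stand-in optimizer and loss below are assumptions)::

    class _Opt:                          # stand-in optimizer for the sketch
        def step(self):
            pass

    auto_cast, GradScaler = _build_fp16_env(dummy=True)  # fp16 itself is not implemented yet
    scaler = GradScaler()                # DummyGradScaler: scale/step/update are no-ops
    with auto_cast():                    # ExitStack() acts as a do-nothing context manager
        loss = 0.0                       # stand-in for the real forward pass / loss
    scaler.scale(loss)                   # returns `loss` unchanged
    scaler.step(_Opt())                  # simply calls optimizer.step()
    scaler.update()                      # no-op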
@@ -113,12 +113,11 @@ class PaddleDriver(Driver): | |||
@staticmethod | |||
def tensor_to_numeric(tensor, reduce=None): | |||
r""" | |||
将一个 `tensor` 对象(类型为 `paddle.Tensor` )转换为 python 的 `numeric` 对象;如果 tensor 只包含一个元素则返回 float 或 int 。 | |||
将一个 :class:`paddle.Tensor` 对象转换成 python 中的数值类型; | |||
:param tensor: 需要被转换的 `tensor` 对象 | |||
:param reduce: 可选 ['sum', 'max', 'mea', 'min'],如果不为 None 将使用该 reduce 方法来处理当前 tensor 再返回 | |||
float 或 int 对象。 | |||
:return: 转换后返回的结果 | |||
:param tensor: :class:`paddle.Tensor` 类型的对象; | |||
:param reduce: 当 tensor 是一个包含多个数值的张量时,应当使用何种归约操作将其转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``; | |||
:return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等; | |||
""" | |||
if tensor is None: | |||
return None | |||
@@ -12,7 +12,7 @@ if _NEED_IMPORT_TORCH: | |||
from torch.optim import Optimizer | |||
from torch.utils.data import RandomSampler as TorchRandomSampler | |||
_reduces = { | |||
'sum': torch.max, | |||
'sum': torch.sum, | |||
'min': torch.min, | |||
'max': torch.max, | |||
'mean': torch.mean | |||
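With the corrected mapping, ``reduce='sum'`` now actually sums the tensor instead of taking its maximum, e.g.::

    t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    _reduces['sum'](t).item()   # 10.0 (the old mapping to torch.max returned 4.0)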
@@ -68,7 +68,8 @@ def paddle_to(data: "paddle.Tensor", device: Union[str, int, 'paddle.fluid.core_ | |||
该函数只是集成了 :func:`paddle.Tensor.cpu` 和 :func:`paddle.Tensor.cuda` 两个函数。 | |||
:param data: 要迁移的张量; | |||
:param device: 目标设备,可以是 ``str`` 或 ``int`` 类型; | |||
:param device: 目标设备,可以是 ``str`` 或 ``int`` 及 **paddle** 自己的 :class:`paddle.fluid.core_avx.Place`、 | |||
:class:`paddle.CPUPlace` 和 :class:`paddle.CUDAPlace` 类型; | |||
:return: 迁移后的张量; | |||
""" | |||
if isinstance(device, paddle.fluid.core_avx.Place): | |||
@@ -74,7 +74,7 @@ def seq_len_to_mask(seq_len, max_len: Optional[int]=None): | |||
if isinstance(seq_len, jittor.Var): | |||
assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}." | |||
batch_size = seq_len.shape[0] | |||
broad_cast_seq_len = jittor.arange(max_len).expand(batch_size, -1) | |||
broad_cast_seq_len = jittor.arange(max_len).reshape(1, max_len).expand(batch_size, -1) | |||
mask = broad_cast_seq_len < seq_len.unsqueeze(1) | |||
return mask | |||
except NameError as e: | |||
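A small worked example of the reshaped broadcast above (assuming jittor is importable)::

    import jittor
    seq_len = jittor.array([1, 3, 2])
    broad = jittor.arange(4).reshape(1, 4).expand(3, -1)  # shape (3, 4)
    mask = broad < seq_len.unsqueeze(1)
    # mask -> [[True, False, False, False],
    #          [True, True,  True,  False],
    #          [True, True,  False, False]]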
@@ -3,7 +3,8 @@ __all__ = [ | |||
"CONFIG_MAPPING", | |||
"MODEL_NAMES_MAPPING", | |||
"AutoConfig", | |||
"TOKENIZER_MAPPING_NAMES", | |||
"TOKENIZER_MAPPING", | |||
"AutoTokenizer", | |||
"get_values", | |||
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", | |||
"MODEL_FOR_CAUSAL_LM_MAPPING", | |||
@@ -43,7 +44,7 @@ __all__ = [ | |||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \ | |||
AutoConfig | |||
from .tokenization_auto import TOKENIZER_MAPPING_NAMES | |||
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer | |||
from .auto_factory import get_values | |||
from .modeling_auto import ( | |||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, | |||
@@ -13,14 +13,31 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" Auto Tokenizer class. """ | |||
import importlib | |||
import json | |||
import os | |||
from collections import OrderedDict | |||
from typing import TYPE_CHECKING, Optional, Tuple | |||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union | |||
from ...configuration_utils import PretrainedConfig | |||
from ...file_utils import ( | |||
cached_path, | |||
hf_bucket_url, | |||
is_offline_mode, | |||
is_sentencepiece_available, | |||
is_tokenizers_available, | |||
) | |||
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE | |||
from ..encoder_decoder import EncoderDecoderConfig | |||
from .auto_factory import _LazyAutoMapping | |||
from .configuration_auto import ( | |||
CONFIG_MAPPING_NAMES, | |||
AutoConfig, | |||
config_class_to_model_type, | |||
model_type_to_module_name, | |||
replace_list_option_in_docstrings, | |||
) | |||
from fastNLP.core.log import logger | |||
if TYPE_CHECKING: | |||
# This significantly improves completion suggestion performance when | |||
@@ -34,4 +51,297 @@ else: | |||
("bert", ("BertTokenizer", None)), | |||
("gpt2", ("GPT2Tokenizer", None)), | |||
] | |||
) | |||
) | |||
TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES) | |||
CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()} | |||
def tokenizer_class_from_name(class_name: str): | |||
if class_name == "PreTrainedTokenizerFast": | |||
raise RuntimeError("fastNLP does not support TokenizerFast now.") | |||
for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items(): | |||
if class_name in tokenizers: | |||
module_name = model_type_to_module_name(module_name) | |||
try: | |||
module = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models") | |||
except ImportError: | |||
raise ImportError(f"fastNLP transformers does not support {module_name} now, please install and import `transformers` to use it.") | |||
return getattr(module, class_name) | |||
return None | |||
def get_tokenizer_config( | |||
pretrained_model_name_or_path: Union[str, os.PathLike], | |||
cache_dir: Optional[Union[str, os.PathLike]] = None, | |||
force_download: bool = False, | |||
resume_download: bool = False, | |||
proxies: Optional[Dict[str, str]] = None, | |||
use_auth_token: Optional[Union[bool, str]] = None, | |||
revision: Optional[str] = None, | |||
local_files_only: bool = False, | |||
**kwargs, | |||
): | |||
""" | |||
Loads the tokenizer configuration from a pretrained model tokenizer configuration. | |||
Args: | |||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): | |||
This can be either: | |||
- a string, the `model id` of a pretrained model configuration hosted inside a model repo on | |||
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or | |||
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. | |||
- a path to a `directory` containing a configuration file saved using the | |||
:func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``. | |||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): | |||
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard | |||
cache should not be used. | |||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to force to (re-)download the configuration files and override the cached versions if they | |||
exist. | |||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to delete incompletely received files. Attempts to resume the download if such a file exists. | |||
proxies (:obj:`Dict[str, str]`, `optional`): | |||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', | |||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |||
use_auth_token (:obj:`str` or `bool`, `optional`): | |||
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token | |||
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). | |||
revision (:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||
identifier allowed by git. | |||
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
If :obj:`True`, will only try to load the tokenizer configuration from local files. | |||
.. note:: | |||
Passing :obj:`use_auth_token=True` is required when you want to use a private model. | |||
Returns: | |||
:obj:`Dict`: The configuration of the tokenizer. | |||
Examples:: | |||
# Download configuration from huggingface.co and cache. | |||
tokenizer_config = get_tokenizer_config("bert-base-uncased") | |||
# This model does not have a tokenizer config so the result will be an empty dict. | |||
tokenizer_config = get_tokenizer_config("xlm-roberta-base") | |||
# Save a pretrained tokenizer locally and you can reload its config | |||
from transformers import AutoTokenizer | |||
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") | |||
tokenizer.save_pretrained("tokenizer-test") | |||
tokenizer_config = get_tokenizer_config("tokenizer-test") | |||
""" | |||
if is_offline_mode() and not local_files_only: | |||
logger.info("Offline mode: forcing local_files_only=True") | |||
local_files_only = True | |||
pretrained_model_name_or_path = str(pretrained_model_name_or_path) | |||
if os.path.isdir(pretrained_model_name_or_path): | |||
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE) | |||
else: | |||
config_file = hf_bucket_url( | |||
pretrained_model_name_or_path, filename=TOKENIZER_CONFIG_FILE, revision=revision, mirror=None | |||
) | |||
try: | |||
# Load from URL or cache if already cached | |||
resolved_config_file = cached_path( | |||
config_file, | |||
cache_dir=cache_dir, | |||
force_download=force_download, | |||
proxies=proxies, | |||
resume_download=resume_download, | |||
local_files_only=local_files_only, | |||
use_auth_token=use_auth_token, | |||
) | |||
except EnvironmentError: | |||
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.") | |||
return {} | |||
with open(resolved_config_file, encoding="utf-8") as reader: | |||
return json.load(reader) | |||
class AutoTokenizer: | |||
r""" | |||
This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when | |||
created with the :meth:`AutoTokenizer.from_pretrained` class method. | |||
This class cannot be instantiated directly using ``__init__()`` (throws an error). | |||
""" | |||
def __init__(self): | |||
raise EnvironmentError( | |||
"AutoTokenizer is designed to be instantiated " | |||
"using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method." | |||
) | |||
@classmethod | |||
@replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES) | |||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): | |||
r""" | |||
Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary. | |||
The tokenizer class to instantiate is selected based on the :obj:`model_type` property of the config object | |||
(either passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's | |||
missing, by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: | |||
List options | |||
Params: | |||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): | |||
Can be either: | |||
- A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co. | |||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under | |||
a user or organization name, like ``dbmdz/bert-base-german-cased``. | |||
- A path to a `directory` containing vocabulary files required by the tokenizer, for instance saved | |||
using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., | |||
``./my_model_directory/``. | |||
- A path or url to a single saved vocabulary file if and only if the tokenizer only requires a | |||
single vocabulary file (like Bert or XLNet), e.g.: ``./my_model_directory/vocab.txt``. (Not | |||
applicable to all derived classes) | |||
inputs (additional positional arguments, `optional`): | |||
Will be passed along to the Tokenizer ``__init__()`` method. | |||
config (:class:`~transformers.PretrainedConfig`, `optional`): | |||
The configuration object used to determine the tokenizer class to instantiate. | |||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): | |||
Path to a directory in which a downloaded pretrained model configuration should be cached if the | |||
standard cache should not be used. | |||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to force the (re-)download the model weights and configuration files and override the | |||
cached versions if they exist. | |||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a | |||
file exists. | |||
proxies (:obj:`Dict[str, str]`, `optional`): | |||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', | |||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |||
revision (:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||
identifier allowed by git. | |||
subfolder (:obj:`str`, `optional`): | |||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for | |||
facebook/rag-token-base), specify it here. | |||
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`): | |||
Whether or not to try to load the fast version of the tokenizer. | |||
tokenizer_type (:obj:`str`, `optional`): | |||
Tokenizer type to be loaded. | |||
kwargs (additional keyword arguments, `optional`): | |||
Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like | |||
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, | |||
``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__()`` for more details. | |||
Examples:: | |||
>>> from transformers import AutoTokenizer | |||
>>> # Download vocabulary from huggingface.co and cache. | |||
>>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') | |||
>>> # Download vocabulary from huggingface.co (user-uploaded) and cache. | |||
>>> tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-cased') | |||
>>> # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`) | |||
>>> tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/') | |||
""" | |||
config = kwargs.pop("config", None) | |||
kwargs["_from_auto"] = True | |||
use_fast = kwargs.pop("use_fast", True) | |||
tokenizer_type = kwargs.pop("tokenizer_type", None) | |||
# First, let's see whether the tokenizer_type is passed so that we can leverage it | |||
if tokenizer_type is not None: | |||
tokenizer_class = None | |||
tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None) | |||
if tokenizer_class_tuple is None: | |||
raise ValueError( | |||
f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of " | |||
f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES.keys())}." | |||
) | |||
tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple | |||
if use_fast and tokenizer_fast_class_name is not None: | |||
tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name) | |||
if tokenizer_class is None: | |||
tokenizer_class = tokenizer_class_from_name(tokenizer_class_name) | |||
if tokenizer_class is None: | |||
raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.") | |||
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) | |||
# Next, let's try to use the tokenizer_config file to get the tokenizer class. | |||
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs) | |||
config_tokenizer_class = tokenizer_config.get("tokenizer_class") | |||
# If that did not work, let's try to use the config. | |||
if config_tokenizer_class is None: | |||
if not isinstance(config, PretrainedConfig): | |||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) | |||
config_tokenizer_class = config.tokenizer_class | |||
# If we have the tokenizer class from the tokenizer config or the model config we're good! | |||
if config_tokenizer_class is not None: | |||
tokenizer_class = None | |||
if use_fast and not config_tokenizer_class.endswith("Fast"): | |||
tokenizer_class_candidate = f"{config_tokenizer_class}Fast" | |||
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) | |||
if tokenizer_class is None: | |||
tokenizer_class_candidate = config_tokenizer_class | |||
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) | |||
if tokenizer_class is None: | |||
raise ValueError( | |||
f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported." | |||
) | |||
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) | |||
# Otherwise we have to be creative. | |||
# if model is an encoder decoder, the encoder tokenizer class is used by default | |||
if isinstance(config, EncoderDecoderConfig): | |||
if type(config.decoder) is not type(config.encoder): # noqa: E721 | |||
logger.warning( | |||
f"The encoder model config class: {config.encoder.__class__} is different from the decoder model " | |||
f"config class: {config.decoder.__class__}. It is not recommended to use the " | |||
"`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder " | |||
"specific tokenizer classes." | |||
) | |||
config = config.encoder | |||
model_type = config_class_to_model_type(type(config).__name__) | |||
if model_type is not None: | |||
tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)] | |||
if tokenizer_class_fast and (use_fast or tokenizer_class_py is None): | |||
return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) | |||
else: | |||
if tokenizer_class_py is not None: | |||
return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) | |||
else: | |||
raise ValueError( | |||
"This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed " | |||
"in order to use this tokenizer." | |||
) | |||
raise ValueError( | |||
f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n" | |||
f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}." | |||
) |
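Besides the config-driven resolution shown in the docstring, the explicit ``tokenizer_type`` branch can be exercised directly (the model id is illustrative and requires network access or a local cache)::

    # resolve via tokenizer_config.json / the model config
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    # or force the class lookup through TOKENIZER_MAPPING_NAMES
    tok = AutoTokenizer.from_pretrained("bert-base-uncased", tokenizer_type="bert")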
@@ -0,0 +1,5 @@ | |||
__all__ = [ | |||
"EncoderDecoderConfig", | |||
] | |||
from .configuration_encoder_decoder import EncoderDecoderConfig |
@@ -0,0 +1,114 @@ | |||
# coding=utf-8 | |||
# Copyright 2020 The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import copy | |||
from ...configuration_utils import PretrainedConfig | |||
from fastNLP.core.log import logger | |||
class EncoderDecoderConfig(PretrainedConfig): | |||
r""" | |||
:class:`~transformers.EncoderDecoderConfig` is the configuration class to store the configuration of a | |||
:class:`~transformers.EncoderDecoderModel`. It is used to instantiate an Encoder Decoder model according to the | |||
specified arguments, defining the encoder and decoder configs. | |||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model | |||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. | |||
Args: | |||
kwargs (`optional`): | |||
Dictionary of keyword arguments. Notably: | |||
- **encoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration | |||
object that defines the encoder config. | |||
- **decoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration | |||
object that defines the decoder config. | |||
Examples:: | |||
>>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel | |||
>>> # Initializing a BERT bert-base-uncased style configuration | |||
>>> config_encoder = BertConfig() | |||
>>> config_decoder = BertConfig() | |||
>>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder) | |||
>>> # Initializing a Bert2Bert model from the bert-base-uncased style configurations | |||
>>> model = EncoderDecoderModel(config=config) | |||
>>> # Accessing the model configuration | |||
>>> config_encoder = model.config.encoder | |||
>>> config_decoder = model.config.decoder | |||
>>> # set decoder config to causal lm | |||
>>> config_decoder.is_decoder = True | |||
>>> config_decoder.add_cross_attention = True | |||
>>> # Saving the model, including its configuration | |||
>>> model.save_pretrained('my-model') | |||
>>> # loading model and config from pretrained folder | |||
>>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained('my-model') | |||
>>> model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config) | |||
""" | |||
model_type = "encoder-decoder" | |||
is_composition = True | |||
def __init__(self, **kwargs): | |||
super().__init__(**kwargs) | |||
assert ( | |||
"encoder" in kwargs and "decoder" in kwargs | |||
), "Config has to be initialized with encoder and decoder config" | |||
encoder_config = kwargs.pop("encoder") | |||
encoder_model_type = encoder_config.pop("model_type") | |||
decoder_config = kwargs.pop("decoder") | |||
decoder_model_type = decoder_config.pop("model_type") | |||
from ..auto.configuration_auto import AutoConfig | |||
self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config) | |||
self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config) | |||
self.is_encoder_decoder = True | |||
@classmethod | |||
def from_encoder_decoder_configs( | |||
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs | |||
) -> PretrainedConfig: | |||
r""" | |||
Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model | |||
configuration and decoder model configuration. | |||
Returns: | |||
:class:`EncoderDecoderConfig`: An instance of a configuration object | |||
""" | |||
logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config") | |||
decoder_config.is_decoder = True | |||
decoder_config.add_cross_attention = True | |||
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs) | |||
def to_dict(self): | |||
""" | |||
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`. | |||
Returns: | |||
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance. | |||
""" | |||
output = copy.deepcopy(self.__dict__) | |||
output["encoder"] = self.encoder.to_dict() | |||
output["decoder"] = self.decoder.to_dict() | |||
output["model_type"] = self.__class__.model_type | |||
return output |
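A round-trip sketch of ``from_encoder_decoder_configs`` and ``to_dict``; the ``BertConfig`` import path is an assumption (the upstream example uses ``from transformers import BertConfig``)::

    from fastNLP.transformers.torch.models.bert import BertConfig  # import path is an assumption

    cfg = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(), BertConfig())
    d = cfg.to_dict()
    assert d["model_type"] == "encoder-decoder"
    assert d["decoder"]["is_decoder"] and d["decoder"]["add_cross_attention"]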
@@ -75,12 +75,12 @@ class TestPaddleDriverFunctions: | |||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||
""" | |||
dataloader = DataLoader(PaddleNormalDataset()) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# batch_size 和 batch_sampler 均为 None 的情形 | |||
dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# 创建torch的dataloader | |||
dataloader = torch.utils.data.DataLoader( | |||
@@ -88,7 +88,7 @@ class TestPaddleDriverFunctions: | |||
batch_size=32, shuffle=True | |||
) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
@pytest.mark.torchpaddle | |||
def test_check_dataloader_legality_in_test(self): | |||
@@ -100,7 +100,7 @@ class TestPaddleDriverFunctions: | |||
"train": DataLoader(PaddleNormalDataset()), | |||
"test":DataLoader(PaddleNormalDataset()) | |||
} | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# batch_size 和 batch_sampler 均为 None 的情形 | |||
dataloader = { | |||
@@ -108,12 +108,12 @@ class TestPaddleDriverFunctions: | |||
"test":DataLoader(PaddleNormalDataset(), batch_size=None) | |||
} | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# 传入的不是 dict ,应该报错 | |||
dataloader = DataLoader(PaddleNormalDataset()) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# 创建 torch 的 dataloader | |||
train_loader = torch.utils.data.DataLoader( | |||
@@ -126,7 +126,7 @@ class TestPaddleDriverFunctions: | |||
) | |||
dataloader = {"train": train_loader, "test": test_loader} | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
@pytest.mark.paddle | |||
def test_tensor_to_numeric(self): | |||
@@ -183,6 +183,21 @@ class TestPaddleDriverFunctions: | |||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | |||
@pytest.mark.paddle | |||
def test_tensor_to_numeric_reduce(self): | |||
tensor = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) | |||
res_max = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="max") | |||
res_min = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="min") | |||
res_sum = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="sum") | |||
res_mean = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="mean") | |||
assert res_max == 6 | |||
assert res_min == 1 | |||
assert res_sum == 21 | |||
assert res_mean == 3.5 | |||
@pytest.mark.paddle | |||
def test_set_model_mode(self): | |||
""" | |||
测试 set_model_mode 函数 | |||
@@ -117,7 +117,7 @@ class TestTorchDriverFunctions: | |||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||
""" | |||
dataloader = DataLoader(TorchNormalDataset()) | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# 创建 paddle 的 dataloader | |||
dataloader = paddle.io.DataLoader( | |||
@@ -125,7 +125,7 @@ class TestTorchDriverFunctions: | |||
batch_size=32, shuffle=True | |||
) | |||
with pytest.raises(ValueError): | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
@pytest.mark.torchpaddle | |||
def test_check_dataloader_legality_in_test(self): | |||
@@ -137,12 +137,12 @@ class TestTorchDriverFunctions: | |||
"train": DataLoader(TorchNormalDataset()), | |||
"test": DataLoader(TorchNormalDataset()) | |||
} | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# 传入的不是 dict,应该报错 | |||
dataloader = DataLoader(TorchNormalDataset()) | |||
with pytest.raises(ValueError): | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
# 创建 paddle 的 dataloader | |||
train_loader = paddle.io.DataLoader( | |||
@@ -155,7 +155,7 @@ class TestTorchDriverFunctions: | |||
) | |||
dataloader = {"train": train_loader, "test": test_loader} | |||
with pytest.raises(ValueError): | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
@pytest.mark.torch | |||
def test_tensor_to_numeric(self): | |||
@@ -212,6 +212,20 @@ class TestTorchDriverFunctions: | |||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | |||
@pytest.mark.torch | |||
def test_tensor_to_numeric_reduce(self): | |||
tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) | |||
res_max = TorchSingleDriver.tensor_to_numeric(tensor, reduce="max") | |||
res_min = TorchSingleDriver.tensor_to_numeric(tensor, reduce="min") | |||
res_sum = TorchSingleDriver.tensor_to_numeric(tensor, reduce="sum") | |||
res_mean = TorchSingleDriver.tensor_to_numeric(tensor, reduce="mean") | |||
assert res_max == 6 | |||
assert res_min == 1 | |||
assert res_sum == 21 | |||
assert res_mean == 3.5 | |||
@pytest.mark.torch | |||
def test_set_model_mode(self): | |||
""" | |||
测试set_model_mode函数 | |||
@@ -78,7 +78,7 @@ class TestSeqLenToMask: | |||
mask = seq_len_to_mask(seq_len) | |||
# 3. pad到指定长度 | |||
seq_len = paddle.randint(1, 10, size=(10,)) | |||
seq_len = paddle.randint(1, 10, shape=(10,)) | |||
mask = seq_len_to_mask(seq_len, 100) | |||
assert 100 == mask.shape[1] | |||