
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

Tag: v1.0.0alpha
Author: yh, 3 years ago
Commit: 1ec3337a49
17 changed files with 676 additions and 149 deletions:
  1. fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py (+17, -7)
  2. fastNLP/core/drivers/driver.py (+1, -2)
  3. fastNLP/core/drivers/jittor_driver/jittor_driver.py (+123, -45)
  4. fastNLP/core/drivers/jittor_driver/mpi.py (+14, -1)
  5. fastNLP/core/drivers/jittor_driver/single_device.py (+36, -18)
  6. fastNLP/core/drivers/jittor_driver/utils.py (+0, -50)
  7. fastNLP/core/drivers/paddle_driver/paddle_driver.py (+4, -5)
  8. fastNLP/core/drivers/torch_driver/torch_driver.py (+1, -1)
  9. fastNLP/core/utils/paddle_utils.py (+2, -1)
  10. fastNLP/core/utils/seq_len_to_mask.py (+1, -1)
  11. fastNLP/transformers/torch/models/auto/__init__.py (+3, -2)
  12. fastNLP/transformers/torch/models/auto/tokenization_auto.py (+313, -3)
  13. fastNLP/transformers/torch/models/encoder_decoder/__init__.py (+5, -0)
  14. fastNLP/transformers/torch/models/encoder_decoder/configuration_encoder_decoder.py (+114, -0)
  15. tests/core/drivers/paddle_driver/test_single_device.py (+22, -7)
  16. tests/core/drivers/torch_driver/test_single_device.py (+19, -5)
  17. tests/core/utils/test_seq_len_to_mask.py (+1, -1)

fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py (+17, -7)

@@ -5,6 +5,7 @@ __all__ = [
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping

import numpy as np
from pkg_resources import parse_version

from fastNLP.core.dataset import DataSet, Instance
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler
@@ -12,6 +13,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.collators import Collator

if _NEED_IMPORT_TORCH:
from torch import __version__ as torchversion
from torch.utils.data import DataLoader, Sampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
@@ -213,13 +215,21 @@ class MixDataLoader(DataLoader):
else:
raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler")

super(MixDataLoader, self).__init__(
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None, generator=None,
prefetch_factor=2, persistent_workers=False
)
if parse_version(torchversion) >= parse_version('1.7'):
super(MixDataLoader, self).__init__(
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None, generator=None,
prefetch_factor=2, persistent_workers=False
)
else:
super(MixDataLoader, self).__init__(
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None, generator=None,
)

def __iter__(self):
return super().__iter__()
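
Aside: the branch above exists because prefetch_factor and persistent_workers are DataLoader arguments that only appeared in torch 1.7, so they cannot be passed on older versions. A minimal sketch of the same version gate, independent of MixDataLoader and purely illustrative:

from pkg_resources import parse_version
from torch import __version__ as torchversion
from torch.utils.data import DataLoader

kwargs = dict(num_workers=2)
if parse_version(torchversion) >= parse_version('1.7'):
    # these two arguments raise TypeError on torch < 1.7
    kwargs.update(prefetch_factor=2, persistent_workers=False)

loader = DataLoader(list(range(10)), batch_size=2, **kwargs)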

fastNLP/core/drivers/driver.py (+1, -2)

@@ -41,7 +41,7 @@ class Driver(ABC):
r"""
根据输入的 ``dataloader`` 得到一个 支持分布式 (``distributed``) 与 可复现的 (``reproducible``) 的 dataloader。

:param dataloader: 根据 ``dataloade``r 设置其对应的分布式版本以及可复现版本;
:param dataloader: 根据 ``dataloader`` 设置其对应的分布式版本以及可复现版本;
:param dist: 应当为一个字符串,其值应当为以下之一:``[None, "dist", "unrepeatdist"]``;为 ``None`` 时,表示不需要考虑当前 dataloader
切换为分布式状态;为 ``dist`` 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在
不同 gpu 上出现重复;为 ``unrepeatdist`` 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的
@@ -263,7 +263,6 @@ class Driver(ABC):
:param filepath: 需要被加载的对象的文件位置(需要包括文件名)或一个 ``BytesIO`` 对象;
:param load_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便是保存的完整的模型,此处也只能使用尝试加载filepath
模型中的权重到自身模型,而不会直接替代当前 Driver 中的模型。
:return: 返回加载指定文件后的结果;
"""
raise NotImplementedError("Each specific driver should implemented its own `load_model` function.")



fastNLP/core/drivers/jittor_driver/jittor_driver.py (+123, -45)

@@ -1,13 +1,17 @@
import os
import warnings
from typing import Optional, Callable, Dict
import random
from pathlib import Path
from typing import Union, Optional
from functools import partial

import numpy as np

from .utils import _build_fp16_env
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.dataloaders import JittorDataLoader
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_SEED_WORKERS

if _NEED_IMPORT_JITTOR:
import jittor as jt
@@ -47,17 +51,18 @@ class JittorDriver(Driver):
f"`jittor.Module` type.")
super(JittorDriver, self).__init__(model)

self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
self.grad_scaler = _grad_scaler()
if fp16:
jt.flags.auto_mixed_precision_level = 6
else:
jt.flags.auto_mixed_precision_level = 0
self.fp16 = fp16

# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)

def check_dataloader_legality(self, dataloader):
# 在fastnlp中实现了JittorDataLoader
if not isinstance(dataloader, Dataset):
raise TypeError(f"{Dataset} is expected, instead of `{type(dataloader)}`")

if not isinstance(dataloader, (Dataset, JittorDataLoader)):
raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`")

@staticmethod
def _check_optimizer_legality(optimizers):
@@ -66,54 +71,102 @@ class JittorDriver(Driver):
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, "
f"not {type(each_optimizer)}.")

def check_evaluator_mode(self, mode: str):
def step(self):
for optimizer in self.optimizers:
optimizer.step()

def backward(self, loss):
for optimizer in self.optimizers:
optimizer.backward(loss)

def zero_grad(self):
for optimizer in self.optimizers:
optimizer.zero_grad()

def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs):
r"""
将模型保存到 ``filepath`` 中。

:param filepath: 保存文件的文件位置(需要包括文件名);
:param only_state_dict: 在 **Jittor** 中,该参数无效,**Jittor** 仅支持保存模型的 ``state_dict``。
"""
if not only_state_dict:
logger.rank_zero_warning(
"Jittor only supports saving state_dict, and we will also save state_dict for you.",
once=True
)
if isinstance(filepath, Path):
filepath = str(filepath)
model = self.unwrap_model()
if mode == "evaluate":
if not hasattr(model, "evaluate_step"):
if hasattr(model, "test_step"):
logger.warning_once(
"Your model does not have 'evaluate_step' method but has 'test_step' method, but you"
"are using 'evaluate_fn=validate', we are going to use 'test_step' to substitute for"
"'evaluate_step'.")
model.save(filepath)

else:
if not hasattr(model, "test_step"):
if hasattr(model, "evaluate_step"):
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
"are using 'evaluate_fn=test', we are going to use 'evaluate_step' to substitute for"
"'test_step'.")

def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None):
if model_save_fn is not None:
outputs = model_save_fn(filepath)
if outputs is not None:
jt.save(outputs, filepath)
else:
if only_state_dict:
states = self.model.state_dict()
else:
warnings.warn("Saving the whole model is not supported now in Jittor. Save state dict instead.")
jt.save(states, filepath)
def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs):
r"""
加载模型的函数;将 ``filepath`` 中的模型加载并赋值给当前 ``model`` 。

def load_model(self, filepath: str):
if not os.path.exists(filepath):
raise FileNotFoundError("Checkpoint at {} not found.".format(filepath))
return jt.load(filepath)
:param filepath: 保存文件的文件位置(需要包括文件名);
:param load_state_dict: 在 **Jittor** 中,该参数无效,**Jittor** 仅支持加载模型的 ``state_dict``。
"""
if not only_state_dict:
logger.rank_zero_warning(
"Jittor only supports loading state_dict, and we will also load state_dict for you.",
once=True
)
if isinstance(filepath, Path):
filepath = str(filepath)
model = self.unwrap_model()
model.load(filepath)

def save_checkpoint(self):
...

def get_optimizer_state(self):
# optimizers_state_dict = {}
# for i in range(len(self.optimizers)):
# optimizer: torch.optim.Optimizer = self.optimizers[i]
# optimizer_state = optimizer.state_dict()
# optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
# optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;
# return optimizers_state_dict
...

def load_optimizer_state(self, states):
# assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
# f"checkpoint it is:{len(states)}"
# for i in range(len(self.optimizers)):
# optimizer: torch.optim.Optimizer = self.optimizers[i]
# optimizer.load_state_dict(states[f"optimizer{i}"])
# logger.debug("Load optimizer state dict.")
...

def load_checkpoint(self):
...

def get_evaluate_context(self):
return jt.no_grad

def get_model_device(self):
return self.model_device
@staticmethod
def move_model_to_device(model: "jt.Module", device):
r"""
将模型转移到指定的设备上。由于 **Jittor** 会自动为数据分配设备,因此该函数实际上无效。
"""
...

def move_data_to_device(self, batch):
"""
将数据 ``batch`` 转移到指定的设备上。由于 **Jittor** 会自动为数据分配设备,因此该函数实际上无效。
"""
return batch

@staticmethod
def tensor_to_numeric(tensor, reduce=None):
r"""
将一个 :class:`jittor.Var` 对象转换为 转换成 python 中的数值类型;

:param tensor: :class:`jittor.Var` 类型的对象;
:param reduce: 当 tensor 是一个多数值的张量时,应当使用何种归一化操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``;
:return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等;
"""
if tensor is None:
return None

@@ -145,7 +198,32 @@ class JittorDriver(Driver):
"""
return batch

# def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx):
# # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的;
# if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
# dataloader.batch_sampler.set_epoch(cur_epoch_idx)
@staticmethod
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))
process_seed = jt.get_seed()
# back out the base seed so we can use all the bits
base_seed = process_seed - worker_id
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
# use 128 bits (4 x 32-bit words)
np.random.seed(ss.generate_state(4))
# Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
jittor_ss, stdlib_ss = ss.spawn(2)
jt.set_global_seed(jittor_ss.generate_state(1, dtype=np.uint64)[0])
# use 128 bits expressed as an integer
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
random.seed(stdlib_seed)

def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]):
if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(self.worker_init_function,
rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)))

def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int):
# 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的;
if callable(getattr(dataloader.sampler, "set_epoch", None)):
dataloader.sampler.set_epoch(cur_epoch_idx)

@staticmethod
def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]):
pass
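
Aside: worker_init_function derives independent, reproducible seeds for numpy, the framework and the stdlib random module from a single per-worker SeedSequence. A stripped-down sketch of that idea in plain numpy (illustrative only, not part of the fastNLP API):

import random
import numpy as np

def derive_worker_seeds(base_seed: int, worker_id: int, global_rank: int) -> int:
    # one SeedSequence per (base seed, worker, rank) triple
    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
    np.random.seed(ss.generate_state(4))                  # 128 bits for numpy
    framework_ss, stdlib_ss = ss.spawn(2)                 # independent children
    framework_seed = framework_ss.generate_state(1, dtype=np.uint64)[0]
    stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64)
                   .astype(object) * [1 << 64, 1]).sum()  # 128-bit python int
    random.seed(stdlib_seed)
    return int(framework_seed)                            # seed for the DL framework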

fastNLP/core/drivers/jittor_driver/mpi.py (+14, -1)

@@ -146,7 +146,10 @@ class JittorMPIDriver(JittorDriver):
return self.model.no_sync

def unwrap_model(self):
pass
"""
返回训练使用的模型。
"""
return self.model

def get_local_rank(self) -> int:
return self.local_rank
@@ -155,4 +158,14 @@ class JittorMPIDriver(JittorDriver):
pass

def is_distributed(self):
"""
判断是否为分布式的 **Driver** ,在 ``JittorSingleDriver`` 中,返回 ``True``。
"""
return True

@property
def data_device(self) -> str:
"""
:return: 数据所在的设备;
"""
return self.model_device

fastNLP/core/drivers/jittor_driver/single_device.py (+36, -18)

@@ -27,28 +27,36 @@ class JittorSingleDriver(JittorDriver):
支持 cpu 和 gpu 的切换;
实现断点重训中替换 dataloader 的 set_dist_repro_dataloader 函数

:param model: 传入给 ``Trainer`` 的 ``model`` 参数;
:param device: 训练和模型所在的设备,在 **Jittor** 中,应当为以下值之一:``[None, 'cpu', 'gpu', 'cuda']``;
* 为 ``None`` 或 ``cpu`` 时
表示在 ``cpu`` 上进行训练;
* 为 ``gpu`` 或 ``cuda`` 时
表示在显卡设备上进行训练;

:param fp16: 是否开启 fp16;
"""

def __init__(self, model, device=None, fp16: bool = False, **kwargs):
if device not in [None, "cpu", "gpu", "cuda"]:
raise RuntimeError("Parameter `device` should be one of [None, 'cpu', 'gpu', 'cuda'] .")
super(JittorSingleDriver, self).__init__(model, fp16)

self.model_device = device
self.model_device = device if device is not None else "cpu"

self.local_rank = 0
self.global_rank = 0
self.world_size = 1

def step(self):
for optimizer in self.optimizers:
optimizer.step()

def backward(self, loss):
for optimizer in self.optimizers:
optimizer.backward(loss)

def zero_grad(self):
for optimizer in self.optimizers:
optimizer.zero_grad()
def setup(self):
r"""
初始化训练环境;根据传入的 ``device`` 值设置模型的训练场景为 ``cpu`` 或 ``gpu``;
"""
if self.model_device in ["cpu", None]:
jt.flags.use_cuda = 0 # 使用 cpu
else:
jt.flags.use_cuda = 1 # 使用 cuda

def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
@@ -70,9 +78,15 @@ class JittorSingleDriver(JittorDriver):
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

def unwrap_model(self):
"""
返回训练使用的模型。
"""
return self.model

def is_distributed(self):
"""
判断是否为分布式的 **Driver** ,在 ``JittorSingleDriver`` 中,返回 ``False``。
"""
return False

def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
@@ -103,11 +117,15 @@ class JittorSingleDriver(JittorDriver):
else:
return dataloader

def setup(self):
def unwrap_model(self):
"""
支持 cpu 和 gpu 的切换
返回训练使用的模型。
"""
if self.model_device in ["cpu", None]:
jt.flags.use_cuda = 0 # 使用 cpu
else:
jt.flags.use_cuda = 1 # 使用 cuda
return self.model

@property
def data_device(self) -> str:
"""
:return: 数据和模型所在的设备;
"""
return self.model_device

fastNLP/core/drivers/jittor_driver/utils.py (+0, -50)

@@ -1,56 +1,6 @@
from contextlib import ExitStack

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

if _NEED_IMPORT_JITTOR:
import jittor

__all__ = []

class DummyGradScaler:
"""
用于仿造的 **GradScaler** 对象,防止重复写大量的if判断
"""
def __init__(self, *args, **kwargs):
pass

def get_scale(self):
return 1.0

def is_enabled(self):
return False

def scale(self, outputs):
return outputs

def step(self, optimizer, *args, **kwargs):
optimizer.step(*args, **kwargs)

def update(self, new_scale=None):
pass

def unscale_(self, optimizer):
pass

def load_state_dict(self, state_dict):
pass

def state_dict(self):
return {}


def _build_fp16_env(dummy=False):
if dummy:
auto_cast = ExitStack
GradScaler = DummyGradScaler
else:
raise NotImplementedError("JittorDriver does not support fp16 now.")
# if not jt.flags.use_cuda:
# raise RuntimeError("No cuda")
# if paddle.device.cuda.get_device_capability(0)[0] < 7:
# log.warning(
# "NOTE: your device does NOT support faster training with fp16, "
# "please switch to FP32 which is likely to be faster"
# )
# from paddle.amp import auto_cast, GradScaler
return auto_cast, GradScaler

fastNLP/core/drivers/paddle_driver/paddle_driver.py (+4, -5)

@@ -113,12 +113,11 @@ class PaddleDriver(Driver):
@staticmethod
def tensor_to_numeric(tensor, reduce=None):
r"""
将一个 `tensor` 对象(类型为 `paddle.Tensor` )转换为 python 的 `numeric` 对象;如果 tensor 只包含一个元素则返回 float 或 int 。
将一个 :class:`paddle.Tensor` 对象转换为 转换成 python 中的数值类型;

:param tensor: 需要被转换的 `tensor` 对象
:param reduce: 可选 ['sum', 'max', 'mea', 'min'],如果不为 None 将使用该 reduce 方法来处理当前 tensor 再返回
float 或 int 对象。
:return: 转换后返回的结果
:param tensor: :class:`paddle.Tensor` 类型的对象;
:param reduce: 当 tensor 是一个多数值的张量时,应当使用何种归一化操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``;
:return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等;
"""
if tensor is None:
return None


fastNLP/core/drivers/torch_driver/torch_driver.py (+1, -1)

@@ -12,7 +12,7 @@ if _NEED_IMPORT_TORCH:
from torch.optim import Optimizer
from torch.utils.data import RandomSampler as TorchRandomSampler
_reduces = {
'sum': torch.max,
'sum': torch.sum,
'min': torch.min,
'max': torch.max,
'mean': torch.mean
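
This one-line fix matters: reduce="sum" previously dispatched to torch.max. A quick sanity check, assuming TorchSingleDriver is importable from fastNLP.core.drivers.torch_driver as in the tests below:

import torch
from fastNLP.core.drivers.torch_driver import TorchSingleDriver

t = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
assert TorchSingleDriver.tensor_to_numeric(t, reduce="sum") == 21   # was 6 (torch.max) before the fix
assert TorchSingleDriver.tensor_to_numeric(t, reduce="mean") == 3.5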


fastNLP/core/utils/paddle_utils.py (+2, -1)

@@ -68,7 +68,8 @@ def paddle_to(data: "paddle.Tensor", device: Union[str, int, 'paddle.fluid.core_
该函数只是集成了 :func:`paddle.Tensor.cpu` 和 :func:`paddle.Tensor.cuda` 两个函数。

:param data: 要迁移的张量;
:param device: 目标设备,可以是 ``str`` 或 ``int`` 类型;
:param device: 目标设备,可以是 ``str`` 或 ``int`` 及 **paddle** 自己的 :class:`paddle.fluid.core_avx.Place`、
:class:`paddle.CPUPlace` 和 :class:`paddle.CUDAPlace` 类型;
:return: 迁移后的张量;
"""
if isinstance(device, paddle.fluid.core_avx.Place):
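
The updated docstring reflects that paddle_to also accepts paddle Place objects in addition to str and int devices. A small usage sketch (assuming a CUDA-enabled paddle build for the GPU lines):

import paddle
from fastNLP.core.utils.paddle_utils import paddle_to

x = paddle.to_tensor([1, 2, 3])
x_cpu = paddle_to(x, "cpu")
x_gpu = paddle_to(x, 0)                    # int -> gpu:0
x_gpu = paddle_to(x, paddle.CUDAPlace(0))  # Place objects are accepted as well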


fastNLP/core/utils/seq_len_to_mask.py (+1, -1)

@@ -74,7 +74,7 @@ def seq_len_to_mask(seq_len, max_len: Optional[int]=None):
if isinstance(seq_len, jittor.Var):
assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim == 1}."
batch_size = seq_len.shape[0]
broad_cast_seq_len = jittor.arange(max_len).expand(batch_size, -1)
broad_cast_seq_len = jittor.arange(max_len).reshape(1, max_len).expand(batch_size, -1)
mask = broad_cast_seq_len < seq_len.unsqueeze(1)
return mask
except NameError as e:
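
The jittor fix makes the broadcast explicit: arange(max_len) must first become a (1, max_len) row before expanding to (batch_size, max_len). The same shape logic spelled out in numpy (illustrative only, not the jittor code path):

import numpy as np

seq_len = np.array([2, 4, 1])
max_len = 5
arange = np.arange(max_len).reshape(1, max_len)                        # (1, max_len)
broad_cast_seq_len = np.broadcast_to(arange, (len(seq_len), max_len))  # (batch_size, max_len)
mask = broad_cast_seq_len < seq_len[:, None]
# row i is True for the first seq_len[i] positions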


fastNLP/transformers/torch/models/auto/__init__.py (+3, -2)

@@ -3,7 +3,8 @@ __all__ = [
"CONFIG_MAPPING",
"MODEL_NAMES_MAPPING",
"AutoConfig",
"TOKENIZER_MAPPING_NAMES",
"TOKENIZER_MAPPING",
"AutoTokenizer",
"get_values",
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
@@ -43,7 +44,7 @@ __all__ = [

from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \
AutoConfig
from .tokenization_auto import TOKENIZER_MAPPING_NAMES
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from .auto_factory import get_values
from .modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,


fastNLP/transformers/torch/models/auto/tokenization_auto.py (+313, -3)

@@ -13,14 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Tokenizer class. """

import importlib
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

from ...configuration_utils import PretrainedConfig
from ...file_utils import (
cached_path,
hf_bucket_url,
is_offline_mode,
is_sentencepiece_available,
is_tokenizers_available,
)
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
config_class_to_model_type,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
from fastNLP.core.log import logger

if TYPE_CHECKING:
# This significantly improves completion suggestion performance when
@@ -34,4 +51,297 @@ else:
("bert", ("BertTokenizer", None)),
("gpt2", ("GPT2Tokenizer", None)),
]
)
)

TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)

CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}


def tokenizer_class_from_name(class_name: str):
if class_name == "PreTrainedTokenizerFast":
raise RuntimeError("fastNLP does not support TokenizerFast now.")

for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
if class_name in tokenizers:
module_name = model_type_to_module_name(module_name)

try:
module = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models")
except ImportError:
raise ImportError(f"fastNLP transformers does not support {module_name} now, please install and import `transformers` to use it.")
return getattr(module, class_name)

return None


def get_tokenizer_config(
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
**kwargs,
):
"""
Loads the tokenizer configuration from a pretrained model tokenizer configuration.

Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
This can be either:

- a string, the `model id` of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a configuration file saved using the
:func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``.

cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (:obj:`str` or `bool`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, will only try to load the tokenizer configuration from local files.

.. note::

Passing :obj:`use_auth_token=True` is required when you want to use a private model.


Returns:
:obj:`Dict`: The configuration of the tokenizer.

Examples::

# Download configuration from huggingface.co and cache.
tokenizer_config = get_tokenizer_config("bert-base-uncased")
# This model does not have a tokenizer config so the result will be an empty dict.
tokenizer_config = get_tokenizer_config("xlm-roberta-base")

# Save a pretrained tokenizer locally and you can reload its config
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer.save_pretrained("tokenizer-test")
tokenizer_config = get_tokenizer_config("tokenizer-test")
"""
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True

pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
else:
config_file = hf_bucket_url(
pretrained_model_name_or_path, filename=TOKENIZER_CONFIG_FILE, revision=revision, mirror=None
)

try:
# Load from URL or cache if already cached
resolved_config_file = cached_path(
config_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)

except EnvironmentError:
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
return {}

with open(resolved_config_file, encoding="utf-8") as reader:
return json.load(reader)


class AutoTokenizer:
r"""
This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
created with the :meth:`AutoTokenizer.from_pretrained` class method.

This class cannot be instantiated directly using ``__init__()`` (throws an error).
"""

def __init__(self):
raise EnvironmentError(
"AutoTokenizer is designed to be instantiated "
"using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
)

@classmethod
@replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
r"""
Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.

The tokenizer class to instantiate is selected based on the :obj:`model_type` property of the config object
(either passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's
missing, by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:

List options

Params:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either:

- A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing vocabulary files required by the tokenizer, for instance saved
using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.,
``./my_model_directory/``.
- A path or url to a single saved vocabulary file if and only if the tokenizer only requires a
single vocabulary file (like Bert or XLNet), e.g.: ``./my_model_directory/vocab.txt``. (Not
applicable to all derived classes)
inputs (additional positional arguments, `optional`):
Will be passed along to the Tokenizer ``__init__()`` method.
config (:class:`~transformers.PretrainedConfig`, `optional`)
The configuration object used to dertermine the tokenizer class to instantiate.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download the model weights and configuration files and override the
cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
subfolder (:obj:`str`, `optional`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
facebook/rag-token-base), specify it here.
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to try to load the fast version of the tokenizer.
tokenizer_type (:obj:`str`, `optional`):
Tokenizer type to be loaded.
kwargs (additional keyword arguments, `optional`):
Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__()`` for more details.

Examples::

>>> from transformers import AutoTokenizer

>>> # Download vocabulary from huggingface.co and cache.
>>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

>>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
>>> tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-cased')

>>> # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
>>> tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')

"""
config = kwargs.pop("config", None)
kwargs["_from_auto"] = True

use_fast = kwargs.pop("use_fast", True)
tokenizer_type = kwargs.pop("tokenizer_type", None)

# First, let's see whether the tokenizer_type is passed so that we can leverage it
if tokenizer_type is not None:
tokenizer_class = None
tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)

if tokenizer_class_tuple is None:
raise ValueError(
f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES.keys())}."
)

tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple

if use_fast and tokenizer_fast_class_name is not None:
tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)

if tokenizer_class is None:
tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)

if tokenizer_class is None:
raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")

return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

# Next, let's try to use the tokenizer_config file to get the tokenizer class.
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
config_tokenizer_class = tokenizer_config.get("tokenizer_class")

# If that did not work, let's try to use the config.
if config_tokenizer_class is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
config_tokenizer_class = config.tokenizer_class

# If we have the tokenizer class from the tokenizer config or the model config we're good!
if config_tokenizer_class is not None:
tokenizer_class = None
if use_fast and not config_tokenizer_class.endswith("Fast"):
tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
if tokenizer_class is None:
tokenizer_class_candidate = config_tokenizer_class
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)

if tokenizer_class is None:
raise ValueError(
f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
)
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

# Otherwise we have to be creative.
# if model is an encoder decoder, the encoder tokenizer class is used by default
if isinstance(config, EncoderDecoderConfig):
if type(config.decoder) is not type(config.encoder): # noqa: E721
logger.warning(
f"The encoder model config class: {config.encoder.__class__} is different from the decoder model "
f"config class: {config.decoder.__class__}. It is not recommended to use the "
"`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder "
"specific tokenizer classes."
)
config = config.encoder

model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None:
tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
if tokenizer_class_py is not None:
return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
raise ValueError(
"This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
"in order to use this tokenizer."
)

raise ValueError(
f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}."
)
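
A short usage sketch for the newly ported AutoTokenizer (assumptions: network access or a local cache, and a model type registered in TOKENIZER_MAPPING_NAMES above, which currently covers only bert and gpt2):

from fastNLP.transformers.torch.models.auto import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")   # resolves to BertTokenizer
print(tokenizer("hello world")["input_ids"])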

fastNLP/transformers/torch/models/encoder_decoder/__init__.py (+5, -0)

@@ -0,0 +1,5 @@
__all__ = [
"EncoderDecoderConfig",
]

from .configuration_encoder_decoder import EncoderDecoderConfig

fastNLP/transformers/torch/models/encoder_decoder/configuration_encoder_decoder.py (+114, -0)

@@ -0,0 +1,114 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from ...configuration_utils import PretrainedConfig
from fastNLP.core.log import logger

class EncoderDecoderConfig(PretrainedConfig):
r"""
:class:`~transformers.EncoderDecoderConfig` is the configuration class to store the configuration of a
:class:`~transformers.EncoderDecoderModel`. It is used to instantiate an Encoder Decoder model according to the
specified arguments, defining the encoder and decoder configs.

Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

Args:
kwargs (`optional`):
Dictionary of keyword arguments. Notably:

- **encoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration
object that defines the encoder config.
- **decoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration
object that defines the decoder config.

Examples::

>>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

>>> # Initializing a BERT bert-base-uncased style configuration
>>> config_encoder = BertConfig()
>>> config_decoder = BertConfig()

>>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

>>> # Initializing a Bert2Bert model from the bert-base-uncased style configurations
>>> model = EncoderDecoderModel(config=config)

>>> # Accessing the model configuration
>>> config_encoder = model.config.encoder
>>> config_decoder = model.config.decoder
>>> # set decoder config to causal lm
>>> config_decoder.is_decoder = True
>>> config_decoder.add_cross_attention = True

>>> # Saving the model, including its configuration
>>> model.save_pretrained('my-model')

>>> # loading model and config from pretrained folder
>>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained('my-model')
>>> model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config)
"""
model_type = "encoder-decoder"
is_composition = True

def __init__(self, **kwargs):
super().__init__(**kwargs)
assert (
"encoder" in kwargs and "decoder" in kwargs
), "Config has to be initialized with encoder and decoder config"
encoder_config = kwargs.pop("encoder")
encoder_model_type = encoder_config.pop("model_type")
decoder_config = kwargs.pop("decoder")
decoder_model_type = decoder_config.pop("model_type")

from ..auto.configuration_auto import AutoConfig

self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
self.is_encoder_decoder = True

@classmethod
def from_encoder_decoder_configs(
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
) -> PretrainedConfig:
r"""
Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model
configuration and decoder model configuration.

Returns:
:class:`EncoderDecoderConfig`: An instance of a configuration object
"""
logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True

return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)

def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.

Returns:
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["encoder"] = self.encoder.to_dict()
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type
return output

tests/core/drivers/paddle_driver/test_single_device.py (+22, -7)

@@ -75,12 +75,12 @@ class TestPaddleDriverFunctions:
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现
"""
dataloader = DataLoader(PaddleNormalDataset())
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")

# batch_size 和 batch_sampler 均为 None 的情形
dataloader = DataLoader(PaddleNormalDataset(), batch_size=None)
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")

# 创建torch的dataloader
dataloader = torch.utils.data.DataLoader(
@@ -88,7 +88,7 @@ class TestPaddleDriverFunctions:
batch_size=32, shuffle=True
)
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")

@pytest.mark.torchpaddle
def test_check_dataloader_legality_in_test(self):
@@ -100,7 +100,7 @@ class TestPaddleDriverFunctions:
"train": DataLoader(PaddleNormalDataset()),
"test":DataLoader(PaddleNormalDataset())
}
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")

# batch_size 和 batch_sampler 均为 None 的情形
dataloader = {
@@ -108,12 +108,12 @@ class TestPaddleDriverFunctions:
"test":DataLoader(PaddleNormalDataset(), batch_size=None)
}
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")

# 传入的不是 dict ,应该报错
dataloader = DataLoader(PaddleNormalDataset())
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")

# 创建 torch 的 dataloader
train_loader = torch.utils.data.DataLoader(
@@ -126,7 +126,7 @@ class TestPaddleDriverFunctions:
)
dataloader = {"train": train_loader, "test": test_loader}
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")

@pytest.mark.paddle
def test_tensor_to_numeric(self):
@@ -183,6 +183,21 @@ class TestPaddleDriverFunctions:
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()

@pytest.mark.paddle
def test_tensor_to_numeric_reduce(self):
tensor = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

res_max = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="max")
res_min = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="min")
res_sum = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="sum")
res_mean = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="mean")

assert res_max == 6
assert res_min == 1
assert res_sum == 21
assert res_mean == 3.5


@pytest.mark.paddle
def test_set_model_mode(self):
"""
测试 set_model_mode 函数


tests/core/drivers/torch_driver/test_single_device.py (+19, -5)

@@ -117,7 +117,7 @@ class TestTorchDriverFunctions:
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现
"""
dataloader = DataLoader(TorchNormalDataset())
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")

# 创建 paddle 的 dataloader
dataloader = paddle.io.DataLoader(
@@ -125,7 +125,7 @@ class TestTorchDriverFunctions:
batch_size=32, shuffle=True
)
with pytest.raises(ValueError):
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")

@pytest.mark.torchpaddle
def test_check_dataloader_legality_in_test(self):
@@ -137,12 +137,12 @@ class TestTorchDriverFunctions:
"train": DataLoader(TorchNormalDataset()),
"test": DataLoader(TorchNormalDataset())
}
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")

# 传入的不是 dict,应该报错
dataloader = DataLoader(TorchNormalDataset())
with pytest.raises(ValueError):
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")

# 创建 paddle 的 dataloader
train_loader = paddle.io.DataLoader(
@@ -155,7 +155,7 @@ class TestTorchDriverFunctions:
)
dataloader = {"train": train_loader, "test": test_loader}
with pytest.raises(ValueError):
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")

@pytest.mark.torch
def test_tensor_to_numeric(self):
@@ -212,6 +212,20 @@ class TestTorchDriverFunctions:
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()

@pytest.mark.torch
def test_tensor_to_numeric_reduce(self):
tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

res_max = TorchSingleDriver.tensor_to_numeric(tensor, reduce="max")
res_min = TorchSingleDriver.tensor_to_numeric(tensor, reduce="min")
res_sum = TorchSingleDriver.tensor_to_numeric(tensor, reduce="sum")
res_mean = TorchSingleDriver.tensor_to_numeric(tensor, reduce="mean")

assert res_max == 6
assert res_min == 1
assert res_sum == 21
assert res_mean == 3.5

@pytest.mark.torch
def test_set_model_mode(self):
"""
测试set_model_mode函数


tests/core/utils/test_seq_len_to_mask.py (+1, -1)

@@ -78,7 +78,7 @@ class TestSeqLenToMask:
mask = seq_len_to_mask(seq_len)

# 3. pad到指定长度
seq_len = paddle.randint(1, 10, size=(10,))
seq_len = paddle.randint(1, 10, shape=(10,))
mask = seq_len_to_mask(seq_len, 100)
assert 100 == mask.shape[1]


