
transformers: add AutoTokenizer

tags/v1.0.0alpha
x54-729 committed 3 years ago
commit e734d11d3b
4 changed files with 435 additions and 5 deletions
1. fastNLP/transformers/torch/models/auto/__init__.py (+3 -2)
2. fastNLP/transformers/torch/models/auto/tokenization_auto.py (+313 -3)
3. fastNLP/transformers/torch/models/encoder_decoder/__init__.py (+5 -0)
4. fastNLP/transformers/torch/models/encoder_decoder/configuration_encoder_decoder.py (+114 -0)

fastNLP/transformers/torch/models/auto/__init__.py (+3 -2)

@@ -3,7 +3,8 @@ __all__ = [
"CONFIG_MAPPING",
"MODEL_NAMES_MAPPING",
"AutoConfig",
"TOKENIZER_MAPPING_NAMES",
"TOKENIZER_MAPPING",
"AutoTokenizer",
"get_values",
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
@@ -43,7 +44,7 @@ __all__ = [

from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \
AutoConfig
- from .tokenization_auto import TOKENIZER_MAPPING_NAMES
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from .auto_factory import get_values
from .modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
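Review note: with the re-export above, `AutoTokenizer` and the lazy `TOKENIZER_MAPPING` become importable from the `auto` package. A minimal usage sketch (editorial, not part of the commit; assumes network access to huggingface.co or a locally cached checkpoint):

# Sketch only: exercises the exports added by this commit.
from fastNLP.transformers.torch.models.auto import AutoTokenizer, TOKENIZER_MAPPING

# Resolves to BertTokenizer through the "bert" entry in TOKENIZER_MAPPING_NAMES.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print(type(tokenizer).__name__)  # BertTokenizer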


fastNLP/transformers/torch/models/auto/tokenization_auto.py (+313 -3)

@@ -13,14 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Tokenizer class. """

import importlib
import json
import os
from collections import OrderedDict
- from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

from ...configuration_utils import PretrainedConfig
from ...file_utils import (
cached_path,
hf_bucket_url,
is_offline_mode,
is_sentencepiece_available,
is_tokenizers_available,
)
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
config_class_to_model_type,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
from fastNLP.core.log import logger

if TYPE_CHECKING:
# This significantly improves completion suggestion performance when
@@ -34,4 +51,297 @@ else:
("bert", ("BertTokenizer", None)),
("gpt2", ("GPT2Tokenizer", None)),
]
)

TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)

CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}


def tokenizer_class_from_name(class_name: str):
if class_name == "PreTrainedTokenizerFast":
raise RuntimeError("fastNLP does not support TokenizerFast now.")

for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
if class_name in tokenizers:
module_name = model_type_to_module_name(module_name)

try:
module = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models")
except ImportError:
raise ImportError(f"fastNLP transformers does not support {module_name} now, please install and import `transformers` to use it.")
return getattr(module, class_name)

return None
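# Editorial note, not part of this commit: resolution above is purely name-based.
#   tokenizer_class_from_name("BertTokenizer")           -> the BertTokenizer class
#   tokenizer_class_from_name("SomeUnknownTokenizer")    -> None (caller decides)
#   tokenizer_class_from_name("PreTrainedTokenizerFast") -> RuntimeError (fastNLP has no fast tokenizers)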


def get_tokenizer_config(
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
**kwargs,
):
"""
Loads the tokenizer configuration from a pretrained model's tokenizer configuration file.

Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
This can be either:

- a string, the `model id` of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a configuration file saved using the
:func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``.

cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force a (re-)download of the configuration files, overriding the cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Attempts to resume the download if such a file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
use_auth_token (:obj:`str` or `bool`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, will only try to load the tokenizer configuration from local files.

.. note::

Passing :obj:`use_auth_token=True` is required when you want to use a private model.


Returns:
:obj:`Dict`: The configuration of the tokenizer.

Examples::

# Download configuration from huggingface.co and cache.
tokenizer_config = get_tokenizer_config("bert-base-uncased")
# This model does not have a tokenizer config so the result will be an empty dict.
tokenizer_config = get_tokenizer_config("xlm-roberta-base")

# Save a pretrained tokenizer locally and you can reload its config
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer.save_pretrained("tokenizer-test")
tokenizer_config = get_tokenizer_config("tokenizer-test")
"""
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True

pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
else:
config_file = hf_bucket_url(
pretrained_model_name_or_path, filename=TOKENIZER_CONFIG_FILE, revision=revision, mirror=None
)

try:
# Load from URL or cache if already cached
resolved_config_file = cached_path(
config_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)

except EnvironmentError:
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
return {}

with open(resolved_config_file, encoding="utf-8") as reader:
return json.load(reader)


class AutoTokenizer:
r"""
This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
created with the :meth:`AutoTokenizer.from_pretrained` class method.

This class cannot be instantiated directly using ``__init__()`` (throws an error).
"""

def __init__(self):
raise EnvironmentError(
"AutoTokenizer is designed to be instantiated "
"using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
)

@classmethod
@replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
r"""
Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.

The tokenizer class to instantiate is selected based on the :obj:`model_type` property of the config object
(either passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's
missing, by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:

List options

Params:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either:

- A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing vocabulary files required by the tokenizer, for instance saved
using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.,
``./my_model_directory/``.
- A path or url to a single saved vocabulary file if and only if the tokenizer only requires a
single vocabulary file (like Bert or XLNet), e.g.: ``./my_model_directory/vocab.txt``. (Not
applicable to all derived classes)
inputs (additional positional arguments, `optional`):
Will be passed along to the Tokenizer ``__init__()`` method.
config (:class:`~transformers.PretrainedConfig`, `optional`):
The configuration object used to determine the tokenizer class to instantiate.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force a (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
subfolder (:obj:`str`, `optional`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
facebook/rag-token-base), specify it here.
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to try to load the fast version of the tokenizer.
tokenizer_type (:obj:`str`, `optional`):
Tokenizer type to be loaded.
kwargs (additional keyword arguments, `optional`):
Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__()`` for more details.

Examples::

>>> from transformers import AutoTokenizer

>>> # Download vocabulary from huggingface.co and cache.
>>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

>>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
>>> tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-cased')

>>> # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
>>> tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')

"""
config = kwargs.pop("config", None)
kwargs["_from_auto"] = True

use_fast = kwargs.pop("use_fast", True)
tokenizer_type = kwargs.pop("tokenizer_type", None)

# First, let's see whether the tokenizer_type is passed so that we can leverage it
if tokenizer_type is not None:
tokenizer_class = None
tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)

if tokenizer_class_tuple is None:
raise ValueError(
f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES.keys())}."
)

tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple

if use_fast and tokenizer_fast_class_name is not None:
tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)

if tokenizer_class is None:
tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)

if tokenizer_class is None:
raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")

return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

# Next, let's try to use the tokenizer_config file to get the tokenizer class.
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
config_tokenizer_class = tokenizer_config.get("tokenizer_class")

# If that did not work, let's try to use the config.
if config_tokenizer_class is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
config_tokenizer_class = config.tokenizer_class

# If we have the tokenizer class from the tokenizer config or the model config we're good!
if config_tokenizer_class is not None:
tokenizer_class = None
if use_fast and not config_tokenizer_class.endswith("Fast"):
tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
if tokenizer_class is None:
tokenizer_class_candidate = config_tokenizer_class
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)

if tokenizer_class is None:
raise ValueError(
f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
)
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

# Otherwise we have to be creative.
# if model is an encoder decoder, the encoder tokenizer class is used by default
if isinstance(config, EncoderDecoderConfig):
if type(config.decoder) is not type(config.encoder): # noqa: E721
logger.warning(
f"The encoder model config class: {config.encoder.__class__} is different from the decoder model "
f"config class: {config.decoder.__class__}. It is not recommended to use the "
"`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder "
"specific tokenizer classes."
)
config = config.encoder

model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None:
tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
if tokenizer_class_py is not None:
return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
raise ValueError(
"This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
"in order to use this tokenizer."
)

raise ValueError(
f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}."
)
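Review note: `AutoTokenizer.from_pretrained` above resolves the tokenizer class in a fixed order: an explicit `tokenizer_type` kwarg, then `tokenizer_class` from `tokenizer_config.json`, then `tokenizer_class` on the model config, and finally the config class itself via `TOKENIZER_MAPPING` (using the encoder config for `EncoderDecoderConfig`). A hedged sketch of the first two paths (editorial, not part of the commit):

from fastNLP.transformers.torch.models.auto.tokenization_auto import (
    AutoTokenizer,
    get_tokenizer_config,
)

# Path 1: an explicit tokenizer_type short-circuits every config lookup.
tok = AutoTokenizer.from_pretrained("gpt2", tokenizer_type="gpt2")

# Path 2: inspect the tokenizer_config.json that would otherwise drive the
# lookup; an empty dict means the checkpoint ships no tokenizer config.
cfg = get_tokenizer_config("gpt2")
print(cfg.get("tokenizer_class"))  # e.g. "GPT2Tokenizer", or None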

fastNLP/transformers/torch/models/encoder_decoder/__init__.py (+5 -0)

@@ -0,0 +1,5 @@
__all__ = [
"EncoderDecoderConfig",
]

from .configuration_encoder_decoder import EncoderDecoderConfig

fastNLP/transformers/torch/models/encoder_decoder/configuration_encoder_decoder.py (+114 -0)

@@ -0,0 +1,114 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from ...configuration_utils import PretrainedConfig
from fastNLP.core.log import logger

class EncoderDecoderConfig(PretrainedConfig):
r"""
:class:`~transformers.EncoderDecoderConfig` is the configuration class to store the configuration of a
:class:`~transformers.EncoderDecoderModel`. It is used to instantiate an Encoder Decoder model according to the
specified arguments, defining the encoder and decoder configs.

Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

Args:
kwargs (`optional`):
Dictionary of keyword arguments. Notably:

- **encoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration
object that defines the encoder config.
- **decoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration
object that defines the decoder config.

Examples::

>>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

>>> # Initializing a BERT bert-base-uncased style configuration
>>> config_encoder = BertConfig()
>>> config_decoder = BertConfig()

>>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

>>> # Initializing a Bert2Bert model from the bert-base-uncased style configurations
>>> model = EncoderDecoderModel(config=config)

>>> # Accessing the model configuration
>>> config_encoder = model.config.encoder
>>> config_decoder = model.config.decoder
>>> # set decoder config to causal lm
>>> config_decoder.is_decoder = True
>>> config_decoder.add_cross_attention = True

>>> # Saving the model, including its configuration
>>> model.save_pretrained('my-model')

>>> # loading model and config from pretrained folder
>>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained('my-model')
>>> model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config)
"""
model_type = "encoder-decoder"
is_composition = True

def __init__(self, **kwargs):
super().__init__(**kwargs)
assert (
"encoder" in kwargs and "decoder" in kwargs
), "Config has to be initialized with encoder and decoder config"
encoder_config = kwargs.pop("encoder")
encoder_model_type = encoder_config.pop("model_type")
decoder_config = kwargs.pop("decoder")
decoder_model_type = decoder_config.pop("model_type")

from ..auto.configuration_auto import AutoConfig

self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
self.is_encoder_decoder = True

@classmethod
def from_encoder_decoder_configs(
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
) -> PretrainedConfig:
r"""
Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model
configuration and decoder model configuration.

Returns:
:class:`EncoderDecoderConfig`: An instance of a configuration object
"""
logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True

return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)

def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.

Returns:
:obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["encoder"] = self.encoder.to_dict()
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type
return output
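Review note: a hedged round-trip sketch of the composition API above (editorial, not part of the commit; `AutoConfig.for_model` is the same helper the constructor uses internally):

from fastNLP.transformers.torch.models.auto import AutoConfig
from fastNLP.transformers.torch.models.encoder_decoder import EncoderDecoderConfig

# Compose two configs; the classmethod flips the decoder into causal-LM mode.
config = EncoderDecoderConfig.from_encoder_decoder_configs(
    AutoConfig.for_model("bert"), AutoConfig.for_model("bert")
)
assert config.decoder.is_decoder and config.decoder.add_cross_attention

# to_dict() re-inlines both sub-configs, so the dict can rebuild the object.
d = config.to_dict()
restored = EncoderDecoderConfig(encoder=d["encoder"], decoder=d["decoder"])
assert restored.encoder.model_type == "bert"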
