
Migrate transformers' modeling_auto

tags/v1.0.0alpha
x54-729 committed 3 years ago
commit 1a1b1e8c9e
5 changed files with 1530 additions and 3 deletions

1. fastNLP/transformers/torch/file_utils.py (+14, -3)
2. fastNLP/transformers/torch/models/auto/__init__.py (+83, -0)
3. fastNLP/transformers/torch/models/auto/auto_factory.py (+562, -0)
4. fastNLP/transformers/torch/models/auto/dynamic.py (+208, -0)
5. fastNLP/transformers/torch/models/auto/modeling_auto.py (+663, -0)

fastNLP/transformers/torch/file_utils.py (+14, -3)

@@ -10,6 +10,8 @@ import sys
import tarfile
import tempfile
import operator
import types
import functools
from collections import OrderedDict, UserDict
from contextlib import contextmanager
from dataclasses import fields
@@ -37,6 +39,8 @@ if _NEED_IMPORT_TORCH:
import torch
_torch_version = importlib_metadata.version("torch")


ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}

hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
@@ -45,10 +49,9 @@ default_cache_path = os.path.join(hf_cache_home, "transformers")
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
SESSION_ID = uuid4().hex

DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES


WEIGHTS_NAME = "pytorch_model.bin"
@@ -1043,3 +1046,11 @@ class TensorType(ExplicitEnum):


PYTORCH = "pt" PYTORCH = "pt"
NUMPY = "np" NUMPY = "np"

def copy_func(f):
"""Returns a copy of a function f."""
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = f.__kwdefaults__
return g
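
`copy_func` is used by `auto_class_update` in `auto_factory.py` below to duplicate the shared `from_config`/`from_pretrained` classmethods, so each generated auto class can carry its own docstring without mutating the base method. A quick sketch of the behavior (the `greet` function is a hypothetical stand-in):

>>> def greet(name="world"):
...     """Say hello."""
...     return f"hello {name}"
>>> g = copy_func(greet)
>>> g("fastNLP")
'hello fastNLP'
>>> g.__doc__ = "A new docstring."
>>> greet.__doc__  # the original is untouched
'Say hello.'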

fastNLP/transformers/torch/models/auto/__init__.py (+83, -0)

@@ -0,0 +1,83 @@
__all__ = [
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
"CONFIG_MAPPING",
"MODEL_NAMES_MAPPING",
"AutoConfig",
"TOKENIZER_MAPPING",
"get_values",
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_CTC_MAPPING",
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"MODEL_FOR_OBJECT_DETECTION_MAPPING",
"MODEL_FOR_PRETRAINING_MAPPING",
"MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
"AutoModelForAudioClassification",
"AutoModelForCausalLM",
"AutoModelForCTC",
"AutoModelForImageClassification",
"AutoModelForMaskedLM",
"AutoModelForMultipleChoice",
"AutoModelForNextSentencePrediction",
"AutoModelForObjectDetection",
"AutoModelForPreTraining",
"AutoModelForQuestionAnswering",
"AutoModelForSeq2SeqLM",
"AutoModelForSequenceClassification",
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelWithLMHead",
]

from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \
AutoConfig
from .tokenization_auto import TOKENIZER_MAPPING
from .auto_factory import get_values
from .modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
MODEL_FOR_OBJECT_DETECTION_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForNextSentencePrediction,
AutoModelForObjectDetection,
AutoModelForPreTraining,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelWithLMHead,
)
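
With these exports in place, the auto classes are reachable through fastNLP's bundled transformers namespace, mirroring the upstream `transformers` API. A usage sketch (`bert-base-uncased` is an illustrative checkpoint; `from_pretrained` needs network access or a populated cache):

>>> from fastNLP.transformers.torch.models.auto import AutoConfig, AutoModel
>>> config = AutoConfig.from_pretrained("bert-base-uncased")  # doctest: +SKIP
>>> model = AutoModel.from_config(config)  # builds the model with random weights  # doctest: +SKIP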

fastNLP/transformers/torch/models/auto/auto_factory.py (+562, -0)

@@ -0,0 +1,562 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory function to build auto-model classes."""
import importlib
from collections import OrderedDict

from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
from .dynamic import get_class_from_dynamic_module
from fastNLP.transformers.torch.configuration_utils import PretrainedConfig
from fastNLP.transformers.torch.file_utils import copy_func
from fastNLP.core.log import logger


CLASS_DOCSTRING = """
This is a generic model class that will be instantiated as one of the model classes of the library when created
with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the
:meth:`~transformers.BaseAutoModelClass.from_config` class method.

This class cannot be instantiated directly using ``__init__()`` (throws an error).
"""

FROM_CONFIG_DOCSTRING = """
Instantiates one of the model classes of the library from a configuration.

Note:
Loading a model from its configuration file does **not** load the model weights. It only affects the
model's configuration. Use :meth:`~transformers.BaseAutoModelClass.from_pretrained` to load the model
weights.

Args:
config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class:

List options

Examples::

>>> from transformers import AutoConfig, BaseAutoModelClass
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('checkpoint_placeholder')
>>> model = BaseAutoModelClass.from_config(config)
"""

FROM_PRETRAINED_TORCH_DOCSTRING = """
Instantiate one of the model classes of the library from a pretrained model.

The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:

List options

The model is set in evaluation mode by default using ``model.eval()`` (so for instance, dropout modules are
deactivated). To train the model, you should first set it back in training mode with ``model.train()``.

Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either:

- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
model_args (additional positional arguments, `optional`):
Will be passed along to the underlying model ``__init__()`` method.
config (:class:`~transformers.PretrainedConfig`, `optional`):
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:

- The model is a model provided by the library (loaded with the `model id` string of a pretrained
model).
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by supplying the save directory.
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
state_dict (`Dict[str, torch.Tensor]`, `optional`):
A state dictionary to use instead of a state dictionary loaded from saved weights file.

This option can be used if you want to create a model from a pretrained configuration but load your own
weights. In this case though, you should check if using
:func:`~transformers.PreTrainedModel.save_pretrained` and
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a TensorFlow checkpoint save file (see docstring of
``pretrained_model_name_or_path`` argument).
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try downloading the model).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs (additional keyword arguments, `optional`):
Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
automatically loaded:

- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
attribute will be passed to the underlying model's ``__init__`` function.

Examples::

>>> from transformers import AutoConfig, BaseAutoModelClass

>>> # Download model and configuration from huggingface.co and cache.
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')

>>> # Update configuration during loading
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
>>> model.config.output_attentions
True

>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained('./tf_model/shortcut_placeholder_tf_model_config.json')
>>> model = BaseAutoModelClass.from_pretrained('./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""

FROM_PRETRAINED_TF_DOCSTRING = """
Instantiate one of the model classes of the library from a pretrained model.

The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:

List options

Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either:

- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
as ``config`` argument. This loading path is slower than converting the PyTorch model in a
TensorFlow model using the provided conversion scripts and loading the TensorFlow model
afterwards.
model_args (additional positional arguments, `optional`):
Will be passed along to the underlying model ``__init__()`` method.
config (:class:`~transformers.PretrainedConfig`, `optional`):
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:

- The model is a model provided by the library (loaded with the `model id` string of a pretrained
model).
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by supplying the save directory.
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a PyTorch checkpoint save file (see docstring of
``pretrained_model_name_or_path`` argument).
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try downloading the model).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs (additional keyword arguments, `optional`):
Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
automatically loaded:

- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
attribute will be passed to the underlying model's ``__init__`` function.

Examples::

>>> from transformers import AutoConfig, BaseAutoModelClass

>>> # Download model and configuration from huggingface.co and cache.
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')

>>> # Update configuration during loading
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
>>> model.config.output_attentions
True

>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json')
>>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config)
"""

FROM_PRETRAINED_FLAX_DOCSTRING = """
Instantiate one of the model classes of the library from a pretrained model.

The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:

List options

Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either:

- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
as ``config`` argument. This loading path is slower than converting the PyTorch model in a
TensorFlow model using the provided conversion scripts and loading the TensorFlow model
afterwards.
model_args (additional positional arguments, `optional`):
Will be passed along to the underlying model ``__init__()`` method.
config (:class:`~transformers.PretrainedConfig`, `optional`):
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:

- The model is a model provided by the library (loaded with the `model id` string of a pretrained
model).
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by supplying the save directory.
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a PyTorch checkpoint save file (see docstring of
``pretrained_model_name_or_path`` argument).
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try downloading the model).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs (additional keyword arguments, `optional`):
Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
automatically loaded:

- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
attribute will be passed to the underlying model's ``__init__`` function.

Examples::

>>> from transformers import AutoConfig, BaseAutoModelClass

>>> # Download model and configuration from huggingface.co and cache.
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')

>>> # Update configuration during loading
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
>>> model.config.output_attentions
True

>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json')
>>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config)
"""


def _get_model_class(config, model_mapping):
supported_models = model_mapping[type(config)]
if not isinstance(supported_models, (list, tuple)):
return supported_models

name_to_model = {model.__name__: model for model in supported_models}
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in name_to_model:
return name_to_model[arch]
elif f"TF{arch}" in name_to_model:
return name_to_model[f"TF{arch}"]
elif f"Flax{arch}" in name_to_model:
return name_to_model[f"Flax{arch}"]

# If no architecture is set in the config, or none of the architectures match the supported models, the first
# element of the tuple is the default.
return supported_models[0]
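
For model types that map to a tuple of classes (e.g. ``funnel`` in ``modeling_auto.py`` below, which provides both ``FunnelModel`` and ``FunnelBaseModel``), ``_get_model_class`` picks the entry whose name matches ``config.architectures`` and otherwise falls back to the first element. A sketch with stand-in classes (not the real configs):

>>> class FunnelModel: pass
>>> class FunnelBaseModel: pass
>>> class FakeConfig:
...     architectures = ["FunnelBaseModel"]
>>> _get_model_class(FakeConfig(), {FakeConfig: (FunnelModel, FunnelBaseModel)}).__name__
'FunnelBaseModel'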


class _BaseAutoModelClass:
# Base class for auto models.
_model_mapping = None

def __init__(self, *args, **kwargs):
raise EnvironmentError(
f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
f"`{self.__class__.__name__}.from_config(config)` methods."
)

@classmethod
def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
return model_class._from_config(config, **kwargs)

raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)
kwargs["_from_auto"] = True
if not isinstance(config, PretrainedConfig):
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warning(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision."
)
class_ref = config.auto_map[cls.__name__]
module_file, class_name = class_ref.split(".")
model_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
)


def insert_head_doc(docstring, head_doc=""):
if len(head_doc) > 0:
return docstring.replace(
"one of the model classes of the library ",
f"one of the model classes of the library (with a {head_doc} head) ",
)
return docstring.replace(
"one of the model classes of the library ", "one of the base model classes of the library "
)


def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""):
# Create a new class with the right name from the base class
model_mapping = cls._model_mapping
name = cls.__name__
class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)

# Now we need to copy and re-register `from_config` and `from_pretrained` as class methods, otherwise we can't
# have specific docstrings for them.
from_config = copy_func(_BaseAutoModelClass.from_config)
from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
from_config.__doc__ = from_config_docstring
from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
cls.from_config = classmethod(from_config)

if name.startswith("TF"):
from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
elif name.startswith("Flax"):
from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
else:
from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
from_pretrained.__doc__ = from_pretrained_docstring
from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
cls.from_pretrained = classmethod(from_pretrained)
return cls
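
This helper is how the concrete classes in ``modeling_auto.py`` are assembled: each one subclasses ``_BaseAutoModelClass``, sets ``_model_mapping``, and is then passed through ``auto_class_update`` to receive name- and checkpoint-specific docstrings. The pattern, as upstream transformers applies it (``AutoModel`` itself is defined at the bottom of this commit):

class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING

AutoModel = auto_class_update(AutoModel)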


def get_values(model_mapping):
result = []
for model in model_mapping.values():
if isinstance(model, (list, tuple)):
result += list(model)
else:
result.append(model)

return result
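
``get_values`` flattens a mapping's values, expanding tuple entries such as the Funnel pair. With a plain dict standing in for a ``_LazyAutoMapping``:

>>> get_values({"funnel": ("FunnelModel", "FunnelBaseModel"), "bert": "BertModel"})
['FunnelModel', 'FunnelBaseModel', 'BertModel']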


def getattribute_from_module(module, attr):
if attr is None:
return None
if isinstance(attr, tuple):
return tuple(getattribute_from_module(module, a) for a in attr)
if hasattr(module, attr):
return getattr(module, attr)
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
# object at the top level.
transformers_module = importlib.import_module("transformers")
return getattribute_from_module(transformers_module, attr)


class _LazyAutoMapping(OrderedDict):
"""
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.

Args:

- config_mapping: The map model type to config class
- model_mapping: The map model type to model (or tokenizer) class
"""

def __init__(self, config_mapping, model_mapping):
self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping
self._modules = {}

def __getitem__(self, key):
model_type = self._reverse_config_mapping[key.__name__]
if model_type not in self._model_mapping:
raise KeyError(key)
model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name)

def _load_attr_from_module(self, model_type, attr):
module_name = model_type_to_module_name(model_type)
if module_name not in self._modules:
self._modules[module_name] = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models")
return getattribute_from_module(self._modules[module_name], attr)

def keys(self):
return [
self._load_attr_from_module(key, name)
for key, name in self._config_mapping.items()
if key in self._model_mapping.keys()
]

def get(self, key, default):
try:
return self.__getitem__(key)
except KeyError:
return default

def __bool__(self):
return bool(self.keys())

def values(self):
return [
self._load_attr_from_module(key, name)
for key, name in self._model_mapping.items()
if key in self._config_mapping.keys()
]

def items(self):
return [
(
self._load_attr_from_module(key, self._config_mapping[key]),
self._load_attr_from_module(key, self._model_mapping[key]),
)
for key in self._model_mapping.keys()
if key in self._config_mapping.keys()
]

def __iter__(self):
return iter(self.keys())

def __contains__(self, item):
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
return False
model_type = self._reverse_config_mapping[item.__name__]
return model_type in self._model_mapping
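
The mapping is keyed by config *class* but compares on ``__name__``, and a model module is only imported when its entry is actually read. Illustrative use against ``MODEL_MAPPING`` defined in ``modeling_auto.py`` below (the ``BertConfig`` import path is an assumption, hence the skips):

>>> from fastNLP.transformers.torch.models.auto.modeling_auto import MODEL_MAPPING
>>> from fastNLP.transformers.torch.models.bert import BertConfig  # doctest: +SKIP
>>> BertConfig in MODEL_MAPPING  # doctest: +SKIP
True
>>> MODEL_MAPPING[BertConfig].__name__  # triggers the lazy import of the bert module  # doctest: +SKIP
'BertModel'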

fastNLP/transformers/torch/models/auto/dynamic.py (+208, -0)

@@ -0,0 +1,208 @@
import importlib
import os
import re
import shutil
import sys
from pathlib import Path
from typing import Dict, Optional, Union

from fastNLP.transformers.torch.file_utils import (
HF_MODULES_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME,
cached_path,
hf_bucket_url,
is_offline_mode,
)
from fastNLP.core.log import logger

def init_hf_modules():
"""
Creates the cache directory for modules with an init, and adds it to the Python path.
"""
# This function has already been executed if HF_MODULES_CACHE is already in the Python path.
if HF_MODULES_CACHE in sys.path:
return

sys.path.append(HF_MODULES_CACHE)
os.makedirs(HF_MODULES_CACHE, exist_ok=True)
init_path = Path(HF_MODULES_CACHE) / "__init__.py"
if not init_path.exists():
init_path.touch()

def create_dynamic_module(name: Union[str, os.PathLike]):
"""
Creates a dynamic module in the cache directory for modules.
"""
init_hf_modules()
dynamic_module_path = Path(HF_MODULES_CACHE) / name
# If the parent module does not exist yet, recursively create it.
if not dynamic_module_path.parent.exists():
create_dynamic_module(dynamic_module_path.parent)
os.makedirs(dynamic_module_path, exist_ok=True)
init_path = dynamic_module_path / "__init__.py"
if not init_path.exists():
init_path.touch()

def check_imports(filename):
"""
Check if the current Python environment contains all the libraries that are imported in a file.
"""
with open(filename, "r", encoding="utf-8") as f:
content = f.read()

# Imports of the form `import xxx`
imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `from xxx import yyy`
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
# Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]

# Unique-ify and test we got them all
imports = list(set(imports))
missing_packages = []
for imp in imports:
try:
importlib.import_module(imp)
except ImportError:
missing_packages.append(imp)

if len(missing_packages) > 0:
raise ImportError(
"This modeling file requires the following packages that were not found in your environment: "
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
)


def get_class_in_module(class_name, module_path):
"""
Import a module on the cache directory for modules and extract a class from it.
"""
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
return getattr(module, class_name)

def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
class_name: str,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
**kwargs,
):
"""
Extracts a class from a module file, present in the local folder or repository of a model.

.. warning::

Calling this function will execute the code in the module file found locally or downloaded from the Hub. It
should therefore only be called on trusted repos.

Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
This can be either:

- a string, the `model id` of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a configuration file saved using the
:func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``.

module_file (:obj:`str`):
The name of the module file containing the class to look for.
class_name (:obj:`str`):
The name of the class to import in the module.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Attempts to resume the download if such a file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
use_auth_token (:obj:`str` or `bool`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, will only try to load the tokenizer configuration from local files.

.. note::

Passing :obj:`use_auth_token=True` is required when you want to use a private model.


Returns:
:obj:`type`: The class, dynamically imported from the module.

Examples::

# Download module `modeling.py` from huggingface.co and cache it, then extract the class `MyBertModel` from
# this module.
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
"""
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True

# Download and cache module_file from the repo `pretrained_model_name_or_path`, or grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
submodule = "local"
else:
module_file_or_url = hf_bucket_url(
pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
)
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)

try:
# Load from URL or cache if already cached
resolved_module_file = cached_path(
module_file_or_url,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)

except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise

# Check we have all the requirements in our environment
check_imports(resolved_module_file)

# Now we move the module inside our cached dynamic modules.
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
if submodule == "local":
# We always copy local files (we could hash the file to see if there was a change, and give it the name of that
# hash, so we only copy on modification, but that seems overkill for now).
# The only reason we do the copy is to avoid putting too many folders in sys.path.
module_name = module_file
shutil.copy(resolved_module_file, submodule_path / module_file)
else:
# The module file will end up being named module_file + the etag. This way we get the benefit of versioning.
resolved_module_file_name = Path(resolved_module_file).name
module_name_parts = [module_file.replace(".py", "")] + resolved_module_file_name.split(".")
module_name = "_".join(module_name_parts) + ".py"
if not (submodule_path / module_name).exists():
shutil.copy(resolved_module_file, submodule_path / module_name)

# And lastly we get the class inside our newly created module
final_module = os.path.join(full_submodule, module_name.replace(".py", ""))
return get_class_in_module(class_name, final_module)

fastNLP/transformers/torch/models/auto/modeling_auto.py (+663, -0)

@@ -0,0 +1,663 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """

import warnings
from collections import OrderedDict

from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
from fastNLP.core.log import logger


MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("fnet", "FNetModel"),
("gptj", "GPTJModel"),
("layoutlmv2", "LayoutLMv2Model"),
("beit", "BeitModel"),
("rembert", "RemBertModel"),
("visual_bert", "VisualBertModel"),
("canine", "CanineModel"),
("roformer", "RoFormerModel"),
("clip", "CLIPModel"),
("bigbird_pegasus", "BigBirdPegasusModel"),
("deit", "DeiTModel"),
("luke", "LukeModel"),
("detr", "DetrModel"),
("gpt_neo", "GPTNeoModel"),
("big_bird", "BigBirdModel"),
("speech_to_text", "Speech2TextModel"),
("vit", "ViTModel"),
("wav2vec2", "Wav2Vec2Model"),
("hubert", "HubertModel"),
("m2m_100", "M2M100Model"),
("convbert", "ConvBertModel"),
("led", "LEDModel"),
("blenderbot-small", "BlenderbotSmallModel"),
("retribert", "RetriBertModel"),
("mt5", "MT5Model"),
("t5", "T5Model"),
("pegasus", "PegasusModel"),
("marian", "MarianModel"),
("mbart", "MBartModel"),
("blenderbot", "BlenderbotModel"),
("distilbert", "DistilBertModel"),
("albert", "AlbertModel"),
("camembert", "CamembertModel"),
("xlm-roberta", "XLMRobertaModel"),
("bart", "BartModel"),
("longformer", "LongformerModel"),
("roberta", "RobertaModel"),
("layoutlm", "LayoutLMModel"),
("squeezebert", "SqueezeBertModel"),
("bert", "BertModel"),
("openai-gpt", "OpenAIGPTModel"),
("gpt2", "GPT2Model"),
("megatron-bert", "MegatronBertModel"),
("mobilebert", "MobileBertModel"),
("transfo-xl", "TransfoXLModel"),
("xlnet", "XLNetModel"),
("flaubert", "FlaubertModel"),
("fsmt", "FSMTModel"),
("xlm", "XLMModel"),
("ctrl", "CTRLModel"),
("electra", "ElectraModel"),
("reformer", "ReformerModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("lxmert", "LxmertModel"),
("bert-generation", "BertGenerationEncoder"),
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
("dpr", "DPRQuestionEncoder"),
("xlm-prophetnet", "XLMProphetNetModel"),
("prophetnet", "ProphetNetModel"),
("mpnet", "MPNetModel"),
("tapas", "TapasModel"),
("ibert", "IBertModel"),
("splinter", "SplinterModel"),
]
)

MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
("fnet", "FNetForPreTraining"),
("visual_bert", "VisualBertForPreTraining"),
("layoutlm", "LayoutLMForMaskedLM"),
("retribert", "RetriBertModel"),
("t5", "T5ForConditionalGeneration"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForPreTraining"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("fsmt", "FSMTForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForPreTraining"),
("big_bird", "BigBirdForPreTraining"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("ctrl", "CTRLLMHeadModel"),
("electra", "ElectraForPreTraining"),
("lxmert", "LxmertForPreTraining"),
("funnel", "FunnelForPreTraining"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("ibert", "IBertForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("wav2vec2", "Wav2Vec2ForPreTraining"),
]
)

MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
("fnet", "FNetForMaskedLM"),
("gptj", "GPTJForCausalLM"),
("rembert", "RemBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("gpt_neo", "GPTNeoForCausalLM"),
("big_bird", "BigBirdForMaskedLM"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"),
("convbert", "ConvBertForMaskedLM"),
("led", "LEDForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("layoutlm", "LayoutLMForMaskedLM"),
("t5", "T5ForConditionalGeneration"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("marian", "MarianMTModel"),
("fsmt", "FSMTForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForMaskedLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("megatron-bert", "MegatronBertForCausalLM"),
("mobilebert", "MobileBertForMaskedLM"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("ctrl", "CTRLLMHeadModel"),
("electra", "ElectraForMaskedLM"),
("encoder-decoder", "EncoderDecoderModel"),
("reformer", "ReformerModelWithLMHead"),
("funnel", "FunnelForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("ibert", "IBertForMaskedLM"),
]
)

MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("gptj", "GPTJForCausalLM"),
("rembert", "RemBertForCausalLM"),
("roformer", "RoFormerForCausalLM"),
("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
("gpt_neo", "GPTNeoForCausalLM"),
("big_bird", "BigBirdForCausalLM"),
("camembert", "CamembertForCausalLM"),
("xlm-roberta", "XLMRobertaForCausalLM"),
("roberta", "RobertaForCausalLM"),
("bert", "BertLMHeadModel"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("ctrl", "CTRLLMHeadModel"),
("reformer", "ReformerModelWithLMHead"),
("bert-generation", "BertGenerationDecoder"),
("xlm-prophetnet", "XLMProphetNetForCausalLM"),
("prophetnet", "ProphetNetForCausalLM"),
("bart", "BartForCausalLM"),
("mbart", "MBartForCausalLM"),
("pegasus", "PegasusForCausalLM"),
("marian", "MarianForCausalLM"),
("blenderbot", "BlenderbotForCausalLM"),
("blenderbot-small", "BlenderbotSmallForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"),
("speech_to_text_2", "Speech2Text2ForCausalLM"),
]
)

MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image Classification mapping
("vit", "ViTForImageClassification"),
("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
("beit", "BeitForImageClassification"),
]
)

MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("fnet", "FNetForMaskedLM"),
("rembert", "RemBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("big_bird", "BigBirdForMaskedLM"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("convbert", "ConvBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("mbart", "MBartForConditionalGeneration"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForMaskedLM"),
("megatron-bert", "MegatronBertForMaskedLM"),
("mobilebert", "MobileBertForMaskedLM"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("electra", "ElectraForMaskedLM"),
("reformer", "ReformerForMaskedLM"),
("funnel", "FunnelForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("ibert", "IBertForMaskedLM"),
]
)

MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
[
# Model for Object Detection mapping
("detr", "DetrForObjectDetection"),
]
)

MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("m2m_100", "M2M100ForConditionalGeneration"),
("led", "LEDForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("mt5", "MT5ForConditionalGeneration"),
("t5", "T5ForConditionalGeneration"),
("pegasus", "PegasusForConditionalGeneration"),
("marian", "MarianMTModel"),
("mbart", "MBartForConditionalGeneration"),
("blenderbot", "BlenderbotForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("fsmt", "FSMTForConditionalGeneration"),
("encoder-decoder", "EncoderDecoderModel"),
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
("prophetnet", "ProphetNetForConditionalGeneration"),
]
)

MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
]
)

MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("fnet", "FNetForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
("rembert", "RemBertForSequenceClassification"),
("canine", "CanineForSequenceClassification"),
("roformer", "RoFormerForSequenceClassification"),
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
("big_bird", "BigBirdForSequenceClassification"),
("convbert", "ConvBertForSequenceClassification"),
("led", "LEDForSequenceClassification"),
("distilbert", "DistilBertForSequenceClassification"),
("albert", "AlbertForSequenceClassification"),
("camembert", "CamembertForSequenceClassification"),
("xlm-roberta", "XLMRobertaForSequenceClassification"),
("mbart", "MBartForSequenceClassification"),
("bart", "BartForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("roberta", "RobertaForSequenceClassification"),
("squeezebert", "SqueezeBertForSequenceClassification"),
("layoutlm", "LayoutLMForSequenceClassification"),
("bert", "BertForSequenceClassification"),
("xlnet", "XLNetForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
("flaubert", "FlaubertForSequenceClassification"),
("xlm", "XLMForSequenceClassification"),
("electra", "ElectraForSequenceClassification"),
("funnel", "FunnelForSequenceClassification"),
("deberta", "DebertaForSequenceClassification"),
("deberta-v2", "DebertaV2ForSequenceClassification"),
("gpt2", "GPT2ForSequenceClassification"),
("gpt_neo", "GPTNeoForSequenceClassification"),
("openai-gpt", "OpenAIGPTForSequenceClassification"),
("reformer", "ReformerForSequenceClassification"),
("ctrl", "CTRLForSequenceClassification"),
("transfo-xl", "TransfoXLForSequenceClassification"),
("mpnet", "MPNetForSequenceClassification"),
("tapas", "TapasForSequenceClassification"),
("ibert", "IBertForSequenceClassification"),
]
)

MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("fnet", "FNetForQuestionAnswering"),
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
("rembert", "RemBertForQuestionAnswering"),
("canine", "CanineForQuestionAnswering"),
("roformer", "RoFormerForQuestionAnswering"),
("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
("big_bird", "BigBirdForQuestionAnswering"),
("convbert", "ConvBertForQuestionAnswering"),
("led", "LEDForQuestionAnswering"),
("distilbert", "DistilBertForQuestionAnswering"),
("albert", "AlbertForQuestionAnswering"),
("camembert", "CamembertForQuestionAnswering"),
("bart", "BartForQuestionAnswering"),
("mbart", "MBartForQuestionAnswering"),
("longformer", "LongformerForQuestionAnswering"),
("xlm-roberta", "XLMRobertaForQuestionAnswering"),
("roberta", "RobertaForQuestionAnswering"),
("squeezebert", "SqueezeBertForQuestionAnswering"),
("bert", "BertForQuestionAnswering"),
("xlnet", "XLNetForQuestionAnsweringSimple"),
("flaubert", "FlaubertForQuestionAnsweringSimple"),
("megatron-bert", "MegatronBertForQuestionAnswering"),
("mobilebert", "MobileBertForQuestionAnswering"),
("xlm", "XLMForQuestionAnsweringSimple"),
("electra", "ElectraForQuestionAnswering"),
("reformer", "ReformerForQuestionAnswering"),
("funnel", "FunnelForQuestionAnswering"),
("lxmert", "LxmertForQuestionAnswering"),
("mpnet", "MPNetForQuestionAnswering"),
("deberta", "DebertaForQuestionAnswering"),
("deberta-v2", "DebertaV2ForQuestionAnswering"),
("ibert", "IBertForQuestionAnswering"),
("splinter", "SplinterForQuestionAnswering"),
]
)

MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Table Question Answering mapping
("tapas", "TapasForQuestionAnswering"),
]
)

MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("fnet", "FNetForTokenClassification"),
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
("rembert", "RemBertForTokenClassification"),
("canine", "CanineForTokenClassification"),
("roformer", "RoFormerForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"),
("convbert", "ConvBertForTokenClassification"),
("layoutlm", "LayoutLMForTokenClassification"),
("distilbert", "DistilBertForTokenClassification"),
("camembert", "CamembertForTokenClassification"),
("flaubert", "FlaubertForTokenClassification"),
("xlm", "XLMForTokenClassification"),
("xlm-roberta", "XLMRobertaForTokenClassification"),
("longformer", "LongformerForTokenClassification"),
("roberta", "RobertaForTokenClassification"),
("squeezebert", "SqueezeBertForTokenClassification"),
("bert", "BertForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
("xlnet", "XLNetForTokenClassification"),
("albert", "AlbertForTokenClassification"),
("electra", "ElectraForTokenClassification"),
("funnel", "FunnelForTokenClassification"),
("mpnet", "MPNetForTokenClassification"),
("deberta", "DebertaForTokenClassification"),
("deberta-v2", "DebertaV2ForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
("ibert", "IBertForTokenClassification"),
]
)

MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("fnet", "FNetForMultipleChoice"),
("rembert", "RemBertForMultipleChoice"),
("canine", "CanineForMultipleChoice"),
("roformer", "RoFormerForMultipleChoice"),
("big_bird", "BigBirdForMultipleChoice"),
("convbert", "ConvBertForMultipleChoice"),
("camembert", "CamembertForMultipleChoice"),
("electra", "ElectraForMultipleChoice"),
("xlm-roberta", "XLMRobertaForMultipleChoice"),
("longformer", "LongformerForMultipleChoice"),
("roberta", "RobertaForMultipleChoice"),
("squeezebert", "SqueezeBertForMultipleChoice"),
("bert", "BertForMultipleChoice"),
("distilbert", "DistilBertForMultipleChoice"),
("megatron-bert", "MegatronBertForMultipleChoice"),
("mobilebert", "MobileBertForMultipleChoice"),
("xlnet", "XLNetForMultipleChoice"),
("albert", "AlbertForMultipleChoice"),
("xlm", "XLMForMultipleChoice"),
("flaubert", "FlaubertForMultipleChoice"),
("funnel", "FunnelForMultipleChoice"),
("mpnet", "MPNetForMultipleChoice"),
("ibert", "IBertForMultipleChoice"),
]
)

MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        ("bert", "BertForNextSentencePrediction"),
("fnet", "FNetForNextSentencePrediction"),
("megatron-bert", "MegatronBertForNextSentencePrediction"),
("mobilebert", "MobileBertForNextSentencePrediction"),
]
)

MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Audio Classification mapping
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
("hubert", "HubertForSequenceClassification"),
]
)

MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
[
        # Model for Connectionist Temporal Classification (CTC) mapping
("wav2vec2", "Wav2Vec2ForCTC"),
("hubert", "HubertForCTC"),
]
)

MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
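# The mappings above resolve a configuration class to its task-specific model
# class lazily: _LazyAutoMapping imports a model module only when the mapping
# is first queried for it. A minimal lookup sketch, assuming this port keeps
# the upstream transformers semantics ("bert-base-uncased" is an illustrative
# checkpoint name):
#
#     from fastNLP.transformers.torch.models.auto import AutoConfig
#     config = AutoConfig.from_pretrained("bert-base-uncased")
#     model_cls = MODEL_FOR_MASKED_LM_MAPPING[type(config)]  # BertForMaskedLM
#     model = model_cls(config)  # randomly initialized weights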


class AutoModel(_BaseAutoModelClass):
_model_mapping = MODEL_MAPPING


AutoModel = auto_class_update(AutoModel)
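# Usage sketch for the generic auto class (the checkpoint name is
# illustrative): AutoModel reads the checkpoint's config and instantiates the
# matching base architecture from MODEL_MAPPING.
#
#     model = AutoModel.from_pretrained("bert-base-uncased")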


class AutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_PRETRAINING_MAPPING


AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")


# Private on purpose, the public class will add the deprecation warnings.
class _AutoModelWithLMHead(_BaseAutoModelClass):
_model_mapping = MODEL_WITH_LM_HEAD_MAPPING


_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")


class AutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")


class AutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASKED_LM_MAPPING


AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")


class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


AutoModelForSeq2SeqLM = auto_class_update(
AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)
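# Usage sketch, reusing the "t5-base" checkpoint named above as
# checkpoint_for_example:
#
#     model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")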


class AutoModelForSequenceClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(
AutoModelForSequenceClassification, head_doc="sequence classification"
)


class AutoModelForQuestionAnswering(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")


class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


AutoModelForTableQuestionAnswering = auto_class_update(
AutoModelForTableQuestionAnswering,
head_doc="table question answering",
checkpoint_for_example="google/tapas-base-finetuned-wtq",
)


class AutoModelForTokenClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")


class AutoModelForMultipleChoice(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")


class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


AutoModelForNextSentencePrediction = auto_class_update(
AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)


class AutoModelForImageClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")


class AutoModelForObjectDetection(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")


class AutoModelForAudioClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")


class AutoModelForCTC(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CTC_MAPPING


AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")


class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


class AutoModelWithLMHead(_AutoModelWithLMHead):
@classmethod
def from_config(cls, config):
warnings.warn(
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
"`AutoModelForSeq2SeqLM` for encoder-decoder models.",
FutureWarning,
)
return super().from_config(config)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
warnings.warn(
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
"`AutoModelForSeq2SeqLM` for encoder-decoder models.",
FutureWarning,
)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
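# Migration sketch for the deprecated class, following the warning above
# (checkpoint names are illustrative):
#
#     # before
#     model = AutoModelWithLMHead.from_pretrained("gpt2")
#     # after: pick the class that matches the architecture
#     model = AutoModelForCausalLM.from_pretrained("gpt2")               # decoder-only
#     model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")  # masked LM
#     model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")           # encoder-decoder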
