|
- # coding=utf-8
- # Copyright 2018 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """ Auto Model class. """
-
- import warnings
- from collections import OrderedDict
-
- from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
- from .configuration_auto import CONFIG_MAPPING_NAMES
- from fastNLP.core.log import logger
-
-
# Per-task lookup tables: model type -> name of the class implementing that task.
# The keys are presumably the same model-type keys used by CONFIG_MAPPING_NAMES (they are
# paired with it further down). OrderedDict keeps the declared entry order.

# Bare models without a task-specific head.
MODEL_MAPPING_NAMES = OrderedDict(
    {
        "bart": "BartModel",
        "roberta": "RobertaModel",
        "bert": "BertModel",
        "gpt2": "GPT2Model",
        "cpt": "CPTModel",
    }
)

# Heads used for pretraining objectives.
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    {
        "bart": "BartForConditionalGeneration",
        "roberta": "RobertaForMaskedLM",
        "bert": "BertForPreTraining",
        "gpt2": "GPT2LMHeadModel",
        "cpt": "CPTForConditionalGeneration",
    }
)

# Generic language-modeling heads (backs the deprecated AutoModelWithLMHead).
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    {
        "bart": "BartForConditionalGeneration",
        "roberta": "RobertaForMaskedLM",
        "bert": "BertForMaskedLM",
        "gpt2": "GPT2LMHeadModel",
        "cpt": "CPTForConditionalGeneration",
    }
)

# Causal (left-to-right) language-modeling heads.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    {
        "roberta": "RobertaForCausalLM",
        "bert": "BertLMHeadModel",
        "gpt2": "GPT2LMHeadModel",
        "bart": "BartForCausalLM",
    }
)

# Masked language-modeling heads; bart and cpt reuse their conditional-generation models.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    {
        "bart": "BartForConditionalGeneration",
        "roberta": "RobertaForMaskedLM",
        "bert": "BertForMaskedLM",
        "cpt": "CPTForConditionalGeneration",
    }
)

# Encoder-decoder (seq2seq) generation heads.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    {
        "bart": "BartForConditionalGeneration",
        "cpt": "CPTForConditionalGeneration",
    }
)

# Sequence-level classification heads.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    {
        "bart": "BartForSequenceClassification",
        "roberta": "RobertaForSequenceClassification",
        "bert": "BertForSequenceClassification",
        "gpt2": "GPT2ForSequenceClassification",
        "cpt": "CPTForSequenceClassification",
    }
)

# Extractive question-answering heads.
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    {
        "bart": "BartForQuestionAnswering",
        "roberta": "RobertaForQuestionAnswering",
        "bert": "BertForQuestionAnswering",
        "cpt": "CPTForQuestionAnswering",
    }
)

# Token-level classification heads.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    {
        "roberta": "RobertaForTokenClassification",
        "bert": "BertForTokenClassification",
        "gpt2": "GPT2ForTokenClassification",
    }
)

# Multiple-choice heads.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    {
        "roberta": "RobertaForMultipleChoice",
        "bert": "BertForMultipleChoice",
    }
)

# Next-sentence-prediction head (BERT only).
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    {
        "bert": "BertForNextSentencePrediction",
    }
)

# Tasks with no registered model type yet: these tables are intentionally empty.
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict()
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict()
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict()
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict()
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict()
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict()
-
# Each public *_MAPPING pairs CONFIG_MAPPING_NAMES with the matching *_MAPPING_NAMES table.
# Model classes are presumably resolved on demand (per the `_LazyAutoMapping` name); see
# auto_factory for the exact lookup behavior.
MODEL_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_WITH_LM_HEAD_MAPPING_NAMES,
)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CTC_MAPPING_NAMES,
)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
)
-
-
# Generic auto class: resolves through MODEL_MAPPING, whose table lists the bare
# (head-less) model classes. Decorator form is equivalent to rebinding the name to
# auto_class_update(AutoModel).
@auto_class_update
class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING
-
-
class AutoModelForPreTraining(_BaseAutoModelClass):
    # Resolved through the pretraining-head table.
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


# head_doc labels the task for auto_class_update (presumably used in generated docs).
AutoModelForPreTraining = auto_class_update(
    AutoModelForPreTraining,
    head_doc="pretraining",
)
-
-
# Deliberately private: the public AutoModelWithLMHead subclass defined later in this
# module wraps it and is the one that emits the deprecation warnings.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


_AutoModelWithLMHead = auto_class_update(
    _AutoModelWithLMHead,
    head_doc="language modeling",
)
-
-
class AutoModelForCausalLM(_BaseAutoModelClass):
    # Left-to-right language-modeling heads.
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


# head_doc labels the task for auto_class_update (presumably used in generated docs).
AutoModelForCausalLM = auto_class_update(
    AutoModelForCausalLM,
    head_doc="causal language modeling",
)
-
-
class AutoModelForMaskedLM(_BaseAutoModelClass):
    # Masked language-modeling heads.
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


# head_doc labels the task for auto_class_update (presumably used in generated docs).
AutoModelForMaskedLM = auto_class_update(
    AutoModelForMaskedLM,
    head_doc="masked language modeling",
)
-
-
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    # Encoder-decoder generation heads.
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# The example checkpoint should belong to a model type this mapping actually registers
# (only bart and cpt appear in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES). The previous
# value, "t5-base", is a T5 checkpoint, and t5 is not a registered model type here.
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="facebook/bart-base",
)
-
-
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    # Sequence-level classification heads.
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# head_doc labels the task for auto_class_update (presumably used in generated docs).
AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification,
    head_doc="sequence classification",
)
-
-
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    # Extractive question-answering heads.
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


# head_doc labels the task for auto_class_update (presumably used in generated docs).
AutoModelForQuestionAnswering = auto_class_update(
    AutoModelForQuestionAnswering,
    head_doc="question answering",
)
-
-
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    # NOTE(review): MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES is empty in this module,
    # so no model type (including the tapas example checkpoint below) is registered yet.
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
    head_doc="table question answering",
)
-
-
class AutoModelForTokenClassification(_BaseAutoModelClass):
    # Token-level classification heads.
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# head_doc labels the task for auto_class_update (presumably used in generated docs).
AutoModelForTokenClassification = auto_class_update(
    AutoModelForTokenClassification,
    head_doc="token classification",
)
-
-
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    # Multiple-choice heads.
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# head_doc labels the task for auto_class_update (presumably used in generated docs).
AutoModelForMultipleChoice = auto_class_update(
    AutoModelForMultipleChoice,
    head_doc="multiple choice",
)
-
-
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    # Next-sentence-prediction heads.
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# head_doc labels the task for auto_class_update (presumably used in generated docs).
AutoModelForNextSentencePrediction = auto_class_update(AutoModelForNextSentencePrediction, head_doc="next sentence prediction")
-
-
class AutoModelForImageClassification(_BaseAutoModelClass):
    # NOTE: the backing name table is empty in this module, so no model type is registered yet.
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


AutoModelForImageClassification = auto_class_update(
    AutoModelForImageClassification,
    head_doc="image classification",
)
-
-
class AutoModelForObjectDetection(_BaseAutoModelClass):
    # NOTE: the backing name table is empty in this module, so no model type is registered yet.
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


AutoModelForObjectDetection = auto_class_update(
    AutoModelForObjectDetection,
    head_doc="object detection",
)
-
-
class AutoModelForAudioClassification(_BaseAutoModelClass):
    # NOTE: the backing name table is empty in this module, so no model type is registered yet.
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


AutoModelForAudioClassification = auto_class_update(
    AutoModelForAudioClassification,
    head_doc="audio classification",
)
-
-
class AutoModelForCTC(_BaseAutoModelClass):
    # NOTE: the backing name table is empty in this module, so no model type is registered yet.
    _model_mapping = MODEL_FOR_CTC_MAPPING


AutoModelForCTC = auto_class_update(
    AutoModelForCTC,
    head_doc="connectionist temporal classification",
)
-
-
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    # NOTE: the backing name table is empty in this module, so no model type is registered yet.
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# head_doc fixed: previously read "modeing" (typo) instead of "modeling".
AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
-
-
class AutoModelWithLMHead(_AutoModelWithLMHead):
    """Deprecated public wrapper around ``_AutoModelWithLMHead``.

    Both factory methods emit a ``FutureWarning`` pointing users at the task-specific
    replacements (``AutoModelForCausalLM``, ``AutoModelForMaskedLM``,
    ``AutoModelForSeq2SeqLM``) and then delegate unchanged to the parent implementation.
    """

    # Single source for the warning text so the two factory methods cannot drift apart.
    _DEPRECATION_MESSAGE = (
        "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
        "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
        "`AutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    @classmethod
    def from_config(cls, config):
        """Warn about the deprecation, then delegate to the parent ``from_config``.

        Args:
            config: forwarded unchanged to ``_AutoModelWithLMHead.from_config``.

        Returns:
            Whatever the parent ``from_config`` returns.
        """
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Warn about the deprecation, then delegate to the parent ``from_pretrained``.

        Args:
            pretrained_model_name_or_path: forwarded unchanged to the parent.
            *model_args: forwarded unchanged to the parent.
            **kwargs: forwarded unchanged to the parent.

        Returns:
            Whatever the parent ``from_pretrained`` returns.
        """
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|