From 639de99563f5b58a14850796576424245e6da082 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 22 May 2022 12:57:38 +0000 Subject: [PATCH] =?UTF-8?q?=E4=B8=BAtransformers=20=E6=B7=BB=E5=8A=A0AutoM?= =?UTF-8?q?odel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/transformers/torch/models/auto/__init__.py | 4 +- .../transformers/torch/models/auto/auto_factory.py | 5 +- .../torch/models/auto/configuration_auto.py | 194 +---------- .../torch/models/auto/modeling_auto.py | 363 +-------------------- .../torch/models/auto/tokenization_auto.py | 170 +--------- fastNLP/transformers/torch/models/cpt/__init__.py | 3 +- 6 files changed, 28 insertions(+), 711 deletions(-) diff --git a/fastNLP/transformers/torch/models/auto/__init__.py b/fastNLP/transformers/torch/models/auto/__init__.py index ac2967d2..0ce22235 100644 --- a/fastNLP/transformers/torch/models/auto/__init__.py +++ b/fastNLP/transformers/torch/models/auto/__init__.py @@ -3,7 +3,7 @@ __all__ = [ "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig", - "TOKENIZER_MAPPING", + "TOKENIZER_MAPPING_NAMES", "get_values", "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", "MODEL_FOR_CAUSAL_LM_MAPPING", @@ -43,7 +43,7 @@ __all__ = [ from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \ AutoConfig -from .tokenization_auto import TOKENIZER_MAPPING +from .tokenization_auto import TOKENIZER_MAPPING_NAMES from .auto_factory import get_values from .modeling_auto import ( MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, diff --git a/fastNLP/transformers/torch/models/auto/auto_factory.py b/fastNLP/transformers/torch/models/auto/auto_factory.py index 015f5642..9eb8ec69 100644 --- a/fastNLP/transformers/torch/models/auto/auto_factory.py +++ b/fastNLP/transformers/torch/models/auto/auto_factory.py @@ -516,7 +516,10 @@ class _LazyAutoMapping(OrderedDict): def _load_attr_from_module(self, model_type, attr): module_name = model_type_to_module_name(model_type) if module_name not in self._modules: - self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models") + try: + self._modules[module_name] = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models") + except ImportError: + raise ImportError(f"fastNLP transformers does not support {module_name} now, please install and import `transformers` to use it.") return getattribute_from_module(self._modules[module_name], attr) def keys(self): diff --git a/fastNLP/transformers/torch/models/auto/configuration_auto.py b/fastNLP/transformers/torch/models/auto/configuration_auto.py index 0138aec7..45d3c071 100644 --- a/fastNLP/transformers/torch/models/auto/configuration_auto.py +++ b/fastNLP/transformers/torch/models/auto/configuration_auto.py @@ -26,221 +26,33 @@ from fastNLP.core.log import logger CONFIG_MAPPING_NAMES = OrderedDict( [ # Add configs here - ("fnet", "FNetConfig"), - ("gptj", "GPTJConfig"), - ("layoutlmv2", "LayoutLMv2Config"), - ("beit", "BeitConfig"), - ("rembert", "RemBertConfig"), - ("visual_bert", "VisualBertConfig"), - ("canine", "CanineConfig"), - ("roformer", "RoFormerConfig"), - ("clip", "CLIPConfig"), - ("bigbird_pegasus", "BigBirdPegasusConfig"), - ("deit", "DeiTConfig"), - ("luke", "LukeConfig"), - ("detr", "DetrConfig"), - ("gpt_neo", "GPTNeoConfig"), - ("big_bird", "BigBirdConfig"), - ("speech_to_text_2", "Speech2Text2Config"), - ("speech_to_text", "Speech2TextConfig"), - ("vit", "ViTConfig"), - ("wav2vec2", "Wav2Vec2Config"), - ("m2m_100", "M2M100Config"), - ("convbert", "ConvBertConfig"), - ("led", "LEDConfig"), - ("blenderbot-small", "BlenderbotSmallConfig"), - ("retribert", "RetriBertConfig"), - ("ibert", "IBertConfig"), - ("mt5", "MT5Config"), - ("t5", "T5Config"), - ("mobilebert", "MobileBertConfig"), - ("distilbert", "DistilBertConfig"), - ("albert", "AlbertConfig"), - ("bert-generation", "BertGenerationConfig"), - ("camembert", "CamembertConfig"), - ("xlm-roberta", "XLMRobertaConfig"), - ("pegasus", "PegasusConfig"), - ("marian", "MarianConfig"), - ("mbart", "MBartConfig"), - ("megatron-bert", "MegatronBertConfig"), - ("mpnet", "MPNetConfig"), ("bart", "BartConfig"), - ("blenderbot", "BlenderbotConfig"), - ("reformer", "ReformerConfig"), - ("longformer", "LongformerConfig"), ("roberta", "RobertaConfig"), - ("deberta-v2", "DebertaV2Config"), - ("deberta", "DebertaConfig"), - ("flaubert", "FlaubertConfig"), - ("fsmt", "FSMTConfig"), - ("squeezebert", "SqueezeBertConfig"), - ("hubert", "HubertConfig"), ("bert", "BertConfig"), - ("openai-gpt", "OpenAIGPTConfig"), ("gpt2", "GPT2Config"), - ("transfo-xl", "TransfoXLConfig"), - ("xlnet", "XLNetConfig"), - ("xlm-prophetnet", "XLMProphetNetConfig"), - ("prophetnet", "ProphetNetConfig"), - ("xlm", "XLMConfig"), - ("ctrl", "CTRLConfig"), - ("electra", "ElectraConfig"), - ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"), - ("encoder-decoder", "EncoderDecoderConfig"), - ("funnel", "FunnelConfig"), - ("lxmert", "LxmertConfig"), - ("dpr", "DPRConfig"), - ("layoutlm", "LayoutLMConfig"), - ("rag", "RagConfig"), - ("tapas", "TapasConfig"), - ("splinter", "SplinterConfig"), + ("cpt", "CPTConfig"), ] ) CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( [ # Add archive maps here - ("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("cpt", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"), ] ) MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here - ("fnet", "FNet"), - ("gptj", "GPT-J"), - ("beit", "BeiT"), - ("rembert", "RemBERT"), - ("layoutlmv2", "LayoutLMv2"), - ("visual_bert", "VisualBert"), - ("canine", "Canine"), - ("roformer", "RoFormer"), - ("clip", "CLIP"), - ("bigbird_pegasus", "BigBirdPegasus"), - ("deit", "DeiT"), - ("luke", "LUKE"), - ("detr", "DETR"), - ("gpt_neo", "GPT Neo"), - ("big_bird", "BigBird"), - ("speech_to_text_2", "Speech2Text2"), - ("speech_to_text", "Speech2Text"), - ("vit", "ViT"), - ("wav2vec2", "Wav2Vec2"), - ("m2m_100", "M2M100"), - ("convbert", "ConvBERT"), - ("led", "LED"), - ("blenderbot-small", "BlenderbotSmall"), - ("retribert", "RetriBERT"), - ("ibert", "I-BERT"), - ("t5", "T5"), - ("mobilebert", "MobileBERT"), - ("distilbert", "DistilBERT"), - ("albert", "ALBERT"), - ("bert-generation", "Bert Generation"), - ("camembert", "CamemBERT"), - ("xlm-roberta", "XLM-RoBERTa"), - ("pegasus", "Pegasus"), - ("blenderbot", "Blenderbot"), - ("marian", "Marian"), - ("mbart", "mBART"), - ("megatron-bert", "MegatronBert"), ("bart", "BART"), - ("reformer", "Reformer"), - ("longformer", "Longformer"), ("roberta", "RoBERTa"), - ("flaubert", "FlauBERT"), - ("fsmt", "FairSeq Machine-Translation"), - ("squeezebert", "SqueezeBERT"), ("bert", "BERT"), - ("openai-gpt", "OpenAI GPT"), ("gpt2", "OpenAI GPT-2"), - ("transfo-xl", "Transformer-XL"), - ("xlnet", "XLNet"), - ("xlm", "XLM"), - ("ctrl", "CTRL"), - ("electra", "ELECTRA"), - ("encoder-decoder", "Encoder decoder"), - ("speech-encoder-decoder", "Speech Encoder decoder"), - ("funnel", "Funnel Transformer"), - ("lxmert", "LXMERT"), - ("deberta-v2", "DeBERTa-v2"), - ("deberta", "DeBERTa"), - ("layoutlm", "LayoutLM"), - ("dpr", "DPR"), - ("rag", "RAG"), - ("xlm-prophetnet", "XLMProphetNet"), - ("prophetnet", "ProphetNet"), - ("mt5", "mT5"), - ("mpnet", "MPNet"), - ("tapas", "TAPAS"), - ("hubert", "Hubert"), - ("barthez", "BARThez"), - ("phobert", "PhoBERT"), - ("cpm", "CPM"), - ("bertweet", "Bertweet"), - ("bert-japanese", "BertJapanese"), - ("byt5", "ByT5"), - ("mbart50", "mBART-50"), - ("splinter", "Splinter"), + ("cpt", "CPT") ] ) diff --git a/fastNLP/transformers/torch/models/auto/modeling_auto.py b/fastNLP/transformers/torch/models/auto/modeling_auto.py index 6406da14..bed12a2a 100644 --- a/fastNLP/transformers/torch/models/auto/modeling_auto.py +++ b/fastNLP/transformers/torch/models/auto/modeling_auto.py @@ -24,455 +24,118 @@ from fastNLP.core.log import logger MODEL_MAPPING_NAMES = OrderedDict( [ - # Base model mapping - ("fnet", "FNetModel"), - ("gptj", "GPTJModel"), - ("layoutlmv2", "LayoutLMv2Model"), - ("beit", "BeitModel"), - ("rembert", "RemBertModel"), - ("visual_bert", "VisualBertModel"), - ("canine", "CanineModel"), - ("roformer", "RoFormerModel"), - ("clip", "CLIPModel"), - ("bigbird_pegasus", "BigBirdPegasusModel"), - ("deit", "DeiTModel"), - ("luke", "LukeModel"), - ("detr", "DetrModel"), - ("gpt_neo", "GPTNeoModel"), - ("big_bird", "BigBirdModel"), - ("speech_to_text", "Speech2TextModel"), - ("vit", "ViTModel"), - ("wav2vec2", "Wav2Vec2Model"), - ("hubert", "HubertModel"), - ("m2m_100", "M2M100Model"), - ("convbert", "ConvBertModel"), - ("led", "LEDModel"), - ("blenderbot-small", "BlenderbotSmallModel"), - ("retribert", "RetriBertModel"), - ("mt5", "MT5Model"), - ("t5", "T5Model"), - ("pegasus", "PegasusModel"), - ("marian", "MarianModel"), - ("mbart", "MBartModel"), - ("blenderbot", "BlenderbotModel"), - ("distilbert", "DistilBertModel"), - ("albert", "AlbertModel"), - ("camembert", "CamembertModel"), - ("xlm-roberta", "XLMRobertaModel"), ("bart", "BartModel"), - ("longformer", "LongformerModel"), ("roberta", "RobertaModel"), - ("layoutlm", "LayoutLMModel"), - ("squeezebert", "SqueezeBertModel"), ("bert", "BertModel"), - ("openai-gpt", "OpenAIGPTModel"), ("gpt2", "GPT2Model"), - ("megatron-bert", "MegatronBertModel"), - ("mobilebert", "MobileBertModel"), - ("transfo-xl", "TransfoXLModel"), - ("xlnet", "XLNetModel"), - ("flaubert", "FlaubertModel"), - ("fsmt", "FSMTModel"), - ("xlm", "XLMModel"), - ("ctrl", "CTRLModel"), - ("electra", "ElectraModel"), - ("reformer", "ReformerModel"), - ("funnel", ("FunnelModel", "FunnelBaseModel")), - ("lxmert", "LxmertModel"), - ("bert-generation", "BertGenerationEncoder"), - ("deberta", "DebertaModel"), - ("deberta-v2", "DebertaV2Model"), - ("dpr", "DPRQuestionEncoder"), - ("xlm-prophetnet", "XLMProphetNetModel"), - ("prophetnet", "ProphetNetModel"), - ("mpnet", "MPNetModel"), - ("tapas", "TapasModel"), - ("ibert", "IBertModel"), - ("splinter", "SplinterModel"), + ("cpt", "CPTModel"), ] ) MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( [ - # Model for pre-training mapping - ("fnet", "FNetForPreTraining"), - ("visual_bert", "VisualBertForPreTraining"), - ("layoutlm", "LayoutLMForMaskedLM"), - ("retribert", "RetriBertModel"), - ("t5", "T5ForConditionalGeneration"), - ("distilbert", "DistilBertForMaskedLM"), - ("albert", "AlbertForPreTraining"), - ("camembert", "CamembertForMaskedLM"), - ("xlm-roberta", "XLMRobertaForMaskedLM"), ("bart", "BartForConditionalGeneration"), - ("fsmt", "FSMTForConditionalGeneration"), - ("longformer", "LongformerForMaskedLM"), ("roberta", "RobertaForMaskedLM"), - ("squeezebert", "SqueezeBertForMaskedLM"), ("bert", "BertForPreTraining"), - ("big_bird", "BigBirdForPreTraining"), - ("openai-gpt", "OpenAIGPTLMHeadModel"), ("gpt2", "GPT2LMHeadModel"), - ("megatron-bert", "MegatronBertForPreTraining"), - ("mobilebert", "MobileBertForPreTraining"), - ("transfo-xl", "TransfoXLLMHeadModel"), - ("xlnet", "XLNetLMHeadModel"), - ("flaubert", "FlaubertWithLMHeadModel"), - ("xlm", "XLMWithLMHeadModel"), - ("ctrl", "CTRLLMHeadModel"), - ("electra", "ElectraForPreTraining"), - ("lxmert", "LxmertForPreTraining"), - ("funnel", "FunnelForPreTraining"), - ("mpnet", "MPNetForMaskedLM"), - ("tapas", "TapasForMaskedLM"), - ("ibert", "IBertForMaskedLM"), - ("deberta", "DebertaForMaskedLM"), - ("deberta-v2", "DebertaV2ForMaskedLM"), - ("wav2vec2", "Wav2Vec2ForPreTraining"), + ("cpt", "CPTForConditionalGeneration"), ] ) MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( [ # Model with LM heads mapping - ("fnet", "FNetForMaskedLM"), - ("gptj", "GPTJForCausalLM"), - ("rembert", "RemBertForMaskedLM"), - ("roformer", "RoFormerForMaskedLM"), - ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), - ("gpt_neo", "GPTNeoForCausalLM"), - ("big_bird", "BigBirdForMaskedLM"), - ("speech_to_text", "Speech2TextForConditionalGeneration"), - ("wav2vec2", "Wav2Vec2ForMaskedLM"), - ("m2m_100", "M2M100ForConditionalGeneration"), - ("convbert", "ConvBertForMaskedLM"), - ("led", "LEDForConditionalGeneration"), - ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), - ("layoutlm", "LayoutLMForMaskedLM"), - ("t5", "T5ForConditionalGeneration"), - ("distilbert", "DistilBertForMaskedLM"), - ("albert", "AlbertForMaskedLM"), - ("camembert", "CamembertForMaskedLM"), - ("xlm-roberta", "XLMRobertaForMaskedLM"), - ("marian", "MarianMTModel"), - ("fsmt", "FSMTForConditionalGeneration"), ("bart", "BartForConditionalGeneration"), - ("longformer", "LongformerForMaskedLM"), ("roberta", "RobertaForMaskedLM"), - ("squeezebert", "SqueezeBertForMaskedLM"), ("bert", "BertForMaskedLM"), - ("openai-gpt", "OpenAIGPTLMHeadModel"), ("gpt2", "GPT2LMHeadModel"), - ("megatron-bert", "MegatronBertForCausalLM"), - ("mobilebert", "MobileBertForMaskedLM"), - ("transfo-xl", "TransfoXLLMHeadModel"), - ("xlnet", "XLNetLMHeadModel"), - ("flaubert", "FlaubertWithLMHeadModel"), - ("xlm", "XLMWithLMHeadModel"), - ("ctrl", "CTRLLMHeadModel"), - ("electra", "ElectraForMaskedLM"), - ("encoder-decoder", "EncoderDecoderModel"), - ("reformer", "ReformerModelWithLMHead"), - ("funnel", "FunnelForMaskedLM"), - ("mpnet", "MPNetForMaskedLM"), - ("tapas", "TapasForMaskedLM"), - ("deberta", "DebertaForMaskedLM"), - ("deberta-v2", "DebertaV2ForMaskedLM"), - ("ibert", "IBertForMaskedLM"), + ("cpt", "CPTForConditionalGeneration"), ] ) MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping - ("gptj", "GPTJForCausalLM"), - ("rembert", "RemBertForCausalLM"), - ("roformer", "RoFormerForCausalLM"), - ("bigbird_pegasus", "BigBirdPegasusForCausalLM"), - ("gpt_neo", "GPTNeoForCausalLM"), - ("big_bird", "BigBirdForCausalLM"), - ("camembert", "CamembertForCausalLM"), - ("xlm-roberta", "XLMRobertaForCausalLM"), ("roberta", "RobertaForCausalLM"), ("bert", "BertLMHeadModel"), - ("openai-gpt", "OpenAIGPTLMHeadModel"), ("gpt2", "GPT2LMHeadModel"), - ("transfo-xl", "TransfoXLLMHeadModel"), - ("xlnet", "XLNetLMHeadModel"), - ("xlm", "XLMWithLMHeadModel"), - ("ctrl", "CTRLLMHeadModel"), - ("reformer", "ReformerModelWithLMHead"), - ("bert-generation", "BertGenerationDecoder"), - ("xlm-prophetnet", "XLMProphetNetForCausalLM"), - ("prophetnet", "ProphetNetForCausalLM"), ("bart", "BartForCausalLM"), - ("mbart", "MBartForCausalLM"), - ("pegasus", "PegasusForCausalLM"), - ("marian", "MarianForCausalLM"), - ("blenderbot", "BlenderbotForCausalLM"), - ("blenderbot-small", "BlenderbotSmallForCausalLM"), - ("megatron-bert", "MegatronBertForCausalLM"), - ("speech_to_text_2", "Speech2Text2ForCausalLM"), ] ) -MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - # Model for Image Classification mapping - ("vit", "ViTForImageClassification"), - ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")), - ("beit", "BeitForImageClassification"), - ] -) +MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict([]) MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping - ("fnet", "FNetForMaskedLM"), - ("rembert", "RemBertForMaskedLM"), - ("roformer", "RoFormerForMaskedLM"), - ("big_bird", "BigBirdForMaskedLM"), - ("wav2vec2", "Wav2Vec2ForMaskedLM"), - ("convbert", "ConvBertForMaskedLM"), - ("layoutlm", "LayoutLMForMaskedLM"), - ("distilbert", "DistilBertForMaskedLM"), - ("albert", "AlbertForMaskedLM"), ("bart", "BartForConditionalGeneration"), - ("mbart", "MBartForConditionalGeneration"), - ("camembert", "CamembertForMaskedLM"), - ("xlm-roberta", "XLMRobertaForMaskedLM"), - ("longformer", "LongformerForMaskedLM"), ("roberta", "RobertaForMaskedLM"), - ("squeezebert", "SqueezeBertForMaskedLM"), ("bert", "BertForMaskedLM"), - ("megatron-bert", "MegatronBertForMaskedLM"), - ("mobilebert", "MobileBertForMaskedLM"), - ("flaubert", "FlaubertWithLMHeadModel"), - ("xlm", "XLMWithLMHeadModel"), - ("electra", "ElectraForMaskedLM"), - ("reformer", "ReformerForMaskedLM"), - ("funnel", "FunnelForMaskedLM"), - ("mpnet", "MPNetForMaskedLM"), - ("tapas", "TapasForMaskedLM"), - ("deberta", "DebertaForMaskedLM"), - ("deberta-v2", "DebertaV2ForMaskedLM"), - ("ibert", "IBertForMaskedLM"), + ("cpt", "CPTForConditionalGeneration"), ] ) -MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( - [ - # Model for Object Detection mapping - ("detr", "DetrForObjectDetection"), - ] -) +MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict([]) MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping - ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), - ("m2m_100", "M2M100ForConditionalGeneration"), - ("led", "LEDForConditionalGeneration"), - ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), - ("mt5", "MT5ForConditionalGeneration"), - ("t5", "T5ForConditionalGeneration"), - ("pegasus", "PegasusForConditionalGeneration"), - ("marian", "MarianMTModel"), - ("mbart", "MBartForConditionalGeneration"), - ("blenderbot", "BlenderbotForConditionalGeneration"), ("bart", "BartForConditionalGeneration"), - ("fsmt", "FSMTForConditionalGeneration"), - ("encoder-decoder", "EncoderDecoderModel"), - ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), - ("prophetnet", "ProphetNetForConditionalGeneration"), + ("cpt", "CPTForConditionalGeneration"), ] ) -MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( - [ - ("speech-encoder-decoder", "SpeechEncoderDecoderModel"), - ("speech_to_text", "Speech2TextForConditionalGeneration"), - ] -) +MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict([]) MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping - ("fnet", "FNetForSequenceClassification"), - ("gptj", "GPTJForSequenceClassification"), - ("layoutlmv2", "LayoutLMv2ForSequenceClassification"), - ("rembert", "RemBertForSequenceClassification"), - ("canine", "CanineForSequenceClassification"), - ("roformer", "RoFormerForSequenceClassification"), - ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"), - ("big_bird", "BigBirdForSequenceClassification"), - ("convbert", "ConvBertForSequenceClassification"), - ("led", "LEDForSequenceClassification"), - ("distilbert", "DistilBertForSequenceClassification"), - ("albert", "AlbertForSequenceClassification"), - ("camembert", "CamembertForSequenceClassification"), - ("xlm-roberta", "XLMRobertaForSequenceClassification"), - ("mbart", "MBartForSequenceClassification"), ("bart", "BartForSequenceClassification"), - ("longformer", "LongformerForSequenceClassification"), ("roberta", "RobertaForSequenceClassification"), - ("squeezebert", "SqueezeBertForSequenceClassification"), - ("layoutlm", "LayoutLMForSequenceClassification"), ("bert", "BertForSequenceClassification"), - ("xlnet", "XLNetForSequenceClassification"), - ("megatron-bert", "MegatronBertForSequenceClassification"), - ("mobilebert", "MobileBertForSequenceClassification"), - ("flaubert", "FlaubertForSequenceClassification"), - ("xlm", "XLMForSequenceClassification"), - ("electra", "ElectraForSequenceClassification"), - ("funnel", "FunnelForSequenceClassification"), - ("deberta", "DebertaForSequenceClassification"), - ("deberta-v2", "DebertaV2ForSequenceClassification"), ("gpt2", "GPT2ForSequenceClassification"), - ("gpt_neo", "GPTNeoForSequenceClassification"), - ("openai-gpt", "OpenAIGPTForSequenceClassification"), - ("reformer", "ReformerForSequenceClassification"), - ("ctrl", "CTRLForSequenceClassification"), - ("transfo-xl", "TransfoXLForSequenceClassification"), - ("mpnet", "MPNetForSequenceClassification"), - ("tapas", "TapasForSequenceClassification"), - ("ibert", "IBertForSequenceClassification"), + ("cpt", "CPTForSequenceClassification"), ] ) MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping - ("fnet", "FNetForQuestionAnswering"), - ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), - ("rembert", "RemBertForQuestionAnswering"), - ("canine", "CanineForQuestionAnswering"), - ("roformer", "RoFormerForQuestionAnswering"), - ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"), - ("big_bird", "BigBirdForQuestionAnswering"), - ("convbert", "ConvBertForQuestionAnswering"), - ("led", "LEDForQuestionAnswering"), - ("distilbert", "DistilBertForQuestionAnswering"), - ("albert", "AlbertForQuestionAnswering"), - ("camembert", "CamembertForQuestionAnswering"), ("bart", "BartForQuestionAnswering"), - ("mbart", "MBartForQuestionAnswering"), - ("longformer", "LongformerForQuestionAnswering"), - ("xlm-roberta", "XLMRobertaForQuestionAnswering"), ("roberta", "RobertaForQuestionAnswering"), - ("squeezebert", "SqueezeBertForQuestionAnswering"), ("bert", "BertForQuestionAnswering"), - ("xlnet", "XLNetForQuestionAnsweringSimple"), - ("flaubert", "FlaubertForQuestionAnsweringSimple"), - ("megatron-bert", "MegatronBertForQuestionAnswering"), - ("mobilebert", "MobileBertForQuestionAnswering"), - ("xlm", "XLMForQuestionAnsweringSimple"), - ("electra", "ElectraForQuestionAnswering"), - ("reformer", "ReformerForQuestionAnswering"), - ("funnel", "FunnelForQuestionAnswering"), - ("lxmert", "LxmertForQuestionAnswering"), - ("mpnet", "MPNetForQuestionAnswering"), - ("deberta", "DebertaForQuestionAnswering"), - ("deberta-v2", "DebertaV2ForQuestionAnswering"), - ("ibert", "IBertForQuestionAnswering"), - ("splinter", "SplinterForQuestionAnswering"), + ("cpt", "CPTForQuestionAnswering"), ] ) -MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( - [ - # Model for Table Question Answering mapping - ("tapas", "TapasForQuestionAnswering"), - ] -) +MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict([]) MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping - ("fnet", "FNetForTokenClassification"), - ("layoutlmv2", "LayoutLMv2ForTokenClassification"), - ("rembert", "RemBertForTokenClassification"), - ("canine", "CanineForTokenClassification"), - ("roformer", "RoFormerForTokenClassification"), - ("big_bird", "BigBirdForTokenClassification"), - ("convbert", "ConvBertForTokenClassification"), - ("layoutlm", "LayoutLMForTokenClassification"), - ("distilbert", "DistilBertForTokenClassification"), - ("camembert", "CamembertForTokenClassification"), - ("flaubert", "FlaubertForTokenClassification"), - ("xlm", "XLMForTokenClassification"), - ("xlm-roberta", "XLMRobertaForTokenClassification"), - ("longformer", "LongformerForTokenClassification"), ("roberta", "RobertaForTokenClassification"), - ("squeezebert", "SqueezeBertForTokenClassification"), ("bert", "BertForTokenClassification"), - ("megatron-bert", "MegatronBertForTokenClassification"), - ("mobilebert", "MobileBertForTokenClassification"), - ("xlnet", "XLNetForTokenClassification"), - ("albert", "AlbertForTokenClassification"), - ("electra", "ElectraForTokenClassification"), - ("funnel", "FunnelForTokenClassification"), - ("mpnet", "MPNetForTokenClassification"), - ("deberta", "DebertaForTokenClassification"), - ("deberta-v2", "DebertaV2ForTokenClassification"), ("gpt2", "GPT2ForTokenClassification"), - ("ibert", "IBertForTokenClassification"), ] ) MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( [ # Model for Multiple Choice mapping - ("fnet", "FNetForMultipleChoice"), - ("rembert", "RemBertForMultipleChoice"), - ("canine", "CanineForMultipleChoice"), - ("roformer", "RoFormerForMultipleChoice"), - ("big_bird", "BigBirdForMultipleChoice"), - ("convbert", "ConvBertForMultipleChoice"), - ("camembert", "CamembertForMultipleChoice"), - ("electra", "ElectraForMultipleChoice"), - ("xlm-roberta", "XLMRobertaForMultipleChoice"), - ("longformer", "LongformerForMultipleChoice"), ("roberta", "RobertaForMultipleChoice"), - ("squeezebert", "SqueezeBertForMultipleChoice"), ("bert", "BertForMultipleChoice"), - ("distilbert", "DistilBertForMultipleChoice"), - ("megatron-bert", "MegatronBertForMultipleChoice"), - ("mobilebert", "MobileBertForMultipleChoice"), - ("xlnet", "XLNetForMultipleChoice"), - ("albert", "AlbertForMultipleChoice"), - ("xlm", "XLMForMultipleChoice"), - ("flaubert", "FlaubertForMultipleChoice"), - ("funnel", "FunnelForMultipleChoice"), - ("mpnet", "MPNetForMultipleChoice"), - ("ibert", "IBertForMultipleChoice"), ] ) MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( [ ("bert", "BertForNextSentencePrediction"), - ("fnet", "FNetForNextSentencePrediction"), - ("megatron-bert", "MegatronBertForNextSentencePrediction"), - ("mobilebert", "MobileBertForNextSentencePrediction"), ] ) -MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - # Model for Audio Classification mapping - ("wav2vec2", "Wav2Vec2ForSequenceClassification"), - ("hubert", "HubertForSequenceClassification"), - ] -) +MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([]) -MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict( - [ - # Model for Connectionist temporal classification (CTC) mapping - ("wav2vec2", "Wav2Vec2ForCTC"), - ("hubert", "HubertForCTC"), - ] -) +MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict([]) MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) diff --git a/fastNLP/transformers/torch/models/auto/tokenization_auto.py b/fastNLP/transformers/torch/models/auto/tokenization_auto.py index f1618d6a..d30cbae1 100644 --- a/fastNLP/transformers/torch/models/auto/tokenization_auto.py +++ b/fastNLP/transformers/torch/models/auto/tokenization_auto.py @@ -29,171 +29,9 @@ if TYPE_CHECKING: else: TOKENIZER_MAPPING_NAMES = OrderedDict( [ - ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)), - ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)), - ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), - ( - "t5", - ( - "T5Tokenizer" if is_sentencepiece_available() else None, - "T5TokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "mt5", - ( - "MT5Tokenizer" if is_sentencepiece_available() else None, - "MT5TokenizerFast" if is_tokenizers_available() else None, - ), - ), - ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), - ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)), - ( - "albert", - ( - "AlbertTokenizer" if is_sentencepiece_available() else None, - "AlbertTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "camembert", - ( - "CamembertTokenizer" if is_sentencepiece_available() else None, - "CamembertTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "pegasus", - ( - "PegasusTokenizer" if is_sentencepiece_available() else None, - "PegasusTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "mbart", - ( - "MBartTokenizer" if is_sentencepiece_available() else None, - "MBartTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "xlm-roberta", - ( - "XLMRobertaTokenizer" if is_sentencepiece_available() else None, - "XLMRobertaTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)), - ("blenderbot-small", ("BlenderbotSmallTokenizer", None)), - ("blenderbot", ("BlenderbotTokenizer", None)), - ("bart", ("BartTokenizer", "BartTokenizerFast")), - ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)), - ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), - ( - "reformer", - ( - "ReformerTokenizer" if is_sentencepiece_available() else None, - "ReformerTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)), - ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)), - ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)), - ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)), - ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)), - ( - "dpr", - ( - "DPRQuestionEncoderTokenizer", - "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "squeezebert", - ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None), - ), - ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), - ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)), - ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), - ("transfo-xl", ("TransfoXLTokenizer", None)), - ( - "xlnet", - ( - "XLNetTokenizer" if is_sentencepiece_available() else None, - "XLNetTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ("flaubert", ("FlaubertTokenizer", None)), - ("xlm", ("XLMTokenizer", None)), - ("ctrl", ("CTRLTokenizer", None)), - ("fsmt", ("FSMTTokenizer", None)), - ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)), - ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)), - ("deberta-v2", ("DebertaV2Tokenizer" if is_sentencepiece_available() else None, None)), - ("rag", ("RagTokenizer", None)), - ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)), - ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)), - ("speech_to_text_2", ("Speech2Text2Tokenizer", None)), - ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)), - ("prophetnet", ("ProphetNetTokenizer", None)), - ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)), - ("tapas", ("TapasTokenizer", None)), - ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)), - ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)), - ( - "big_bird", - ( - "BigBirdTokenizer" if is_sentencepiece_available() else None, - "BigBirdTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), - ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)), - ("hubert", ("Wav2Vec2CTCTokenizer", None)), - ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), - ("luke", ("LukeTokenizer", None)), - ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)), - ("canine", ("CanineTokenizer", None)), - ("bertweet", ("BertweetTokenizer", None)), - ("bert-japanese", ("BertJapaneseTokenizer", None)), - ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")), - ("byt5", ("ByT5Tokenizer", None)), - ( - "cpm", - ( - "CpmTokenizer" if is_sentencepiece_available() else None, - "CpmTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)), - ("phobert", ("PhobertTokenizer", None)), - ( - "barthez", - ( - "BarthezTokenizer" if is_sentencepiece_available() else None, - "BarthezTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "mbart50", - ( - "MBart50Tokenizer" if is_sentencepiece_available() else None, - "MBart50TokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "rembert", - ( - "RemBertTokenizer" if is_sentencepiece_available() else None, - "RemBertTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "clip", - ( - "CLIPTokenizer", - "CLIPTokenizerFast" if is_tokenizers_available() else None, - ), - ), + ("bart", ("BartTokenizer", None)), + ("roberta", ("RobertaTokenizer", None)), + ("bert", ("BertTokenizer", None)), + ("gpt2", ("GPT2Tokenizer", None)), ] ) \ No newline at end of file diff --git a/fastNLP/transformers/torch/models/cpt/__init__.py b/fastNLP/transformers/torch/models/cpt/__init__.py index 58d9f918..07a85d6c 100644 --- a/fastNLP/transformers/torch/models/cpt/__init__.py +++ b/fastNLP/transformers/torch/models/cpt/__init__.py @@ -1,5 +1,6 @@ __all__ = [ "CPT_PRETRAINED_MODEL_ARCHIVE_LIST", + "CPTConfig", "CPTForConditionalGeneration", "CPTForSequenceClassification", "CPTForMaskedLM", @@ -9,4 +10,4 @@ __all__ = [ ] from .modeling_cpt import CPT_PRETRAINED_MODEL_ARCHIVE_LIST, CPTForConditionalGeneration, CPTForSequenceClassification, \ - CPTForMaskedLM, CPTForQuestionAnswering, CPTModel, CPTPretrainedModel \ No newline at end of file + CPTForMaskedLM, CPTForQuestionAnswering, CPTModel, CPTPretrainedModel, CPTConfig \ No newline at end of file