|
|
@@ -24,455 +24,118 @@ from fastNLP.core.log import logger |
|
|
|
|
|
|
|
MODEL_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Base model mapping |
|
|
|
("fnet", "FNetModel"), |
|
|
|
("gptj", "GPTJModel"), |
|
|
|
("layoutlmv2", "LayoutLMv2Model"), |
|
|
|
("beit", "BeitModel"), |
|
|
|
("rembert", "RemBertModel"), |
|
|
|
("visual_bert", "VisualBertModel"), |
|
|
|
("canine", "CanineModel"), |
|
|
|
("roformer", "RoFormerModel"), |
|
|
|
("clip", "CLIPModel"), |
|
|
|
("bigbird_pegasus", "BigBirdPegasusModel"), |
|
|
|
("deit", "DeiTModel"), |
|
|
|
("luke", "LukeModel"), |
|
|
|
("detr", "DetrModel"), |
|
|
|
("gpt_neo", "GPTNeoModel"), |
|
|
|
("big_bird", "BigBirdModel"), |
|
|
|
("speech_to_text", "Speech2TextModel"), |
|
|
|
("vit", "ViTModel"), |
|
|
|
("wav2vec2", "Wav2Vec2Model"), |
|
|
|
("hubert", "HubertModel"), |
|
|
|
("m2m_100", "M2M100Model"), |
|
|
|
("convbert", "ConvBertModel"), |
|
|
|
("led", "LEDModel"), |
|
|
|
("blenderbot-small", "BlenderbotSmallModel"), |
|
|
|
("retribert", "RetriBertModel"), |
|
|
|
("mt5", "MT5Model"), |
|
|
|
("t5", "T5Model"), |
|
|
|
("pegasus", "PegasusModel"), |
|
|
|
("marian", "MarianModel"), |
|
|
|
("mbart", "MBartModel"), |
|
|
|
("blenderbot", "BlenderbotModel"), |
|
|
|
("distilbert", "DistilBertModel"), |
|
|
|
("albert", "AlbertModel"), |
|
|
|
("camembert", "CamembertModel"), |
|
|
|
("xlm-roberta", "XLMRobertaModel"), |
|
|
|
("bart", "BartModel"), |
|
|
|
("longformer", "LongformerModel"), |
|
|
|
("roberta", "RobertaModel"), |
|
|
|
("layoutlm", "LayoutLMModel"), |
|
|
|
("squeezebert", "SqueezeBertModel"), |
|
|
|
("bert", "BertModel"), |
|
|
|
("openai-gpt", "OpenAIGPTModel"), |
|
|
|
("gpt2", "GPT2Model"), |
|
|
|
("megatron-bert", "MegatronBertModel"), |
|
|
|
("mobilebert", "MobileBertModel"), |
|
|
|
("transfo-xl", "TransfoXLModel"), |
|
|
|
("xlnet", "XLNetModel"), |
|
|
|
("flaubert", "FlaubertModel"), |
|
|
|
("fsmt", "FSMTModel"), |
|
|
|
("xlm", "XLMModel"), |
|
|
|
("ctrl", "CTRLModel"), |
|
|
|
("electra", "ElectraModel"), |
|
|
|
("reformer", "ReformerModel"), |
|
|
|
("funnel", ("FunnelModel", "FunnelBaseModel")), |
|
|
|
("lxmert", "LxmertModel"), |
|
|
|
("bert-generation", "BertGenerationEncoder"), |
|
|
|
("deberta", "DebertaModel"), |
|
|
|
("deberta-v2", "DebertaV2Model"), |
|
|
|
("dpr", "DPRQuestionEncoder"), |
|
|
|
("xlm-prophetnet", "XLMProphetNetModel"), |
|
|
|
("prophetnet", "ProphetNetModel"), |
|
|
|
("mpnet", "MPNetModel"), |
|
|
|
("tapas", "TapasModel"), |
|
|
|
("ibert", "IBertModel"), |
|
|
|
("splinter", "SplinterModel"), |
|
|
|
("cpt", "CPTModel"), |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for pre-training mapping |
|
|
|
("fnet", "FNetForPreTraining"), |
|
|
|
("visual_bert", "VisualBertForPreTraining"), |
|
|
|
("layoutlm", "LayoutLMForMaskedLM"), |
|
|
|
("retribert", "RetriBertModel"), |
|
|
|
("t5", "T5ForConditionalGeneration"), |
|
|
|
("distilbert", "DistilBertForMaskedLM"), |
|
|
|
("albert", "AlbertForPreTraining"), |
|
|
|
("camembert", "CamembertForMaskedLM"), |
|
|
|
("xlm-roberta", "XLMRobertaForMaskedLM"), |
|
|
|
("bart", "BartForConditionalGeneration"), |
|
|
|
("fsmt", "FSMTForConditionalGeneration"), |
|
|
|
("longformer", "LongformerForMaskedLM"), |
|
|
|
("roberta", "RobertaForMaskedLM"), |
|
|
|
("squeezebert", "SqueezeBertForMaskedLM"), |
|
|
|
("bert", "BertForPreTraining"), |
|
|
|
("big_bird", "BigBirdForPreTraining"), |
|
|
|
("openai-gpt", "OpenAIGPTLMHeadModel"), |
|
|
|
("gpt2", "GPT2LMHeadModel"), |
|
|
|
("megatron-bert", "MegatronBertForPreTraining"), |
|
|
|
("mobilebert", "MobileBertForPreTraining"), |
|
|
|
("transfo-xl", "TransfoXLLMHeadModel"), |
|
|
|
("xlnet", "XLNetLMHeadModel"), |
|
|
|
("flaubert", "FlaubertWithLMHeadModel"), |
|
|
|
("xlm", "XLMWithLMHeadModel"), |
|
|
|
("ctrl", "CTRLLMHeadModel"), |
|
|
|
("electra", "ElectraForPreTraining"), |
|
|
|
("lxmert", "LxmertForPreTraining"), |
|
|
|
("funnel", "FunnelForPreTraining"), |
|
|
|
("mpnet", "MPNetForMaskedLM"), |
|
|
|
("tapas", "TapasForMaskedLM"), |
|
|
|
("ibert", "IBertForMaskedLM"), |
|
|
|
("deberta", "DebertaForMaskedLM"), |
|
|
|
("deberta-v2", "DebertaV2ForMaskedLM"), |
|
|
|
("wav2vec2", "Wav2Vec2ForPreTraining"), |
|
|
|
("cpt", "CPTForConditionalGeneration"), |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model with LM heads mapping |
|
|
|
("fnet", "FNetForMaskedLM"), |
|
|
|
("gptj", "GPTJForCausalLM"), |
|
|
|
("rembert", "RemBertForMaskedLM"), |
|
|
|
("roformer", "RoFormerForMaskedLM"), |
|
|
|
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), |
|
|
|
("gpt_neo", "GPTNeoForCausalLM"), |
|
|
|
("big_bird", "BigBirdForMaskedLM"), |
|
|
|
("speech_to_text", "Speech2TextForConditionalGeneration"), |
|
|
|
("wav2vec2", "Wav2Vec2ForMaskedLM"), |
|
|
|
("m2m_100", "M2M100ForConditionalGeneration"), |
|
|
|
("convbert", "ConvBertForMaskedLM"), |
|
|
|
("led", "LEDForConditionalGeneration"), |
|
|
|
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), |
|
|
|
("layoutlm", "LayoutLMForMaskedLM"), |
|
|
|
("t5", "T5ForConditionalGeneration"), |
|
|
|
("distilbert", "DistilBertForMaskedLM"), |
|
|
|
("albert", "AlbertForMaskedLM"), |
|
|
|
("camembert", "CamembertForMaskedLM"), |
|
|
|
("xlm-roberta", "XLMRobertaForMaskedLM"), |
|
|
|
("marian", "MarianMTModel"), |
|
|
|
("fsmt", "FSMTForConditionalGeneration"), |
|
|
|
("bart", "BartForConditionalGeneration"), |
|
|
|
("longformer", "LongformerForMaskedLM"), |
|
|
|
("roberta", "RobertaForMaskedLM"), |
|
|
|
("squeezebert", "SqueezeBertForMaskedLM"), |
|
|
|
("bert", "BertForMaskedLM"), |
|
|
|
("openai-gpt", "OpenAIGPTLMHeadModel"), |
|
|
|
("gpt2", "GPT2LMHeadModel"), |
|
|
|
("megatron-bert", "MegatronBertForCausalLM"), |
|
|
|
("mobilebert", "MobileBertForMaskedLM"), |
|
|
|
("transfo-xl", "TransfoXLLMHeadModel"), |
|
|
|
("xlnet", "XLNetLMHeadModel"), |
|
|
|
("flaubert", "FlaubertWithLMHeadModel"), |
|
|
|
("xlm", "XLMWithLMHeadModel"), |
|
|
|
("ctrl", "CTRLLMHeadModel"), |
|
|
|
("electra", "ElectraForMaskedLM"), |
|
|
|
("encoder-decoder", "EncoderDecoderModel"), |
|
|
|
("reformer", "ReformerModelWithLMHead"), |
|
|
|
("funnel", "FunnelForMaskedLM"), |
|
|
|
("mpnet", "MPNetForMaskedLM"), |
|
|
|
("tapas", "TapasForMaskedLM"), |
|
|
|
("deberta", "DebertaForMaskedLM"), |
|
|
|
("deberta-v2", "DebertaV2ForMaskedLM"), |
|
|
|
("ibert", "IBertForMaskedLM"), |
|
|
|
("cpt", "CPTForConditionalGeneration"), |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for Causal LM mapping |
|
|
|
("gptj", "GPTJForCausalLM"), |
|
|
|
("rembert", "RemBertForCausalLM"), |
|
|
|
("roformer", "RoFormerForCausalLM"), |
|
|
|
("bigbird_pegasus", "BigBirdPegasusForCausalLM"), |
|
|
|
("gpt_neo", "GPTNeoForCausalLM"), |
|
|
|
("big_bird", "BigBirdForCausalLM"), |
|
|
|
("camembert", "CamembertForCausalLM"), |
|
|
|
("xlm-roberta", "XLMRobertaForCausalLM"), |
|
|
|
("roberta", "RobertaForCausalLM"), |
|
|
|
("bert", "BertLMHeadModel"), |
|
|
|
("openai-gpt", "OpenAIGPTLMHeadModel"), |
|
|
|
("gpt2", "GPT2LMHeadModel"), |
|
|
|
("transfo-xl", "TransfoXLLMHeadModel"), |
|
|
|
("xlnet", "XLNetLMHeadModel"), |
|
|
|
("xlm", "XLMWithLMHeadModel"), |
|
|
|
("ctrl", "CTRLLMHeadModel"), |
|
|
|
("reformer", "ReformerModelWithLMHead"), |
|
|
|
("bert-generation", "BertGenerationDecoder"), |
|
|
|
("xlm-prophetnet", "XLMProphetNetForCausalLM"), |
|
|
|
("prophetnet", "ProphetNetForCausalLM"), |
|
|
|
("bart", "BartForCausalLM"), |
|
|
|
("mbart", "MBartForCausalLM"), |
|
|
|
("pegasus", "PegasusForCausalLM"), |
|
|
|
("marian", "MarianForCausalLM"), |
|
|
|
("blenderbot", "BlenderbotForCausalLM"), |
|
|
|
("blenderbot-small", "BlenderbotSmallForCausalLM"), |
|
|
|
("megatron-bert", "MegatronBertForCausalLM"), |
|
|
|
("speech_to_text_2", "Speech2Text2ForCausalLM"), |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for Image Classification mapping |
|
|
|
("vit", "ViTForImageClassification"), |
|
|
|
("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")), |
|
|
|
("beit", "BeitForImageClassification"), |
|
|
|
] |
|
|
|
) |
|
|
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict([]) |
|
|
|
|
|
|
|
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for Masked LM mapping |
|
|
|
("fnet", "FNetForMaskedLM"), |
|
|
|
("rembert", "RemBertForMaskedLM"), |
|
|
|
("roformer", "RoFormerForMaskedLM"), |
|
|
|
("big_bird", "BigBirdForMaskedLM"), |
|
|
|
("wav2vec2", "Wav2Vec2ForMaskedLM"), |
|
|
|
("convbert", "ConvBertForMaskedLM"), |
|
|
|
("layoutlm", "LayoutLMForMaskedLM"), |
|
|
|
("distilbert", "DistilBertForMaskedLM"), |
|
|
|
("albert", "AlbertForMaskedLM"), |
|
|
|
("bart", "BartForConditionalGeneration"), |
|
|
|
("mbart", "MBartForConditionalGeneration"), |
|
|
|
("camembert", "CamembertForMaskedLM"), |
|
|
|
("xlm-roberta", "XLMRobertaForMaskedLM"), |
|
|
|
("longformer", "LongformerForMaskedLM"), |
|
|
|
("roberta", "RobertaForMaskedLM"), |
|
|
|
("squeezebert", "SqueezeBertForMaskedLM"), |
|
|
|
("bert", "BertForMaskedLM"), |
|
|
|
("megatron-bert", "MegatronBertForMaskedLM"), |
|
|
|
("mobilebert", "MobileBertForMaskedLM"), |
|
|
|
("flaubert", "FlaubertWithLMHeadModel"), |
|
|
|
("xlm", "XLMWithLMHeadModel"), |
|
|
|
("electra", "ElectraForMaskedLM"), |
|
|
|
("reformer", "ReformerForMaskedLM"), |
|
|
|
("funnel", "FunnelForMaskedLM"), |
|
|
|
("mpnet", "MPNetForMaskedLM"), |
|
|
|
("tapas", "TapasForMaskedLM"), |
|
|
|
("deberta", "DebertaForMaskedLM"), |
|
|
|
("deberta-v2", "DebertaV2ForMaskedLM"), |
|
|
|
("ibert", "IBertForMaskedLM"), |
|
|
|
("cpt", "CPTForConditionalGeneration"), |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for Object Detection mapping |
|
|
|
("detr", "DetrForObjectDetection"), |
|
|
|
] |
|
|
|
) |
|
|
|
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict([]) |
|
|
|
|
|
|
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for Seq2Seq Causal LM mapping |
|
|
|
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), |
|
|
|
("m2m_100", "M2M100ForConditionalGeneration"), |
|
|
|
("led", "LEDForConditionalGeneration"), |
|
|
|
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), |
|
|
|
("mt5", "MT5ForConditionalGeneration"), |
|
|
|
("t5", "T5ForConditionalGeneration"), |
|
|
|
("pegasus", "PegasusForConditionalGeneration"), |
|
|
|
("marian", "MarianMTModel"), |
|
|
|
("mbart", "MBartForConditionalGeneration"), |
|
|
|
("blenderbot", "BlenderbotForConditionalGeneration"), |
|
|
|
("bart", "BartForConditionalGeneration"), |
|
|
|
("fsmt", "FSMTForConditionalGeneration"), |
|
|
|
("encoder-decoder", "EncoderDecoderModel"), |
|
|
|
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), |
|
|
|
("prophetnet", "ProphetNetForConditionalGeneration"), |
|
|
|
("cpt", "CPTForConditionalGeneration"), |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
("speech-encoder-decoder", "SpeechEncoderDecoderModel"), |
|
|
|
("speech_to_text", "Speech2TextForConditionalGeneration"), |
|
|
|
] |
|
|
|
) |
|
|
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict([]) |
|
|
|
|
|
|
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for Sequence Classification mapping |
|
|
|
("fnet", "FNetForSequenceClassification"), |
|
|
|
("gptj", "GPTJForSequenceClassification"), |
|
|
|
("layoutlmv2", "LayoutLMv2ForSequenceClassification"), |
|
|
|
("rembert", "RemBertForSequenceClassification"), |
|
|
|
("canine", "CanineForSequenceClassification"), |
|
|
|
("roformer", "RoFormerForSequenceClassification"), |
|
|
|
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"), |
|
|
|
("big_bird", "BigBirdForSequenceClassification"), |
|
|
|
("convbert", "ConvBertForSequenceClassification"), |
|
|
|
("led", "LEDForSequenceClassification"), |
|
|
|
("distilbert", "DistilBertForSequenceClassification"), |
|
|
|
("albert", "AlbertForSequenceClassification"), |
|
|
|
("camembert", "CamembertForSequenceClassification"), |
|
|
|
("xlm-roberta", "XLMRobertaForSequenceClassification"), |
|
|
|
("mbart", "MBartForSequenceClassification"), |
|
|
|
("bart", "BartForSequenceClassification"), |
|
|
|
("longformer", "LongformerForSequenceClassification"), |
|
|
|
("roberta", "RobertaForSequenceClassification"), |
|
|
|
("squeezebert", "SqueezeBertForSequenceClassification"), |
|
|
|
("layoutlm", "LayoutLMForSequenceClassification"), |
|
|
|
("bert", "BertForSequenceClassification"), |
|
|
|
("xlnet", "XLNetForSequenceClassification"), |
|
|
|
("megatron-bert", "MegatronBertForSequenceClassification"), |
|
|
|
("mobilebert", "MobileBertForSequenceClassification"), |
|
|
|
("flaubert", "FlaubertForSequenceClassification"), |
|
|
|
("xlm", "XLMForSequenceClassification"), |
|
|
|
("electra", "ElectraForSequenceClassification"), |
|
|
|
("funnel", "FunnelForSequenceClassification"), |
|
|
|
("deberta", "DebertaForSequenceClassification"), |
|
|
|
("deberta-v2", "DebertaV2ForSequenceClassification"), |
|
|
|
("gpt2", "GPT2ForSequenceClassification"), |
|
|
|
("gpt_neo", "GPTNeoForSequenceClassification"), |
|
|
|
("openai-gpt", "OpenAIGPTForSequenceClassification"), |
|
|
|
("reformer", "ReformerForSequenceClassification"), |
|
|
|
("ctrl", "CTRLForSequenceClassification"), |
|
|
|
("transfo-xl", "TransfoXLForSequenceClassification"), |
|
|
|
("mpnet", "MPNetForSequenceClassification"), |
|
|
|
("tapas", "TapasForSequenceClassification"), |
|
|
|
("ibert", "IBertForSequenceClassification"), |
|
|
|
("cpt", "CPTForSequenceClassification"), |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for Question Answering mapping |
|
|
|
("fnet", "FNetForQuestionAnswering"), |
|
|
|
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), |
|
|
|
("rembert", "RemBertForQuestionAnswering"), |
|
|
|
("canine", "CanineForQuestionAnswering"), |
|
|
|
("roformer", "RoFormerForQuestionAnswering"), |
|
|
|
("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"), |
|
|
|
("big_bird", "BigBirdForQuestionAnswering"), |
|
|
|
("convbert", "ConvBertForQuestionAnswering"), |
|
|
|
("led", "LEDForQuestionAnswering"), |
|
|
|
("distilbert", "DistilBertForQuestionAnswering"), |
|
|
|
("albert", "AlbertForQuestionAnswering"), |
|
|
|
("camembert", "CamembertForQuestionAnswering"), |
|
|
|
("bart", "BartForQuestionAnswering"), |
|
|
|
("mbart", "MBartForQuestionAnswering"), |
|
|
|
("longformer", "LongformerForQuestionAnswering"), |
|
|
|
("xlm-roberta", "XLMRobertaForQuestionAnswering"), |
|
|
|
("roberta", "RobertaForQuestionAnswering"), |
|
|
|
("squeezebert", "SqueezeBertForQuestionAnswering"), |
|
|
|
("bert", "BertForQuestionAnswering"), |
|
|
|
("xlnet", "XLNetForQuestionAnsweringSimple"), |
|
|
|
("flaubert", "FlaubertForQuestionAnsweringSimple"), |
|
|
|
("megatron-bert", "MegatronBertForQuestionAnswering"), |
|
|
|
("mobilebert", "MobileBertForQuestionAnswering"), |
|
|
|
("xlm", "XLMForQuestionAnsweringSimple"), |
|
|
|
("electra", "ElectraForQuestionAnswering"), |
|
|
|
("reformer", "ReformerForQuestionAnswering"), |
|
|
|
("funnel", "FunnelForQuestionAnswering"), |
|
|
|
("lxmert", "LxmertForQuestionAnswering"), |
|
|
|
("mpnet", "MPNetForQuestionAnswering"), |
|
|
|
("deberta", "DebertaForQuestionAnswering"), |
|
|
|
("deberta-v2", "DebertaV2ForQuestionAnswering"), |
|
|
|
("ibert", "IBertForQuestionAnswering"), |
|
|
|
("splinter", "SplinterForQuestionAnswering"), |
|
|
|
("cpt", "CPTForQuestionAnswering"), |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for Table Question Answering mapping |
|
|
|
("tapas", "TapasForQuestionAnswering"), |
|
|
|
] |
|
|
|
) |
|
|
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict([]) |
|
|
|
|
|
|
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for Token Classification mapping |
|
|
|
("fnet", "FNetForTokenClassification"), |
|
|
|
("layoutlmv2", "LayoutLMv2ForTokenClassification"), |
|
|
|
("rembert", "RemBertForTokenClassification"), |
|
|
|
("canine", "CanineForTokenClassification"), |
|
|
|
("roformer", "RoFormerForTokenClassification"), |
|
|
|
("big_bird", "BigBirdForTokenClassification"), |
|
|
|
("convbert", "ConvBertForTokenClassification"), |
|
|
|
("layoutlm", "LayoutLMForTokenClassification"), |
|
|
|
("distilbert", "DistilBertForTokenClassification"), |
|
|
|
("camembert", "CamembertForTokenClassification"), |
|
|
|
("flaubert", "FlaubertForTokenClassification"), |
|
|
|
("xlm", "XLMForTokenClassification"), |
|
|
|
("xlm-roberta", "XLMRobertaForTokenClassification"), |
|
|
|
("longformer", "LongformerForTokenClassification"), |
|
|
|
("roberta", "RobertaForTokenClassification"), |
|
|
|
("squeezebert", "SqueezeBertForTokenClassification"), |
|
|
|
("bert", "BertForTokenClassification"), |
|
|
|
("megatron-bert", "MegatronBertForTokenClassification"), |
|
|
|
("mobilebert", "MobileBertForTokenClassification"), |
|
|
|
("xlnet", "XLNetForTokenClassification"), |
|
|
|
("albert", "AlbertForTokenClassification"), |
|
|
|
("electra", "ElectraForTokenClassification"), |
|
|
|
("funnel", "FunnelForTokenClassification"), |
|
|
|
("mpnet", "MPNetForTokenClassification"), |
|
|
|
("deberta", "DebertaForTokenClassification"), |
|
|
|
("deberta-v2", "DebertaV2ForTokenClassification"), |
|
|
|
("gpt2", "GPT2ForTokenClassification"), |
|
|
|
("ibert", "IBertForTokenClassification"), |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for Multiple Choice mapping |
|
|
|
("fnet", "FNetForMultipleChoice"), |
|
|
|
("rembert", "RemBertForMultipleChoice"), |
|
|
|
("canine", "CanineForMultipleChoice"), |
|
|
|
("roformer", "RoFormerForMultipleChoice"), |
|
|
|
("big_bird", "BigBirdForMultipleChoice"), |
|
|
|
("convbert", "ConvBertForMultipleChoice"), |
|
|
|
("camembert", "CamembertForMultipleChoice"), |
|
|
|
("electra", "ElectraForMultipleChoice"), |
|
|
|
("xlm-roberta", "XLMRobertaForMultipleChoice"), |
|
|
|
("longformer", "LongformerForMultipleChoice"), |
|
|
|
("roberta", "RobertaForMultipleChoice"), |
|
|
|
("squeezebert", "SqueezeBertForMultipleChoice"), |
|
|
|
("bert", "BertForMultipleChoice"), |
|
|
|
("distilbert", "DistilBertForMultipleChoice"), |
|
|
|
("megatron-bert", "MegatronBertForMultipleChoice"), |
|
|
|
("mobilebert", "MobileBertForMultipleChoice"), |
|
|
|
("xlnet", "XLNetForMultipleChoice"), |
|
|
|
("albert", "AlbertForMultipleChoice"), |
|
|
|
("xlm", "XLMForMultipleChoice"), |
|
|
|
("flaubert", "FlaubertForMultipleChoice"), |
|
|
|
("funnel", "FunnelForMultipleChoice"), |
|
|
|
("mpnet", "MPNetForMultipleChoice"), |
|
|
|
("ibert", "IBertForMultipleChoice"), |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
("bert", "BertForNextSentencePrediction"), |
|
|
|
("fnet", "FNetForNextSentencePrediction"), |
|
|
|
("megatron-bert", "MegatronBertForNextSentencePrediction"), |
|
|
|
("mobilebert", "MobileBertForNextSentencePrediction"), |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for Audio Classification mapping |
|
|
|
("wav2vec2", "Wav2Vec2ForSequenceClassification"), |
|
|
|
("hubert", "HubertForSequenceClassification"), |
|
|
|
] |
|
|
|
) |
|
|
|
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([]) |
|
|
|
|
|
|
|
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict( |
|
|
|
[ |
|
|
|
# Model for Connectionist temporal classification (CTC) mapping |
|
|
|
("wav2vec2", "Wav2Vec2ForCTC"), |
|
|
|
("hubert", "HubertForCTC"), |
|
|
|
] |
|
|
|
) |
|
|
|
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict([]) |
|
|
|
|
|
|
|
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) |
|
|
|
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) |
|
|
|