Browse Source

为transformers 添加AutoModel

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
639de99563
6 changed files with 28 additions and 711 deletions
  1. +2
    -2
      fastNLP/transformers/torch/models/auto/__init__.py
  2. +4
    -1
      fastNLP/transformers/torch/models/auto/auto_factory.py
  3. +3
    -191
      fastNLP/transformers/torch/models/auto/configuration_auto.py
  4. +13
    -350
      fastNLP/transformers/torch/models/auto/modeling_auto.py
  5. +4
    -166
      fastNLP/transformers/torch/models/auto/tokenization_auto.py
  6. +2
    -1
      fastNLP/transformers/torch/models/cpt/__init__.py

+ 2
- 2
fastNLP/transformers/torch/models/auto/__init__.py View File

@@ -3,7 +3,7 @@ __all__ = [
"CONFIG_MAPPING",
"MODEL_NAMES_MAPPING",
"AutoConfig",
"TOKENIZER_MAPPING",
"TOKENIZER_MAPPING_NAMES",
"get_values",
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
@@ -43,7 +43,7 @@ __all__ = [

from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \
AutoConfig
from .tokenization_auto import TOKENIZER_MAPPING
from .tokenization_auto import TOKENIZER_MAPPING_NAMES
from .auto_factory import get_values
from .modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,


+ 4
- 1
fastNLP/transformers/torch/models/auto/auto_factory.py View File

@@ -516,7 +516,10 @@ class _LazyAutoMapping(OrderedDict):
def _load_attr_from_module(self, model_type, attr):
module_name = model_type_to_module_name(model_type)
if module_name not in self._modules:
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
try:
self._modules[module_name] = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models")
except ImportError:
raise ImportError(f"fastNLP transformers does not support {module_name} now, please install and import `transformers` to use it.")
return getattribute_from_module(self._modules[module_name], attr)

def keys(self):


+ 3
- 191
fastNLP/transformers/torch/models/auto/configuration_auto.py View File

@@ -26,221 +26,33 @@ from fastNLP.core.log import logger
CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
("fnet", "FNetConfig"),
("gptj", "GPTJConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("beit", "BeitConfig"),
("rembert", "RemBertConfig"),
("visual_bert", "VisualBertConfig"),
("canine", "CanineConfig"),
("roformer", "RoFormerConfig"),
("clip", "CLIPConfig"),
("bigbird_pegasus", "BigBirdPegasusConfig"),
("deit", "DeiTConfig"),
("luke", "LukeConfig"),
("detr", "DetrConfig"),
("gpt_neo", "GPTNeoConfig"),
("big_bird", "BigBirdConfig"),
("speech_to_text_2", "Speech2Text2Config"),
("speech_to_text", "Speech2TextConfig"),
("vit", "ViTConfig"),
("wav2vec2", "Wav2Vec2Config"),
("m2m_100", "M2M100Config"),
("convbert", "ConvBertConfig"),
("led", "LEDConfig"),
("blenderbot-small", "BlenderbotSmallConfig"),
("retribert", "RetriBertConfig"),
("ibert", "IBertConfig"),
("mt5", "MT5Config"),
("t5", "T5Config"),
("mobilebert", "MobileBertConfig"),
("distilbert", "DistilBertConfig"),
("albert", "AlbertConfig"),
("bert-generation", "BertGenerationConfig"),
("camembert", "CamembertConfig"),
("xlm-roberta", "XLMRobertaConfig"),
("pegasus", "PegasusConfig"),
("marian", "MarianConfig"),
("mbart", "MBartConfig"),
("megatron-bert", "MegatronBertConfig"),
("mpnet", "MPNetConfig"),
("bart", "BartConfig"),
("blenderbot", "BlenderbotConfig"),
("reformer", "ReformerConfig"),
("longformer", "LongformerConfig"),
("roberta", "RobertaConfig"),
("deberta-v2", "DebertaV2Config"),
("deberta", "DebertaConfig"),
("flaubert", "FlaubertConfig"),
("fsmt", "FSMTConfig"),
("squeezebert", "SqueezeBertConfig"),
("hubert", "HubertConfig"),
("bert", "BertConfig"),
("openai-gpt", "OpenAIGPTConfig"),
("gpt2", "GPT2Config"),
("transfo-xl", "TransfoXLConfig"),
("xlnet", "XLNetConfig"),
("xlm-prophetnet", "XLMProphetNetConfig"),
("prophetnet", "ProphetNetConfig"),
("xlm", "XLMConfig"),
("ctrl", "CTRLConfig"),
("electra", "ElectraConfig"),
("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
("encoder-decoder", "EncoderDecoderConfig"),
("funnel", "FunnelConfig"),
("lxmert", "LxmertConfig"),
("dpr", "DPRConfig"),
("layoutlm", "LayoutLMConfig"),
("rag", "RagConfig"),
("tapas", "TapasConfig"),
("splinter", "SplinterConfig"),
("cpt", "CPTConfig"),
]
)

CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[
# Add archive maps here
("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("cpt", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
]
)

MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("fnet", "FNet"),
("gptj", "GPT-J"),
("beit", "BeiT"),
("rembert", "RemBERT"),
("layoutlmv2", "LayoutLMv2"),
("visual_bert", "VisualBert"),
("canine", "Canine"),
("roformer", "RoFormer"),
("clip", "CLIP"),
("bigbird_pegasus", "BigBirdPegasus"),
("deit", "DeiT"),
("luke", "LUKE"),
("detr", "DETR"),
("gpt_neo", "GPT Neo"),
("big_bird", "BigBird"),
("speech_to_text_2", "Speech2Text2"),
("speech_to_text", "Speech2Text"),
("vit", "ViT"),
("wav2vec2", "Wav2Vec2"),
("m2m_100", "M2M100"),
("convbert", "ConvBERT"),
("led", "LED"),
("blenderbot-small", "BlenderbotSmall"),
("retribert", "RetriBERT"),
("ibert", "I-BERT"),
("t5", "T5"),
("mobilebert", "MobileBERT"),
("distilbert", "DistilBERT"),
("albert", "ALBERT"),
("bert-generation", "Bert Generation"),
("camembert", "CamemBERT"),
("xlm-roberta", "XLM-RoBERTa"),
("pegasus", "Pegasus"),
("blenderbot", "Blenderbot"),
("marian", "Marian"),
("mbart", "mBART"),
("megatron-bert", "MegatronBert"),
("bart", "BART"),
("reformer", "Reformer"),
("longformer", "Longformer"),
("roberta", "RoBERTa"),
("flaubert", "FlauBERT"),
("fsmt", "FairSeq Machine-Translation"),
("squeezebert", "SqueezeBERT"),
("bert", "BERT"),
("openai-gpt", "OpenAI GPT"),
("gpt2", "OpenAI GPT-2"),
("transfo-xl", "Transformer-XL"),
("xlnet", "XLNet"),
("xlm", "XLM"),
("ctrl", "CTRL"),
("electra", "ELECTRA"),
("encoder-decoder", "Encoder decoder"),
("speech-encoder-decoder", "Speech Encoder decoder"),
("funnel", "Funnel Transformer"),
("lxmert", "LXMERT"),
("deberta-v2", "DeBERTa-v2"),
("deberta", "DeBERTa"),
("layoutlm", "LayoutLM"),
("dpr", "DPR"),
("rag", "RAG"),
("xlm-prophetnet", "XLMProphetNet"),
("prophetnet", "ProphetNet"),
("mt5", "mT5"),
("mpnet", "MPNet"),
("tapas", "TAPAS"),
("hubert", "Hubert"),
("barthez", "BARThez"),
("phobert", "PhoBERT"),
("cpm", "CPM"),
("bertweet", "Bertweet"),
("bert-japanese", "BertJapanese"),
("byt5", "ByT5"),
("mbart50", "mBART-50"),
("splinter", "Splinter"),
("cpt", "CPT")
]
)



+ 13
- 350
fastNLP/transformers/torch/models/auto/modeling_auto.py View File

@@ -24,455 +24,118 @@ from fastNLP.core.log import logger

MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("fnet", "FNetModel"),
("gptj", "GPTJModel"),
("layoutlmv2", "LayoutLMv2Model"),
("beit", "BeitModel"),
("rembert", "RemBertModel"),
("visual_bert", "VisualBertModel"),
("canine", "CanineModel"),
("roformer", "RoFormerModel"),
("clip", "CLIPModel"),
("bigbird_pegasus", "BigBirdPegasusModel"),
("deit", "DeiTModel"),
("luke", "LukeModel"),
("detr", "DetrModel"),
("gpt_neo", "GPTNeoModel"),
("big_bird", "BigBirdModel"),
("speech_to_text", "Speech2TextModel"),
("vit", "ViTModel"),
("wav2vec2", "Wav2Vec2Model"),
("hubert", "HubertModel"),
("m2m_100", "M2M100Model"),
("convbert", "ConvBertModel"),
("led", "LEDModel"),
("blenderbot-small", "BlenderbotSmallModel"),
("retribert", "RetriBertModel"),
("mt5", "MT5Model"),
("t5", "T5Model"),
("pegasus", "PegasusModel"),
("marian", "MarianModel"),
("mbart", "MBartModel"),
("blenderbot", "BlenderbotModel"),
("distilbert", "DistilBertModel"),
("albert", "AlbertModel"),
("camembert", "CamembertModel"),
("xlm-roberta", "XLMRobertaModel"),
("bart", "BartModel"),
("longformer", "LongformerModel"),
("roberta", "RobertaModel"),
("layoutlm", "LayoutLMModel"),
("squeezebert", "SqueezeBertModel"),
("bert", "BertModel"),
("openai-gpt", "OpenAIGPTModel"),
("gpt2", "GPT2Model"),
("megatron-bert", "MegatronBertModel"),
("mobilebert", "MobileBertModel"),
("transfo-xl", "TransfoXLModel"),
("xlnet", "XLNetModel"),
("flaubert", "FlaubertModel"),
("fsmt", "FSMTModel"),
("xlm", "XLMModel"),
("ctrl", "CTRLModel"),
("electra", "ElectraModel"),
("reformer", "ReformerModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("lxmert", "LxmertModel"),
("bert-generation", "BertGenerationEncoder"),
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
("dpr", "DPRQuestionEncoder"),
("xlm-prophetnet", "XLMProphetNetModel"),
("prophetnet", "ProphetNetModel"),
("mpnet", "MPNetModel"),
("tapas", "TapasModel"),
("ibert", "IBertModel"),
("splinter", "SplinterModel"),
("cpt", "CPTModel"),
]
)

MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
("fnet", "FNetForPreTraining"),
("visual_bert", "VisualBertForPreTraining"),
("layoutlm", "LayoutLMForMaskedLM"),
("retribert", "RetriBertModel"),
("t5", "T5ForConditionalGeneration"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForPreTraining"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("fsmt", "FSMTForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForPreTraining"),
("big_bird", "BigBirdForPreTraining"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("ctrl", "CTRLLMHeadModel"),
("electra", "ElectraForPreTraining"),
("lxmert", "LxmertForPreTraining"),
("funnel", "FunnelForPreTraining"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("ibert", "IBertForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("wav2vec2", "Wav2Vec2ForPreTraining"),
("cpt", "CPTForConditionalGeneration"),
]
)

MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
("fnet", "FNetForMaskedLM"),
("gptj", "GPTJForCausalLM"),
("rembert", "RemBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("gpt_neo", "GPTNeoForCausalLM"),
("big_bird", "BigBirdForMaskedLM"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"),
("convbert", "ConvBertForMaskedLM"),
("led", "LEDForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("layoutlm", "LayoutLMForMaskedLM"),
("t5", "T5ForConditionalGeneration"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("marian", "MarianMTModel"),
("fsmt", "FSMTForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForMaskedLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("megatron-bert", "MegatronBertForCausalLM"),
("mobilebert", "MobileBertForMaskedLM"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("ctrl", "CTRLLMHeadModel"),
("electra", "ElectraForMaskedLM"),
("encoder-decoder", "EncoderDecoderModel"),
("reformer", "ReformerModelWithLMHead"),
("funnel", "FunnelForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("ibert", "IBertForMaskedLM"),
("cpt", "CPTForConditionalGeneration"),
]
)

MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("gptj", "GPTJForCausalLM"),
("rembert", "RemBertForCausalLM"),
("roformer", "RoFormerForCausalLM"),
("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
("gpt_neo", "GPTNeoForCausalLM"),
("big_bird", "BigBirdForCausalLM"),
("camembert", "CamembertForCausalLM"),
("xlm-roberta", "XLMRobertaForCausalLM"),
("roberta", "RobertaForCausalLM"),
("bert", "BertLMHeadModel"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("ctrl", "CTRLLMHeadModel"),
("reformer", "ReformerModelWithLMHead"),
("bert-generation", "BertGenerationDecoder"),
("xlm-prophetnet", "XLMProphetNetForCausalLM"),
("prophetnet", "ProphetNetForCausalLM"),
("bart", "BartForCausalLM"),
("mbart", "MBartForCausalLM"),
("pegasus", "PegasusForCausalLM"),
("marian", "MarianForCausalLM"),
("blenderbot", "BlenderbotForCausalLM"),
("blenderbot-small", "BlenderbotSmallForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"),
("speech_to_text_2", "Speech2Text2ForCausalLM"),
]
)

MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image Classification mapping
("vit", "ViTForImageClassification"),
("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
("beit", "BeitForImageClassification"),
]
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict([])

MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("fnet", "FNetForMaskedLM"),
("rembert", "RemBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("big_bird", "BigBirdForMaskedLM"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("convbert", "ConvBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("mbart", "MBartForConditionalGeneration"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForMaskedLM"),
("megatron-bert", "MegatronBertForMaskedLM"),
("mobilebert", "MobileBertForMaskedLM"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("electra", "ElectraForMaskedLM"),
("reformer", "ReformerForMaskedLM"),
("funnel", "FunnelForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("ibert", "IBertForMaskedLM"),
("cpt", "CPTForConditionalGeneration"),
]
)

MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
[
# Model for Object Detection mapping
("detr", "DetrForObjectDetection"),
]
)
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict([])

MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("m2m_100", "M2M100ForConditionalGeneration"),
("led", "LEDForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("mt5", "MT5ForConditionalGeneration"),
("t5", "T5ForConditionalGeneration"),
("pegasus", "PegasusForConditionalGeneration"),
("marian", "MarianMTModel"),
("mbart", "MBartForConditionalGeneration"),
("blenderbot", "BlenderbotForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("fsmt", "FSMTForConditionalGeneration"),
("encoder-decoder", "EncoderDecoderModel"),
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
("prophetnet", "ProphetNetForConditionalGeneration"),
("cpt", "CPTForConditionalGeneration"),
]
)

MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
]
)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict([])

MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("fnet", "FNetForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
("rembert", "RemBertForSequenceClassification"),
("canine", "CanineForSequenceClassification"),
("roformer", "RoFormerForSequenceClassification"),
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
("big_bird", "BigBirdForSequenceClassification"),
("convbert", "ConvBertForSequenceClassification"),
("led", "LEDForSequenceClassification"),
("distilbert", "DistilBertForSequenceClassification"),
("albert", "AlbertForSequenceClassification"),
("camembert", "CamembertForSequenceClassification"),
("xlm-roberta", "XLMRobertaForSequenceClassification"),
("mbart", "MBartForSequenceClassification"),
("bart", "BartForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("roberta", "RobertaForSequenceClassification"),
("squeezebert", "SqueezeBertForSequenceClassification"),
("layoutlm", "LayoutLMForSequenceClassification"),
("bert", "BertForSequenceClassification"),
("xlnet", "XLNetForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
("flaubert", "FlaubertForSequenceClassification"),
("xlm", "XLMForSequenceClassification"),
("electra", "ElectraForSequenceClassification"),
("funnel", "FunnelForSequenceClassification"),
("deberta", "DebertaForSequenceClassification"),
("deberta-v2", "DebertaV2ForSequenceClassification"),
("gpt2", "GPT2ForSequenceClassification"),
("gpt_neo", "GPTNeoForSequenceClassification"),
("openai-gpt", "OpenAIGPTForSequenceClassification"),
("reformer", "ReformerForSequenceClassification"),
("ctrl", "CTRLForSequenceClassification"),
("transfo-xl", "TransfoXLForSequenceClassification"),
("mpnet", "MPNetForSequenceClassification"),
("tapas", "TapasForSequenceClassification"),
("ibert", "IBertForSequenceClassification"),
("cpt", "CPTForSequenceClassification"),
]
)

MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("fnet", "FNetForQuestionAnswering"),
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
("rembert", "RemBertForQuestionAnswering"),
("canine", "CanineForQuestionAnswering"),
("roformer", "RoFormerForQuestionAnswering"),
("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
("big_bird", "BigBirdForQuestionAnswering"),
("convbert", "ConvBertForQuestionAnswering"),
("led", "LEDForQuestionAnswering"),
("distilbert", "DistilBertForQuestionAnswering"),
("albert", "AlbertForQuestionAnswering"),
("camembert", "CamembertForQuestionAnswering"),
("bart", "BartForQuestionAnswering"),
("mbart", "MBartForQuestionAnswering"),
("longformer", "LongformerForQuestionAnswering"),
("xlm-roberta", "XLMRobertaForQuestionAnswering"),
("roberta", "RobertaForQuestionAnswering"),
("squeezebert", "SqueezeBertForQuestionAnswering"),
("bert", "BertForQuestionAnswering"),
("xlnet", "XLNetForQuestionAnsweringSimple"),
("flaubert", "FlaubertForQuestionAnsweringSimple"),
("megatron-bert", "MegatronBertForQuestionAnswering"),
("mobilebert", "MobileBertForQuestionAnswering"),
("xlm", "XLMForQuestionAnsweringSimple"),
("electra", "ElectraForQuestionAnswering"),
("reformer", "ReformerForQuestionAnswering"),
("funnel", "FunnelForQuestionAnswering"),
("lxmert", "LxmertForQuestionAnswering"),
("mpnet", "MPNetForQuestionAnswering"),
("deberta", "DebertaForQuestionAnswering"),
("deberta-v2", "DebertaV2ForQuestionAnswering"),
("ibert", "IBertForQuestionAnswering"),
("splinter", "SplinterForQuestionAnswering"),
("cpt", "CPTForQuestionAnswering"),
]
)

MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Table Question Answering mapping
("tapas", "TapasForQuestionAnswering"),
]
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict([])

MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("fnet", "FNetForTokenClassification"),
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
("rembert", "RemBertForTokenClassification"),
("canine", "CanineForTokenClassification"),
("roformer", "RoFormerForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"),
("convbert", "ConvBertForTokenClassification"),
("layoutlm", "LayoutLMForTokenClassification"),
("distilbert", "DistilBertForTokenClassification"),
("camembert", "CamembertForTokenClassification"),
("flaubert", "FlaubertForTokenClassification"),
("xlm", "XLMForTokenClassification"),
("xlm-roberta", "XLMRobertaForTokenClassification"),
("longformer", "LongformerForTokenClassification"),
("roberta", "RobertaForTokenClassification"),
("squeezebert", "SqueezeBertForTokenClassification"),
("bert", "BertForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
("xlnet", "XLNetForTokenClassification"),
("albert", "AlbertForTokenClassification"),
("electra", "ElectraForTokenClassification"),
("funnel", "FunnelForTokenClassification"),
("mpnet", "MPNetForTokenClassification"),
("deberta", "DebertaForTokenClassification"),
("deberta-v2", "DebertaV2ForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
("ibert", "IBertForTokenClassification"),
]
)

MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("fnet", "FNetForMultipleChoice"),
("rembert", "RemBertForMultipleChoice"),
("canine", "CanineForMultipleChoice"),
("roformer", "RoFormerForMultipleChoice"),
("big_bird", "BigBirdForMultipleChoice"),
("convbert", "ConvBertForMultipleChoice"),
("camembert", "CamembertForMultipleChoice"),
("electra", "ElectraForMultipleChoice"),
("xlm-roberta", "XLMRobertaForMultipleChoice"),
("longformer", "LongformerForMultipleChoice"),
("roberta", "RobertaForMultipleChoice"),
("squeezebert", "SqueezeBertForMultipleChoice"),
("bert", "BertForMultipleChoice"),
("distilbert", "DistilBertForMultipleChoice"),
("megatron-bert", "MegatronBertForMultipleChoice"),
("mobilebert", "MobileBertForMultipleChoice"),
("xlnet", "XLNetForMultipleChoice"),
("albert", "AlbertForMultipleChoice"),
("xlm", "XLMForMultipleChoice"),
("flaubert", "FlaubertForMultipleChoice"),
("funnel", "FunnelForMultipleChoice"),
("mpnet", "MPNetForMultipleChoice"),
("ibert", "IBertForMultipleChoice"),
]
)

MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
("bert", "BertForNextSentencePrediction"),
("fnet", "FNetForNextSentencePrediction"),
("megatron-bert", "MegatronBertForNextSentencePrediction"),
("mobilebert", "MobileBertForNextSentencePrediction"),
]
)

MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Audio Classification mapping
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
("hubert", "HubertForSequenceClassification"),
]
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([])

MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
[
# Model for Connectionist temporal classification (CTC) mapping
("wav2vec2", "Wav2Vec2ForCTC"),
("hubert", "HubertForCTC"),
]
)
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict([])

MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)


+ 4
- 166
fastNLP/transformers/torch/models/auto/tokenization_auto.py View File

@@ -29,171 +29,9 @@ if TYPE_CHECKING:
else:
TOKENIZER_MAPPING_NAMES = OrderedDict(
[
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
(
"t5",
(
"T5Tokenizer" if is_sentencepiece_available() else None,
"T5TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mt5",
(
"MT5Tokenizer" if is_sentencepiece_available() else None,
"MT5TokenizerFast" if is_tokenizers_available() else None,
),
),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
(
"albert",
(
"AlbertTokenizer" if is_sentencepiece_available() else None,
"AlbertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"camembert",
(
"CamembertTokenizer" if is_sentencepiece_available() else None,
"CamembertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"pegasus",
(
"PegasusTokenizer" if is_sentencepiece_available() else None,
"PegasusTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mbart",
(
"MBartTokenizer" if is_sentencepiece_available() else None,
"MBartTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"xlm-roberta",
(
"XLMRobertaTokenizer" if is_sentencepiece_available() else None,
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
),
),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
("blenderbot", ("BlenderbotTokenizer", None)),
("bart", ("BartTokenizer", "BartTokenizerFast")),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
(
"reformer",
(
"ReformerTokenizer" if is_sentencepiece_available() else None,
"ReformerTokenizerFast" if is_tokenizers_available() else None,
),
),
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
(
"dpr",
(
"DPRQuestionEncoderTokenizer",
"DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"squeezebert",
("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
),
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
(
"xlnet",
(
"XLNetTokenizer" if is_sentencepiece_available() else None,
"XLNetTokenizerFast" if is_tokenizers_available() else None,
),
),
("flaubert", ("FlaubertTokenizer", None)),
("xlm", ("XLMTokenizer", None)),
("ctrl", ("CTRLTokenizer", None)),
("fsmt", ("FSMTTokenizer", None)),
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
("deberta-v2", ("DebertaV2Tokenizer" if is_sentencepiece_available() else None, None)),
("rag", ("RagTokenizer", None)),
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
("tapas", ("TapasTokenizer", None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"big_bird",
(
"BigBirdTokenizer" if is_sentencepiece_available() else None,
"BigBirdTokenizerFast" if is_tokenizers_available() else None,
),
),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("luke", ("LukeTokenizer", None)),
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
("canine", ("CanineTokenizer", None)),
("bertweet", ("BertweetTokenizer", None)),
("bert-japanese", ("BertJapaneseTokenizer", None)),
("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
("byt5", ("ByT5Tokenizer", None)),
(
"cpm",
(
"CpmTokenizer" if is_sentencepiece_available() else None,
"CpmTokenizerFast" if is_tokenizers_available() else None,
),
),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
(
"barthez",
(
"BarthezTokenizer" if is_sentencepiece_available() else None,
"BarthezTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mbart50",
(
"MBart50Tokenizer" if is_sentencepiece_available() else None,
"MBart50TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"rembert",
(
"RemBertTokenizer" if is_sentencepiece_available() else None,
"RemBertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"clip",
(
"CLIPTokenizer",
"CLIPTokenizerFast" if is_tokenizers_available() else None,
),
),
("bart", ("BartTokenizer", None)),
("roberta", ("RobertaTokenizer", None)),
("bert", ("BertTokenizer", None)),
("gpt2", ("GPT2Tokenizer", None)),
]
)

+ 2
- 1
fastNLP/transformers/torch/models/cpt/__init__.py View File

@@ -1,5 +1,6 @@
__all__ = [
"CPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"CPTConfig",
"CPTForConditionalGeneration",
"CPTForSequenceClassification",
"CPTForMaskedLM",
@@ -9,4 +10,4 @@ __all__ = [
]

from .modeling_cpt import CPT_PRETRAINED_MODEL_ARCHIVE_LIST, CPTForConditionalGeneration, CPTForSequenceClassification, \
CPTForMaskedLM, CPTForQuestionAnswering, CPTModel, CPTPretrainedModel
CPTForMaskedLM, CPTForQuestionAnswering, CPTModel, CPTPretrainedModel, CPTConfig

Loading…
Cancel
Save