From 296c1acc31c4d2f8fab292e32bb0cce7ecdab617 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 25 May 2022 08:21:33 +0000 Subject: [PATCH 1/6] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dmixdataloader=E5=9C=A8tor?= =?UTF-8?q?ch1.6=E4=B8=8B=E5=8F=82=E6=95=B0=E4=B8=8D=E5=8C=B9=E9=85=8D?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dataloaders/torch_dataloader/mix_dataloader.py | 24 +++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py b/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py index 29b0cd0b..1b77be77 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py +++ b/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py @@ -5,6 +5,7 @@ __all__ = [ from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping import numpy as np +from pkg_resources import parse_version from fastNLP.core.dataset import DataSet, Instance from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler @@ -12,6 +13,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH from fastNLP.core.collators import Collator if _NEED_IMPORT_TORCH: + from torch import __version__ as torchversion from torch.utils.data import DataLoader, Sampler else: from fastNLP.core.utils.dummy_class import DummyClass as DataLoader @@ -213,13 +215,21 @@ class MixDataLoader(DataLoader): else: raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler") - super(MixDataLoader, self).__init__( - _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, - batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, - pin_memory=pin_memory, drop_last=False, timeout=0, - worker_init_fn=None, multiprocessing_context=None, generator=None, - prefetch_factor=2, persistent_workers=False - ) + if parse_version(torchversion) >= parse_version('1.7'): + super(MixDataLoader, self).__init__( + _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, + pin_memory=pin_memory, drop_last=False, timeout=0, + worker_init_fn=None, multiprocessing_context=None, generator=None, + prefetch_factor=2, persistent_workers=False + ) + else: + super(MixDataLoader, self).__init__( + _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, + pin_memory=pin_memory, drop_last=False, timeout=0, + worker_init_fn=None, multiprocessing_context=None, generator=None, + ) def __iter__(self): return super().__iter__() From e734d11d3b71b26b642a72201b6c818667b84e64 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 25 May 2022 15:30:14 +0000 Subject: [PATCH 2/6] =?UTF-8?q?transformers=20=E6=B7=BB=E5=8A=A0=20AutoTok?= =?UTF-8?q?enizer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/transformers/torch/models/auto/__init__.py | 5 +- .../torch/models/auto/tokenization_auto.py | 316 ++++++++++++++++++++- .../torch/models/encoder_decoder/__init__.py | 5 + .../configuration_encoder_decoder.py | 114 ++++++++ 4 files changed, 435 insertions(+), 5 deletions(-) create mode 100644 fastNLP/transformers/torch/models/encoder_decoder/__init__.py create mode 100644 fastNLP/transformers/torch/models/encoder_decoder/configuration_encoder_decoder.py diff --git a/fastNLP/transformers/torch/models/auto/__init__.py b/fastNLP/transformers/torch/models/auto/__init__.py index 0ce22235..7a6cfb21 100644 --- a/fastNLP/transformers/torch/models/auto/__init__.py +++ b/fastNLP/transformers/torch/models/auto/__init__.py @@ -3,7 +3,8 @@ __all__ = [ "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig", - "TOKENIZER_MAPPING_NAMES", + "TOKENIZER_MAPPING", + "AutoTokenizer", "get_values", "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", "MODEL_FOR_CAUSAL_LM_MAPPING", @@ -43,7 +44,7 @@ __all__ = [ from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \ AutoConfig -from .tokenization_auto import TOKENIZER_MAPPING_NAMES +from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer from .auto_factory import get_values from .modeling_auto import ( MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, diff --git a/fastNLP/transformers/torch/models/auto/tokenization_auto.py b/fastNLP/transformers/torch/models/auto/tokenization_auto.py index d30cbae1..8adbe50b 100644 --- a/fastNLP/transformers/torch/models/auto/tokenization_auto.py +++ b/fastNLP/transformers/torch/models/auto/tokenization_auto.py @@ -13,14 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Auto Tokenizer class. """ - +import importlib +import json +import os from collections import OrderedDict -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union +from ...configuration_utils import PretrainedConfig from ...file_utils import ( + cached_path, + hf_bucket_url, + is_offline_mode, is_sentencepiece_available, is_tokenizers_available, ) +from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE +from ..encoder_decoder import EncoderDecoderConfig +from .auto_factory import _LazyAutoMapping +from .configuration_auto import ( + CONFIG_MAPPING_NAMES, + AutoConfig, + config_class_to_model_type, + model_type_to_module_name, + replace_list_option_in_docstrings, +) +from fastNLP.core.log import logger if TYPE_CHECKING: # This significantly improves completion suggestion performance when @@ -34,4 +51,297 @@ else: ("bert", ("BertTokenizer", None)), ("gpt2", ("GPT2Tokenizer", None)), ] - ) \ No newline at end of file + ) + +TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES) + +CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()} + + +def tokenizer_class_from_name(class_name: str): + if class_name == "PreTrainedTokenizerFast": + raise RuntimeError("fastNLP does not support TokenizerFast now.") + + for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items(): + if class_name in tokenizers: + module_name = model_type_to_module_name(module_name) + + try: + module = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models") + except ImportError: + raise ImportError(f"fastNLP transformers does not support {module_name} now, please install and import `transformers` to use it.") + return getattr(module, class_name) + + return None + + +def get_tokenizer_config( + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + **kwargs, +): + """ + Loads the tokenizer configuration from a pretrained model tokenizer configuration. + + Args: + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + This can be either: + + - a string, the `model id` of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or + namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing a configuration file saved using the + :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``. + + cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + proxies (:obj:`Dict[str, str]`, `optional`): + A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (:obj:`str` or `bool`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). + revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any + identifier allowed by git. + local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): + If :obj:`True`, will only try to load the tokenizer configuration from local files. + + .. note:: + + Passing :obj:`use_auth_token=True` is required when you want to use a private model. + + + Returns: + :obj:`Dict`: The configuration of the tokenizer. + + Examples:: + + # Download configuration from huggingface.co and cache. + tokenizer_config = get_tokenizer_config("bert-base-uncased") + # This model does not have a tokenizer config so the result will be an empty dict. + tokenizer_config = get_tokenizer_config("xlm-roberta-base") + + # Save a pretrained tokenizer locally and you can reload its config + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + tokenizer.save_pretrained("tokenizer-test") + tokenizer_config = get_tokenizer_config("tokenizer-test") + """ + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE) + else: + config_file = hf_bucket_url( + pretrained_model_name_or_path, filename=TOKENIZER_CONFIG_FILE, revision=revision, mirror=None + ) + + try: + # Load from URL or cache if already cached + resolved_config_file = cached_path( + config_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) + + except EnvironmentError: + logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.") + return {} + + with open(resolved_config_file, encoding="utf-8") as reader: + return json.load(reader) + + +class AutoTokenizer: + r""" + This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when + created with the :meth:`AutoTokenizer.from_pretrained` class method. + + This class cannot be instantiated directly using ``__init__()`` (throws an error). + """ + + def __init__(self): + raise EnvironmentError( + "AutoTokenizer is designed to be instantiated " + "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method." + ) + + @classmethod + @replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES) + def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): + r""" + Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary. + + The tokenizer class to instantiate is selected based on the :obj:`model_type` property of the config object + (either passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's + missing, by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: + + List options + + Params: + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + Can be either: + + - A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing vocabulary files required by the tokenizer, for instance saved + using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., + ``./my_model_directory/``. + - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a + single vocabulary file (like Bert or XLNet), e.g.: ``./my_model_directory/vocab.txt``. (Not + applicable to all derived classes) + inputs (additional positional arguments, `optional`): + Will be passed along to the Tokenizer ``__init__()`` method. + config (:class:`~transformers.PretrainedConfig`, `optional`) + The configuration object used to dertermine the tokenizer class to instantiate. + cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to force the (re-)download the model weights and configuration files and override the + cached versions if they exist. + resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (:obj:`Dict[str, str]`, `optional`): + A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any + identifier allowed by git. + subfolder (:obj:`str`, `optional`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for + facebook/rag-token-base), specify it here. + use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to try to load the fast version of the tokenizer. + tokenizer_type (:obj:`str`, `optional`): + Tokenizer type to be loaded. + kwargs (additional keyword arguments, `optional`): + Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like + ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, + ``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__()`` for more details. + + Examples:: + + >>> from transformers import AutoTokenizer + + >>> # Download vocabulary from huggingface.co and cache. + >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') + + >>> # Download vocabulary from huggingface.co (user-uploaded) and cache. + >>> tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-cased') + + >>> # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`) + >>> tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/') + + """ + config = kwargs.pop("config", None) + kwargs["_from_auto"] = True + + use_fast = kwargs.pop("use_fast", True) + tokenizer_type = kwargs.pop("tokenizer_type", None) + + # First, let's see whether the tokenizer_type is passed so that we can leverage it + if tokenizer_type is not None: + tokenizer_class = None + tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None) + + if tokenizer_class_tuple is None: + raise ValueError( + f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of " + f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES.keys())}." + ) + + tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple + + if use_fast and tokenizer_fast_class_name is not None: + tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name) + + if tokenizer_class is None: + tokenizer_class = tokenizer_class_from_name(tokenizer_class_name) + + if tokenizer_class is None: + raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.") + + return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + + # Next, let's try to use the tokenizer_config file to get the tokenizer class. + tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs) + config_tokenizer_class = tokenizer_config.get("tokenizer_class") + + # If that did not work, let's try to use the config. + if config_tokenizer_class is None: + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config_tokenizer_class = config.tokenizer_class + + # If we have the tokenizer class from the tokenizer config or the model config we're good! + if config_tokenizer_class is not None: + tokenizer_class = None + if use_fast and not config_tokenizer_class.endswith("Fast"): + tokenizer_class_candidate = f"{config_tokenizer_class}Fast" + tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) + if tokenizer_class is None: + tokenizer_class_candidate = config_tokenizer_class + tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) + + if tokenizer_class is None: + raise ValueError( + f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported." + ) + return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + + # Otherwise we have to be creative. + # if model is an encoder decoder, the encoder tokenizer class is used by default + if isinstance(config, EncoderDecoderConfig): + if type(config.decoder) is not type(config.encoder): # noqa: E721 + logger.warning( + f"The encoder model config class: {config.encoder.__class__} is different from the decoder model " + f"config class: {config.decoder.__class__}. It is not recommended to use the " + "`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder " + "specific tokenizer classes." + ) + config = config.encoder + + model_type = config_class_to_model_type(type(config).__name__) + if model_type is not None: + tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)] + if tokenizer_class_fast and (use_fast or tokenizer_class_py is None): + return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + else: + if tokenizer_class_py is not None: + return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + else: + raise ValueError( + "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed " + "in order to use this tokenizer." + ) + + raise ValueError( + f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n" + f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}." + ) diff --git a/fastNLP/transformers/torch/models/encoder_decoder/__init__.py b/fastNLP/transformers/torch/models/encoder_decoder/__init__.py new file mode 100644 index 00000000..41327243 --- /dev/null +++ b/fastNLP/transformers/torch/models/encoder_decoder/__init__.py @@ -0,0 +1,5 @@ +__all__ = [ + "EncoderDecoderConfig", +] + +from .configuration_encoder_decoder import EncoderDecoderConfig \ No newline at end of file diff --git a/fastNLP/transformers/torch/models/encoder_decoder/configuration_encoder_decoder.py b/fastNLP/transformers/torch/models/encoder_decoder/configuration_encoder_decoder.py new file mode 100644 index 00000000..b36294d1 --- /dev/null +++ b/fastNLP/transformers/torch/models/encoder_decoder/configuration_encoder_decoder.py @@ -0,0 +1,114 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from ...configuration_utils import PretrainedConfig +from fastNLP.core.log import logger + +class EncoderDecoderConfig(PretrainedConfig): + r""" + :class:`~transformers.EncoderDecoderConfig` is the configuration class to store the configuration of a + :class:`~transformers.EncoderDecoderModel`. It is used to instantiate an Encoder Decoder model according to the + specified arguments, defining the encoder and decoder configs. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + Args: + kwargs (`optional`): + Dictionary of keyword arguments. Notably: + + - **encoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration + object that defines the encoder config. + - **decoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration + object that defines the decoder config. + + Examples:: + + >>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel + + >>> # Initializing a BERT bert-base-uncased style configuration + >>> config_encoder = BertConfig() + >>> config_decoder = BertConfig() + + >>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder) + + >>> # Initializing a Bert2Bert model from the bert-base-uncased style configurations + >>> model = EncoderDecoderModel(config=config) + + >>> # Accessing the model configuration + >>> config_encoder = model.config.encoder + >>> config_decoder = model.config.decoder + >>> # set decoder config to causal lm + >>> config_decoder.is_decoder = True + >>> config_decoder.add_cross_attention = True + + >>> # Saving the model, including its configuration + >>> model.save_pretrained('my-model') + + >>> # loading model and config from pretrained folder + >>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained('my-model') + >>> model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config) + """ + model_type = "encoder-decoder" + is_composition = True + + def __init__(self, **kwargs): + super().__init__(**kwargs) + assert ( + "encoder" in kwargs and "decoder" in kwargs + ), "Config has to be initialized with encoder and decoder config" + encoder_config = kwargs.pop("encoder") + encoder_model_type = encoder_config.pop("model_type") + decoder_config = kwargs.pop("decoder") + decoder_model_type = decoder_config.pop("model_type") + + from ..auto.configuration_auto import AutoConfig + + self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config) + self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config) + self.is_encoder_decoder = True + + @classmethod + def from_encoder_decoder_configs( + cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs + ) -> PretrainedConfig: + r""" + Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model + configuration and decoder model configuration. + + Returns: + :class:`EncoderDecoderConfig`: An instance of a configuration object + """ + logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config") + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`. + + Returns: + :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["encoder"] = self.encoder.to_dict() + output["decoder"] = self.decoder.to_dict() + output["model_type"] = self.__class__.model_type + return output From efc2675741da3740d255a5d12aed2413d31e86dc Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 25 May 2022 17:03:12 +0000 Subject: [PATCH 3/6] =?UTF-8?q?1.=E4=BF=AE=E6=94=B9Driver.py=E7=9A=84?= =?UTF-8?q?=E5=B0=8F=E9=83=A8=E5=88=86=E6=96=87=E6=A1=A3=EF=BC=9B2.?= =?UTF-8?q?=E5=AE=8C=E5=96=84=20JittorDriver=20JittorSingleDriver=20?= =?UTF-8?q?=E7=9A=84=E9=83=A8=E5=88=86=E5=9F=BA=E7=A1=80=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/driver.py | 3 +- .../core/drivers/jittor_driver/jittor_driver.py | 168 +++++++++++++++------ fastNLP/core/drivers/jittor_driver/mpi.py | 15 +- .../core/drivers/jittor_driver/single_device.py | 54 ++++--- fastNLP/core/drivers/jittor_driver/utils.py | 50 ------ .../core/drivers/paddle_driver/paddle_driver.py | 9 +- 6 files changed, 178 insertions(+), 121 deletions(-) diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py index bd06e705..fda1e6b1 100644 --- a/fastNLP/core/drivers/driver.py +++ b/fastNLP/core/drivers/driver.py @@ -41,7 +41,7 @@ class Driver(ABC): r""" 根据输入的 ``dataloader`` 得到一个 支持分布式 (``distributed``) 与 可复现的 (``reproducible``) 的 dataloader。 - :param dataloader: 根据 ``dataloade``r 设置其对应的分布式版本以及可复现版本; + :param dataloader: 根据 ``dataloader`` 设置其对应的分布式版本以及可复现版本; :param dist: 应当为一个字符串,其值应当为以下之一:``[None, "dist", "unrepeatdist"]``;为 ``None`` 时,表示不需要考虑当前 dataloader 切换为分布式状态;为 ``dist`` 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 不同 gpu 上出现重复;为 ``unrepeatdist`` 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 @@ -263,7 +263,6 @@ class Driver(ABC): :param filepath: 需要被加载的对象的文件位置(需要包括文件名)或一个 ``BytesIO`` 对象; :param load_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便是保存的完整的模型,此处也只能使用尝试加载filepath 模型中的权重到自身模型,而不会直接替代当前 Driver 中的模型。 - :return: 返回加载指定文件后的结果; """ raise NotImplementedError("Each specific driver should implemented its own `load_model` function.") diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index 4f7f23bd..0dd6d0fb 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -1,13 +1,17 @@ import os -import warnings -from typing import Optional, Callable, Dict +import random +from pathlib import Path +from typing import Union, Optional +from functools import partial + +import numpy as np -from .utils import _build_fp16_env from fastNLP.envs.imports import _NEED_IMPORT_JITTOR from fastNLP.core.drivers.driver import Driver from fastNLP.core.dataloaders import JittorDataLoader from fastNLP.core.log import logger from fastNLP.core.utils import apply_to_collection +from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_SEED_WORKERS if _NEED_IMPORT_JITTOR: import jittor as jt @@ -47,17 +51,18 @@ class JittorDriver(Driver): f"`jittor.Module` type.") super(JittorDriver, self).__init__(model) - self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) - self.grad_scaler = _grad_scaler() + if fp16: + jt.flags.auto_mixed_precision_level = 6 + else: + jt.flags.auto_mixed_precision_level = 0 + self.fp16 = fp16 # 用来设置是否关闭 auto_param_call 中的参数匹配问题; self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) def check_dataloader_legality(self, dataloader): - # 在fastnlp中实现了JittorDataLoader - if not isinstance(dataloader, Dataset): - raise TypeError(f"{Dataset} is expected, instead of `{type(dataloader)}`") - + if not isinstance(dataloader, (Dataset, JittorDataLoader)): + raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`") @staticmethod def _check_optimizer_legality(optimizers): @@ -66,54 +71,102 @@ class JittorDriver(Driver): raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " f"not {type(each_optimizer)}.") - def check_evaluator_mode(self, mode: str): + def step(self): + for optimizer in self.optimizers: + optimizer.step() + + def backward(self, loss): + for optimizer in self.optimizers: + optimizer.backward(loss) + + def zero_grad(self): + for optimizer in self.optimizers: + optimizer.zero_grad() + + def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs): + r""" + 将模型保存到 ``filepath`` 中。 + + :param filepath: 保存文件的文件位置(需要包括文件名); + :param only_state_dict: 在 **Jittor** 中,该参数无效,**Jittor** 仅支持保存模型的 ``state_dict``。 + """ + if not only_state_dict: + logger.rank_zero_warning( + "Jittor only supports saving state_dict, and we will also save state_dict for you.", + once=True + ) + if isinstance(filepath, Path): + filepath = str(filepath) model = self.unwrap_model() - if mode == "evaluate": - if not hasattr(model, "evaluate_step"): - if hasattr(model, "test_step"): - logger.warning_once( - "Your model does not have 'evaluate_step' method but has 'test_step' method, but you" - "are using 'evaluate_fn=validate', we are going to use 'test_step' to substitute for" - "'evaluate_step'.") + model.save(filepath) - else: - if not hasattr(model, "test_step"): - if hasattr(model, "evaluate_step"): - logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" - "are using 'evaluate_fn=test', we are going to use 'evaluate_step' to substitute for" - "'test_step'.") - - def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None): - if model_save_fn is not None: - outputs = model_save_fn(filepath) - if outputs is not None: - jt.save(outputs, filepath) - else: - if only_state_dict: - states = self.model.state_dict() - else: - warnings.warn("Saving the whole model is not supported now in Jittor. Save state dict instead.") - jt.save(states, filepath) + def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs): + r""" + 加载模型的函数;将 ``filepath`` 中的模型加载并赋值给当前 ``model`` 。 - def load_model(self, filepath: str): - if not os.path.exists(filepath): - raise FileNotFoundError("Checkpoint at {} not found.".format(filepath)) - return jt.load(filepath) + :param filepath: 保存文件的文件位置(需要包括文件名); + :param load_state_dict: 在 **Jittor** 中,该参数无效,**Jittor** 仅支持加载模型的 ``state_dict``。 + """ + if not only_state_dict: + logger.rank_zero_warning( + "Jittor only supports loading state_dict, and we will also load state_dict for you.", + once=True + ) + if isinstance(filepath, Path): + filepath = str(filepath) + model = self.unwrap_model() + model.load(filepath) def save_checkpoint(self): ... + def get_optimizer_state(self): + # optimizers_state_dict = {} + # for i in range(len(self.optimizers)): + # optimizer: torch.optim.Optimizer = self.optimizers[i] + # optimizer_state = optimizer.state_dict() + # optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) + # optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; + # return optimizers_state_dict + ... + + def load_optimizer_state(self, states): + # assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ + # f"checkpoint it is:{len(states)}" + # for i in range(len(self.optimizers)): + # optimizer: torch.optim.Optimizer = self.optimizers[i] + # optimizer.load_state_dict(states[f"optimizer{i}"]) + # logger.debug("Load optimizer state dict.") + ... + def load_checkpoint(self): ... def get_evaluate_context(self): return jt.no_grad - def get_model_device(self): - return self.model_device + @staticmethod + def move_model_to_device(model: "jt.Module", device): + r""" + 将模型转移到指定的设备上。由于 **Jittor** 会自动为数据分配设备,因此该函数实际上无效。 + """ + ... + + def move_data_to_device(self, batch): + """ + 将数据 ``batch`` 转移到指定的设备上。由于 **Jittor** 会自动为数据分配设备,因此该函数实际上无效。 + """ + return batch @staticmethod def tensor_to_numeric(tensor, reduce=None): + r""" + 将一个 :class:`jittor.Var` 对象转换为 转换成 python 中的数值类型; + + :param tensor: :class:`jittor.Var` 类型的对象; + :param reduce: 当 tensor 是一个多数值的张量时,应当使用何种归一化操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``; + :return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等; + """ if tensor is None: return None @@ -145,7 +198,32 @@ class JittorDriver(Driver): """ return batch - # def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx): - # # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; - # if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): - # dataloader.batch_sampler.set_epoch(cur_epoch_idx) + @staticmethod + def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover + global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) + process_seed = jt.get_seed() + # back out the base seed so we can use all the bits + base_seed = process_seed - worker_id + ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) + # use 128 bits (4 x 32-bit words) + np.random.seed(ss.generate_state(4)) + # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module + jittor_ss, stdlib_ss = ss.spawn(2) + jt.set_global_seed(jittor_ss.generate_state(1, dtype=np.uint64)[0]) + # use 128 bits expressed as an integer + stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() + random.seed(stdlib_seed) + + def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]): + if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: + dataloader.worker_init_fn = partial(self.worker_init_function, + rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))) + + def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int): + # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; + if callable(getattr(dataloader.sampler, "set_epoch", None)): + dataloader.sampler.set_epoch(cur_epoch_idx) + + @staticmethod + def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]): + pass diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py index ee2514e9..93187ede 100644 --- a/fastNLP/core/drivers/jittor_driver/mpi.py +++ b/fastNLP/core/drivers/jittor_driver/mpi.py @@ -146,7 +146,10 @@ class JittorMPIDriver(JittorDriver): return self.model.no_sync def unwrap_model(self): - pass + """ + 返回训练使用的模型。 + """ + return self.model def get_local_rank(self) -> int: return self.local_rank @@ -155,4 +158,14 @@ class JittorMPIDriver(JittorDriver): pass def is_distributed(self): + """ + 判断是否为分布式的 **Driver** ,在 ``JittorSingleDriver`` 中,返回 ``True``。 + """ return True + + @property + def data_device(self) -> str: + """ + :return: 数据所在的设备; + """ + return self.model_device \ No newline at end of file diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py index 19c4b4c2..b559fd92 100644 --- a/fastNLP/core/drivers/jittor_driver/single_device.py +++ b/fastNLP/core/drivers/jittor_driver/single_device.py @@ -27,28 +27,36 @@ class JittorSingleDriver(JittorDriver): 支持 cpu 和 gpu 的切换; 实现断点重训中替换 dataloader 的 set_dist_repro_dataloader 函数 + :param model: 传入给 ``Trainer`` 的 ``model`` 参数; + :param device: 训练和模型所在的设备,在 **Jittor** 中,应当为以下值之一:``[None, 'cpu', 'gpu', 'cuda']``; + + * 为 ``None`` 或 ``cpu`` 时 + 表示在 ``cpu`` 上进行训练; + * 为 ``gpu`` 或 ``cuda`` 时 + 表示在显卡设备上进行训练; + + :param fp16: 是否开启 fp16; """ def __init__(self, model, device=None, fp16: bool = False, **kwargs): + if device not in [None, "cpu", "gpu", "cuda"]: + raise RuntimeError("Parameter `device` should be one of [None, 'cpu', 'gpu', 'cuda'] .") super(JittorSingleDriver, self).__init__(model, fp16) - self.model_device = device + self.model_device = device if device is not None else "cpu" self.local_rank = 0 self.global_rank = 0 self.world_size = 1 - def step(self): - for optimizer in self.optimizers: - optimizer.step() - - def backward(self, loss): - for optimizer in self.optimizers: - optimizer.backward(loss) - - def zero_grad(self): - for optimizer in self.optimizers: - optimizer.zero_grad() + def setup(self): + r""" + 初始化训练环境;根据传入的 ``device`` 值设置模型的训练场景为 ``cpu`` 或 ``gpu``; + """ + if self.model_device in ["cpu", None]: + jt.flags.use_cuda = 0 # 使用 cpu + else: + jt.flags.use_cuda = 1 # 使用 cuda def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: if isinstance(batch, Dict) and not self.wo_auto_param_call: @@ -70,9 +78,15 @@ class JittorSingleDriver(JittorDriver): raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") def unwrap_model(self): + """ + 返回训练使用的模型。 + """ return self.model def is_distributed(self): + """ + 判断是否为分布式的 **Driver** ,在 ``JittorSingleDriver`` 中,返回 ``False``。 + """ return False def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], @@ -103,11 +117,15 @@ class JittorSingleDriver(JittorDriver): else: return dataloader - def setup(self): + def unwrap_model(self): """ - 支持 cpu 和 gpu 的切换 + 返回训练使用的模型。 """ - if self.model_device in ["cpu", None]: - jt.flags.use_cuda = 0 # 使用 cpu - else: - jt.flags.use_cuda = 1 # 使用 cuda + return self.model + + @property + def data_device(self) -> str: + """ + :return: 数据和模型所在的设备; + """ + return self.model_device diff --git a/fastNLP/core/drivers/jittor_driver/utils.py b/fastNLP/core/drivers/jittor_driver/utils.py index c6d44cfc..50eed7e3 100644 --- a/fastNLP/core/drivers/jittor_driver/utils.py +++ b/fastNLP/core/drivers/jittor_driver/utils.py @@ -1,56 +1,6 @@ -from contextlib import ExitStack - from fastNLP.envs.imports import _NEED_IMPORT_JITTOR if _NEED_IMPORT_JITTOR: import jittor __all__ = [] - -class DummyGradScaler: - """ - 用于仿造的 **GradScaler** 对象,防止重复写大量的if判断 - """ - def __init__(self, *args, **kwargs): - pass - - def get_scale(self): - return 1.0 - - def is_enabled(self): - return False - - def scale(self, outputs): - return outputs - - def step(self, optimizer, *args, **kwargs): - optimizer.step(*args, **kwargs) - - def update(self, new_scale=None): - pass - - def unscale_(self, optimizer): - pass - - def load_state_dict(self, state_dict): - pass - - def state_dict(self): - return {} - - -def _build_fp16_env(dummy=False): - if dummy: - auto_cast = ExitStack - GradScaler = DummyGradScaler - else: - raise NotImplementedError("JittorDriver does not support fp16 now.") - # if not jt.flags.use_cuda: - # raise RuntimeError("No cuda") - # if paddle.device.cuda.get_device_capability(0)[0] < 7: - # log.warning( - # "NOTE: your device does NOT support faster training with fp16, " - # "please switch to FP32 which is likely to be faster" - # ) - # from paddle.amp import auto_cast, GradScaler - return auto_cast, GradScaler \ No newline at end of file diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index e879dd90..090bf567 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -113,12 +113,11 @@ class PaddleDriver(Driver): @staticmethod def tensor_to_numeric(tensor, reduce=None): r""" - 将一个 `tensor` 对象(类型为 `paddle.Tensor` )转换为 python 的 `numeric` 对象;如果 tensor 只包含一个元素则返回 float 或 int 。 + 将一个 :class:`paddle.Tensor` 对象转换为 转换成 python 中的数值类型; - :param tensor: 需要被转换的 `tensor` 对象 - :param reduce: 可选 ['sum', 'max', 'mea', 'min'],如果不为 None 将使用该 reduce 方法来处理当前 tensor 再返回 - float 或 int 对象。 - :return: 转换后返回的结果 + :param tensor: :class:`paddle.Tensor` 类型的对象; + :param reduce: 当 tensor 是一个多数值的张量时,应当使用何种归一化操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``; + :return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等; """ if tensor is None: return None From 28eb1a58361f26f145783da0751aa37a3d33343d Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 25 May 2022 17:14:01 +0000 Subject: [PATCH 4/6] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E5=AF=B9tensor?= =?UTF-8?q?=5Fto=5Fnumeric=20reduce=20=E5=8F=82=E6=95=B0=E7=9A=84=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/torch_driver/torch_driver.py | 2 +- .../drivers/paddle_driver/test_single_device.py | 29 ++++++++++++++++------ .../drivers/torch_driver/test_single_device.py | 24 ++++++++++++++---- 3 files changed, 42 insertions(+), 13 deletions(-) diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 21325b5c..17d65d54 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -12,7 +12,7 @@ if _NEED_IMPORT_TORCH: from torch.optim import Optimizer from torch.utils.data import RandomSampler as TorchRandomSampler _reduces = { - 'sum': torch.max, + 'sum': torch.sum, 'min': torch.min, 'max': torch.max, 'mean': torch.mean diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 9b7a8560..3c2d7e27 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -75,12 +75,12 @@ class TestPaddleDriverFunctions: 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 """ dataloader = DataLoader(PaddleNormalDataset()) - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") # batch_size 和 batch_sampler 均为 None 的情形 dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") # 创建torch的dataloader dataloader = torch.utils.data.DataLoader( @@ -88,7 +88,7 @@ class TestPaddleDriverFunctions: batch_size=32, shuffle=True ) with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") @pytest.mark.torchpaddle def test_check_dataloader_legality_in_test(self): @@ -100,7 +100,7 @@ class TestPaddleDriverFunctions: "train": DataLoader(PaddleNormalDataset()), "test":DataLoader(PaddleNormalDataset()) } - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") # batch_size 和 batch_sampler 均为 None 的情形 dataloader = { @@ -108,12 +108,12 @@ class TestPaddleDriverFunctions: "test":DataLoader(PaddleNormalDataset(), batch_size=None) } with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") # 传入的不是 dict ,应该报错 dataloader = DataLoader(PaddleNormalDataset()) with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") # 创建 torch 的 dataloader train_loader = torch.utils.data.DataLoader( @@ -126,7 +126,7 @@ class TestPaddleDriverFunctions: ) dataloader = {"train": train_loader, "test": test_loader} with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") @pytest.mark.paddle def test_tensor_to_numeric(self): @@ -183,6 +183,21 @@ class TestPaddleDriverFunctions: assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() @pytest.mark.paddle + def test_tensor_to_numeric_reduce(self): + tensor = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + + res_max = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="max") + res_min = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="min") + res_sum = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="sum") + res_mean = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="mean") + + assert res_max == 6 + assert res_min == 1 + assert res_sum == 21 + assert res_mean == 3.5 + + + @pytest.mark.paddle def test_set_model_mode(self): """ 测试 set_model_mode 函数 diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index 7839e1c9..4d92b05a 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -117,7 +117,7 @@ class TestTorchDriverFunctions: 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 """ dataloader = DataLoader(TorchNormalDataset()) - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") # 创建 paddle 的 dataloader dataloader = paddle.io.DataLoader( @@ -125,7 +125,7 @@ class TestTorchDriverFunctions: batch_size=32, shuffle=True ) with pytest.raises(ValueError): - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") @pytest.mark.torchpaddle def test_check_dataloader_legality_in_test(self): @@ -137,12 +137,12 @@ class TestTorchDriverFunctions: "train": DataLoader(TorchNormalDataset()), "test": DataLoader(TorchNormalDataset()) } - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") # 传入的不是 dict,应该报错 dataloader = DataLoader(TorchNormalDataset()) with pytest.raises(ValueError): - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") # 创建 paddle 的 dataloader train_loader = paddle.io.DataLoader( @@ -155,7 +155,7 @@ class TestTorchDriverFunctions: ) dataloader = {"train": train_loader, "test": test_loader} with pytest.raises(ValueError): - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") @pytest.mark.torch def test_tensor_to_numeric(self): @@ -212,6 +212,20 @@ class TestTorchDriverFunctions: assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() @pytest.mark.torch + def test_tensor_to_numeric_reduce(self): + tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + + res_max = TorchSingleDriver.tensor_to_numeric(tensor, reduce="max") + res_min = TorchSingleDriver.tensor_to_numeric(tensor, reduce="min") + res_sum = TorchSingleDriver.tensor_to_numeric(tensor, reduce="sum") + res_mean = TorchSingleDriver.tensor_to_numeric(tensor, reduce="mean") + + assert res_max == 6 + assert res_min == 1 + assert res_sum == 21 + assert res_mean == 3.5 + + @pytest.mark.torch def test_set_model_mode(self): """ 测试set_model_mode函数 From b29b002bccc272f5b5f5bb14c7a76ddee4f86cc2 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 25 May 2022 17:17:58 +0000 Subject: [PATCH 5/6] =?UTF-8?q?=E4=BF=AE=E6=94=B9=20paddle=5Fto=20?= =?UTF-8?q?=E4=B8=AD=E5=85=B3=E4=BA=8E=20device=20=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E7=9A=84=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/utils/paddle_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py index f14a2bce..9e7e73a4 100644 --- a/fastNLP/core/utils/paddle_utils.py +++ b/fastNLP/core/utils/paddle_utils.py @@ -68,7 +68,8 @@ def paddle_to(data: "paddle.Tensor", device: Union[str, int, 'paddle.fluid.core_ 该函数只是集成了 :func:`paddle.Tensor.cpu` 和 :func:`paddle.Tensor.cuda` 两个函数。 :param data: 要迁移的张量; - :param device: 目标设备,可以是 ``str`` 或 ``int`` 类型; + :param device: 目标设备,可以是 ``str`` 或 ``int`` 及 **paddle** 自己的 :class:`paddle.fluid.core_avx.Place`、 + :class:`paddle.CPUPlace` 和 :class:`paddle.CUDAPlace` 类型; :return: 迁移后的张量; """ if isinstance(device, paddle.fluid.core_avx.Place): From a6cfc4086f1b5cfaebbec3c4f33af66b95254aee Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 25 May 2022 17:34:32 +0000 Subject: [PATCH 6/6] =?UTF-8?q?=E4=BF=AE=E6=94=B9seq=5Flen=5Fto=5Fmask?= =?UTF-8?q?=E7=9A=84jittor=E5=AE=9E=E7=8E=B0=E5=8F=8A=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E4=BE=8B=E4=B8=AD=E7=9A=84=E4=B8=80=E5=A4=84=E4=BC=A0=E5=8F=82?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/utils/seq_len_to_mask.py | 2 +- tests/core/utils/test_seq_len_to_mask.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastNLP/core/utils/seq_len_to_mask.py b/fastNLP/core/utils/seq_len_to_mask.py index e244603c..710c0a2b 100644 --- a/fastNLP/core/utils/seq_len_to_mask.py +++ b/fastNLP/core/utils/seq_len_to_mask.py @@ -74,7 +74,7 @@ def seq_len_to_mask(seq_len, max_len: Optional[int]=None): if isinstance(seq_len, jittor.Var): assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim == 1}." batch_size = seq_len.shape[0] - broad_cast_seq_len = jittor.arange(max_len).expand(batch_size, -1) + broad_cast_seq_len = jittor.arange(max_len).reshape(1, max_len).expand(batch_size, -1) mask = broad_cast_seq_len < seq_len.unsqueeze(1) return mask except NameError as e: diff --git a/tests/core/utils/test_seq_len_to_mask.py b/tests/core/utils/test_seq_len_to_mask.py index 0a17bae6..64c84837 100644 --- a/tests/core/utils/test_seq_len_to_mask.py +++ b/tests/core/utils/test_seq_len_to_mask.py @@ -78,7 +78,7 @@ class TestSeqLenToMask: mask = seq_len_to_mask(seq_len) # 3. pad到指定长度 - seq_len = paddle.randint(1, 10, size=(10,)) + seq_len = paddle.randint(1, 10, shape=(10,)) mask = seq_len_to_mask(seq_len, 100) assert 100 == mask.shape[1]