- import copy
- import fnmatch
- import importlib.util
- import io
- import json
- import os
- import re
- import shutil
- import sys
- import tarfile
- import tempfile
- import operator
- import types
- import functools
- from collections import OrderedDict, UserDict
- from contextlib import contextmanager
- from dataclasses import fields
- from enum import Enum
- from functools import partial
- from hashlib import sha256
- from pathlib import Path
- from typing import Any, BinaryIO, Dict, Optional, Tuple, Union, List
- from urllib.parse import urlparse
- from uuid import uuid4
- from zipfile import ZipFile, is_zipfile
-
- import numpy as np
- # from tqdm.auto import tqdm
-
- import requests
-
- from . import __version__
- from .utils.versions import importlib_metadata
- from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_8
- from fastNLP.envs.utils import _compare_version
- from fastNLP.core.log import logger
-
- if _NEED_IMPORT_TORCH:
- import torch
- _torch_version = importlib_metadata.version("torch")
-
- ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
-
- hf_cache_home = os.path.expanduser(
- os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
- )
- default_cache_path = os.path.join(hf_cache_home, "transformers")
-
- PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
- PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
- TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
- HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
- TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
- SESSION_ID = uuid4().hex
- DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "0").upper() in ENV_VARS_TRUE_VALUES
-
- WEIGHTS_NAME = "pytorch_model.bin"
- DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-
- _staging_mode = os.environ.get("HUGGINGFACE_CO_STAGING", "NO").upper() in ENV_VARS_TRUE_VALUES
- _default_endpoint = "https://moon-staging.huggingface.co" if _staging_mode else "https://huggingface.co"
-
- HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", _default_endpoint)
- HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
-
- CONFIG_NAME = "config.json"
-
- _is_offline_mode = os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES
-
- @contextmanager
- def filelock(path):
-     """
-     Best-effort advisory file lock. If ``fcntl`` is unavailable (e.g. on Windows), the lock
-     file cannot be opened, or the lock is already held elsewhere, this degrades to a no-op.
-     """
-     fd = None
-     try:
-         import fcntl
- 
-         open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC
-         fd = os.open(path, open_mode)
-         fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
-     except (ImportError, OSError):
-         # Could not import fcntl, open the lock file, or acquire the lock: proceed without locking.
-         pass
- 
-     try:
-         yield
-     finally:
-         if fd is not None:
-             try:
-                 fcntl.flock(fd, fcntl.LOCK_UN)
-             except OSError:
-                 pass
-             os.close(fd)
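- 
- # Illustrative (commented-out) sketch of how ``filelock`` is meant to be used; ``some_path``
- # is a hypothetical local file, not something defined in this module.
- #
- #     with filelock(some_path + ".lock"):
- #         ...  # download or extract ``some_path`` without racing other processes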
-
- class HfFolder:
- """
- hugging_face.HfFolder
- version = 0.5.1
- """
- path_token = os.path.expanduser("~/.huggingface/token")
-
- @classmethod
- def save_token(cls, token):
- """
- Save token, creating folder as needed.
-
- Args:
- token (`str`):
- The token to save to the [`HfFolder`]
- """
- os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
- with open(cls.path_token, "w+") as f:
- f.write(token)
-
- @classmethod
- def get_token(cls):
- """
- Retrieves the token
-
- Returns:
- `str` or `None`: The token, `None` if it doesn't exist.
-
- """
- try:
- with open(cls.path_token, "r") as f:
- return f.read()
- except FileNotFoundError:
- pass
-
- @classmethod
- def delete_token(cls):
- """
- Deletes the token from storage. Does not fail if token does not exist.
- """
- try:
- os.remove(cls.path_token)
- except FileNotFoundError:
- pass
-
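- # Illustrative (commented-out) doctest-style sketch of the ``HfFolder`` helpers; the token
- # value shown is a placeholder, not a real credential.
- #
- #     >>> HfFolder.save_token("hf_xxx")     # writes ~/.huggingface/token
- #     >>> HfFolder.get_token()              # returns "hf_xxx", or None if the file is missing
- #     >>> HfFolder.delete_token()           # silently ignores a missing file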
-
- def is_offline_mode():
- return _is_offline_mode
-
- def is_training_run_on_sagemaker():
- return "SAGEMAKER_JOB_NAME" in os.environ
-
- def add_start_docstrings(*docstr):
- def docstring_decorator(fn):
- fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
- return fn
-
- return docstring_decorator
-
-
- def add_start_docstrings_to_model_forward(*docstr):
- def docstring_decorator(fn):
- class_name = f":class:`~transformers.{fn.__qualname__.split('.')[0]}`"
- intro = f" The {class_name} forward method, overrides the :func:`__call__` special method."
- note = r"""
-
- .. note::
- Although the recipe for forward pass needs to be defined within this function, one should call the
- :class:`Module` instance afterwards instead of this since the former takes care of running the pre and post
- processing steps while the latter silently ignores them.
- """
- fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
- return fn
-
- return docstring_decorator
-
-
- def add_end_docstrings(*docstr):
- def docstring_decorator(fn):
-         fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
- return fn
-
- return docstring_decorator
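- 
- # Illustrative (commented-out) sketch of how these docstring decorators compose; the function
- # and docstring fragments are hypothetical.
- #
- #     >>> @add_start_docstrings("Shared intro. ")
- #     ... def forward(self):
- #     ...     """Original docstring."""
- #     >>> forward.__doc__
- #     'Shared intro. Original docstring.'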
-
- PT_RETURN_INTRODUCTION = r"""
- Returns:
- :class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)`: A :class:`~{full_output_type}` or a tuple of
- :obj:`torch.FloatTensor` (if ``return_dict=False`` is passed or when ``config.return_dict=False``) comprising
- various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs.
-
- """
-
- def _get_indent(t):
- """Returns the indentation in the first line of t"""
- search = re.search(r"^(\s*)\S", t)
- return "" if search is None else search.groups()[0]
-
-
- def _convert_output_args_doc(output_args_doc):
- """Convert output_args_doc to display properly."""
-     # Split output_args_doc into blocks of argument/description
- indent = _get_indent(output_args_doc)
- blocks = []
- current_block = ""
- for line in output_args_doc.split("\n"):
-         # If the indent is the same as the beginning, the line is the name of a new arg.
- if _get_indent(line) == indent:
- if len(current_block) > 0:
- blocks.append(current_block[:-1])
- current_block = f"{line}\n"
- else:
- # Otherwise it's part of the description of the current arg.
-             # We need to remove 2 spaces from the indentation.
- current_block += f"{line[2:]}\n"
- blocks.append(current_block[:-1])
-
- # Format each block for proper rendering
- for i in range(len(blocks)):
- blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
- blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])
-
- return "\n".join(blocks)
-
- def _prepare_output_docstrings(output_type, config_class):
- """
- Prepares the return part of the docstring using `output_type`.
- """
- docstrings = output_type.__doc__
-
- # Remove the head of the docstring to keep the list of args only
- lines = docstrings.split("\n")
- i = 0
- while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
- i += 1
- if i < len(lines):
- docstrings = "\n".join(lines[(i + 1) :])
- docstrings = _convert_output_args_doc(docstrings)
-
- # Add the return introduction
- full_output_type = f"{output_type.__module__}.{output_type.__name__}"
- intro = PT_RETURN_INTRODUCTION
- intro = intro.format(full_output_type=full_output_type, config_class=config_class)
- return intro + docstrings
-
- PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
- Example::
-
- >>> from transformers import {tokenizer_class}, {model_class}
- >>> import torch
-
- >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
- >>> model = {model_class}.from_pretrained('{checkpoint}')
-
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
- >>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1
-
- >>> outputs = model(**inputs, labels=labels)
- >>> loss = outputs.loss
- >>> logits = outputs.logits
- """
-
- PT_QUESTION_ANSWERING_SAMPLE = r"""
- Example::
-
- >>> from transformers import {tokenizer_class}, {model_class}
- >>> import torch
-
- >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
- >>> model = {model_class}.from_pretrained('{checkpoint}')
-
- >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
- >>> inputs = tokenizer(question, text, return_tensors='pt')
- >>> start_positions = torch.tensor([1])
- >>> end_positions = torch.tensor([3])
-
- >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
- >>> loss = outputs.loss
- >>> start_scores = outputs.start_logits
- >>> end_scores = outputs.end_logits
- """
-
- PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
- Example::
-
- >>> from transformers import {tokenizer_class}, {model_class}
- >>> import torch
-
- >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
- >>> model = {model_class}.from_pretrained('{checkpoint}')
-
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
- >>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
- >>> outputs = model(**inputs, labels=labels)
- >>> loss = outputs.loss
- >>> logits = outputs.logits
- """
-
- PT_MASKED_LM_SAMPLE = r"""
- Example::
-
- >>> from transformers import {tokenizer_class}, {model_class}
- >>> import torch
-
- >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
- >>> model = {model_class}.from_pretrained('{checkpoint}')
-
- >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
- >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
-
- >>> outputs = model(**inputs, labels=labels)
- >>> loss = outputs.loss
- >>> logits = outputs.logits
- """
-
- PT_BASE_MODEL_SAMPLE = r"""
- Example::
-
- >>> from transformers import {tokenizer_class}, {model_class}
- >>> import torch
-
- >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
- >>> model = {model_class}.from_pretrained('{checkpoint}')
-
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
- >>> outputs = model(**inputs)
-
- >>> last_hidden_states = outputs.last_hidden_state
- """
-
- PT_MULTIPLE_CHOICE_SAMPLE = r"""
- Example::
-
- >>> from transformers import {tokenizer_class}, {model_class}
- >>> import torch
-
- >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
- >>> model = {model_class}.from_pretrained('{checkpoint}')
-
- >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
- >>> choice0 = "It is eaten with a fork and a knife."
- >>> choice1 = "It is eaten while held in the hand."
- >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
-
- >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors='pt', padding=True)
- >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels) # batch size is 1
-
- >>> # the linear classifier still needs to be trained
- >>> loss = outputs.loss
- >>> logits = outputs.logits
- """
-
- PT_CAUSAL_LM_SAMPLE = r"""
- Example::
-
- >>> import torch
- >>> from transformers import {tokenizer_class}, {model_class}
-
- >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
- >>> model = {model_class}.from_pretrained('{checkpoint}')
-
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
- >>> outputs = model(**inputs, labels=inputs["input_ids"])
- >>> loss = outputs.loss
- >>> logits = outputs.logits
- """
-
- PT_SAMPLE_DOCSTRINGS = {
- "SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
- "QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
- "TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
- "MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
- "MaskedLM": PT_MASKED_LM_SAMPLE,
- "LMHead": PT_CAUSAL_LM_SAMPLE,
- "BaseModel": PT_BASE_MODEL_SAMPLE,
- }
-
- def add_code_sample_docstrings(
- *docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None, model_cls=None
- ):
- def docstring_decorator(fn):
- # model_class defaults to function's class if not specified otherwise
- model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
-
- sample_docstrings = PT_SAMPLE_DOCSTRINGS
-
- doc_kwargs = dict(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
-
- if "SequenceClassification" in model_class:
- code_sample = sample_docstrings["SequenceClassification"]
- elif "QuestionAnswering" in model_class:
- code_sample = sample_docstrings["QuestionAnswering"]
- elif "TokenClassification" in model_class:
- code_sample = sample_docstrings["TokenClassification"]
- elif "MultipleChoice" in model_class:
- code_sample = sample_docstrings["MultipleChoice"]
- elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
- doc_kwargs["mask"] = "[MASK]" if mask is None else mask
- code_sample = sample_docstrings["MaskedLM"]
- elif "LMHead" in model_class or "CausalLM" in model_class:
- code_sample = sample_docstrings["LMHead"]
- elif "Model" in model_class or "Encoder" in model_class:
- code_sample = sample_docstrings["BaseModel"]
- else:
- raise ValueError(f"Docstring can't be built for model {model_class}")
-
- output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""
- built_doc = code_sample.format(**doc_kwargs)
- fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc
- return fn
-
- return docstring_decorator
-
- def replace_return_docstrings(output_type=None, config_class=None):
- def docstring_decorator(fn):
- docstrings = fn.__doc__
- lines = docstrings.split("\n")
- i = 0
- while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None:
- i += 1
- if i < len(lines):
- lines[i] = _prepare_output_docstrings(output_type, config_class)
- docstrings = "\n".join(lines)
- else:
- raise ValueError(
- f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, current docstring is:\n{docstrings}"
- )
- fn.__doc__ = docstrings
- return fn
-
- return docstring_decorator
-
- def is_remote_url(url_or_filename):
- parsed = urlparse(url_or_filename)
- return parsed.scheme in ("http", "https")
-
- def hf_bucket_url(
- model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
- ) -> str:
- """
- Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
- to Cloudfront (a Content Delivery Network, or CDN) for large files.
-
- Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
- bandwidth costs).
-
- Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
- because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
- in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
- can't ever be stale.
-
-     In terms of client-side caching from this library, we base our caching on the objects' ETag. An object's ETag is:
- its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
- are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
- """
- if subfolder is not None:
- filename = f"{subfolder}/{filename}"
-
- if mirror:
- if mirror in ["tuna", "bfsu"]:
- raise ValueError("The Tuna and BFSU mirrors are no longer available. Try removing the mirror argument.")
- legacy_format = "/" not in model_id
- if legacy_format:
- return f"{mirror}/{model_id}-{filename}"
- else:
- return f"{mirror}/{model_id}/{filename}"
-
- if revision is None:
- revision = "main"
- return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
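- 
- # For illustration only: with the default endpoint, hf_bucket_url("bert-base-uncased", "config.json")
- # resolves to "https://huggingface.co/bert-base-uncased/resolve/main/config.json".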
-
- def url_to_filename(url: str, etag: Optional[str] = None) -> str:
- """
- Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
-     delimited by a period. If the url ends with .h5 (Keras HDF5 weights), '.h5' is appended to the name so that TF 2.0
-     can identify it as an HDF5 file (see
- https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
- """
- url_bytes = url.encode("utf-8")
- filename = sha256(url_bytes).hexdigest()
-
- if etag:
- etag_bytes = etag.encode("utf-8")
- filename += "." + sha256(etag_bytes).hexdigest()
-
- if url.endswith(".h5"):
- filename += ".h5"
-
- return filename
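- 
- # Rough sketch of the resulting cache layout (hashes abbreviated for illustration):
- #     <cache_dir>/<sha256(url)>.<sha256(etag)>         # the cached file itself
- #     <cache_dir>/<sha256(url)>.<sha256(etag)>.json    # metadata written by get_from_cache: {"url": ..., "etag": ...}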
-
- def cached_path(
- url_or_filename,
- cache_dir=None,
- force_download=False,
- proxies=None,
- resume_download=False,
- user_agent: Union[Dict, str, None] = None,
- extract_compressed_file=False,
- force_extract=False,
- use_auth_token: Union[bool, str, None] = None,
- local_files_only=False,
- ) -> Optional[str]:
- """
- Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file
- and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and
-     then return the path.
-
- Args:
-         cache_dir: specify a cache directory to save the file to (overrides the default cache dir).
-         force_download: if True, re-download the file even if it's already cached in the cache dir.
-         resume_download: if True, resume the download if an incompletely received file is found.
-         user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
-         use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
-             will get the token from ~/.huggingface.
-         extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
-             file into a folder alongside the archive.
-         force_extract: if True, when extract_compressed_file is True and the archive was already extracted,
-             re-extract the archive and overwrite the folder it was extracted into.
-
-     Return:
-         Local path (string) of the file, or, if networking is off, the last version of the file cached on disk.
- 
-     Raises:
-         In case of a non-recoverable file (non-existent or inaccessible url, plus no cache on disk).
- """
- if cache_dir is None:
- cache_dir = TRANSFORMERS_CACHE
- if isinstance(url_or_filename, Path):
- url_or_filename = str(url_or_filename)
- if isinstance(cache_dir, Path):
- cache_dir = str(cache_dir)
-
- if is_offline_mode() and not local_files_only:
- logger.info("Offline mode: forcing local_files_only=True")
- local_files_only = True
-
- if is_remote_url(url_or_filename):
- # URL, so get it from the cache (downloading if necessary)
- output_path = get_from_cache(
- url_or_filename,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- user_agent=user_agent,
- use_auth_token=use_auth_token,
- local_files_only=local_files_only,
- )
- elif os.path.exists(url_or_filename):
- # File, and it exists.
- output_path = url_or_filename
- elif urlparse(url_or_filename).scheme == "":
- # File, but it doesn't exist.
- raise EnvironmentError(f"file {url_or_filename} not found")
- else:
- # Something unknown
- raise ValueError(f"unable to parse {url_or_filename} as a URL or as a local path")
-
- if extract_compressed_file:
- if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
- return output_path
-
- # Path where we extract compressed archives
- # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
- output_dir, output_file = os.path.split(output_path)
- output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
- output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
-
- if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
- return output_path_extracted
-
- # Prevent parallel extractions
- lock_path = output_path + ".lock"
- with filelock(lock_path):
- shutil.rmtree(output_path_extracted, ignore_errors=True)
- os.makedirs(output_path_extracted)
-             if is_zipfile(output_path):
-                 with ZipFile(output_path, "r") as zip_file:
-                     zip_file.extractall(output_path_extracted)
-             elif tarfile.is_tarfile(output_path):
-                 with tarfile.open(output_path) as tar_file:
-                     tar_file.extractall(output_path_extracted)
- else:
- raise EnvironmentError(f"Archive format of {output_path} could not be identified")
-
- return output_path_extracted
-
- return output_path
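- 
- # Illustrative (commented-out) usage of ``cached_path``; the model id is only an example.
- #
- #     >>> url = hf_bucket_url("bert-base-uncased", CONFIG_NAME)
- #     >>> local_config = cached_path(url)   # downloads on the first call, then returns the cached copy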
-
- def define_sagemaker_information():
- try:
- instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
- dlc_container_used = instance_data["Image"]
- dlc_tag = instance_data["Image"].split(":")[1]
- except Exception:
- dlc_container_used = None
- dlc_tag = None
-
- sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
-     runs_distributed_training = "sagemaker_distributed_dataparallel_enabled" in sagemaker_params
- account_id = os.getenv("TRAINING_JOB_ARN").split(":")[4] if "TRAINING_JOB_ARN" in os.environ else None
-
- sagemaker_object = {
- "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None),
- "sm_region": os.getenv("AWS_REGION", None),
- "sm_number_gpu": os.getenv("SM_NUM_GPUS", 0),
- "sm_number_cpu": os.getenv("SM_NUM_CPUS", 0),
- "sm_distributed_training": runs_distributed_training,
- "sm_deep_learning_container": dlc_container_used,
- "sm_deep_learning_container_tag": dlc_tag,
- "sm_account_id": account_id,
- }
- return sagemaker_object
-
- def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
- """
- Formats a user-agent string with basic info about a request.
- """
- ua = f"transformers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
- if _NEED_IMPORT_TORCH:
- ua += f"; torch/{_torch_version}"
- if DISABLE_TELEMETRY:
- return ua + "; telemetry/off"
- if is_training_run_on_sagemaker():
- ua += "; " + "; ".join(f"{k}/{v}" for k, v in define_sagemaker_information().items())
- # CI will set this value to True
- if os.environ.get("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
- ua += "; is_ci/true"
- if isinstance(user_agent, dict):
- ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
- elif isinstance(user_agent, str):
- ua += "; " + user_agent
- return ua
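- 
- # For illustration only, the assembled string looks roughly like
- # "transformers/<version>; python/3.8.10; session_id/<hex>; torch/<version>", with optional
- # "; <key>/<value>" pairs appended from ``user_agent``.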
-
- def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
-     """
-     Download a remote file into ``temp_file``. Do not gobble up errors.
-     """
-     headers = copy.deepcopy(headers) if headers is not None else {}
-     if resume_size > 0:
-         headers["Range"] = f"bytes={resume_size}-"
- r = requests.get(url, stream=True, proxies=proxies, headers=headers)
- r.raise_for_status()
- content_length = r.headers.get("Content-Length")
- total = resume_size + int(content_length) if content_length is not None else None
- # progress = tqdm(
- # unit="B",
- # unit_scale=True,
- # unit_divisor=1024,
- # total=total,
- # initial=resume_size,
- # desc="Downloading",
- # disable=bool(logging.get_verbosity() == logging.NOTSET),
- # )
- for chunk in r.iter_content(chunk_size=1024):
- if chunk: # filter out keep-alive new chunks
- # progress.update(len(chunk))
- temp_file.write(chunk)
- # progress.close()
-
- def get_from_cache(
- url: str,
- cache_dir=None,
- force_download=False,
- proxies=None,
- etag_timeout=10,
- resume_download=False,
- user_agent: Union[Dict, str, None] = None,
- use_auth_token: Union[bool, str, None] = None,
- local_files_only=False,
- ) -> Optional[str]:
- """
- Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
- path to the cached file.
-
-     Return:
-         Local path (string) of the file, or, if networking is off, the last version of the file cached on disk.
- 
-     Raises:
-         In case of a non-recoverable file (non-existent or inaccessible url, plus no cache on disk).
- """
- if cache_dir is None:
- cache_dir = TRANSFORMERS_CACHE
- if isinstance(cache_dir, Path):
- cache_dir = str(cache_dir)
-
- os.makedirs(cache_dir, exist_ok=True)
-
- headers = {"user-agent": http_user_agent(user_agent)}
- if isinstance(use_auth_token, str):
- headers["authorization"] = f"Bearer {use_auth_token}"
- elif use_auth_token:
- token = HfFolder.get_token()
- if token is None:
- raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
- headers["authorization"] = f"Bearer {token}"
-
- url_to_download = url
- etag = None
- if not local_files_only:
- try:
- r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
- r.raise_for_status()
- etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
-             # We favor a custom header indicating the etag of the linked resource, and
-             # we fall back to the regular etag header.
-             # If we don't have either of those, raise an error.
- if etag is None:
- raise OSError(
- "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
- )
- # In case of a redirect,
- # save an extra redirect on the request.get call,
- # and ensure we download the exact atomic version even if it changed
- # between the HEAD and the GET (unlikely, but hey).
- if 300 <= r.status_code <= 399:
- url_to_download = r.headers["Location"]
- except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
- # Actually raise for those subclasses of ConnectionError
- raise
- except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
- # Otherwise, our Internet connection is down.
- # etag is None
- pass
-
- filename = url_to_filename(url, etag)
-
- # get cache path to put the file
- cache_path = os.path.join(cache_dir, filename)
-
-     # etag is None means either we have no connection or local_files_only was passed:
-     # try to return the last downloaded version from the cache.
- if etag is None:
- if os.path.exists(cache_path):
- return cache_path
- else:
- matching_files = [
- file
- for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
- if not file.endswith(".json") and not file.endswith(".lock")
- ]
- if len(matching_files) > 0:
- return os.path.join(cache_dir, matching_files[-1])
- else:
- # If files cannot be found and local_files_only=True,
- # the models might've been found if local_files_only=False
- # Notify the user about that
- if local_files_only:
- raise FileNotFoundError(
- "Cannot find the requested files in the cached path and outgoing traffic has been"
- " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
- " to False."
- )
- else:
- raise ValueError(
- "Connection error, and we cannot find the requested files in the cached path."
- " Please try again or make sure your Internet connection is on."
- )
-
- # From now on, etag is not None.
- if os.path.exists(cache_path) and not force_download:
- return cache_path
-
- # Prevent parallel downloads of the same file with a lock.
- lock_path = cache_path + ".lock"
- with filelock(lock_path):
-
- # If the download just completed while the lock was activated.
- if os.path.exists(cache_path) and not force_download:
- # Even if returning early like here, the lock will be released.
- return cache_path
-
- if resume_download:
- incomplete_path = cache_path + ".incomplete"
-
- @contextmanager
- def _resumable_file_manager() -> "io.BufferedWriter":
- with open(incomplete_path, "ab") as f:
- yield f
-
- temp_file_manager = _resumable_file_manager
- if os.path.exists(incomplete_path):
- resume_size = os.stat(incomplete_path).st_size
- else:
- resume_size = 0
- else:
- temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
- resume_size = 0
-
- # Download to temporary file, then copy to cache dir once finished.
- # Otherwise you get corrupt cache entries if the download gets interrupted.
- with temp_file_manager() as temp_file:
- logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")
-
- http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
-
- logger.info(f"storing {url} in cache at {cache_path}")
- os.replace(temp_file.name, cache_path)
-
- # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
- umask = os.umask(0o666)
- os.umask(umask)
- os.chmod(cache_path, 0o666 & ~umask)
-
- logger.info(f"creating metadata file for {cache_path}")
- meta = {"url": url, "etag": etag}
- meta_path = cache_path + ".json"
- with open(meta_path, "w") as meta_file:
- json.dump(meta, meta_file)
-
- return cache_path
-
- def get_list_of_files(
- path_or_repo: Union[str, os.PathLike],
- revision: Optional[str] = None,
- use_auth_token: Optional[Union[bool, str]] = None,
- local_files_only: bool = False,
- ) -> List[str]:
- """
- Gets the list of files inside :obj:`path_or_repo`.
-
- Args:
- path_or_repo (:obj:`str` or :obj:`os.PathLike`):
- Can be either the id of a repo on huggingface.co or a path to a `directory`.
- revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
-             The specific model version to use. It can be a branch name, a tag name, or a commit id; since we use a
-             git-based system for storing models and other artifacts on huggingface.co, ``revision`` can be any
-             identifier allowed by git.
- use_auth_token (:obj:`str` or `bool`, `optional`):
- The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
- generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
- local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
- Whether or not to only rely on local files and not to attempt to download any files.
-
- Returns:
- :obj:`List[str]`: The list of files available in :obj:`path_or_repo`.
- """
- path_or_repo = str(path_or_repo)
- # If path_or_repo is a folder, we just return what is inside (subdirectories included).
- if os.path.isdir(path_or_repo):
- list_of_files = []
- for path, dir_names, file_names in os.walk(path_or_repo):
- list_of_files.extend([os.path.join(path, f) for f in file_names])
- return list_of_files
-
-     # Can't grab the files if we are in offline mode.
- if is_offline_mode() or local_files_only:
- return []
-
- # Otherwise we grab the token and use the model_info method.
- if isinstance(use_auth_token, str):
- token = use_auth_token
- elif use_auth_token is True:
- token = HfFolder.get_token()
- else:
- token = None
-     # Equivalent to huggingface_hub's
-     # HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info(path_or_repo, revision=revision, token=token),
-     # implemented with a plain requests call to avoid the extra dependency.
- path = (
- f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/api/models/{path_or_repo}"
- if revision is None
- else f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/api/models/{path_or_repo}/revision/{revision}"
- )
- headers = {"authorization": f"Bearer {token}"} if token is not None else None
- status_query_param = None
- r = requests.get(
- path, headers=headers, timeout=None, params=status_query_param
- )
- r.raise_for_status()
- d = r.json()
-     siblings = d.get("siblings", None)
-     # Match the declared return type: return an empty list when the API reports no siblings.
-     return [x["rfilename"] for x in siblings] if siblings is not None else []
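- 
- # Illustrative (commented-out) usage of ``get_list_of_files``; the repo id and the exact
- # filenames returned are only examples.
- #
- #     >>> get_list_of_files("bert-base-uncased")
- #     ['.gitattributes', 'config.json', 'pytorch_model.bin', ...]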
-
- def is_torch_fx_available():
- return _TORCH_GREATER_EQUAL_1_8 and _compare_version("torch", operator.lt, "1.9.0")
-
- def is_torch_fx_proxy(x):
- if is_torch_fx_available():
- import torch.fx
-
- return isinstance(x, torch.fx.Proxy)
- return False
-
- def is_sentencepiece_available():
- return importlib.util.find_spec("sentencepiece") is not None
-
- def is_tokenizers_available():
- return importlib.util.find_spec("tokenizers") is not None
-
- def is_tensor(x):
-     """
-     Tests if ``x`` is a :obj:`torch.Tensor`, a :obj:`torch.fx.Proxy` or a :obj:`np.ndarray`.
-     """
-     if is_torch_fx_proxy(x):
-         return True
- 
-     # Only touch ``torch`` when it was actually imported above, otherwise this would raise a NameError.
-     if _NEED_IMPORT_TORCH and isinstance(x, torch.Tensor):
-         return True
- 
-     return isinstance(x, np.ndarray)
-
- def to_py_obj(obj):
- """
-     Convert a PyTorch tensor, NumPy array or (possibly nested) Python container to plain Python objects.
- """
- if isinstance(obj, (dict, UserDict)):
- return {k: to_py_obj(v) for k, v in obj.items()}
- elif isinstance(obj, (list, tuple)):
- return [to_py_obj(o) for o in obj]
- elif _NEED_IMPORT_TORCH and _is_torch(obj):
- return obj.detach().cpu().tolist()
- elif isinstance(obj, np.ndarray):
- return obj.tolist()
- else:
- return obj
-
- def _is_numpy(x):
- return isinstance(x, np.ndarray)
-
- def _is_torch(x):
- import torch
-
- return isinstance(x, torch.Tensor)
-
-
- def _is_torch_device(x):
- import torch
-
- return isinstance(x, torch.device)
-
- class ModelOutput(OrderedDict):
- """
-     Base class for all model outputs as dataclasses. Has a ``__getitem__`` that allows indexing by integer or slice
-     (like a tuple) or by string (like a dictionary), ignoring the ``None`` attributes. Otherwise behaves like a
-     regular Python dictionary.
- 
-     .. warning::
-         You can't unpack a :obj:`ModelOutput` directly. Use the :meth:`~transformers.file_utils.ModelOutput.to_tuple`
-         method to convert it to a tuple beforehand.
- """
-
- def __post_init__(self):
- class_fields = fields(self)
-
- # Safety and consistency checks
- assert len(class_fields), f"{self.__class__.__name__} has no fields."
- assert all(
- field.default is None for field in class_fields[1:]
- ), f"{self.__class__.__name__} should not have more than one required field."
-
- first_field = getattr(self, class_fields[0].name)
- other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
-
- if other_fields_are_none and not is_tensor(first_field):
- if isinstance(first_field, dict):
- iterator = first_field.items()
- first_field_iterator = True
- else:
- try:
- iterator = iter(first_field)
- first_field_iterator = True
- except TypeError:
- first_field_iterator = False
-
- # if we provided an iterator as first field and the iterator is a (key, value) iterator
- # set the associated fields
- if first_field_iterator:
- for element in iterator:
- if (
- not isinstance(element, (list, tuple))
- or not len(element) == 2
- or not isinstance(element[0], str)
- ):
- break
- setattr(self, element[0], element[1])
- if element[1] is not None:
- self[element[0]] = element[1]
- elif first_field is not None:
- self[class_fields[0].name] = first_field
- else:
- for field in class_fields:
- v = getattr(self, field.name)
- if v is not None:
- self[field.name] = v
-
- def __delitem__(self, *args, **kwargs):
- raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
-
- def setdefault(self, *args, **kwargs):
- raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
-
- def pop(self, *args, **kwargs):
- raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
-
- def update(self, *args, **kwargs):
- raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
-
- def __getitem__(self, k):
- if isinstance(k, str):
- inner_dict = {k: v for (k, v) in self.items()}
- return inner_dict[k]
- else:
- return self.to_tuple()[k]
-
- def __setattr__(self, name, value):
- if name in self.keys() and value is not None:
- # Don't call self.__setitem__ to avoid recursion errors
- super().__setitem__(name, value)
- super().__setattr__(name, value)
-
- def __setitem__(self, key, value):
-         # Will raise a KeyError if needed
- super().__setitem__(key, value)
- # Don't call self.__setattr__ to avoid recursion errors
- super().__setattr__(key, value)
-
- def to_tuple(self) -> Tuple[Any]:
- """
- Convert self to a tuple containing all the attributes/keys that are not ``None``.
- """
- return tuple(self[k] for k in self.keys())
-
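- # Illustrative (commented-out) sketch of how a ``ModelOutput`` subclass behaves; ``MyOutput``
- # is hypothetical and not part of this module.
- #
- #     >>> from dataclasses import dataclass
- #     >>> from typing import Optional
- #
- #     >>> @dataclass
- #     ... class MyOutput(ModelOutput):
- #     ...     loss: Optional["torch.Tensor"] = None
- #     ...     logits: Optional["torch.Tensor"] = None
- #
- #     >>> out = MyOutput(logits=torch.ones(2))
- #     >>> out["logits"] is out.logits is out[0]   # loss is None, so it is skipped by __getitem__/to_tuple
- #     True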
-
- class ExplicitEnum(Enum):
- """
- Enum with more explicit error message for missing values.
- """
-
- @classmethod
- def _missing_(cls, value):
- raise ValueError(
- f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
- )
-
-
- class PaddingStrategy(ExplicitEnum):
- """
- Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
- in an IDE.
- """
-
- LONGEST = "longest"
- MAX_LENGTH = "max_length"
- DO_NOT_PAD = "do_not_pad"
-
-
- class TensorType(ExplicitEnum):
- """
- Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
- tab-completion in an IDE.
- """
-
- PYTORCH = "pt"
- NUMPY = "np"
-
- def copy_func(f):
- """Returns a copy of a function f."""
- # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
- g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
- g = functools.update_wrapper(g, f)
- g.__kwdefaults__ = f.__kwdefaults__
- return g