|
- import importlib
- import os
- import re
- import shutil
- import sys
- from pathlib import Path
- from typing import Dict, Optional, Union
-
- from fastNLP.transformers.torch.file_utils import (
- HF_MODULES_CACHE,
- TRANSFORMERS_DYNAMIC_MODULE_NAME,
- cached_path,
- hf_bucket_url,
- is_offline_mode,
- )
- from fastNLP.core.log import logger
-
def init_hf_modules():
    """
    Make the dynamic-modules cache importable.

    On the first call, ``HF_MODULES_CACHE`` is appended to ``sys.path``, the directory is created if needed, and
    an ``__init__.py`` file is placed in it when one is not already there. Later calls do nothing.
    """
    # Presence on sys.path is the marker that the setup below has already run.
    if HF_MODULES_CACHE not in sys.path:
        sys.path.append(HF_MODULES_CACHE)
        os.makedirs(HF_MODULES_CACHE, exist_ok=True)
        package_marker = Path(HF_MODULES_CACHE) / "__init__.py"
        # Only create the marker when absent so an existing file is left untouched.
        if not package_marker.exists():
            package_marker.touch()
-
def create_dynamic_module(name: Union[str, os.PathLike]):
    """
    Create the package ``name`` (and any missing parent packages) under ``HF_MODULES_CACHE``.

    Args:
        name (:obj:`str` or :obj:`os.PathLike`):
            Path of the package to create, joined onto ``HF_MODULES_CACHE``.

    Every directory created on the way receives an ``__init__.py`` file if it does not have one yet.
    """
    init_hf_modules()

    # Walk upwards, collecting every directory whose parent is still missing.
    packages_to_create = [Path(HF_MODULES_CACHE) / name]
    while not packages_to_create[-1].parent.exists():
        packages_to_create.append(packages_to_create[-1].parent)

    # Create from the outermost missing package down to the requested one.
    for package_dir in reversed(packages_to_create):
        os.makedirs(package_dir, exist_ok=True)
        package_marker = package_dir / "__init__.py"
        if not package_marker.exists():
            package_marker.touch()
-
def check_imports(filename):
    """
    Check if the current Python environment contains all the libraries that are imported in a file.

    The file is scanned textually, line by line, for statements of the form ``import xxx`` and
    ``from xxx import yyy``; relative imports (starting with ``.``) are ignored and only the top-level package of
    each import is tested.

    Args:
        filename (:obj:`str` or :obj:`os.PathLike`):
            Path to the Python source file to inspect. It is read as UTF-8.

    Raises:
        :obj:`ImportError`: If one or more of the top-level packages cannot be imported. The message lists the
        missing packages in alphabetical order together with a suggested ``pip install`` command.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    # Raw strings so that `\s` / `\S` are regex classes and not (invalid) string escape sequences.
    # Imports of the form `import xxx`
    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from xxx import yyy`
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)

    # Only keep the top-level module, de-duplicated. Sorting makes the order in which packages are tested (and
    # reported) independent of string hash randomization.
    top_level_packages = sorted({imp.split(".")[0] for imp in imports if not imp.startswith(".")})

    missing_packages = []
    for imp in top_level_packages:
        try:
            importlib.import_module(imp)
        except ImportError:
            missing_packages.append(imp)

    if missing_packages:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
        )
-
-
def get_class_in_module(class_name, module_path):
    """
    Import a module on the cache directory for modules and extract a class from it.

    Args:
        class_name (:obj:`str`):
            Name of the attribute to fetch from the imported module.
        module_path (:obj:`str`):
            Path of the module relative to an entry of ``sys.path``, without the ``.py`` extension. Path
            separators are converted to dots to build the dotted module name.

    Returns:
        The attribute named ``class_name`` of the imported module.

    Raises:
        :obj:`ImportError`: If the module cannot be imported.
        :obj:`AttributeError`: If the module has no attribute ``class_name``.
    """
    module_path = module_path.replace(os.path.sep, ".")
    # The module file may have been written to disk moments ago; the import system caches directory listings, so
    # the finders must be told to refresh before a freshly created file is guaranteed to be found.
    importlib.invalidate_caches()
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
-
def get_class_from_dynamic_module(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    class_name: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Load a class out of a module file that lives in a local model folder or in a model repository.

    .. warning::

        The code of the module file (local or downloaded from the Hub) is executed when it is imported. Only call
        this function on repos you trust.

    Args:
        pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
            Either:

                - the `model id` of a model repo hosted on huggingface.co, root-level (``bert-base-uncased``) or
                  namespaced under a user or organization (``dbmdz/bert-base-german-cased``);
                - the path of a `directory` that contains the module file, e.g. ``./my_model_directory/``.

        module_file (:obj:`str`):
            Name of the module file in which the class is defined.
        class_name (:obj:`str`):
            Name of the class to fetch from that module.
        cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
            Directory in which downloaded files are cached when the standard cache should not be used.
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to (re-)download the file and override a cached version if one exists.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to try resuming from an incompletely received file instead of deleting it.
        proxies (:obj:`Dict[str, str]`, `optional`):
            Proxy servers to use, keyed by protocol or endpoint, e.g. :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. Used on each request.
        use_auth_token (:obj:`str` or `bool`, `optional`):
            Token used as HTTP bearer authorization for remote files. With :obj:`True`, the token generated by
            :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`) is used.
        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
            Model version to use: a branch name, a tag name or a commit id, i.e. any identifier allowed by git.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, only local files are considered. Forced to :obj:`True` in offline mode.
        kwargs:
            Accepted for call-site compatibility; not used by this function.

    .. note::

        Passing :obj:`use_auth_token=True` is required when you want to use a private model.

    Returns:
        :obj:`type`: The class, dynamically imported from the module.

    Examples::

        # Fetch `modeling.py` from huggingface.co, cache it, then pull the class `MyBertModel` out of it.
        cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
    """
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Either grab module_file from a local directory, or build the URL to download it from the repo.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isdir(pretrained_model_name_or_path):
        submodule = "local"
        module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
    else:
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
        module_file_or_url = hf_bucket_url(
            pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
        )

    try:
        # Resolve to a path on disk: downloads from the URL, or reuses the cache if already present.
        resolved_module_file = cached_path(
            module_file_or_url,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
        )
    except EnvironmentError:
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Fail early if the module file imports packages that are not installed.
    check_imports(resolved_module_file)

    # Make sure the package that will host the module exists inside the dynamic-modules cache.
    full_submodule = os.path.sep.join([TRANSFORMERS_DYNAMIC_MODULE_NAME, submodule])
    create_dynamic_module(full_submodule)
    package_dir = Path(HF_MODULES_CACHE) / full_submodule

    if submodule == "local":
        # Local files are copied on every call; copying (rather than importing in place) keeps the number of
        # folders on sys.path small.
        module_name = module_file
        shutil.copy(resolved_module_file, package_dir / module_name)
    else:
        # Name the module after module_file plus the cached file name, so each cached version of the file gets
        # its own module.
        cached_file_name = Path(resolved_module_file).name
        module_name = "_".join([module_file.replace(".py", ""), *cached_file_name.split(".")]) + ".py"
        destination = package_dir / module_name
        if not destination.exists():
            shutil.copy(resolved_module_file, destination)

    # Finally import the copied module and pull the requested class out of it.
    final_module = os.path.join(full_submodule, module_name.replace(".py", ""))
    return get_class_in_module(class_name, final_module)
|