@@ -0,0 +1,297 @@
import os | |||
import os.path | |||
import hashlib | |||
import gzip | |||
import errno | |||
import tarfile | |||
from typing import Any, Callable, List, Iterable, Optional, TypeVar | |||
import zipfile | |||
# import torch | |||
# from torch.utils.model_zoo import tqdm | |||
from hub import tqdm | |||
def gen_bar_updater() -> Callable[[int, int, int], None]: | |||
pbar = tqdm(total=None) | |||
def bar_update(count, block_size, total_size): | |||
if pbar.total is None and total_size: | |||
pbar.total = total_size | |||
progress_bytes = count * block_size | |||
pbar.update(progress_bytes - pbar.n) | |||
return bar_update | |||
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str: | |||
md5 = hashlib.md5() | |||
with open(fpath, 'rb') as f: | |||
for chunk in iter(lambda: f.read(chunk_size), b''): | |||
md5.update(chunk) | |||
return md5.hexdigest() | |||
def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool: | |||
return md5 == calculate_md5(fpath, **kwargs) | |||
def check_integrity(fpath: str, md5: Optional[str] = None) -> bool: | |||
if not os.path.isfile(fpath): | |||
return False | |||
if md5 is None: | |||
return True | |||
return check_md5(fpath, md5) | |||
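# --- Hedged usage sketch (not part of the original module): the path and
# checksum below are placeholders for illustration only. ---
def _example_check_integrity() -> None:
    fpath = "./datasets/raw/learningData.csv"              # hypothetical local file
    expected_md5 = "d41d8cd98f00b204e9800998ecf8427e"      # placeholder checksum
    if check_integrity(fpath, expected_md5):
        print("File present and checksum matches:", fpath)
    else:
        print("File missing or corrupted:", fpath)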
def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None: | |||
"""Download a file from a url and place it in root. | |||
Args: | |||
url (str): URL to download file from | |||
root (str): Directory to place downloaded file in | |||
filename (str, optional): Name to save the file under. If None, use the basename of the URL | |||
md5 (str, optional): MD5 checksum of the download. If None, do not check | |||
""" | |||
import urllib | |||
root = os.path.expanduser(root) | |||
if not filename: | |||
filename = os.path.basename(url) | |||
fpath = os.path.join(root, filename) | |||
os.makedirs(root, exist_ok=True) | |||
# check if file is already present locally | |||
if check_integrity(fpath, md5): | |||
print('Using downloaded and verified file: ' + fpath) | |||
else: # download the file | |||
try: | |||
print('Downloading ' + url + ' to ' + fpath) | |||
urllib.request.urlretrieve( | |||
url, fpath, | |||
reporthook=gen_bar_updater() | |||
) | |||
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] | |||
if url[:5] == 'https': | |||
url = url.replace('https:', 'http:') | |||
print('Failed download. Trying https -> http instead.' | |||
' Downloading ' + url + ' to ' + fpath) | |||
urllib.request.urlretrieve( | |||
url, fpath, | |||
reporthook=gen_bar_updater() | |||
) | |||
else: | |||
raise e | |||
# check integrity of downloaded file | |||
if not check_integrity(fpath, md5): | |||
raise RuntimeError("File not found or corrupted.") | |||
def list_dir(root: str, prefix: bool = False) -> List[str]: | |||
"""List all directories at a given root | |||
Args: | |||
root (str): Path to directory whose folders need to be listed | |||
prefix (bool, optional): If true, prepends the path to each result, otherwise | |||
only returns the name of the directories found | |||
""" | |||
root = os.path.expanduser(root) | |||
directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))] | |||
if prefix is True: | |||
directories = [os.path.join(root, d) for d in directories] | |||
return directories | |||
def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]: | |||
"""List all files ending with a suffix at a given root | |||
Args: | |||
        root (str): Path to directory whose files need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). | |||
It uses the Python "str.endswith" method and is passed directly | |||
prefix (bool, optional): If true, prepends the path to each result, otherwise | |||
only returns the name of the files found | |||
""" | |||
root = os.path.expanduser(root) | |||
files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)] | |||
if prefix is True: | |||
files = [os.path.join(root, d) for d in files] | |||
return files | |||
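# --- Hedged usage sketch (illustration only): `list_dir` and `list_files` are
# thin wrappers over os.listdir; prefix=True joins results with `root`. The
# paths below are placeholders. ---
def _example_listing() -> None:
    subdirs = list_dir("./datasets", prefix=True)
    csv_files = list_files("./datasets/kpi_dataset/raw", ".csv", prefix=False)
    print(subdirs, csv_files)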
def _quota_exceeded(response: "requests.models.Response") -> bool: # type: ignore[name-defined] | |||
return "Google Drive - Quota exceeded" in response.text | |||
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None): | |||
"""Download a Google Drive file from and place it in root. | |||
Args: | |||
file_id (str): id of file to be downloaded | |||
root (str): Directory to place downloaded file in | |||
filename (str, optional): Name to save the file under. If None, use the id of the file. | |||
md5 (str, optional): MD5 checksum of the download. If None, do not check | |||
""" | |||
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url | |||
import requests | |||
url = "https://docs.google.com/uc?export=download" | |||
root = os.path.expanduser(root) | |||
if not filename: | |||
filename = file_id | |||
fpath = os.path.join(root, filename) | |||
os.makedirs(root, exist_ok=True) | |||
if os.path.isfile(fpath) and check_integrity(fpath, md5): | |||
print('Using downloaded and verified file: ' + fpath) | |||
else: | |||
session = requests.Session() | |||
response = session.get(url, params={'id': file_id}, stream=True) | |||
token = _get_confirm_token(response) | |||
if token: | |||
params = {'id': file_id, 'confirm': token} | |||
response = session.get(url, params=params, stream=True) | |||
if _quota_exceeded(response): | |||
msg = ( | |||
f"The daily quota of the file {filename} is exceeded and it " | |||
f"can't be downloaded. This is a limitation of Google Drive " | |||
f"and can only be overcome by trying again later." | |||
) | |||
raise RuntimeError(msg) | |||
_save_response_content(response, fpath) | |||
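# --- Hedged usage sketch (illustration only): the file id below is a
# placeholder; a real id comes from a Google Drive sharing link of the form
# https://drive.google.com/file/d/<file_id>/view. ---
def _example_google_drive_download() -> None:
    download_file_from_google_drive(
        file_id="0B_placeholder_file_id",                  # placeholder id
        root="./datasets/raw",
        filename="archive.zip",
        md5=None,
    )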
def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined] | |||
for key, value in response.cookies.items(): | |||
if key.startswith('download_warning'): | |||
return value | |||
return None | |||
def _save_response_content( | |||
response: "requests.models.Response", destination: str, chunk_size: int = 32768, # type: ignore[name-defined] | |||
) -> None: | |||
with open(destination, "wb") as f: | |||
pbar = tqdm(total=None) | |||
progress = 0 | |||
for chunk in response.iter_content(chunk_size): | |||
if chunk: # filter out keep-alive new chunks | |||
f.write(chunk) | |||
progress += len(chunk) | |||
pbar.update(progress - pbar.n) | |||
pbar.close() | |||
def _is_tarxz(filename: str) -> bool: | |||
return filename.endswith(".tar.xz") | |||
def _is_tar(filename: str) -> bool: | |||
return filename.endswith(".tar") | |||
def _is_targz(filename: str) -> bool: | |||
return filename.endswith(".tar.gz") | |||
def _is_tgz(filename: str) -> bool: | |||
return filename.endswith(".tgz") | |||
def _is_gzip(filename: str) -> bool: | |||
return filename.endswith(".gz") and not filename.endswith(".tar.gz") | |||
def _is_zip(filename: str) -> bool: | |||
return filename.endswith(".zip") | |||
def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None: | |||
if to_path is None: | |||
to_path = os.path.dirname(from_path) | |||
if _is_tar(from_path): | |||
with tarfile.open(from_path, 'r') as tar: | |||
tar.extractall(path=to_path) | |||
elif _is_targz(from_path) or _is_tgz(from_path): | |||
with tarfile.open(from_path, 'r:gz') as tar: | |||
tar.extractall(path=to_path) | |||
elif _is_tarxz(from_path): | |||
with tarfile.open(from_path, 'r:xz') as tar: | |||
tar.extractall(path=to_path) | |||
elif _is_gzip(from_path): | |||
to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) | |||
with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: | |||
out_f.write(zip_f.read()) | |||
elif _is_zip(from_path): | |||
with zipfile.ZipFile(from_path, 'r') as z: | |||
z.extractall(to_path) | |||
else: | |||
raise ValueError("Extraction of {} not supported".format(from_path)) | |||
if remove_finished: | |||
os.remove(from_path) | |||
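# --- Hedged usage sketch (illustration only): `extract_archive` dispatches on
# the file extension (.tar, .tar.gz/.tgz, .tar.xz, .gz, .zip) and raises
# ValueError for anything else. The paths are placeholders. ---
def _example_extract_archive() -> None:
    extract_archive("./datasets/raw/archive.zip", to_path="./datasets/raw",
                    remove_finished=False)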
def download_and_extract_archive( | |||
url: str, | |||
download_root: str, | |||
extract_root: Optional[str] = None, | |||
filename: Optional[str] = None, | |||
md5: Optional[str] = None, | |||
remove_finished: bool = False, | |||
) -> None: | |||
download_root = os.path.expanduser(download_root) | |||
if extract_root is None: | |||
extract_root = download_root | |||
if not filename: | |||
filename = os.path.basename(url) | |||
download_url(url, download_root, filename, md5) | |||
archive = os.path.join(download_root, filename) | |||
print("Extracting {} to {}".format(archive, extract_root)) | |||
# print(archive) | |||
# print(extract_root) | |||
# extract_archive(archive, extract_root, remove_finished) | |||
def iterable_to_str(iterable: Iterable) -> str: | |||
return "'" + "', '".join([str(item) for item in iterable]) + "'" | |||
T = TypeVar("T", str, bytes) | |||
def verify_str_arg(
    value: T, arg: Optional[str] = None, valid_values: Optional[Iterable[T]] = None, custom_msg: Optional[str] = None,
) -> T:
    # torch is not imported in this module, so check against `str` directly
    # instead of `torch._six.string_classes`.
    if not isinstance(value, str):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)
if valid_values is None: | |||
return value | |||
if value not in valid_values: | |||
if custom_msg is not None: | |||
msg = custom_msg | |||
else: | |||
msg = ("Unknown value '{value}' for argument {arg}. " | |||
"Valid values are {{{valid_values}}}.") | |||
msg = msg.format(value=value, arg=arg, | |||
valid_values=iterable_to_str(valid_values)) | |||
raise ValueError(msg) | |||
return value |
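# --- Hedged usage sketch (illustration only): `verify_str_arg` returns the
# value unchanged when it is a string contained in `valid_values` and raises
# ValueError otherwise. ---
def _example_verify_str_arg() -> None:
    split = verify_str_arg("train", arg="split", valid_values=("train", "test"))
    print(split)  # -> 'train'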
@@ -0,0 +1,559 @@
import errno | |||
import hashlib | |||
import os | |||
import re | |||
import shutil | |||
import sys | |||
import tempfile | |||
# import torch | |||
import warnings | |||
import zipfile | |||
from urllib.request import urlopen, Request | |||
from urllib.parse import urlparse # noqa: F401 | |||
try: | |||
from tqdm.auto import tqdm # automatically select proper tqdm submodule if available | |||
except ImportError: | |||
try: | |||
from tqdm import tqdm | |||
except ImportError: | |||
# fake tqdm if it's not installed | |||
class tqdm(object): # type: ignore | |||
def __init__(self, total=None, disable=False, | |||
unit=None, unit_scale=None, unit_divisor=None): | |||
self.total = total | |||
self.disable = disable | |||
self.n = 0 | |||
# ignore unit, unit_scale, unit_divisor; they're just for real tqdm | |||
def update(self, n): | |||
if self.disable: | |||
return | |||
self.n += n | |||
if self.total is None: | |||
sys.stderr.write("\r{0:.1f} bytes".format(self.n)) | |||
else: | |||
sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total))) | |||
sys.stderr.flush() | |||
def __enter__(self): | |||
return self | |||
def __exit__(self, exc_type, exc_val, exc_tb): | |||
if self.disable: | |||
return | |||
sys.stderr.write('\n') | |||
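# --- Hedged illustration (not part of the original module): whichever tqdm is
# picked up above (real or fallback), this file only relies on `total`,
# `update(n)` and the context-manager protocol. ---
def _example_progress_bar() -> None:
    with tqdm(total=100, disable=False, unit='B', unit_scale=True, unit_divisor=1024) as pbar:
        for _ in range(10):
            pbar.update(10)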
# # matches bfd8deac from resnet18-bfd8deac.pth | |||
# HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') | |||
# | |||
# MASTER_BRANCH = 'master' | |||
# ENV_TORCH_HOME = 'TORCH_HOME' | |||
# ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' | |||
# DEFAULT_CACHE_DIR = '~/.cache' | |||
# VAR_DEPENDENCY = 'dependencies' | |||
# MODULE_HUBCONF = 'hubconf.py' | |||
# READ_DATA_CHUNK = 8192 | |||
# _hub_dir = None | |||
# | |||
# | |||
# # Copied from tools/shared/module_loader to be included in torch package | |||
# def import_module(name, path): | |||
# import importlib.util | |||
# from importlib.abc import Loader | |||
# spec = importlib.util.spec_from_file_location(name, path) | |||
# module = importlib.util.module_from_spec(spec) | |||
# assert isinstance(spec.loader, Loader) | |||
# spec.loader.exec_module(module) | |||
# return module | |||
# | |||
# | |||
# def _remove_if_exists(path): | |||
# if os.path.exists(path): | |||
# if os.path.isfile(path): | |||
# os.remove(path) | |||
# else: | |||
# shutil.rmtree(path) | |||
# | |||
# | |||
# def _git_archive_link(repo_owner, repo_name, branch): | |||
# return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch) | |||
# | |||
# | |||
# def _load_attr_from_module(module, func_name): | |||
# # Check if callable is defined in the module | |||
# if func_name not in dir(module): | |||
# return None | |||
# return getattr(module, func_name) | |||
# | |||
# | |||
# def _get_torch_home(): | |||
# torch_home = os.path.expanduser( | |||
# os.getenv(ENV_TORCH_HOME, | |||
# os.path.join(os.getenv(ENV_XDG_CACHE_HOME, | |||
# DEFAULT_CACHE_DIR), 'torch'))) | |||
# return torch_home | |||
# | |||
# | |||
# def _parse_repo_info(github): | |||
# branch = MASTER_BRANCH | |||
# if ':' in github: | |||
# repo_info, branch = github.split(':') | |||
# else: | |||
# repo_info = github | |||
# repo_owner, repo_name = repo_info.split('/') | |||
# return repo_owner, repo_name, branch | |||
# | |||
# | |||
# def _get_cache_or_reload(github, force_reload, verbose=True): | |||
# # Setup hub_dir to save downloaded files | |||
# hub_dir = get_dir() | |||
# if not os.path.exists(hub_dir): | |||
# os.makedirs(hub_dir) | |||
# # Parse github repo information | |||
# repo_owner, repo_name, branch = _parse_repo_info(github) | |||
# # Github allows branch name with slash '/', | |||
# # this causes confusion with path on both Linux and Windows. | |||
# # Backslash is not allowed in Github branch name so no need to | |||
# # to worry about it. | |||
# normalized_br = branch.replace('/', '_') | |||
# # Github renames folder repo-v1.x.x to repo-1.x.x | |||
# # We don't know the repo name before downloading the zip file | |||
# # and inspect name from it. | |||
# # To check if cached repo exists, we need to normalize folder names. | |||
# repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, normalized_br])) | |||
# | |||
# use_cache = (not force_reload) and os.path.exists(repo_dir) | |||
# | |||
# if use_cache: | |||
# if verbose: | |||
# sys.stderr.write('Using cache found in {}\n'.format(repo_dir)) | |||
# else: | |||
# cached_file = os.path.join(hub_dir, normalized_br + '.zip') | |||
# _remove_if_exists(cached_file) | |||
# | |||
# url = _git_archive_link(repo_owner, repo_name, branch) | |||
# sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, cached_file)) | |||
# download_url_to_file(url, cached_file, progress=False) | |||
# | |||
# with zipfile.ZipFile(cached_file) as cached_zipfile: | |||
# extraced_repo_name = cached_zipfile.infolist()[0].filename | |||
# extracted_repo = os.path.join(hub_dir, extraced_repo_name) | |||
# _remove_if_exists(extracted_repo) | |||
# # Unzip the code and rename the base folder | |||
# cached_zipfile.extractall(hub_dir) | |||
# | |||
# _remove_if_exists(cached_file) | |||
# _remove_if_exists(repo_dir) | |||
# shutil.move(extracted_repo, repo_dir) # rename the repo | |||
# | |||
# return repo_dir | |||
# | |||
# | |||
# def _check_module_exists(name): | |||
# if sys.version_info >= (3, 4): | |||
# import importlib.util | |||
# return importlib.util.find_spec(name) is not None | |||
# elif sys.version_info >= (3, 3): | |||
# # Special case for python3.3 | |||
# import importlib.find_loader | |||
# return importlib.find_loader(name) is not None | |||
# else: | |||
# # NB: Python2.7 imp.find_module() doesn't respect PEP 302, | |||
# # it cannot find a package installed as .egg(zip) file. | |||
# # Here we use workaround from: | |||
# # https://stackoverflow.com/questions/28962344/imp-find-module-which-supports-zipped-eggs?lq=1 | |||
# # Also imp doesn't handle hierarchical module names (names contains dots). | |||
# try: | |||
# # 1. Try imp.find_module(), which searches sys.path, but does | |||
# # not respect PEP 302 import hooks. | |||
# import imp | |||
# result = imp.find_module(name) | |||
# if result: | |||
# return True | |||
# except ImportError: | |||
# pass | |||
# path = sys.path | |||
# for item in path: | |||
# # 2. Scan path for import hooks. sys.path_importer_cache maps | |||
# # path items to optional "importer" objects, that implement | |||
# # find_module() etc. Note that path must be a subset of | |||
# # sys.path for this to work. | |||
# importer = sys.path_importer_cache.get(item) | |||
# if importer: | |||
# try: | |||
# result = importer.find_module(name, [item]) | |||
# if result: | |||
# return True | |||
# except ImportError: | |||
# pass | |||
# return False | |||
# | |||
# def _check_dependencies(m): | |||
# dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) | |||
# | |||
# if dependencies is not None: | |||
# missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)] | |||
# if len(missing_deps): | |||
# raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps))) | |||
# | |||
# | |||
# def _load_entry_from_hubconf(m, model): | |||
# if not isinstance(model, str): | |||
# raise ValueError('Invalid input: model should be a string of function name') | |||
# | |||
# # Note that if a missing dependency is imported at top level of hubconf, it will | |||
# # throw before this function. It's a chicken and egg situation where we have to | |||
# # load hubconf to know what're the dependencies, but to import hubconf it requires | |||
# # a missing package. This is fine, Python will throw proper error message for users. | |||
# _check_dependencies(m) | |||
# | |||
# func = _load_attr_from_module(m, model) | |||
# | |||
# if func is None or not callable(func): | |||
# raise RuntimeError('Cannot find callable {} in hubconf'.format(model)) | |||
# | |||
# return func | |||
# | |||
# | |||
# def get_dir(): | |||
# r""" | |||
# Get the Torch Hub cache directory used for storing downloaded models & weights. | |||
# | |||
# If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where | |||
# environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. | |||
# ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux | |||
# filesystem layout, with a default value ``~/.cache`` if the environment | |||
# variable is not set. | |||
# """ | |||
# # Issue warning to move data if old env is set | |||
# if os.getenv('TORCH_HUB'): | |||
# warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead') | |||
# | |||
# if _hub_dir is not None: | |||
# return _hub_dir | |||
# return os.path.join(_get_torch_home(), 'hub') | |||
# | |||
# | |||
# def set_dir(d): | |||
# r""" | |||
# Optionally set the Torch Hub directory used to save downloaded models & weights. | |||
# | |||
# Args: | |||
# d (string): path to a local folder to save downloaded models & weights. | |||
# """ | |||
# global _hub_dir | |||
# _hub_dir = d | |||
# | |||
# | |||
# def list(github, force_reload=False): | |||
# r""" | |||
# List all entrypoints available in `github` hubconf. | |||
# | |||
# Args: | |||
# github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional | |||
# tag/branch. The default branch is `master` if not specified. | |||
# Example: 'pytorch/vision[:hub]' | |||
# force_reload (bool, optional): whether to discard the existing cache and force a fresh download. | |||
# Default is `False`. | |||
# Returns: | |||
# entrypoints: a list of available entrypoint names | |||
# | |||
# Example: | |||
# >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True) | |||
# """ | |||
# repo_dir = _get_cache_or_reload(github, force_reload, True) | |||
# | |||
# sys.path.insert(0, repo_dir) | |||
# | |||
# hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) | |||
# | |||
# sys.path.remove(repo_dir) | |||
# | |||
# # We take functions starts with '_' as internal helper functions | |||
# entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')] | |||
# | |||
# return entrypoints | |||
# | |||
# | |||
# def help(github, model, force_reload=False): | |||
# r""" | |||
# Show the docstring of entrypoint `model`. | |||
# | |||
# Args: | |||
# github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional | |||
# tag/branch. The default branch is `master` if not specified. | |||
# Example: 'pytorch/vision[:hub]' | |||
# model (string): a string of entrypoint name defined in repo's hubconf.py | |||
# force_reload (bool, optional): whether to discard the existing cache and force a fresh download. | |||
# Default is `False`. | |||
# Example: | |||
# >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)) | |||
# """ | |||
# repo_dir = _get_cache_or_reload(github, force_reload, True) | |||
# | |||
# sys.path.insert(0, repo_dir) | |||
# | |||
# hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) | |||
# | |||
# sys.path.remove(repo_dir) | |||
# | |||
# entry = _load_entry_from_hubconf(hub_module, model) | |||
# | |||
# return entry.__doc__ | |||
# | |||
# | |||
# # Ideally this should be `def load(github, model, *args, forece_reload=False, **kwargs):`, | |||
# # but Python2 complains syntax error for it. We have to skip force_reload in function | |||
# # signature here but detect it in kwargs instead. | |||
# # TODO: fix it after Python2 EOL | |||
# def load(repo_or_dir, model, *args, **kwargs): | |||
# r""" | |||
# Load a model from a github repo or a local directory. | |||
# | |||
# Note: Loading a model is the typical use case, but this can also be used to | |||
# for loading other objects such as tokenizers, loss functions, etc. | |||
# | |||
# If :attr:`source` is ``'github'``, :attr:`repo_or_dir` is expected to be | |||
# of the form ``repo_owner/repo_name[:tag_name]`` with an optional | |||
# tag/branch. | |||
# | |||
# If :attr:`source` is ``'local'``, :attr:`repo_or_dir` is expected to be a | |||
# path to a local directory. | |||
# | |||
# Args: | |||
# repo_or_dir (string): repo name (``repo_owner/repo_name[:tag_name]``), | |||
# if ``source = 'github'``; or a path to a local directory, if | |||
# ``source = 'local'``. | |||
# model (string): the name of a callable (entrypoint) defined in the | |||
# repo/dir's ``hubconf.py``. | |||
# *args (optional): the corresponding args for callable :attr:`model`. | |||
# source (string, optional): ``'github'`` | ``'local'``. Specifies how | |||
# ``repo_or_dir`` is to be interpreted. Default is ``'github'``. | |||
# force_reload (bool, optional): whether to force a fresh download of | |||
# the github repo unconditionally. Does not have any effect if | |||
# ``source = 'local'``. Default is ``False``. | |||
# verbose (bool, optional): If ``False``, mute messages about hitting | |||
# local caches. Note that the message about first download cannot be | |||
# muted. Does not have any effect if ``source = 'local'``. | |||
# Default is ``True``. | |||
# **kwargs (optional): the corresponding kwargs for callable | |||
# :attr:`model`. | |||
# | |||
# Returns: | |||
# The output of the :attr:`model` callable when called with the given | |||
# ``*args`` and ``**kwargs``. | |||
# | |||
# Example: | |||
# >>> # from a github repo | |||
# >>> repo = 'pytorch/vision' | |||
# >>> model = torch.hub.load(repo, 'resnet50', pretrained=True) | |||
# >>> # from a local directory | |||
# >>> path = '/some/local/path/pytorch/vision' | |||
# >>> model = torch.hub.load(path, 'resnet50', pretrained=True) | |||
# """ | |||
# source = kwargs.pop('source', 'github').lower() | |||
# force_reload = kwargs.pop('force_reload', False) | |||
# verbose = kwargs.pop('verbose', True) | |||
# | |||
# if source not in ('github', 'local'): | |||
# raise ValueError( | |||
# f'Unknown source: "{source}". Allowed values: "github" | "local".') | |||
# | |||
# if source == 'github': | |||
# repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose) | |||
# | |||
# model = _load_local(repo_or_dir, model, *args, **kwargs) | |||
# return model | |||
# | |||
# | |||
# def _load_local(hubconf_dir, model, *args, **kwargs): | |||
# r""" | |||
# Load a model from a local directory with a ``hubconf.py``. | |||
# | |||
# Args: | |||
# hubconf_dir (string): path to a local directory that contains a | |||
# ``hubconf.py``. | |||
# model (string): name of an entrypoint defined in the directory's | |||
# `hubconf.py`. | |||
# *args (optional): the corresponding args for callable ``model``. | |||
# **kwargs (optional): the corresponding kwargs for callable ``model``. | |||
# | |||
# Returns: | |||
# a single model with corresponding pretrained weights. | |||
# | |||
# Example: | |||
# >>> path = '/some/local/path/pytorch/vision' | |||
# >>> model = _load_local(path, 'resnet50', pretrained=True) | |||
# """ | |||
# sys.path.insert(0, hubconf_dir) | |||
# | |||
# hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF) | |||
# hub_module = import_module(MODULE_HUBCONF, hubconf_path) | |||
# | |||
# entry = _load_entry_from_hubconf(hub_module, model) | |||
# model = entry(*args, **kwargs) | |||
# | |||
# sys.path.remove(hubconf_dir) | |||
# | |||
# return model | |||
# | |||
# | |||
# def download_url_to_file(url, dst, hash_prefix=None, progress=True): | |||
# r"""Download object at the given URL to a local path. | |||
# | |||
# Args: | |||
# url (string): URL of the object to download | |||
# dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file` | |||
# hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`. | |||
# Default: None | |||
# progress (bool, optional): whether or not to display a progress bar to stderr | |||
# Default: True | |||
# | |||
# Example: | |||
# >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') | |||
# | |||
# """ | |||
# file_size = None | |||
# # We use a different API for python2 since urllib(2) doesn't recognize the CA | |||
# # certificates in older Python | |||
# req = Request(url, headers={"User-Agent": "torch.hub"}) | |||
# u = urlopen(req) | |||
# meta = u.info() | |||
# if hasattr(meta, 'getheaders'): | |||
# content_length = meta.getheaders("Content-Length") | |||
# else: | |||
# content_length = meta.get_all("Content-Length") | |||
# if content_length is not None and len(content_length) > 0: | |||
# file_size = int(content_length[0]) | |||
# | |||
# # We deliberately save it in a temp file and move it after | |||
# # download is complete. This prevents a local working checkpoint | |||
# # being overridden by a broken download. | |||
# dst = os.path.expanduser(dst) | |||
# dst_dir = os.path.dirname(dst) | |||
# f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) | |||
# | |||
# try: | |||
# if hash_prefix is not None: | |||
# sha256 = hashlib.sha256() | |||
# with tqdm(total=file_size, disable=not progress, | |||
# unit='B', unit_scale=True, unit_divisor=1024) as pbar: | |||
# while True: | |||
# buffer = u.read(8192) | |||
# if len(buffer) == 0: | |||
# break | |||
# f.write(buffer) | |||
# if hash_prefix is not None: | |||
# sha256.update(buffer) | |||
# pbar.update(len(buffer)) | |||
# | |||
# f.close() | |||
# if hash_prefix is not None: | |||
# digest = sha256.hexdigest() | |||
# if digest[:len(hash_prefix)] != hash_prefix: | |||
# raise RuntimeError('invalid hash value (expected "{}", got "{}")' | |||
# .format(hash_prefix, digest)) | |||
# shutil.move(f.name, dst) | |||
# finally: | |||
# f.close() | |||
# if os.path.exists(f.name): | |||
# os.remove(f.name) | |||
# | |||
# def _download_url_to_file(url, dst, hash_prefix=None, progress=True): | |||
# warnings.warn('torch.hub._download_url_to_file has been renamed to\ | |||
# torch.hub.download_url_to_file to be a public API,\ | |||
# _download_url_to_file will be removed in after 1.3 release') | |||
# download_url_to_file(url, dst, hash_prefix, progress) | |||
# | |||
# # Hub used to support automatically extracts from zipfile manually compressed by users. | |||
# # The legacy zip format expects only one file from torch.save() < 1.6 in the zip. | |||
# # We should remove this support since zipfile is now default zipfile format for torch.save(). | |||
# def _is_legacy_zip_format(filename): | |||
# if zipfile.is_zipfile(filename): | |||
# infolist = zipfile.ZipFile(filename).infolist() | |||
# return len(infolist) == 1 and not infolist[0].is_dir() | |||
# return False | |||
# | |||
# def _legacy_zip_load(filename, model_dir, map_location): | |||
# warnings.warn('Falling back to the old format < 1.6. This support will be ' | |||
# 'deprecated in favor of default zipfile format introduced in 1.6. ' | |||
# 'Please redo torch.save() to save it in the new zipfile format.') | |||
# # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. | |||
# # We deliberately don't handle tarfile here since our legacy serialization format was in tar. | |||
# # E.g. resnet18-5c106cde.pth which is widely used. | |||
# with zipfile.ZipFile(filename) as f: | |||
# members = f.infolist() | |||
# if len(members) != 1: | |||
# raise RuntimeError('Only one file(not dir) is allowed in the zipfile') | |||
# f.extractall(model_dir) | |||
# extraced_name = members[0].filename | |||
# extracted_file = os.path.join(model_dir, extraced_name) | |||
# return torch.load(extracted_file, map_location=map_location) | |||
# | |||
# def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None): | |||
# r"""Loads the Torch serialized object at the given URL. | |||
# | |||
# If downloaded file is a zip file, it will be automatically | |||
# decompressed. | |||
# | |||
# If the object is already present in `model_dir`, it's deserialized and | |||
# returned. | |||
# The default value of `model_dir` is ``<hub_dir>/checkpoints`` where | |||
# `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. | |||
# | |||
# Args: | |||
# url (string): URL of the object to download | |||
# model_dir (string, optional): directory in which to save the object | |||
# map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) | |||
# progress (bool, optional): whether or not to display a progress bar to stderr. | |||
# Default: True | |||
# check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention | |||
# ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more | |||
# digits of the SHA256 hash of the contents of the file. The hash is used to | |||
# ensure unique names and to verify the contents of the file. | |||
# Default: False | |||
# file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set. | |||
# | |||
# Example: | |||
# >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') | |||
# | |||
# """ | |||
# # Issue warning to move data if old env is set | |||
# if os.getenv('TORCH_MODEL_ZOO'): | |||
# warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') | |||
# | |||
# if model_dir is None: | |||
# hub_dir = get_dir() | |||
# model_dir = os.path.join(hub_dir, 'checkpoints') | |||
# | |||
# try: | |||
# os.makedirs(model_dir) | |||
# except OSError as e: | |||
# if e.errno == errno.EEXIST: | |||
# # Directory already exists, ignore. | |||
# pass | |||
# else: | |||
# # Unexpected OSError, re-raise. | |||
# raise | |||
# | |||
# parts = urlparse(url) | |||
# filename = os.path.basename(parts.path) | |||
# if file_name is not None: | |||
# filename = file_name | |||
# cached_file = os.path.join(model_dir, filename) | |||
# if not os.path.exists(cached_file): | |||
# sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) | |||
# hash_prefix = None | |||
# if check_hash: | |||
# r = HASH_REGEX.search(filename) # r is Optional[Match[str]] | |||
# hash_prefix = r.group(1) if r else None | |||
# download_url_to_file(url, cached_file, hash_prefix, progress=progress) | |||
# | |||
# if _is_legacy_zip_format(cached_file): | |||
# return _legacy_zip_load(cached_file, model_dir, map_location) | |||
# return torch.load(cached_file, map_location=map_location) |
@@ -0,0 +1,139 @@
import warnings | |||
import os | |||
import os.path | |||
import numpy as np | |||
import codecs | |||
import string | |||
import gzip | |||
import lzma | |||
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union | |||
from dataset_utils import download_url, download_and_extract_archive, extract_archive, verify_str_arg | |||
# tqdm >= 4.31.1 | |||
from tods import generate_dataset | |||
from sklearn import preprocessing | |||
import pandas as pd | |||
class TODS_dataset: | |||
resources = [] | |||
training_file = None | |||
testing_file = None | |||
ground_truth_index = None | |||
_repr_indent = None | |||
@property | |||
def raw_folder(self) -> str: | |||
return os.path.join(self.root, self.__class__.__name__, 'raw') | |||
@property | |||
def processed_folder(self) -> str: | |||
return os.path.join(self.root, self.__class__.__name__, 'processed') | |||
    def __init__(self, root, train, transform=None, download=True):
        self.root = root
        self.train = train
        self.transform = self.transform_init(transform)
        if download:
            self.download()
        self.process()
def _check_exists(self) -> bool: | |||
return (os.path.exists(os.path.join(self.processed_folder, | |||
self.training_file)) and | |||
os.path.exists(os.path.join(self.processed_folder, | |||
self.testing_file))) | |||
def download(self) -> None: | |||
if self._check_exists(): | |||
return | |||
os.makedirs(self.raw_folder, exist_ok=True) | |||
# download files | |||
for url, md5 in self.resources: | |||
filename = url.rpartition('/')[2] | |||
download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) | |||
def process(self) -> None: | |||
pass | |||
def process_dataframe(self) -> None: | |||
if self.transform is None: | |||
pass | |||
else: | |||
self.transform.fit(self.training_set_dataframe) | |||
self.training_set_array = self.transform.transform(self.training_set_dataframe.values) | |||
self.testing_set_array = self.transform.transform(self.testing_set_dataframe.values) | |||
self.training_set_dataframe = pd.DataFrame(self.training_set_array) | |||
self.testing_set_dataframe = pd.DataFrame(self.testing_set_array) | |||
def transform_init(self, transform_str): | |||
if transform_str is None: | |||
return None | |||
elif transform_str == 'standardscale': | |||
return preprocessing.StandardScaler() | |||
elif transform_str == 'normalize': | |||
return preprocessing.Normalizer() | |||
elif transform_str == 'minmaxscale': | |||
return preprocessing.MinMaxScaler() | |||
elif transform_str == 'maxabsscale': | |||
return preprocessing.MaxAbsScaler() | |||
elif transform_str == 'binarize': | |||
return preprocessing.Binarizer() | |||
else: | |||
raise ValueError("Input parameter transform must take value of 'standardscale', 'normalize', " + | |||
"'minmaxscale', 'maxabsscale' or 'binarize'." | |||
) | |||
def to_axolotl_dataset(self): | |||
if self.train: | |||
return generate_dataset(self.training_set_dataframe, self.ground_truth_index) | |||
else: | |||
return generate_dataset(self.testing_set_dataframe, self.ground_truth_index) | |||
    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        if self.transform is not None:
            body += [repr(self.transform)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)
def __len__(self) -> int: | |||
return len(self.training_set_dataframe) | |||
def extra_repr(self) -> str: | |||
return "" | |||
# kpi(root='./datasets', train=True) | |||
# class yahoo5: | |||
# | |||
# def __init__(self): | |||
# pass |
@@ -0,0 +1,116 @@
import os | |||
import pandas as pd | |||
from tods_dataset_base import TODS_dataset | |||
from shutil import copyfile | |||
class kpi_dataset(TODS_dataset): | |||
resources = [ | |||
# ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), | |||
# ("https://github.com/datamllab/tods/blob/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), | |||
# ("https://github.com/NetManAIOps/KPI-Anomaly-Detection/blob/master/Preliminary_dataset/train.csv", None), | |||
("https://hegsns.github.io/tods_datasets/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), # it needs md5 to check if local learningData.csv is the same with online. | |||
("https://hegsns.github.io/tods_datasets/kpi/TRAIN/dataset_TRAIN/datasetDoc.json", None), | |||
# needs a server to store the dataset. | |||
# ("https://raw.githubusercontent.com/datamllab/tods/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), # it needs md5 to check if local learningData.csv is the same with online. | |||
] | |||
training_file = 'learningData.csv' | |||
testing_file = 'testingData.csv' | |||
ground_truth_index = 3 | |||
_repr_indent = 4 | |||
# def __init__(self, root, train, transform=None, target_transform=None, download=True): | |||
# super().__init__(root, train, transform=None, target_transform=None, download=True) | |||
def process(self) -> None: | |||
print('Processing...') | |||
os.makedirs(self.processed_folder, exist_ok=True) | |||
os.makedirs(os.path.join(self.processed_folder, 'tables'), exist_ok=True) | |||
training_set_fname = os.path.join(self.raw_folder, 'learningData.csv') | |||
self.training_set_dataframe = pd.read_csv(training_set_fname) | |||
        testing_set_fname = os.path.join(self.raw_folder, 'learningData.csv')  # temporarily the same as the training set
self.testing_set_dataframe = pd.read_csv(testing_set_fname) | |||
self.process_dataframe() | |||
self.training_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.training_file)) | |||
self.testing_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.testing_file)) | |||
copyfile(os.path.join(self.raw_folder, 'datasetDoc.json'), os.path.join(self.processed_folder, 'datasetDoc.json')) | |||
print('Done!') | |||
class yahoo_dataset(TODS_dataset): | |||
resources = [ | |||
# ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), | |||
# ("https://github.com/datamllab/tods/blob/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), | |||
# ("https://github.com/NetManAIOps/KPI-Anomaly-Detection/blob/master/Preliminary_dataset/train.csv", None), | |||
("https://hegsns.github.io/tods_datasets/yahoo_sub_5/TRAIN/dataset_TRAIN/tables/learningData.csv", None), # it needs md5 to check if local learningData.csv is the same with online. | |||
("https://hegsns.github.io/tods_datasets/yahoo_sub_5/TRAIN/dataset_TRAIN/datasetDoc.json", None), | |||
# needs a server to store the dataset. | |||
# ("https://raw.githubusercontent.com/datamllab/tods/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), # it needs md5 to check if local learningData.csv is the same with online. | |||
] | |||
training_file = 'learningData.csv' | |||
testing_file = 'testingData.csv' | |||
ground_truth_index = 7 | |||
_repr_indent = 4 | |||
def process(self) -> None: | |||
print('Processing...') | |||
os.makedirs(self.processed_folder, exist_ok=True) | |||
os.makedirs(os.path.join(self.processed_folder, 'tables'), exist_ok=True) | |||
training_set_fname = os.path.join(self.raw_folder, 'learningData.csv') | |||
self.training_set_dataframe = pd.read_csv(training_set_fname) | |||
        testing_set_fname = os.path.join(self.raw_folder, 'learningData.csv')  # temporarily the same as the training set
self.testing_set_dataframe = pd.read_csv(testing_set_fname) | |||
self.process_dataframe() | |||
self.training_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.training_file)) | |||
self.testing_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.testing_file)) | |||
copyfile(os.path.join(self.raw_folder, 'datasetDoc.json'), os.path.join(self.processed_folder, 'datasetDoc.json')) | |||
print('Done!') | |||
class NAB_dataset(TODS_dataset): | |||
resources = [ | |||
("https://hegsns.github.io/tods_datasets/NAB/realTweets/labeled_Twitter_volume_AMZN.csv", None), | |||
        # an MD5 checksum is needed here to verify that the local CSV matches the online copy.
("https://hegsns.github.io/tods_datasets/NAB/realTweets/labeled_Twitter_volume_AMZN.json", None), | |||
# needs a server to store the dataset. | |||
] | |||
training_file = 'learningData.csv' | |||
testing_file = 'testingData.csv' | |||
ground_truth_index = 2 | |||
_repr_indent = 4 | |||
def process(self) -> None: | |||
print('Processing...') | |||
os.makedirs(self.processed_folder, exist_ok=True) | |||
os.makedirs(os.path.join(self.processed_folder, 'tables'), exist_ok=True) | |||
training_set_fname = os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.csv') | |||
self.training_set_dataframe = pd.read_csv(training_set_fname) | |||
        testing_set_fname = os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.csv')  # temporarily the same as the training set
self.testing_set_dataframe = pd.read_csv(testing_set_fname) | |||
self.process_dataframe() | |||
self.training_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.training_file)) | |||
self.testing_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.testing_file)) | |||
copyfile(os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.json'), | |||
os.path.join(self.processed_folder, 'datasetDoc.json')) | |||
print('Done!') | |||
# kpi_dataset(root='./datasets', train=True, transform='binarize') | |||
# yahoo_dataset(root='./datasets', train=True, transform='binarize') | |||
# NAB_dataset(root='./datasets', train=True, transform='binarize') |
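# --- Hedged usage sketch (illustration only): constructing one of the dataset
# classes above downloads its resources into `<root>/<ClassName>/raw`, writes
# processed tables under `.../processed`, and can then be handed to axolotl. ---
def _example_kpi_to_axolotl():
    dataset = kpi_dataset(root='./datasets', train=True, transform='standardscale')
    return dataset.to_axolotl_dataset()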
@@ -1 +0,0 @@
Subproject commit af54e6970476a081bf0cd65990c9f56a1200d8a2
@@ -1 +0,0 @@
Subproject commit 046b20d2f6d4543dcbe18f0a1d4bcbb1f61cf518
@@ -1 +0,0 @@
Subproject commit 70aeefed6b7307941581357c4b7858bb3f88e1da
@@ -0,0 +1,116 @@
import os | |||
import typing | |||
import numpy | |||
import pandas | |||
from d3m import container, exceptions, utils as d3m_utils | |||
from d3m.metadata import base as metadata_base, hyperparams | |||
from d3m.base import primitives | |||
__all__ = ('FixedSplitDatasetSplitPrimitive',) | |||
class Hyperparams(hyperparams.Hyperparams): | |||
primary_index_values = hyperparams.Set( | |||
elements=hyperparams.Hyperparameter[str](''), | |||
default=(), | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description='A set of primary index values of the main resource belonging to the test (score) split. Cannot be set together with "row_indices".', | |||
) | |||
row_indices = hyperparams.Set( | |||
elements=hyperparams.Hyperparameter[int](-1), | |||
default=(), | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description='A set of row indices of the main resource belonging to the test (score) split. Cannot be set together with "primary_index_values".', | |||
) | |||
delete_recursive = hyperparams.Hyperparameter[bool]( | |||
default=False, | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.", | |||
) | |||
class FixedSplitDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): | |||
""" | |||
    A primitive which splits a tabular Dataset so that the test (score) split
    contains a fixed list of primary index values or row indices of the main
    resource. All other rows are used for the train split.
""" | |||
metadata = metadata_base.PrimitiveMetadata( | |||
{ | |||
'id': '1654f000-2178-4520-be4c-a95bc26b8d3a', | |||
'version': '0.1.0', | |||
'name': "Fixed split tabular dataset splits", | |||
'python_path': 'd3m.primitives.tods.evaluation.fixed_split_dataset_split', | |||
'source': { | |||
'name': "DATALab@TexasA&M University", | |||
'contact': 'mailto:mitar.commonprimitives@tnode.com', | |||
'uris': [ | |||
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/fixed_split.py', | |||
'https://gitlab.com/datadrivendiscovery/common-primitives.git', | |||
], | |||
}, | |||
'algorithm_types': [ | |||
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, | |||
], | |||
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, | |||
}, | |||
) | |||
def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: | |||
# This should be handled by "Set" hyper-parameter, but we check it here again just to be sure. | |||
if d3m_utils.has_duplicates(self.hyperparams['primary_index_values']): | |||
raise exceptions.InvalidArgumentValueError("\"primary_index_values\" hyper-parameter has duplicate values.") | |||
if d3m_utils.has_duplicates(self.hyperparams['row_indices']): | |||
raise exceptions.InvalidArgumentValueError("\"row_indices\" hyper-parameter has duplicate values.") | |||
if self.hyperparams['primary_index_values'] and self.hyperparams['row_indices']: | |||
raise exceptions.InvalidArgumentValueError("Both \"primary_index_values\" and \"row_indices\" cannot be provided.") | |||
if self.hyperparams['primary_index_values']: | |||
primary_index_values = numpy.array(self.hyperparams['primary_index_values']) | |||
index_columns = dataset.metadata.get_index_columns(at=(main_resource_id,)) | |||
if not index_columns: | |||
raise exceptions.InvalidArgumentValueError("Cannot find index columns in the main resource of the dataset, but \"primary_index_values\" is provided.") | |||
main_resource = dataset[main_resource_id] | |||
# We reset the index so that the index corresponds to row indices. | |||
main_resource = main_resource.reset_index(drop=True) | |||
# We use just the "d3mIndex" column and ignore multi-key indices. | |||
# This works for now because it seems that every current multi-key | |||
            # dataset in fact has a unique value in "d3mIndex" alone.
# See: https://gitlab.datadrivendiscovery.org/MIT-LL/d3m_data_supply/issues/117 | |||
index_column = index_columns[0] | |||
score_data = numpy.array(main_resource.loc[main_resource.iloc[:, index_column].isin(primary_index_values)].index) | |||
score_data_set = set(score_data) | |||
assert len(score_data) == len(score_data_set), (len(score_data), len(score_data_set)) | |||
if len(score_data) != len(primary_index_values): | |||
raise exceptions.InvalidArgumentValueError("\"primary_index_values\" contains values which do not exist.") | |||
else: | |||
score_data = numpy.array(self.hyperparams['row_indices']) | |||
score_data_set = set(score_data) | |||
all_data_set = set(numpy.arange(len(attributes))) | |||
if not score_data_set <= all_data_set: | |||
raise exceptions.InvalidArgumentValueError("\"row_indices\" contains indices which do not exist, e.g., {indices}.".format( | |||
indices=sorted(score_data_set - all_data_set)[:5], | |||
)) | |||
train_data = [] | |||
for i in numpy.arange(len(attributes)): | |||
if i not in score_data_set: | |||
train_data.append(i) | |||
assert len(train_data) + len(score_data) == len(attributes), (len(train_data), len(score_data), len(attributes)) | |||
return [(numpy.array(train_data), score_data)] |
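# --- Hedged illustration (plain numpy, outside the d3m machinery): for the
# "row_indices" case the primitive returns a single (train, score) pair where
# the requested indices form the score split and all remaining row indices form
# the train split. The toy sizes below are arbitrary. ---
def _example_fixed_split_logic():
    n_rows = 10
    score = numpy.array([2, 5, 7])                          # hypothetical row_indices
    score_set = set(score.tolist())
    train = numpy.array([i for i in range(n_rows) if i not in score_set])
    return [(train, score)]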
@@ -0,0 +1,87 @@
import os | |||
import typing | |||
import numpy | |||
import pandas | |||
from sklearn import model_selection | |||
from d3m import container, exceptions, utils as d3m_utils | |||
from d3m.metadata import base as metadata_base, hyperparams | |||
from d3m.base import primitives | |||
__all__ = ('KFoldDatasetSplitPrimitive',) | |||
class Hyperparams(hyperparams.Hyperparams): | |||
number_of_folds = hyperparams.Bounded[int]( | |||
lower=2, | |||
upper=None, | |||
default=5, | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description="Number of folds for k-folds cross-validation.", | |||
) | |||
stratified = hyperparams.UniformBool( | |||
default=False, | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description="Do stratified folds. The folds are made by preserving the percentage of samples for each class.", | |||
) | |||
shuffle = hyperparams.UniformBool( | |||
default=False, | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description="Whether to shuffle the data before splitting into batches.", | |||
) | |||
delete_recursive = hyperparams.Hyperparameter[bool]( | |||
default=False, | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.", | |||
) | |||
class KFoldDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): | |||
""" | |||
A primitive which splits a tabular Dataset for k-fold cross-validation. | |||
""" | |||
__author__ = 'Mingjie Sun <sunmj15@gmail.com>' | |||
metadata = metadata_base.PrimitiveMetadata( | |||
{ | |||
'id': 'bfedaf3a-6dd0-4a83-ad83-3a50fe882bf8', | |||
'version': '0.1.0', | |||
'name': "K-fold cross-validation tabular dataset splits", | |||
'python_path': 'd3m.primitives.tods.evaluation.kfold_dataset_split', | |||
'source': { | |||
'name': 'DATALab@Texas A&M University', | |||
'contact': 'mailto:sunmj15@gmail.com', | |||
'uris': [ | |||
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/kfold_split.py', | |||
'https://gitlab.com/datadrivendiscovery/common-primitives.git', | |||
], | |||
}, | |||
'algorithm_types': [ | |||
metadata_base.PrimitiveAlgorithmType.K_FOLD, | |||
metadata_base.PrimitiveAlgorithmType.CROSS_VALIDATION, | |||
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, | |||
], | |||
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, | |||
}, | |||
) | |||
def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: | |||
if self.hyperparams['stratified']: | |||
if not len(targets.columns): | |||
raise exceptions.InvalidArgumentValueError("Stratified split is requested, but no target columns found.") | |||
k_fold = model_selection.StratifiedKFold( | |||
n_splits=self.hyperparams['number_of_folds'], | |||
shuffle=self.hyperparams['shuffle'], | |||
random_state=self._random_state, | |||
) | |||
else: | |||
k_fold = model_selection.KFold( | |||
n_splits=self.hyperparams['number_of_folds'], | |||
shuffle=self.hyperparams['shuffle'], | |||
random_state=self._random_state, | |||
) | |||
return list(k_fold.split(attributes, targets)) |
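# --- Hedged illustration (scikit-learn only, outside the d3m machinery):
# `_get_splits` materializes sklearn's (Stratified)KFold indices; each element
# of the returned list is a (train_indices, test_indices) pair. Toy data below. ---
def _example_kfold_splits():
    toy_attributes = pandas.DataFrame({"value": list(range(10))})
    toy_targets = pandas.DataFrame({"label": [0, 1] * 5})
    k_fold = model_selection.KFold(n_splits=5, shuffle=False)
    return list(k_fold.split(toy_attributes, toy_targets))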
@@ -0,0 +1,187 @@
import os | |||
import uuid | |||
import typing | |||
from collections import OrderedDict | |||
import numpy | |||
import pandas | |||
from sklearn import model_selection | |||
from d3m import container, exceptions, utils as d3m_utils | |||
from d3m.metadata import base as metadata_base, hyperparams | |||
from d3m.base import primitives | |||
import utils | |||
__all__ = ('KFoldTimeSeriesSplitPrimitive',) | |||
class Hyperparams(hyperparams.Hyperparams): | |||
number_of_folds = hyperparams.Bounded[int]( | |||
lower=2, | |||
upper=None, | |||
default=5, | |||
semantic_types=[ | |||
'https://metadata.datadrivendiscovery.org/types/ControlParameter' | |||
], | |||
description="Number of folds for k-folds cross-validation.", | |||
) | |||
number_of_window_folds = hyperparams.Union[typing.Union[int, None]]( | |||
configuration=OrderedDict( | |||
fixed=hyperparams.Bounded[int]( | |||
lower=1, | |||
upper=None, | |||
default=1, | |||
description="Number of folds in train set (window). These folds come directly " | |||
"before test set (streaming window).", | |||
), | |||
all_records=hyperparams.Constant( | |||
default=None, | |||
description="Number of folds in train set (window) = maximum number possible.", | |||
), | |||
), | |||
default='all_records', | |||
semantic_types=[ | |||
'https://metadata.datadrivendiscovery.org/types/ControlParameter' | |||
], | |||
description="Maximum size for a single training set.", | |||
) | |||
time_column_index = hyperparams.Union[typing.Union[int, None]]( | |||
configuration=OrderedDict( | |||
fixed=hyperparams.Bounded[int]( | |||
lower=1, | |||
upper=None, | |||
default=1, | |||
description="Specific column that contains the time index", | |||
), | |||
one_column=hyperparams.Constant( | |||
default=None, | |||
description="Only one column contains a time index. " | |||
"It is detected automatically using semantic types.", | |||
), | |||
), | |||
default='one_column', | |||
semantic_types=[ | |||
'https://metadata.datadrivendiscovery.org/types/ControlParameter' | |||
], | |||
description="Column index to use as datetime index. " | |||
"If None, it is required that only one column with time column role semantic type is " | |||
"present and otherwise an exception is raised. " | |||
"If column index specified is not a datetime column an exception is" | |||
"also raised.", | |||
) | |||
fuzzy_time_parsing = hyperparams.UniformBool( | |||
default=True, | |||
semantic_types=[ | |||
'https://metadata.datadrivendiscovery.org/types/ControlParameter' | |||
], | |||
description="Use fuzzy time parsing.", | |||
) | |||
class KFoldTimeSeriesSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): | |||
""" | |||
A primitive which splits a tabular time-series Dataset for k-fold cross-validation. | |||
    The primitive sorts the data by the time column, so care should be taken to ensure that
    sorting by that column is meaningful. E.g., if the column is not numeric but of string
    structural type, the strings should be formatted so that sorting by them also sorts by time.
""" | |||
__author__ = 'Distil' | |||
__version__ = '0.3.0' | |||
__contact__ = 'mailto:jeffrey.gleason@yonder.co' | |||
metadata = metadata_base.PrimitiveMetadata( | |||
{ | |||
'id': '002f9ad1-46e3-40f4-89ed-eeffbb3a102b', | |||
'version': __version__, | |||
'name': "K-fold cross-validation timeseries dataset splits", | |||
'python_path': 'd3m.primitives.tods.evaluation.kfold_time_series_split', | |||
'source': { | |||
'name': 'DATALab@Texas A&M University', | |||
'contact': __contact__, | |||
'uris': [ | |||
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/kfold_split_timeseries.py', | |||
'https://gitlab.com/datadrivendiscovery/common-primitives.git', | |||
], | |||
}, | |||
'algorithm_types': [ | |||
metadata_base.PrimitiveAlgorithmType.K_FOLD, | |||
metadata_base.PrimitiveAlgorithmType.CROSS_VALIDATION, | |||
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, | |||
], | |||
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, | |||
}, | |||
) | |||
def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: | |||
time_column_indices = dataset.metadata.list_columns_with_semantic_types(['https://metadata.datadrivendiscovery.org/types/Time'], at=(main_resource_id,)) | |||
attribute_column_indices = dataset.metadata.list_columns_with_semantic_types(['https://metadata.datadrivendiscovery.org/types/Attribute'], at=(main_resource_id,)) | |||
# We want only time columns which are also attributes. | |||
time_column_indices = [time_column_index for time_column_index in time_column_indices if time_column_index in attribute_column_indices] | |||
if self.hyperparams['time_column_index'] is None: | |||
if len(time_column_indices) != 1: | |||
raise exceptions.InvalidArgumentValueError( | |||
"If \"time_column_index\" hyper-parameter is \"None\", it is required that exactly one column with time column role semantic type is present.", | |||
) | |||
else: | |||
# We know it exists because "time_column_indices" is a subset of "attribute_column_indices". | |||
time_column_index = attribute_column_indices.index( | |||
time_column_indices[0], | |||
) | |||
else: | |||
if self.hyperparams['time_column_index'] not in time_column_indices: | |||
raise exceptions.InvalidArgumentValueError( | |||
"Time column index specified does not have a time column role semantic type.", | |||
) | |||
else: | |||
time_column_index = attribute_column_indices.index( | |||
self.hyperparams['time_column_index'], | |||
) | |||
# We first reset index. | |||
attributes = attributes.reset_index(drop=True) | |||
# Then convert datetime column to consistent datetime representation | |||
attributes.insert( | |||
loc=0, | |||
column=uuid.uuid4(), # use uuid to ensure we are inserting a new column name | |||
value=self._parse_time_data( | |||
attributes, time_column_index, self.hyperparams['fuzzy_time_parsing'], | |||
), | |||
) | |||
# Then sort dataframe by new datetime column. Index contains original row order. | |||
attributes = attributes.sort_values(by=attributes.columns[0]) | |||
        # Remove the datetime representation used for sorting (downstream primitives might choose to parse this string column differently).
attributes = attributes.drop(attributes.columns[0], axis=1) | |||
max_train_size: typing.Optional[int] = None | |||
if self.hyperparams['number_of_window_folds'] is not None: | |||
max_train_size = int(attributes.shape[0] * self.hyperparams['number_of_window_folds'] / self.hyperparams['number_of_folds']) | |||
k_fold = model_selection.TimeSeriesSplit( | |||
n_splits=self.hyperparams['number_of_folds'], | |||
max_train_size=max_train_size | |||
) | |||
# We sorted "attributes" so we have to map indices on sorted "attributes" back to original | |||
# indices. We do that by using DataFrame's index which contains original row order. | |||
return [ | |||
( | |||
numpy.array([attributes.index[val] for val in train]), | |||
numpy.array([attributes.index[val] for val in test]), | |||
) | |||
for train, test in k_fold.split(attributes) | |||
] | |||
@classmethod | |||
def _parse_time_data(cls, inputs: container.DataFrame, column_index: metadata_base.SimpleSelectorSegment, fuzzy: bool) -> typing.List[float]: | |||
return [ | |||
utils.parse_datetime_to_float(value, fuzzy=fuzzy) | |||
for value in inputs.iloc[:, column_index] | |||
] |
@@ -0,0 +1,52 @@ | |||
import os | |||
import typing | |||
import numpy | |||
import pandas | |||
from d3m import container, utils as d3m_utils | |||
from d3m.metadata import base as metadata_base, hyperparams | |||
from d3m.base import primitives | |||
__all__ = ('NoSplitDatasetSplitPrimitive',) | |||
class Hyperparams(hyperparams.Hyperparams): | |||
pass | |||
class NoSplitDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): | |||
""" | |||
A primitive which splits a tabular Dataset in a way that for all splits it | |||
    produces the same (full) Dataset. Useful for unsupervised learning tasks.
""" | |||
metadata = metadata_base.PrimitiveMetadata( | |||
{ | |||
'id': '48c683ad-da9e-48cf-b3a0-7394dba5e5d2', | |||
'version': '0.1.0', | |||
'name': "No-split tabular dataset splits", | |||
'python_path': 'd3m.primitives.tods.evaluation.no_split_dataset_split', | |||
'source': { | |||
'name': 'DATALab@Texas A&M University', | |||
'contact': 'mailto:mitar.commonprimitives@tnode.com', | |||
'uris': [ | |||
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/no_split.py', | |||
'https://gitlab.com/datadrivendiscovery/common-primitives.git', | |||
], | |||
}, | |||
'algorithm_types': [ | |||
metadata_base.PrimitiveAlgorithmType.IDENTITY_FUNCTION, | |||
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, | |||
], | |||
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, | |||
}, | |||
) | |||
def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: | |||
# We still go through the whole splitting process to assure full compatibility | |||
# (and error conditions) of a regular split, but we use all data for both splits. | |||
all_data = numpy.arange(len(attributes)) | |||
return [(all_data, all_data)] |
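# Illustrative sketch (not part of the primitive above): the "split" is the identity split,
# returning every row position for both the train and the score side. The row count is hypothetical.
import numpy

all_data = numpy.arange(5)          # positions of all rows in the (toy) dataset
splits = [(all_data, all_data)]     # a single fold; train == score == the full dataset
print(splits)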
@@ -0,0 +1,160 @@ | |||
import copy | |||
import os | |||
import typing | |||
from d3m import container, exceptions, utils as d3m_utils | |||
from d3m.metadata import base as metadata_base, hyperparams | |||
from d3m.primitive_interfaces import base, transformer | |||
Inputs = container.List | |||
Outputs = container.List | |||
class Hyperparams(hyperparams.Hyperparams): | |||
match_logic = hyperparams.Enumeration( | |||
values=['all', 'any'], | |||
default='any', | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description="Should a column have all of semantic types in \"semantic_types\" to be redacted, or any of them?", | |||
) | |||
semantic_types = hyperparams.Set( | |||
elements=hyperparams.Hyperparameter[str](''), | |||
default=(), | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description="Redact columns with these semantic types. Only columns having semantic types listed here will be operated on, based on \"match_logic\".", | |||
) | |||
add_semantic_types = hyperparams.Set( | |||
elements=hyperparams.Hyperparameter[str](''), | |||
default=(), | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description="Semantic types to add to redacted columns. All listed semantic types will be added to all columns which were redacted.", | |||
) | |||
# TODO: Make clear the assumption that both the container type (List) and the Datasets inside it should have metadata.
#       The primitive modifies metadata of the Datasets, while there is officially no reason for them
#       to have metadata at all: metadata is stored on the input container type, not on the
#       values inside it.
class RedactColumnsPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]): | |||
""" | |||
A primitive which takes as an input a list of ``Dataset`` objects and redacts values of all columns matching | |||
a given semantic type or types. | |||
Redaction is done by setting all values in a redacted column to an empty string. | |||
It operates only on DataFrame resources inside datasets. | |||
""" | |||
metadata = metadata_base.PrimitiveMetadata( | |||
{ | |||
'id': '744c4090-e2f6-489e-8efc-8b1e051bfad6', | |||
'version': '0.2.0', | |||
'name': "Redact columns for evaluation", | |||
'python_path': 'd3m.primitives.tods.evaluation.redact_columns', | |||
'source': { | |||
'name': 'DATALab@Texas A&M University', | |||
'contact': 'mailto:sunmj15@gmail.com', | |||
'uris': [ | |||
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/redact_columns.py', | |||
'https://gitlab.com/datadrivendiscovery/common-primitives.git', | |||
], | |||
}, | |||
'installation': [{ | |||
'type': metadata_base.PrimitiveInstallationType.PIP, | |||
'package_uri': 'git+https://gitlab.com/datadrivendiscovery/common-primitives.git@{git_commit}#egg=common_primitives'.format( | |||
git_commit=d3m_utils.current_git_commit(os.path.dirname(__file__)), | |||
), | |||
}], | |||
'algorithm_types': [ | |||
metadata_base.PrimitiveAlgorithmType.DATA_CONVERSION, | |||
], | |||
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, | |||
}, | |||
) | |||
def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]: | |||
output_datasets = container.List(generate_metadata=True) | |||
for dataset in inputs: | |||
resources = {} | |||
metadata = dataset.metadata | |||
for resource_id, resource in dataset.items(): | |||
if not isinstance(resource, container.DataFrame): | |||
resources[resource_id] = resource | |||
continue | |||
columns_to_redact = self._get_columns_to_redact(metadata, (resource_id,)) | |||
if not columns_to_redact: | |||
resources[resource_id] = resource | |||
continue | |||
resource = copy.copy(resource) | |||
for column_index in columns_to_redact: | |||
column_metadata = dataset.metadata.query((resource_id, metadata_base.ALL_ELEMENTS, column_index)) | |||
if 'structural_type' in column_metadata and issubclass(column_metadata['structural_type'], str): | |||
resource.iloc[:, column_index] = '' | |||
else: | |||
raise TypeError("Primitive can operate only on columns with structural type \"str\", not \"{type}\".".format( | |||
type=column_metadata.get('structural_type', None), | |||
)) | |||
metadata = self._update_metadata(metadata, resource_id, column_index, ()) | |||
resources[resource_id] = resource | |||
dataset = container.Dataset(resources, metadata) | |||
output_datasets.append(dataset) | |||
output_datasets.metadata = metadata_base.DataMetadata({ | |||
'schema': metadata_base.CONTAINER_SCHEMA_VERSION, | |||
'structural_type': container.List, | |||
'dimension': { | |||
'length': len(output_datasets), | |||
}, | |||
}) | |||
# We update metadata based on metadata of each dataset. | |||
# TODO: In the future this might be done automatically by generate_metadata. | |||
# See: https://gitlab.com/datadrivendiscovery/d3m/issues/119 | |||
for index, dataset in enumerate(output_datasets): | |||
output_datasets.metadata = dataset.metadata.copy_to(output_datasets.metadata, (), (index,)) | |||
return base.CallResult(output_datasets) | |||
def _get_columns_to_redact(self, inputs_metadata: metadata_base.DataMetadata, at: metadata_base.Selector) -> typing.Sequence[int]: | |||
columns = [] | |||
for element in inputs_metadata.get_elements(list(at) + [metadata_base.ALL_ELEMENTS]): | |||
semantic_types = inputs_metadata.query(list(at) + [metadata_base.ALL_ELEMENTS, element]).get('semantic_types', ()) | |||
# TODO: Should we handle inheritance between semantic types here? | |||
if self.hyperparams['match_logic'] == 'all': | |||
matched = all(semantic_type in semantic_types for semantic_type in self.hyperparams['semantic_types']) | |||
elif self.hyperparams['match_logic'] == 'any': | |||
matched = any(semantic_type in semantic_types for semantic_type in self.hyperparams['semantic_types']) | |||
else: | |||
raise exceptions.UnexpectedValueError("Unknown value of hyper-parameter \"match_logic\": {value}".format(value=self.hyperparams['match_logic'])) | |||
if matched: | |||
if element is metadata_base.ALL_ELEMENTS: | |||
return list(range(inputs_metadata.query(list(at) + [metadata_base.ALL_ELEMENTS]).get('dimension', {}).get('length', 0))) | |||
else: | |||
columns.append(typing.cast(int, element)) | |||
return columns | |||
def _update_metadata( | |||
self, inputs_metadata: metadata_base.DataMetadata, resource_id: metadata_base.SelectorSegment, | |||
column_index: int, at: metadata_base.Selector, | |||
) -> metadata_base.DataMetadata: | |||
outputs_metadata = inputs_metadata | |||
for semantic_type in self.hyperparams['add_semantic_types']: | |||
outputs_metadata = outputs_metadata.add_semantic_type(tuple(at) + (resource_id, metadata_base.ALL_ELEMENTS, column_index), semantic_type) | |||
return outputs_metadata |
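# Illustrative sketch (not part of the primitive above): the "any"/"all" matching rule over
# per-column semantic types, followed by the redaction itself on a plain pandas frame. The
# semantic type URIs, column layout, and values below are hypothetical.
import pandas

requested = {'https://metadata.datadrivendiscovery.org/types/TrueTarget'}
column_semantic_types = {
    0: ('https://metadata.datadrivendiscovery.org/types/PrimaryKey',),
    1: ('https://metadata.datadrivendiscovery.org/types/Attribute',),
    2: ('https://metadata.datadrivendiscovery.org/types/TrueTarget',),
}

match_logic = 'any'
matcher = all if match_logic == 'all' else any
columns_to_redact = [
    column_index for column_index, semantic_types in column_semantic_types.items()
    if matcher(semantic_type in semantic_types for semantic_type in requested)
]  # -> [2]

frame = pandas.DataFrame({'d3mIndex': [0, 1], 'feature': ['a', 'b'], 'label': ['x', 'y']})
for column_index in columns_to_redact:
    frame.iloc[:, column_index] = ''  # redaction: string columns are set to empty strings
print(frame)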
@@ -0,0 +1,88 @@ | |||
import os | |||
import typing | |||
import numpy | |||
import pandas | |||
from sklearn import model_selection | |||
from d3m import container, exceptions, utils as d3m_utils | |||
from d3m.metadata import base as metadata_base, hyperparams | |||
from d3m.base import primitives | |||
__all__ = ('TrainScoreDatasetSplitPrimitive',) | |||
class Hyperparams(hyperparams.Hyperparams): | |||
train_score_ratio = hyperparams.Uniform( | |||
lower=0, | |||
upper=1, | |||
default=0.75, | |||
upper_inclusive=True, | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description="The ratio between the train and score data and represents the proportion of the Dataset to include in the train split. The rest is included in the score split.", | |||
) | |||
stratified = hyperparams.UniformBool( | |||
default=False, | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description="Do stratified folds. The folds are made by preserving the percentage of samples for each class.", | |||
) | |||
shuffle = hyperparams.UniformBool( | |||
default=False, | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description="Whether to shuffle the data before splitting into batches.", | |||
) | |||
delete_recursive = hyperparams.Hyperparameter[bool]( | |||
default=False, | |||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||
description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.", | |||
) | |||
class TrainScoreDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): | |||
""" | |||
A primitive which splits a tabular Dataset into random train and score subsets. | |||
""" | |||
metadata = metadata_base.PrimitiveMetadata( | |||
{ | |||
'id': '3fcc6dc4-6681-4c86-948e-066d14e7d803', | |||
'version': '0.1.0', | |||
'name': "Train-score tabular dataset splits", | |||
'python_path': 'd3m.primitives.tods.evaluation.train_score_dataset_split', | |||
'source': { | |||
'name': 'DATALab@Texas A&M University', | |||
'contact': 'mailto:mitar.commonprimitives@tnode.com', | |||
'uris': [ | |||
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/train_score_split.py', | |||
'https://gitlab.com/datadrivendiscovery/common-primitives.git', | |||
], | |||
}, | |||
'installation': [{ | |||
'type': metadata_base.PrimitiveInstallationType.PIP, | |||
'package_uri': 'git+https://gitlab.com/datadrivendiscovery/common-primitives.git@{git_commit}#egg=common_primitives'.format( | |||
git_commit=d3m_utils.current_git_commit(os.path.dirname(__file__)), | |||
), | |||
}], | |||
'algorithm_types': [ | |||
metadata_base.PrimitiveAlgorithmType.HOLDOUT, | |||
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, | |||
], | |||
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, | |||
}, | |||
) | |||
def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: | |||
if self.hyperparams['stratified'] and not len(targets.columns): | |||
raise exceptions.InvalidArgumentValueError("Stratified split is requested, but no target columns found.") | |||
train_data, score_data = model_selection.train_test_split( | |||
numpy.arange(len(attributes)), | |||
test_size=None, | |||
train_size=self.hyperparams['train_score_ratio'], | |||
random_state=self._random_state, | |||
shuffle=self.hyperparams['shuffle'], | |||
stratify=targets if self.hyperparams['stratified'] else None, | |||
) | |||
return [(train_data, score_data)] |
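# Illustrative sketch (not part of the primitive above): the holdout is a single
# scikit-learn train_test_split over row positions. The toy positions and targets are hypothetical.
import numpy
from sklearn import model_selection

row_positions = numpy.arange(8)
targets = numpy.array([0, 0, 0, 0, 1, 1, 1, 1])

train_data, score_data = model_selection.train_test_split(
    row_positions,
    test_size=None,        # inferred as the complement of train_size
    train_size=0.75,
    random_state=42,
    shuffle=True,          # scikit-learn requires shuffle=True when stratify is given
    stratify=targets,      # preserve class proportions across train and score
)
print(train_data, score_data)  # 6 train positions, 2 score positions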
@@ -0,0 +1,192 @@ | |||
import datetime | |||
import logging | |||
import typing | |||
import dateutil.parser | |||
import numpy | |||
from d3m import container, deprecate | |||
from d3m.base import utils as base_utils | |||
from d3m.metadata import base as metadata_base | |||
logger = logging.getLogger(__name__) | |||
DEFAULT_DATETIME = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) | |||
@deprecate.function(message="it should not be used anymore") | |||
def copy_elements_metadata(source_metadata: metadata_base.Metadata, target_metadata: metadata_base.DataMetadata, from_selector: metadata_base.Selector, | |||
to_selector: metadata_base.Selector = (), *, ignore_all_elements: bool = False, check: bool = True, source: typing.Any = None) -> metadata_base.DataMetadata: | |||
return source_metadata._copy_elements_metadata(target_metadata, list(from_selector), list(to_selector), [], ignore_all_elements) | |||
@deprecate.function(message="use Metadata.copy_to method instead") | |||
def copy_metadata(source_metadata: metadata_base.Metadata, target_metadata: metadata_base.DataMetadata, from_selector: metadata_base.Selector, | |||
to_selector: metadata_base.Selector = (), *, ignore_all_elements: bool = False, check: bool = True, source: typing.Any = None) -> metadata_base.DataMetadata: | |||
return source_metadata.copy_to(target_metadata, from_selector, to_selector, ignore_all_elements=ignore_all_elements) | |||
@deprecate.function(message="use DataFrame.select_columns method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def select_columns(inputs: container.DataFrame, columns: typing.Sequence[metadata_base.SimpleSelectorSegment], *, | |||
source: typing.Any = None) -> container.DataFrame: | |||
return inputs.select_columns(columns) | |||
@deprecate.function(message="use DataMetadata.select_columns method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def select_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns: typing.Sequence[metadata_base.SimpleSelectorSegment], *, | |||
source: typing.Any = None) -> metadata_base.DataMetadata: | |||
return inputs_metadata.select_columns(columns) | |||
@deprecate.function(message="use DataMetadata.list_columns_with_semantic_types method instead") | |||
def list_columns_with_semantic_types(metadata: metadata_base.DataMetadata, semantic_types: typing.Sequence[str], *, | |||
at: metadata_base.Selector = ()) -> typing.Sequence[int]: | |||
return metadata.list_columns_with_semantic_types(semantic_types, at=at) | |||
@deprecate.function(message="use DataMetadata.list_columns_with_structural_types method instead") | |||
def list_columns_with_structural_types(metadata: metadata_base.DataMetadata, structural_types: typing.Union[typing.Callable, typing.Sequence[typing.Union[str, type]]], *, | |||
at: metadata_base.Selector = ()) -> typing.Sequence[int]: | |||
return metadata.list_columns_with_structural_types(structural_types, at=at) | |||
@deprecate.function(message="use DataFrame.remove_columns method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def remove_columns(inputs: container.DataFrame, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> container.DataFrame: | |||
return inputs.remove_columns(column_indices) | |||
@deprecate.function(message="use DataMetadata.remove_columns method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def remove_columns_metadata(inputs_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> metadata_base.DataMetadata: | |||
return inputs_metadata.remove_columns(column_indices) | |||
@deprecate.function(message="use DataFrame.append_columns method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def append_columns(left: container.DataFrame, right: container.DataFrame, *, use_right_metadata: bool = False, source: typing.Any = None) -> container.DataFrame: | |||
return left.append_columns(right, use_right_metadata=use_right_metadata) | |||
@deprecate.function(message="use DataMetadata.append_columns method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def append_columns_metadata(left_metadata: metadata_base.DataMetadata, right_metadata: metadata_base.DataMetadata, use_right_metadata: bool = False, source: typing.Any = None) -> metadata_base.DataMetadata: | |||
return left_metadata.append_columns(right_metadata, use_right_metadata=use_right_metadata) | |||
@deprecate.function(message="use DataFrame.insert_columns method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def insert_columns(inputs: container.DataFrame, columns: container.DataFrame, at_column_index: int, *, source: typing.Any = None) -> container.DataFrame: | |||
return inputs.insert_columns(columns, at_column_index) | |||
@deprecate.function(message="use DataMetadata.insert_columns method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def insert_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns_metadata: metadata_base.DataMetadata, at_column_index: int, *, source: typing.Any = None) -> metadata_base.DataMetadata: | |||
return inputs_metadata.insert_columns(columns_metadata, at_column_index) | |||
@deprecate.function(message="use DataFrame.replace_columns method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def replace_columns(inputs: container.DataFrame, columns: container.DataFrame, column_indices: typing.Sequence[int], *, copy: bool = True, source: typing.Any = None) -> container.DataFrame: | |||
return inputs.replace_columns(columns, column_indices, copy=copy) | |||
@deprecate.function(message="use DataMetadata.replace_columns method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def replace_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> metadata_base.DataMetadata: | |||
return inputs_metadata.replace_columns(columns_metadata, column_indices) | |||
@deprecate.function(message="use DataMetadata.get_index_columns method instead") | |||
def get_index_columns(metadata: metadata_base.DataMetadata, *, at: metadata_base.Selector = ()) -> typing.Sequence[int]: | |||
return metadata.get_index_columns(at=at) | |||
@deprecate.function(message="use DataFrame.horizontal_concat method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def horizontal_concat(left: container.DataFrame, right: container.DataFrame, *, use_index: bool = True, | |||
remove_second_index: bool = True, use_right_metadata: bool = False, source: typing.Any = None) -> container.DataFrame: | |||
return left.horizontal_concat(right, use_index=use_index, remove_second_index=remove_second_index, use_right_metadata=use_right_metadata) | |||
@deprecate.function(message="use DataMetadata.horizontal_concat method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def horizontal_concat_metadata(left_metadata: metadata_base.DataMetadata, right_metadata: metadata_base.DataMetadata, *, use_index: bool = True, | |||
remove_second_index: bool = True, use_right_metadata: bool = False, source: typing.Any = None) -> metadata_base.DataMetadata: | |||
return left_metadata.horizontal_concat(right_metadata, use_index=use_index, remove_second_index=remove_second_index, use_right_metadata=use_right_metadata) | |||
@deprecate.function(message="use d3m.base.utils.get_columns_to_use function instead") | |||
def get_columns_to_use(metadata: metadata_base.DataMetadata, use_columns: typing.Sequence[int], exclude_columns: typing.Sequence[int], | |||
can_use_column: typing.Callable) -> typing.Tuple[typing.List[int], typing.List[int]]: | |||
return base_utils.get_columns_to_use(metadata, use_columns, exclude_columns, can_use_column) | |||
@deprecate.function(message="use d3m.base.utils.combine_columns function instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def combine_columns(return_result: str, add_index_columns: bool, inputs: container.DataFrame, column_indices: typing.Sequence[int], | |||
columns_list: typing.Sequence[container.DataFrame], *, source: typing.Any = None) -> container.DataFrame: | |||
return base_utils.combine_columns(inputs, column_indices, columns_list, return_result=return_result, add_index_columns=add_index_columns) | |||
@deprecate.function(message="use d3m.base.utils.combine_columns_metadata function instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def combine_columns_metadata(return_result: str, add_index_columns: bool, inputs_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], | |||
columns_metadata_list: typing.Sequence[metadata_base.DataMetadata], *, source: typing.Any = None) -> metadata_base.DataMetadata: | |||
return base_utils.combine_columns_metadata(inputs_metadata, column_indices, columns_metadata_list, return_result=return_result, add_index_columns=add_index_columns) | |||
@deprecate.function(message="use DataMetadata.set_table_metadata method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def set_table_metadata(inputs_metadata: metadata_base.DataMetadata, *, at: metadata_base.Selector = (), source: typing.Any = None) -> metadata_base.DataMetadata: | |||
return inputs_metadata.set_table_metadata(at=at) | |||
@deprecate.function(message="use DataMetadata.get_column_index_from_column_name method instead") | |||
def get_column_index_from_column_name(inputs_metadata: metadata_base.DataMetadata, column_name: str, *, at: metadata_base.Selector = ()) -> int: | |||
return inputs_metadata.get_column_index_from_column_name(column_name, at=at) | |||
@deprecate.function(message="use Dataset.get_relations_graph method instead") | |||
def build_relation_graph(dataset: container.Dataset) -> typing.Dict[str, typing.List[typing.Tuple[str, bool, int, int, typing.Dict]]]: | |||
return dataset.get_relations_graph() | |||
@deprecate.function(message="use d3m.base.utils.get_tabular_resource function instead") | |||
def get_tabular_resource(dataset: container.Dataset, resource_id: typing.Optional[str], *, | |||
pick_entry_point: bool = True, pick_one: bool = True, has_hyperparameter: bool = True) -> typing.Tuple[str, container.DataFrame]: | |||
return base_utils.get_tabular_resource(dataset, resource_id, pick_entry_point=pick_entry_point, pick_one=pick_one, has_hyperparameter=has_hyperparameter) | |||
@deprecate.function(message="use d3m.base.utils.get_tabular_resource_metadata function instead") | |||
def get_tabular_resource_metadata(dataset_metadata: metadata_base.DataMetadata, resource_id: typing.Optional[metadata_base.SelectorSegment], *, | |||
pick_entry_point: bool = True, pick_one: bool = True) -> metadata_base.SelectorSegment: | |||
return base_utils.get_tabular_resource_metadata(dataset_metadata, resource_id, pick_entry_point=pick_entry_point, pick_one=pick_one) | |||
@deprecate.function(message="use Dataset.select_rows method instead") | |||
@deprecate.arguments('source', message="argument ignored") | |||
def cut_dataset(dataset: container.Dataset, row_indices_to_keep: typing.Mapping[str, typing.Sequence[int]], *, | |||
source: typing.Any = None) -> container.Dataset: | |||
return dataset.select_rows(row_indices_to_keep) | |||
def parse_datetime(value: str, *, fuzzy: bool = True) -> typing.Optional[datetime.datetime]: | |||
try: | |||
return dateutil.parser.parse(value, default=DEFAULT_DATETIME, fuzzy=fuzzy) | |||
except (ValueError, OverflowError, TypeError): | |||
return None | |||
def parse_datetime_to_float(value: str, *, fuzzy: bool = True) -> float: | |||
try: | |||
parsed = parse_datetime(value, fuzzy=fuzzy) | |||
if parsed is None: | |||
return numpy.nan | |||
else: | |||
return parsed.timestamp() | |||
except (ValueError, OverflowError, TypeError): | |||
return numpy.nan |
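# Illustrative usage sketch for the two helpers above (not part of the original module):
# components missing from the string are filled from DEFAULT_DATETIME, fuzzy parsing tolerates
# surrounding text, and anything unparseable maps to None / numpy.nan instead of raising.
if __name__ == '__main__':
    print(parse_datetime('2020-01-02 03:04:05'))                       # aware datetime; tz comes from DEFAULT_DATETIME (UTC)
    print(parse_datetime_to_float('logged 2020-01-02', fuzzy=True))    # POSIX timestamp; the extra word is skipped by fuzzy parsing
    print(parse_datetime('not a date', fuzzy=False))                   # None
    print(parse_datetime_to_float('not a date', fuzzy=False))          # nan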