@@ -0,0 +1,297 @@
import os
import os.path
import hashlib
import gzip
import errno
import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar
import zipfile

# import torch
# from torch.utils.model_zoo import tqdm
from hub import tqdm


def gen_bar_updater() -> Callable[[int, int, int], None]:
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update


def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)


def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    import urllib.request
    import urllib.error

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:  # download the file
        try:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(
                url, fpath,
                reporthook=gen_bar_updater()
            )
        except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
            if url[:5] == 'https':
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      ' Downloading ' + url + ' to ' + fpath)
                urllib.request.urlretrieve(
                    url, fpath,
                    reporthook=gen_bar_updater()
                )
            else:
                raise e
    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
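

# Example usage (illustrative sketch only, not part of the original module; the
# URL, directory, and checksum below are placeholders):
#
#     download_url(
#         "https://example.com/files/data.csv",  # hypothetical URL
#         root="./downloads",
#         filename="data.csv",
#         md5=None,  # pass a real MD5 string to enable integrity checking
#     )

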
def list_dir(root: str, prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files


def _quota_exceeded(response: "requests.models.Response") -> bool:  # type: ignore[name-defined]
    return "Google Drive - Quota exceeded" in response.text


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
    """Download a Google Drive file and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    import requests
    url = "https://docs.google.com/uc?export=download"

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:
        session = requests.Session()

        response = session.get(url, params={'id': file_id}, stream=True)
        token = _get_confirm_token(response)

        if token:
            params = {'id': file_id, 'confirm': token}
            response = session.get(url, params=params, stream=True)

        if _quota_exceeded(response):
            msg = (
                f"The daily quota of the file {filename} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )
            raise RuntimeError(msg)

        _save_response_content(response, fpath)
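

# Example usage (illustrative sketch only; the file id and paths are placeholders,
# not real Google Drive resources):
#
#     download_file_from_google_drive(
#         file_id="0B_FAKE_FILE_ID",  # hypothetical Drive file id
#         root="./downloads",
#         filename="archive.zip",
#         md5=None,
#     )

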
def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:  # type: ignore[name-defined]
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def _save_response_content(
    response: "requests.models.Response", destination: str, chunk_size: int = 32768,  # type: ignore[name-defined]
) -> None:
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        progress = 0
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress += len(chunk)
                pbar.update(progress - pbar.n)
        pbar.close()


def _is_tarxz(filename: str) -> bool:
    return filename.endswith(".tar.xz")


def _is_tar(filename: str) -> bool:
    return filename.endswith(".tar")


def _is_targz(filename: str) -> bool:
    return filename.endswith(".tar.gz")


def _is_tgz(filename: str) -> bool:
    return filename.endswith(".tgz")


def _is_gzip(filename: str) -> bool:
    return filename.endswith(".gz") and not filename.endswith(".tar.gz")


def _is_zip(filename: str) -> bool:
    return filename.endswith(".zip")


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
    if to_path is None:
        to_path = os.path.dirname(from_path)

    if _is_tar(from_path):
        with tarfile.open(from_path, 'r') as tar:
            tar.extractall(path=to_path)
    elif _is_targz(from_path) or _is_tgz(from_path):
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif _is_tarxz(from_path):
        with tarfile.open(from_path, 'r:xz') as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            out_f.write(zip_f.read())
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)
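

# Example usage (illustrative sketch only; paths are placeholders). The archive
# format is inferred from the file extension (.tar, .tar.gz/.tgz, .tar.xz, .gz, .zip):
#
#     extract_archive("./downloads/data.tar.gz", to_path="./data", remove_finished=False)

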
def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    # print(archive)
    # print(extract_root)
    # extract_archive(archive, extract_root, remove_finished)


def iterable_to_str(iterable: Iterable) -> str:
    return "'" + "', '".join([str(item) for item in iterable]) + "'"


T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T, arg: Optional[str] = None, valid_values: Optional[Iterable[T]] = None, custom_msg: Optional[str] = None,
) -> T:
    if not isinstance(value, str):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = ("Unknown value '{value}' for argument {arg}. "
                   "Valid values are {{{valid_values}}}.")
            msg = msg.format(value=value, arg=arg,
                             valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value
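

# Example usage (illustrative sketch only): restrict a string argument to a known
# set of values and get a descriptive error otherwise.
#
#     split = verify_str_arg("train", arg="split", valid_values=("train", "test"))
#     # verify_str_arg("val", arg="split", valid_values=("train", "test"))  -> ValueError

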
@@ -0,0 +1,559 @@
import errno
import hashlib
import os
import re
import shutil
import sys
import tempfile
# import torch
import warnings
import zipfile

from urllib.request import urlopen, Request
from urllib.parse import urlparse  # noqa: F401

try:
    from tqdm.auto import tqdm  # automatically select proper tqdm submodule if available
except ImportError:
    try:
        from tqdm import tqdm
    except ImportError:
        # fake tqdm if it's not installed
        class tqdm(object):  # type: ignore

            def __init__(self, total=None, disable=False,
                         unit=None, unit_scale=None, unit_divisor=None):
                self.total = total
                self.disable = disable
                self.n = 0
                # ignore unit, unit_scale, unit_divisor; they're just for real tqdm

            def update(self, n):
                if self.disable:
                    return

                self.n += n
                if self.total is None:
                    sys.stderr.write("\r{0:.1f} bytes".format(self.n))
                else:
                    sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
                sys.stderr.flush()

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                if self.disable:
                    return

                sys.stderr.write('\n')
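
# Example usage of the fallback progress bar above (illustrative sketch only; the
# real tqdm is used instead whenever it is installed):
#
#     with tqdm(total=1000) as pbar:
#         for _ in range(10):
#             pbar.update(100)  # writes "\r10.0%", "\r20.0%", ... to stderr
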
# # matches bfd8deac from resnet18-bfd8deac.pth | |||||
# HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') | |||||
# | |||||
# MASTER_BRANCH = 'master' | |||||
# ENV_TORCH_HOME = 'TORCH_HOME' | |||||
# ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' | |||||
# DEFAULT_CACHE_DIR = '~/.cache' | |||||
# VAR_DEPENDENCY = 'dependencies' | |||||
# MODULE_HUBCONF = 'hubconf.py' | |||||
# READ_DATA_CHUNK = 8192 | |||||
# _hub_dir = None | |||||
# | |||||
# | |||||
# # Copied from tools/shared/module_loader to be included in torch package | |||||
# def import_module(name, path): | |||||
# import importlib.util | |||||
# from importlib.abc import Loader | |||||
# spec = importlib.util.spec_from_file_location(name, path) | |||||
# module = importlib.util.module_from_spec(spec) | |||||
# assert isinstance(spec.loader, Loader) | |||||
# spec.loader.exec_module(module) | |||||
# return module | |||||
# | |||||
# | |||||
# def _remove_if_exists(path): | |||||
# if os.path.exists(path): | |||||
# if os.path.isfile(path): | |||||
# os.remove(path) | |||||
# else: | |||||
# shutil.rmtree(path) | |||||
# | |||||
# | |||||
# def _git_archive_link(repo_owner, repo_name, branch): | |||||
# return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch) | |||||
# | |||||
# | |||||
# def _load_attr_from_module(module, func_name): | |||||
# # Check if callable is defined in the module | |||||
# if func_name not in dir(module): | |||||
# return None | |||||
# return getattr(module, func_name) | |||||
# | |||||
# | |||||
# def _get_torch_home(): | |||||
# torch_home = os.path.expanduser( | |||||
# os.getenv(ENV_TORCH_HOME, | |||||
# os.path.join(os.getenv(ENV_XDG_CACHE_HOME, | |||||
# DEFAULT_CACHE_DIR), 'torch'))) | |||||
# return torch_home | |||||
# | |||||
# | |||||
# def _parse_repo_info(github): | |||||
# branch = MASTER_BRANCH | |||||
# if ':' in github: | |||||
# repo_info, branch = github.split(':') | |||||
# else: | |||||
# repo_info = github | |||||
# repo_owner, repo_name = repo_info.split('/') | |||||
# return repo_owner, repo_name, branch | |||||
# | |||||
# | |||||
# def _get_cache_or_reload(github, force_reload, verbose=True): | |||||
# # Setup hub_dir to save downloaded files | |||||
# hub_dir = get_dir() | |||||
# if not os.path.exists(hub_dir): | |||||
# os.makedirs(hub_dir) | |||||
# # Parse github repo information | |||||
# repo_owner, repo_name, branch = _parse_repo_info(github) | |||||
# # Github allows branch name with slash '/', | |||||
# # this causes confusion with path on both Linux and Windows. | |||||
# # Backslash is not allowed in Github branch name so no need to | |||||
# # to worry about it. | |||||
# normalized_br = branch.replace('/', '_') | |||||
# # Github renames folder repo-v1.x.x to repo-1.x.x | |||||
# # We don't know the repo name before downloading the zip file | |||||
# # and inspect name from it. | |||||
# # To check if cached repo exists, we need to normalize folder names. | |||||
# repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, normalized_br])) | |||||
# | |||||
# use_cache = (not force_reload) and os.path.exists(repo_dir) | |||||
# | |||||
# if use_cache: | |||||
# if verbose: | |||||
# sys.stderr.write('Using cache found in {}\n'.format(repo_dir)) | |||||
# else: | |||||
# cached_file = os.path.join(hub_dir, normalized_br + '.zip') | |||||
# _remove_if_exists(cached_file) | |||||
# | |||||
# url = _git_archive_link(repo_owner, repo_name, branch) | |||||
# sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, cached_file)) | |||||
# download_url_to_file(url, cached_file, progress=False) | |||||
# | |||||
# with zipfile.ZipFile(cached_file) as cached_zipfile: | |||||
# extraced_repo_name = cached_zipfile.infolist()[0].filename | |||||
# extracted_repo = os.path.join(hub_dir, extraced_repo_name) | |||||
# _remove_if_exists(extracted_repo) | |||||
# # Unzip the code and rename the base folder | |||||
# cached_zipfile.extractall(hub_dir) | |||||
# | |||||
# _remove_if_exists(cached_file) | |||||
# _remove_if_exists(repo_dir) | |||||
# shutil.move(extracted_repo, repo_dir) # rename the repo | |||||
# | |||||
# return repo_dir | |||||
# | |||||
# | |||||
# def _check_module_exists(name): | |||||
# if sys.version_info >= (3, 4): | |||||
# import importlib.util | |||||
# return importlib.util.find_spec(name) is not None | |||||
# elif sys.version_info >= (3, 3): | |||||
# # Special case for python3.3 | |||||
# import importlib.find_loader | |||||
# return importlib.find_loader(name) is not None | |||||
# else: | |||||
# # NB: Python2.7 imp.find_module() doesn't respect PEP 302, | |||||
# # it cannot find a package installed as .egg(zip) file. | |||||
# # Here we use workaround from: | |||||
# # https://stackoverflow.com/questions/28962344/imp-find-module-which-supports-zipped-eggs?lq=1 | |||||
# # Also imp doesn't handle hierarchical module names (names contains dots). | |||||
# try: | |||||
# # 1. Try imp.find_module(), which searches sys.path, but does | |||||
# # not respect PEP 302 import hooks. | |||||
# import imp | |||||
# result = imp.find_module(name) | |||||
# if result: | |||||
# return True | |||||
# except ImportError: | |||||
# pass | |||||
# path = sys.path | |||||
# for item in path: | |||||
# # 2. Scan path for import hooks. sys.path_importer_cache maps | |||||
# # path items to optional "importer" objects, that implement | |||||
# # find_module() etc. Note that path must be a subset of | |||||
# # sys.path for this to work. | |||||
# importer = sys.path_importer_cache.get(item) | |||||
# if importer: | |||||
# try: | |||||
# result = importer.find_module(name, [item]) | |||||
# if result: | |||||
# return True | |||||
# except ImportError: | |||||
# pass | |||||
# return False | |||||
# | |||||
# def _check_dependencies(m): | |||||
# dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) | |||||
# | |||||
# if dependencies is not None: | |||||
# missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)] | |||||
# if len(missing_deps): | |||||
# raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps))) | |||||
# | |||||
# | |||||
# def _load_entry_from_hubconf(m, model): | |||||
# if not isinstance(model, str): | |||||
# raise ValueError('Invalid input: model should be a string of function name') | |||||
# | |||||
# # Note that if a missing dependency is imported at top level of hubconf, it will | |||||
# # throw before this function. It's a chicken and egg situation where we have to | |||||
# # load hubconf to know what're the dependencies, but to import hubconf it requires | |||||
# # a missing package. This is fine, Python will throw proper error message for users. | |||||
# _check_dependencies(m) | |||||
# | |||||
# func = _load_attr_from_module(m, model) | |||||
# | |||||
# if func is None or not callable(func): | |||||
# raise RuntimeError('Cannot find callable {} in hubconf'.format(model)) | |||||
# | |||||
# return func | |||||
# | |||||
# | |||||
# def get_dir(): | |||||
# r""" | |||||
# Get the Torch Hub cache directory used for storing downloaded models & weights. | |||||
# | |||||
# If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where | |||||
# environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. | |||||
# ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux | |||||
# filesystem layout, with a default value ``~/.cache`` if the environment | |||||
# variable is not set. | |||||
# """ | |||||
# # Issue warning to move data if old env is set | |||||
# if os.getenv('TORCH_HUB'): | |||||
# warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead') | |||||
# | |||||
# if _hub_dir is not None: | |||||
# return _hub_dir | |||||
# return os.path.join(_get_torch_home(), 'hub') | |||||
# | |||||
# | |||||
# def set_dir(d): | |||||
# r""" | |||||
# Optionally set the Torch Hub directory used to save downloaded models & weights. | |||||
# | |||||
# Args: | |||||
# d (string): path to a local folder to save downloaded models & weights. | |||||
# """ | |||||
# global _hub_dir | |||||
# _hub_dir = d | |||||
# | |||||
# | |||||
# def list(github, force_reload=False): | |||||
# r""" | |||||
# List all entrypoints available in `github` hubconf. | |||||
# | |||||
# Args: | |||||
# github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional | |||||
# tag/branch. The default branch is `master` if not specified. | |||||
# Example: 'pytorch/vision[:hub]' | |||||
# force_reload (bool, optional): whether to discard the existing cache and force a fresh download. | |||||
# Default is `False`. | |||||
# Returns: | |||||
# entrypoints: a list of available entrypoint names | |||||
# | |||||
# Example: | |||||
# >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True) | |||||
# """ | |||||
# repo_dir = _get_cache_or_reload(github, force_reload, True) | |||||
# | |||||
# sys.path.insert(0, repo_dir) | |||||
# | |||||
# hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) | |||||
# | |||||
# sys.path.remove(repo_dir) | |||||
# | |||||
# # We take functions starts with '_' as internal helper functions | |||||
# entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')] | |||||
# | |||||
# return entrypoints | |||||
# | |||||
# | |||||
# def help(github, model, force_reload=False): | |||||
# r""" | |||||
# Show the docstring of entrypoint `model`. | |||||
# | |||||
# Args: | |||||
# github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional | |||||
# tag/branch. The default branch is `master` if not specified. | |||||
# Example: 'pytorch/vision[:hub]' | |||||
# model (string): a string of entrypoint name defined in repo's hubconf.py | |||||
# force_reload (bool, optional): whether to discard the existing cache and force a fresh download. | |||||
# Default is `False`. | |||||
# Example: | |||||
# >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)) | |||||
# """ | |||||
# repo_dir = _get_cache_or_reload(github, force_reload, True) | |||||
# | |||||
# sys.path.insert(0, repo_dir) | |||||
# | |||||
# hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) | |||||
# | |||||
# sys.path.remove(repo_dir) | |||||
# | |||||
# entry = _load_entry_from_hubconf(hub_module, model) | |||||
# | |||||
# return entry.__doc__ | |||||
# | |||||
# | |||||
# # Ideally this should be `def load(github, model, *args, forece_reload=False, **kwargs):`, | |||||
# # but Python2 complains syntax error for it. We have to skip force_reload in function | |||||
# # signature here but detect it in kwargs instead. | |||||
# # TODO: fix it after Python2 EOL | |||||
# def load(repo_or_dir, model, *args, **kwargs): | |||||
# r""" | |||||
# Load a model from a github repo or a local directory. | |||||
# | |||||
# Note: Loading a model is the typical use case, but this can also be used to | |||||
# for loading other objects such as tokenizers, loss functions, etc. | |||||
# | |||||
# If :attr:`source` is ``'github'``, :attr:`repo_or_dir` is expected to be | |||||
# of the form ``repo_owner/repo_name[:tag_name]`` with an optional | |||||
# tag/branch. | |||||
# | |||||
# If :attr:`source` is ``'local'``, :attr:`repo_or_dir` is expected to be a | |||||
# path to a local directory. | |||||
# | |||||
# Args: | |||||
# repo_or_dir (string): repo name (``repo_owner/repo_name[:tag_name]``), | |||||
# if ``source = 'github'``; or a path to a local directory, if | |||||
# ``source = 'local'``. | |||||
# model (string): the name of a callable (entrypoint) defined in the | |||||
# repo/dir's ``hubconf.py``. | |||||
# *args (optional): the corresponding args for callable :attr:`model`. | |||||
# source (string, optional): ``'github'`` | ``'local'``. Specifies how | |||||
# ``repo_or_dir`` is to be interpreted. Default is ``'github'``. | |||||
# force_reload (bool, optional): whether to force a fresh download of | |||||
# the github repo unconditionally. Does not have any effect if | |||||
# ``source = 'local'``. Default is ``False``. | |||||
# verbose (bool, optional): If ``False``, mute messages about hitting | |||||
# local caches. Note that the message about first download cannot be | |||||
# muted. Does not have any effect if ``source = 'local'``. | |||||
# Default is ``True``. | |||||
# **kwargs (optional): the corresponding kwargs for callable | |||||
# :attr:`model`. | |||||
# | |||||
# Returns: | |||||
# The output of the :attr:`model` callable when called with the given | |||||
# ``*args`` and ``**kwargs``. | |||||
# | |||||
# Example: | |||||
# >>> # from a github repo | |||||
# >>> repo = 'pytorch/vision' | |||||
# >>> model = torch.hub.load(repo, 'resnet50', pretrained=True) | |||||
# >>> # from a local directory | |||||
# >>> path = '/some/local/path/pytorch/vision' | |||||
# >>> model = torch.hub.load(path, 'resnet50', pretrained=True) | |||||
# """ | |||||
# source = kwargs.pop('source', 'github').lower() | |||||
# force_reload = kwargs.pop('force_reload', False) | |||||
# verbose = kwargs.pop('verbose', True) | |||||
# | |||||
# if source not in ('github', 'local'): | |||||
# raise ValueError( | |||||
# f'Unknown source: "{source}". Allowed values: "github" | "local".') | |||||
# | |||||
# if source == 'github': | |||||
# repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose) | |||||
# | |||||
# model = _load_local(repo_or_dir, model, *args, **kwargs) | |||||
# return model | |||||
# | |||||
# | |||||
# def _load_local(hubconf_dir, model, *args, **kwargs): | |||||
# r""" | |||||
# Load a model from a local directory with a ``hubconf.py``. | |||||
# | |||||
# Args: | |||||
# hubconf_dir (string): path to a local directory that contains a | |||||
# ``hubconf.py``. | |||||
# model (string): name of an entrypoint defined in the directory's | |||||
# `hubconf.py`. | |||||
# *args (optional): the corresponding args for callable ``model``. | |||||
# **kwargs (optional): the corresponding kwargs for callable ``model``. | |||||
# | |||||
# Returns: | |||||
# a single model with corresponding pretrained weights. | |||||
# | |||||
# Example: | |||||
# >>> path = '/some/local/path/pytorch/vision' | |||||
# >>> model = _load_local(path, 'resnet50', pretrained=True) | |||||
# """ | |||||
# sys.path.insert(0, hubconf_dir) | |||||
# | |||||
# hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF) | |||||
# hub_module = import_module(MODULE_HUBCONF, hubconf_path) | |||||
# | |||||
# entry = _load_entry_from_hubconf(hub_module, model) | |||||
# model = entry(*args, **kwargs) | |||||
# | |||||
# sys.path.remove(hubconf_dir) | |||||
# | |||||
# return model | |||||
# | |||||
# | |||||
# def download_url_to_file(url, dst, hash_prefix=None, progress=True): | |||||
# r"""Download object at the given URL to a local path. | |||||
# | |||||
# Args: | |||||
# url (string): URL of the object to download | |||||
# dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file` | |||||
# hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`. | |||||
# Default: None | |||||
# progress (bool, optional): whether or not to display a progress bar to stderr | |||||
# Default: True | |||||
# | |||||
# Example: | |||||
# >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') | |||||
# | |||||
# """ | |||||
# file_size = None | |||||
# # We use a different API for python2 since urllib(2) doesn't recognize the CA | |||||
# # certificates in older Python | |||||
# req = Request(url, headers={"User-Agent": "torch.hub"}) | |||||
# u = urlopen(req) | |||||
# meta = u.info() | |||||
# if hasattr(meta, 'getheaders'): | |||||
# content_length = meta.getheaders("Content-Length") | |||||
# else: | |||||
# content_length = meta.get_all("Content-Length") | |||||
# if content_length is not None and len(content_length) > 0: | |||||
# file_size = int(content_length[0]) | |||||
# | |||||
# # We deliberately save it in a temp file and move it after | |||||
# # download is complete. This prevents a local working checkpoint | |||||
# # being overridden by a broken download. | |||||
# dst = os.path.expanduser(dst) | |||||
# dst_dir = os.path.dirname(dst) | |||||
# f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) | |||||
# | |||||
# try: | |||||
# if hash_prefix is not None: | |||||
# sha256 = hashlib.sha256() | |||||
# with tqdm(total=file_size, disable=not progress, | |||||
# unit='B', unit_scale=True, unit_divisor=1024) as pbar: | |||||
# while True: | |||||
# buffer = u.read(8192) | |||||
# if len(buffer) == 0: | |||||
# break | |||||
# f.write(buffer) | |||||
# if hash_prefix is not None: | |||||
# sha256.update(buffer) | |||||
# pbar.update(len(buffer)) | |||||
# | |||||
# f.close() | |||||
# if hash_prefix is not None: | |||||
# digest = sha256.hexdigest() | |||||
# if digest[:len(hash_prefix)] != hash_prefix: | |||||
# raise RuntimeError('invalid hash value (expected "{}", got "{}")' | |||||
# .format(hash_prefix, digest)) | |||||
# shutil.move(f.name, dst) | |||||
# finally: | |||||
# f.close() | |||||
# if os.path.exists(f.name): | |||||
# os.remove(f.name) | |||||
# | |||||
# def _download_url_to_file(url, dst, hash_prefix=None, progress=True): | |||||
# warnings.warn('torch.hub._download_url_to_file has been renamed to\ | |||||
# torch.hub.download_url_to_file to be a public API,\ | |||||
# _download_url_to_file will be removed in after 1.3 release') | |||||
# download_url_to_file(url, dst, hash_prefix, progress) | |||||
# | |||||
# # Hub used to support automatically extracts from zipfile manually compressed by users. | |||||
# # The legacy zip format expects only one file from torch.save() < 1.6 in the zip. | |||||
# # We should remove this support since zipfile is now default zipfile format for torch.save(). | |||||
# def _is_legacy_zip_format(filename): | |||||
# if zipfile.is_zipfile(filename): | |||||
# infolist = zipfile.ZipFile(filename).infolist() | |||||
# return len(infolist) == 1 and not infolist[0].is_dir() | |||||
# return False | |||||
# | |||||
# def _legacy_zip_load(filename, model_dir, map_location): | |||||
# warnings.warn('Falling back to the old format < 1.6. This support will be ' | |||||
# 'deprecated in favor of default zipfile format introduced in 1.6. ' | |||||
# 'Please redo torch.save() to save it in the new zipfile format.') | |||||
# # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. | |||||
# # We deliberately don't handle tarfile here since our legacy serialization format was in tar. | |||||
# # E.g. resnet18-5c106cde.pth which is widely used. | |||||
# with zipfile.ZipFile(filename) as f: | |||||
# members = f.infolist() | |||||
# if len(members) != 1: | |||||
# raise RuntimeError('Only one file(not dir) is allowed in the zipfile') | |||||
# f.extractall(model_dir) | |||||
# extraced_name = members[0].filename | |||||
# extracted_file = os.path.join(model_dir, extraced_name) | |||||
# return torch.load(extracted_file, map_location=map_location) | |||||
# | |||||
# def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None): | |||||
# r"""Loads the Torch serialized object at the given URL. | |||||
# | |||||
# If downloaded file is a zip file, it will be automatically | |||||
# decompressed. | |||||
# | |||||
# If the object is already present in `model_dir`, it's deserialized and | |||||
# returned. | |||||
# The default value of `model_dir` is ``<hub_dir>/checkpoints`` where | |||||
# `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. | |||||
# | |||||
# Args: | |||||
# url (string): URL of the object to download | |||||
# model_dir (string, optional): directory in which to save the object | |||||
# map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) | |||||
# progress (bool, optional): whether or not to display a progress bar to stderr. | |||||
# Default: True | |||||
# check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention | |||||
# ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more | |||||
# digits of the SHA256 hash of the contents of the file. The hash is used to | |||||
# ensure unique names and to verify the contents of the file. | |||||
# Default: False | |||||
# file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set. | |||||
# | |||||
# Example: | |||||
# >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') | |||||
# | |||||
# """ | |||||
# # Issue warning to move data if old env is set | |||||
# if os.getenv('TORCH_MODEL_ZOO'): | |||||
# warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') | |||||
# | |||||
# if model_dir is None: | |||||
# hub_dir = get_dir() | |||||
# model_dir = os.path.join(hub_dir, 'checkpoints') | |||||
# | |||||
# try: | |||||
# os.makedirs(model_dir) | |||||
# except OSError as e: | |||||
# if e.errno == errno.EEXIST: | |||||
# # Directory already exists, ignore. | |||||
# pass | |||||
# else: | |||||
# # Unexpected OSError, re-raise. | |||||
# raise | |||||
# | |||||
# parts = urlparse(url) | |||||
# filename = os.path.basename(parts.path) | |||||
# if file_name is not None: | |||||
# filename = file_name | |||||
# cached_file = os.path.join(model_dir, filename) | |||||
# if not os.path.exists(cached_file): | |||||
# sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) | |||||
# hash_prefix = None | |||||
# if check_hash: | |||||
# r = HASH_REGEX.search(filename) # r is Optional[Match[str]] | |||||
# hash_prefix = r.group(1) if r else None | |||||
# download_url_to_file(url, cached_file, hash_prefix, progress=progress) | |||||
# | |||||
# if _is_legacy_zip_format(cached_file): | |||||
# return _legacy_zip_load(cached_file, model_dir, map_location) | |||||
# return torch.load(cached_file, map_location=map_location) |
@@ -0,0 +1,139 @@
import warnings
import os
import os.path
import numpy as np
import codecs
import string
import gzip
import lzma
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union

from dataset_utils import download_url, download_and_extract_archive, extract_archive, verify_str_arg
# tqdm >= 4.31.1

from tods import generate_dataset
from sklearn import preprocessing
import pandas as pd


class TODS_dataset:
    resources = []
    training_file = None
    testing_file = None
    ground_truth_index = None
    _repr_indent = None

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    def __init__(self, root, train, transform=None, download=True):
        self.root = root
        self.train = train
        self.transform = self.transform_init(transform)

        if download:
            self.download()

        self.process()

    def _check_exists(self) -> bool:
        return (os.path.exists(os.path.join(self.processed_folder,
                                            self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder,
                                            self.testing_file)))

    def download(self) -> None:
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)

    def process(self) -> None:
        pass

    def process_dataframe(self) -> None:
        if self.transform is None:
            pass
        else:
            self.transform.fit(self.training_set_dataframe)
            self.training_set_array = self.transform.transform(self.training_set_dataframe.values)
            self.testing_set_array = self.transform.transform(self.testing_set_dataframe.values)
            self.training_set_dataframe = pd.DataFrame(self.training_set_array)
            self.testing_set_dataframe = pd.DataFrame(self.testing_set_array)

    def transform_init(self, transform_str):
        if transform_str is None:
            return None
        elif transform_str == 'standardscale':
            return preprocessing.StandardScaler()
        elif transform_str == 'normalize':
            return preprocessing.Normalizer()
        elif transform_str == 'minmaxscale':
            return preprocessing.MinMaxScaler()
        elif transform_str == 'maxabsscale':
            return preprocessing.MaxAbsScaler()
        elif transform_str == 'binarize':
            return preprocessing.Binarizer()
        else:
            raise ValueError("Input parameter transform must take value of 'standardscale', 'normalize', "
                             "'minmaxscale', 'maxabsscale' or 'binarize'.")

    def to_axolotl_dataset(self):
        if self.train:
            return generate_dataset(self.training_set_dataframe, self.ground_truth_index)
        else:
            return generate_dataset(self.testing_set_dataframe, self.ground_truth_index)
    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        if hasattr(self, "transform") and self.transform is not None:
            body += [repr(self.transform)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def __len__(self) -> int:
        return len(self.training_set_dataframe)

    def extra_repr(self) -> str:
        return ""


# kpi(root='./datasets', train=True)

# class yahoo5:
#
#     def __init__(self):
#         pass
@@ -0,0 +1,116 @@
import os

import pandas as pd

from tods_dataset_base import TODS_dataset
from shutil import copyfile


class kpi_dataset(TODS_dataset):
    resources = [
        # ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        # ("https://github.com/datamllab/tods/blob/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None),
        # ("https://github.com/NetManAIOps/KPI-Anomaly-Detection/blob/master/Preliminary_dataset/train.csv", None),
        ("https://hegsns.github.io/tods_datasets/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None),  # an md5 is needed to check whether the local learningData.csv matches the online copy.
        ("https://hegsns.github.io/tods_datasets/kpi/TRAIN/dataset_TRAIN/datasetDoc.json", None),
        # needs a server to store the dataset.
        # ("https://raw.githubusercontent.com/datamllab/tods/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None),  # an md5 is needed to check whether the local learningData.csv matches the online copy.
    ]

    training_file = 'learningData.csv'
    testing_file = 'testingData.csv'
    ground_truth_index = 3
    _repr_indent = 4

    # def __init__(self, root, train, transform=None, target_transform=None, download=True):
    #     super().__init__(root, train, transform=None, target_transform=None, download=True)

    def process(self) -> None:
        print('Processing...')

        os.makedirs(self.processed_folder, exist_ok=True)
        os.makedirs(os.path.join(self.processed_folder, 'tables'), exist_ok=True)

        training_set_fname = os.path.join(self.raw_folder, 'learningData.csv')
        self.training_set_dataframe = pd.read_csv(training_set_fname)
        testing_set_fname = os.path.join(self.raw_folder, 'learningData.csv')  # temporarily the same as the training set
        self.testing_set_dataframe = pd.read_csv(testing_set_fname)

        self.process_dataframe()

        self.training_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.training_file))
        self.testing_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.testing_file))
        copyfile(os.path.join(self.raw_folder, 'datasetDoc.json'), os.path.join(self.processed_folder, 'datasetDoc.json'))

        print('Done!')


class yahoo_dataset(TODS_dataset):
    resources = [
        # ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        # ("https://github.com/datamllab/tods/blob/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None),
        # ("https://github.com/NetManAIOps/KPI-Anomaly-Detection/blob/master/Preliminary_dataset/train.csv", None),
        ("https://hegsns.github.io/tods_datasets/yahoo_sub_5/TRAIN/dataset_TRAIN/tables/learningData.csv", None),  # an md5 is needed to check whether the local learningData.csv matches the online copy.
        ("https://hegsns.github.io/tods_datasets/yahoo_sub_5/TRAIN/dataset_TRAIN/datasetDoc.json", None),
        # needs a server to store the dataset.
        # ("https://raw.githubusercontent.com/datamllab/tods/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None),  # an md5 is needed to check whether the local learningData.csv matches the online copy.
    ]

    training_file = 'learningData.csv'
    testing_file = 'testingData.csv'
    ground_truth_index = 7
    _repr_indent = 4

    def process(self) -> None:
        print('Processing...')

        os.makedirs(self.processed_folder, exist_ok=True)
        os.makedirs(os.path.join(self.processed_folder, 'tables'), exist_ok=True)

        training_set_fname = os.path.join(self.raw_folder, 'learningData.csv')
        self.training_set_dataframe = pd.read_csv(training_set_fname)
        testing_set_fname = os.path.join(self.raw_folder, 'learningData.csv')  # temporarily the same as the training set
        self.testing_set_dataframe = pd.read_csv(testing_set_fname)

        self.process_dataframe()

        self.training_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.training_file))
        self.testing_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.testing_file))
        copyfile(os.path.join(self.raw_folder, 'datasetDoc.json'), os.path.join(self.processed_folder, 'datasetDoc.json'))

        print('Done!')


class NAB_dataset(TODS_dataset):
    resources = [
        ("https://hegsns.github.io/tods_datasets/NAB/realTweets/labeled_Twitter_volume_AMZN.csv", None),
        # an md5 is needed to check whether the local learningData.csv matches the online copy.
        ("https://hegsns.github.io/tods_datasets/NAB/realTweets/labeled_Twitter_volume_AMZN.json", None),
        # needs a server to store the dataset.
    ]

    training_file = 'learningData.csv'
    testing_file = 'testingData.csv'
    ground_truth_index = 2
    _repr_indent = 4

    def process(self) -> None:
        print('Processing...')

        os.makedirs(self.processed_folder, exist_ok=True)
        os.makedirs(os.path.join(self.processed_folder, 'tables'), exist_ok=True)

        training_set_fname = os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.csv')
        self.training_set_dataframe = pd.read_csv(training_set_fname)
        testing_set_fname = os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.csv')  # temporarily the same as the training set
        self.testing_set_dataframe = pd.read_csv(testing_set_fname)

        self.process_dataframe()

        self.training_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.training_file))
        self.testing_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.testing_file))
        copyfile(os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.json'),
                 os.path.join(self.processed_folder, 'datasetDoc.json'))

        print('Done!')


# kpi_dataset(root='./datasets', train=True, transform='binarize')
# yahoo_dataset(root='./datasets', train=True, transform='binarize')
# NAB_dataset(root='./datasets', train=True, transform='binarize')
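
# End-to-end usage sketch (illustrative only; it assumes the URLs in `resources`
# above are reachable and that the `tods` package providing `generate_dataset`
# is installed):
#
#     dataset = kpi_dataset(root='./datasets', train=True, transform='standardscale')
#     axolotl_dataset = dataset.to_axolotl_dataset()  # d3m container.Dataset for the TODS pipeline
#     print(dataset)                                  # summary via TODS_dataset.__repr__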
@@ -1 +0,0 @@
Subproject commit af54e6970476a081bf0cd65990c9f56a1200d8a2
@@ -1 +0,0 @@
Subproject commit 046b20d2f6d4543dcbe18f0a1d4bcbb1f61cf518
@@ -1 +0,0 @@
Subproject commit 70aeefed6b7307941581357c4b7858bb3f88e1da
@@ -0,0 +1,116 @@
import os
import typing

import numpy
import pandas

from d3m import container, exceptions, utils as d3m_utils
from d3m.metadata import base as metadata_base, hyperparams
from d3m.base import primitives

__all__ = ('FixedSplitDatasetSplitPrimitive',)


class Hyperparams(hyperparams.Hyperparams):
    primary_index_values = hyperparams.Set(
        elements=hyperparams.Hyperparameter[str](''),
        default=(),
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description='A set of primary index values of the main resource belonging to the test (score) split. Cannot be set together with "row_indices".',
    )
    row_indices = hyperparams.Set(
        elements=hyperparams.Hyperparameter[int](-1),
        default=(),
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description='A set of row indices of the main resource belonging to the test (score) split. Cannot be set together with "primary_index_values".',
    )
    delete_recursive = hyperparams.Hyperparameter[bool](
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.",
    )


class FixedSplitDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
    """
    A primitive which splits a tabular Dataset in a way that uses for the test
    (score) split a fixed list of primary index values or row indices of the main
    resource. All other rows are used for the train split.
    """

    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': '1654f000-2178-4520-be4c-a95bc26b8d3a',
            'version': '0.1.0',
            'name': "Fixed split tabular dataset splits",
            'python_path': 'd3m.primitives.tods.evaluation.fixed_split_dataset_split',
            'source': {
                'name': "DATALab@Texas A&M University",
                'contact': 'mailto:mitar.commonprimitives@tnode.com',
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/fixed_split.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )
    def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        # This should be handled by "Set" hyper-parameter, but we check it here again just to be sure.
        if d3m_utils.has_duplicates(self.hyperparams['primary_index_values']):
            raise exceptions.InvalidArgumentValueError("\"primary_index_values\" hyper-parameter has duplicate values.")
        if d3m_utils.has_duplicates(self.hyperparams['row_indices']):
            raise exceptions.InvalidArgumentValueError("\"row_indices\" hyper-parameter has duplicate values.")

        if self.hyperparams['primary_index_values'] and self.hyperparams['row_indices']:
            raise exceptions.InvalidArgumentValueError("Both \"primary_index_values\" and \"row_indices\" cannot be provided.")

        if self.hyperparams['primary_index_values']:
            primary_index_values = numpy.array(self.hyperparams['primary_index_values'])

            index_columns = dataset.metadata.get_index_columns(at=(main_resource_id,))

            if not index_columns:
                raise exceptions.InvalidArgumentValueError("Cannot find index columns in the main resource of the dataset, but \"primary_index_values\" is provided.")

            main_resource = dataset[main_resource_id]
            # We reset the index so that the index corresponds to row indices.
            main_resource = main_resource.reset_index(drop=True)

            # We use just the "d3mIndex" column and ignore multi-key indices.
            # This works for now because it seems that every current multi-key
            # dataset in fact has an unique value in "d3mIndex" alone.
            # See: https://gitlab.datadrivendiscovery.org/MIT-LL/d3m_data_supply/issues/117
            index_column = index_columns[0]

            score_data = numpy.array(main_resource.loc[main_resource.iloc[:, index_column].isin(primary_index_values)].index)
            score_data_set = set(score_data)

            assert len(score_data) == len(score_data_set), (len(score_data), len(score_data_set))

            if len(score_data) != len(primary_index_values):
                raise exceptions.InvalidArgumentValueError("\"primary_index_values\" contains values which do not exist.")

        else:
            score_data = numpy.array(self.hyperparams['row_indices'])
            score_data_set = set(score_data)

            all_data_set = set(numpy.arange(len(attributes)))

            if not score_data_set <= all_data_set:
                raise exceptions.InvalidArgumentValueError("\"row_indices\" contains indices which do not exist, e.g., {indices}.".format(
                    indices=sorted(score_data_set - all_data_set)[:5],
                ))

        train_data = []
        for i in numpy.arange(len(attributes)):
            if i not in score_data_set:
                train_data.append(i)

        assert len(train_data) + len(score_data) == len(attributes), (len(train_data), len(score_data), len(attributes))

        return [(numpy.array(train_data), score_data)]
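

# Illustrative sketch of the split rule implemented by _get_splits above (plain
# numpy, outside the d3m pipeline machinery; the values are made up):
#
#     import numpy
#     row_indices = {2, 5}                       # rows chosen for the test (score) split
#     all_rows = numpy.arange(8)                 # rows of the main resource
#     score = numpy.array(sorted(row_indices))   # -> [2, 5]
#     train = numpy.array([i for i in all_rows if i not in row_indices])
#     # train -> [0, 1, 3, 4, 6, 7]; the primitive returns [(train, score)]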
@@ -0,0 +1,87 @@
import os
import typing

import numpy
import pandas
from sklearn import model_selection

from d3m import container, exceptions, utils as d3m_utils
from d3m.metadata import base as metadata_base, hyperparams
from d3m.base import primitives

__all__ = ('KFoldDatasetSplitPrimitive',)


class Hyperparams(hyperparams.Hyperparams):
    number_of_folds = hyperparams.Bounded[int](
        lower=2,
        upper=None,
        default=5,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Number of folds for k-folds cross-validation.",
    )
    stratified = hyperparams.UniformBool(
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Do stratified folds. The folds are made by preserving the percentage of samples for each class.",
    )
    shuffle = hyperparams.UniformBool(
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Whether to shuffle the data before splitting into batches.",
    )
    delete_recursive = hyperparams.Hyperparameter[bool](
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.",
    )


class KFoldDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
    """
    A primitive which splits a tabular Dataset for k-fold cross-validation.
    """

    __author__ = 'Mingjie Sun <sunmj15@gmail.com>'
    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': 'bfedaf3a-6dd0-4a83-ad83-3a50fe882bf8',
            'version': '0.1.0',
            'name': "K-fold cross-validation tabular dataset splits",
            'python_path': 'd3m.primitives.tods.evaluation.kfold_dataset_split',
            'source': {
                'name': 'DATALab@Texas A&M University',
                'contact': 'mailto:sunmj15@gmail.com',
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/kfold_split.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.K_FOLD,
                metadata_base.PrimitiveAlgorithmType.CROSS_VALIDATION,
                metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )

    def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        if self.hyperparams['stratified']:
            if not len(targets.columns):
                raise exceptions.InvalidArgumentValueError("Stratified split is requested, but no target columns found.")

            k_fold = model_selection.StratifiedKFold(
                n_splits=self.hyperparams['number_of_folds'],
                shuffle=self.hyperparams['shuffle'],
                random_state=self._random_state,
            )
        else:
            k_fold = model_selection.KFold(
                n_splits=self.hyperparams['number_of_folds'],
                shuffle=self.hyperparams['shuffle'],
                random_state=self._random_state,
            )

        return list(k_fold.split(attributes, targets))
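

# Illustrative sketch of the underlying scikit-learn call that _get_splits wraps
# (outside the d3m pipeline; the toy data below is made up):
#
#     import numpy
#     from sklearn import model_selection
#     X = numpy.arange(20).reshape(10, 2)
#     k_fold = model_selection.KFold(n_splits=5, shuffle=False)
#     folds = list(k_fold.split(X))  # 5 (train_indices, test_indices) pairs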
@@ -0,0 +1,187 @@ | |||||
import os | |||||
import uuid | |||||
import typing | |||||
from collections import OrderedDict | |||||
import numpy | |||||
import pandas | |||||
from sklearn import model_selection | |||||
from d3m import container, exceptions, utils as d3m_utils | |||||
from d3m.metadata import base as metadata_base, hyperparams | |||||
from d3m.base import primitives | |||||
import utils | |||||
__all__ = ('KFoldTimeSeriesSplitPrimitive',) | |||||
class Hyperparams(hyperparams.Hyperparams): | |||||
number_of_folds = hyperparams.Bounded[int]( | |||||
lower=2, | |||||
upper=None, | |||||
default=5, | |||||
semantic_types=[ | |||||
'https://metadata.datadrivendiscovery.org/types/ControlParameter' | |||||
], | |||||
description="Number of folds for k-folds cross-validation.", | |||||
) | |||||
number_of_window_folds = hyperparams.Union[typing.Union[int, None]]( | |||||
configuration=OrderedDict( | |||||
fixed=hyperparams.Bounded[int]( | |||||
lower=1, | |||||
upper=None, | |||||
default=1, | |||||
description="Number of folds in train set (window). These folds come directly " | |||||
"before test set (streaming window).", | |||||
), | |||||
all_records=hyperparams.Constant( | |||||
default=None, | |||||
description="Number of folds in train set (window) = maximum number possible.", | |||||
), | |||||
), | |||||
default='all_records', | |||||
semantic_types=[ | |||||
'https://metadata.datadrivendiscovery.org/types/ControlParameter' | |||||
], | |||||
description="Maximum size for a single training set.", | |||||
) | |||||
time_column_index = hyperparams.Union[typing.Union[int, None]]( | |||||
configuration=OrderedDict( | |||||
fixed=hyperparams.Bounded[int]( | |||||
lower=1, | |||||
upper=None, | |||||
default=1, | |||||
description="Specific column that contains the time index", | |||||
), | |||||
one_column=hyperparams.Constant( | |||||
default=None, | |||||
description="Only one column contains a time index. " | |||||
"It is detected automatically using semantic types.", | |||||
), | |||||
), | |||||
default='one_column', | |||||
semantic_types=[ | |||||
'https://metadata.datadrivendiscovery.org/types/ControlParameter' | |||||
], | |||||
description="Column index to use as datetime index. " | |||||
"If None, it is required that only one column with time column role semantic type is " | |||||
"present and otherwise an exception is raised. " | |||||
"If column index specified is not a datetime column an exception is" | |||||
"also raised.", | |||||
) | |||||
fuzzy_time_parsing = hyperparams.UniformBool( | |||||
default=True, | |||||
semantic_types=[ | |||||
'https://metadata.datadrivendiscovery.org/types/ControlParameter' | |||||
], | |||||
description="Use fuzzy time parsing.", | |||||
) | |||||
class KFoldTimeSeriesSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): | |||||
""" | |||||
A primitive which splits a tabular time-series Dataset for k-fold cross-validation. | |||||
The primitive sorts by the time column, so care should be taken that sorting the
column is meaningful. E.g., if the column is not numeric but of string structural type,
the strings should be formatted so that sorting by them also sorts by time.
""" | |||||
__author__ = 'Distil' | |||||
__version__ = '0.3.0' | |||||
__contact__ = 'mailto:jeffrey.gleason@yonder.co' | |||||
metadata = metadata_base.PrimitiveMetadata( | |||||
{ | |||||
'id': '002f9ad1-46e3-40f4-89ed-eeffbb3a102b', | |||||
'version': __version__, | |||||
'name': "K-fold cross-validation timeseries dataset splits", | |||||
'python_path': 'd3m.primitives.tods.evaluation.kfold_time_series_split', | |||||
'source': { | |||||
'name': 'DATALab@Texas A&M University', | |||||
'contact': __contact__, | |||||
'uris': [ | |||||
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/kfold_split_timeseries.py', | |||||
'https://gitlab.com/datadrivendiscovery/common-primitives.git', | |||||
], | |||||
}, | |||||
'algorithm_types': [ | |||||
metadata_base.PrimitiveAlgorithmType.K_FOLD, | |||||
metadata_base.PrimitiveAlgorithmType.CROSS_VALIDATION, | |||||
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, | |||||
], | |||||
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, | |||||
}, | |||||
) | |||||
def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: | |||||
time_column_indices = dataset.metadata.list_columns_with_semantic_types(['https://metadata.datadrivendiscovery.org/types/Time'], at=(main_resource_id,)) | |||||
attribute_column_indices = dataset.metadata.list_columns_with_semantic_types(['https://metadata.datadrivendiscovery.org/types/Attribute'], at=(main_resource_id,)) | |||||
# We want only time columns which are also attributes. | |||||
time_column_indices = [time_column_index for time_column_index in time_column_indices if time_column_index in attribute_column_indices] | |||||
if self.hyperparams['time_column_index'] is None: | |||||
if len(time_column_indices) != 1: | |||||
raise exceptions.InvalidArgumentValueError( | |||||
"If \"time_column_index\" hyper-parameter is \"None\", it is required that exactly one column with time column role semantic type is present.", | |||||
) | |||||
else: | |||||
# We know it exists because "time_column_indices" is a subset of "attribute_column_indices". | |||||
time_column_index = attribute_column_indices.index( | |||||
time_column_indices[0], | |||||
) | |||||
else: | |||||
if self.hyperparams['time_column_index'] not in time_column_indices: | |||||
raise exceptions.InvalidArgumentValueError( | |||||
"Time column index specified does not have a time column role semantic type.", | |||||
) | |||||
else: | |||||
time_column_index = attribute_column_indices.index( | |||||
self.hyperparams['time_column_index'], | |||||
) | |||||
# We first reset index. | |||||
attributes = attributes.reset_index(drop=True) | |||||
# Then convert datetime column to consistent datetime representation | |||||
attributes.insert( | |||||
loc=0, | |||||
column=uuid.uuid4(), # use uuid to ensure we are inserting a new column name | |||||
value=self._parse_time_data( | |||||
attributes, time_column_index, self.hyperparams['fuzzy_time_parsing'], | |||||
), | |||||
) | |||||
# Then sort dataframe by new datetime column. Index contains original row order. | |||||
attributes = attributes.sort_values(by=attributes.columns[0]) | |||||
# Remove the datetime representation used for sorting (downstream primitives might choose to parse the original string time column differently).
attributes = attributes.drop(attributes.columns[0], axis=1) | |||||
max_train_size: typing.Optional[int] = None | |||||
if self.hyperparams['number_of_window_folds'] is not None: | |||||
max_train_size = int(attributes.shape[0] * self.hyperparams['number_of_window_folds'] / self.hyperparams['number_of_folds']) | |||||
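# Worked example (illustrative values): with 100 rows, number_of_folds=5 and
# number_of_window_folds=1, max_train_size = int(100 * 1 / 5) = 20, i.e. each
# training window is capped at 20 rows.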
k_fold = model_selection.TimeSeriesSplit( | |||||
n_splits=self.hyperparams['number_of_folds'], | |||||
max_train_size=max_train_size | |||||
) | |||||
# We sorted "attributes" so we have to map indices on sorted "attributes" back to original | |||||
# indices. We do that by using DataFrame's index which contains original row order. | |||||
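# E.g., if the sorted "attributes.index" is [2, 0, 1] and a fold's "train" is [0, 1],
# the returned training indices are [2, 0], i.e. positions in the original (unsorted) data.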
return [ | |||||
( | |||||
numpy.array([attributes.index[val] for val in train]), | |||||
numpy.array([attributes.index[val] for val in test]), | |||||
) | |||||
for train, test in k_fold.split(attributes) | |||||
] | |||||
@classmethod | |||||
def _parse_time_data(cls, inputs: container.DataFrame, column_index: metadata_base.SimpleSelectorSegment, fuzzy: bool) -> typing.List[float]: | |||||
return [ | |||||
utils.parse_datetime_to_float(value, fuzzy=fuzzy) | |||||
for value in inputs.iloc[:, column_index] | |||||
] |
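# Illustrative sketch: the splitting above is done by scikit-learn's TimeSeriesSplit,
# where training indices always precede test indices in (sorted) time order. The
# guarded demo below only shows the shape of its output; the toy sizes are
# assumptions for illustration.
if __name__ == '__main__':
    import numpy
    from sklearn import model_selection

    rows = numpy.arange(8).reshape(-1, 1)
    splitter = model_selection.TimeSeriesSplit(n_splits=3, max_train_size=4)
    for train_indices, test_indices in splitter.split(rows):
        # With 8 rows and 3 splits: test folds of size 2, and training windows capped
        # at 4 rows which always end right before the test fold starts.
        print(train_indices, test_indices)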
@@ -0,0 +1,52 @@ | |||||
import os | |||||
import typing | |||||
import numpy | |||||
import pandas | |||||
from d3m import container, utils as d3m_utils | |||||
from d3m.metadata import base as metadata_base, hyperparams | |||||
from d3m.base import primitives | |||||
__all__ = ('NoSplitDatasetSplitPrimitive',) | |||||
class Hyperparams(hyperparams.Hyperparams): | |||||
pass | |||||
class NoSplitDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): | |||||
""" | |||||
A primitive which splits a tabular Dataset in such a way that every split
produces the same (full) Dataset. Useful for unsupervised learning tasks.
""" | |||||
metadata = metadata_base.PrimitiveMetadata( | |||||
{ | |||||
'id': '48c683ad-da9e-48cf-b3a0-7394dba5e5d2', | |||||
'version': '0.1.0', | |||||
'name': "No-split tabular dataset splits", | |||||
'python_path': 'd3m.primitives.tods.evaluation.no_split_dataset_split', | |||||
'source': { | |||||
'name': 'DATALab@Texas A&M University', | |||||
'contact': 'mailto:mitar.commonprimitives@tnode.com', | |||||
'uris': [ | |||||
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/no_split.py', | |||||
'https://gitlab.com/datadrivendiscovery/common-primitives.git', | |||||
], | |||||
}, | |||||
'algorithm_types': [ | |||||
metadata_base.PrimitiveAlgorithmType.IDENTITY_FUNCTION, | |||||
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, | |||||
], | |||||
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, | |||||
}, | |||||
) | |||||
def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: | |||||
# We still go through the whole splitting process to ensure full compatibility
# (including error conditions) with a regular split, but we use all data for both splits.
all_data = numpy.arange(len(attributes)) | |||||
return [(all_data, all_data)] |
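# Illustrative note: with 4 rows the single "split" returned above is
# (array([0, 1, 2, 3]), array([0, 1, 2, 3])), i.e. the same full index set is used
# for both the train and the score side, so downstream evaluation sees all data.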
@@ -0,0 +1,160 @@ | |||||
import copy | |||||
import os | |||||
import typing | |||||
from d3m import container, exceptions, utils as d3m_utils | |||||
from d3m.metadata import base as metadata_base, hyperparams | |||||
from d3m.primitive_interfaces import base, transformer | |||||
Inputs = container.List | |||||
Outputs = container.List | |||||
class Hyperparams(hyperparams.Hyperparams): | |||||
match_logic = hyperparams.Enumeration( | |||||
values=['all', 'any'], | |||||
default='any', | |||||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||||
description="Should a column have all of semantic types in \"semantic_types\" to be redacted, or any of them?", | |||||
) | |||||
semantic_types = hyperparams.Set( | |||||
elements=hyperparams.Hyperparameter[str](''), | |||||
default=(), | |||||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||||
description="Redact columns with these semantic types. Only columns having semantic types listed here will be operated on, based on \"match_logic\".", | |||||
) | |||||
add_semantic_types = hyperparams.Set( | |||||
elements=hyperparams.Hyperparameter[str](''), | |||||
default=(), | |||||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||||
description="Semantic types to add to redacted columns. All listed semantic types will be added to all columns which were redacted.", | |||||
) | |||||
# TODO: Make clear the assumption that both container type (List) and Datasets should have metadata. | |||||
# The primitive modifies the metadata of Datasets, while there is officially no reason for them
# to have metadata at all: metadata is stored on the input container type, not
# on the values inside it.
class RedactColumnsPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]): | |||||
""" | |||||
A primitive which takes as an input a list of ``Dataset`` objects and redacts values of all columns matching | |||||
a given semantic type or types. | |||||
Redaction is done by setting all values in a redacted column to an empty string. | |||||
It operates only on DataFrame resources inside datasets. | |||||
""" | |||||
metadata = metadata_base.PrimitiveMetadata( | |||||
{ | |||||
'id': '744c4090-e2f6-489e-8efc-8b1e051bfad6', | |||||
'version': '0.2.0', | |||||
'name': "Redact columns for evaluation", | |||||
'python_path': 'd3m.primitives.tods.evaluation.redact_columns', | |||||
'source': { | |||||
'name': 'DATALab@Texas A&M University', | |||||
'contact': 'mailto:sunmj15@gmail.com', | |||||
'uris': [ | |||||
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/redact_columns.py', | |||||
'https://gitlab.com/datadrivendiscovery/common-primitives.git', | |||||
], | |||||
}, | |||||
'installation': [{ | |||||
'type': metadata_base.PrimitiveInstallationType.PIP, | |||||
'package_uri': 'git+https://gitlab.com/datadrivendiscovery/common-primitives.git@{git_commit}#egg=common_primitives'.format( | |||||
git_commit=d3m_utils.current_git_commit(os.path.dirname(__file__)), | |||||
), | |||||
}], | |||||
'algorithm_types': [ | |||||
metadata_base.PrimitiveAlgorithmType.DATA_CONVERSION, | |||||
], | |||||
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, | |||||
}, | |||||
) | |||||
def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]: | |||||
output_datasets = container.List(generate_metadata=True) | |||||
for dataset in inputs: | |||||
resources = {} | |||||
metadata = dataset.metadata | |||||
for resource_id, resource in dataset.items(): | |||||
if not isinstance(resource, container.DataFrame): | |||||
resources[resource_id] = resource | |||||
continue | |||||
columns_to_redact = self._get_columns_to_redact(metadata, (resource_id,)) | |||||
if not columns_to_redact: | |||||
resources[resource_id] = resource | |||||
continue | |||||
resource = copy.copy(resource) | |||||
for column_index in columns_to_redact: | |||||
column_metadata = dataset.metadata.query((resource_id, metadata_base.ALL_ELEMENTS, column_index)) | |||||
if 'structural_type' in column_metadata and issubclass(column_metadata['structural_type'], str): | |||||
resource.iloc[:, column_index] = '' | |||||
else: | |||||
raise TypeError("Primitive can operate only on columns with structural type \"str\", not \"{type}\".".format( | |||||
type=column_metadata.get('structural_type', None), | |||||
)) | |||||
metadata = self._update_metadata(metadata, resource_id, column_index, ()) | |||||
resources[resource_id] = resource | |||||
dataset = container.Dataset(resources, metadata) | |||||
output_datasets.append(dataset) | |||||
output_datasets.metadata = metadata_base.DataMetadata({ | |||||
'schema': metadata_base.CONTAINER_SCHEMA_VERSION, | |||||
'structural_type': container.List, | |||||
'dimension': { | |||||
'length': len(output_datasets), | |||||
}, | |||||
}) | |||||
# We update metadata based on metadata of each dataset. | |||||
# TODO: In the future this might be done automatically by generate_metadata. | |||||
# See: https://gitlab.com/datadrivendiscovery/d3m/issues/119 | |||||
for index, dataset in enumerate(output_datasets): | |||||
output_datasets.metadata = dataset.metadata.copy_to(output_datasets.metadata, (), (index,)) | |||||
return base.CallResult(output_datasets) | |||||
def _get_columns_to_redact(self, inputs_metadata: metadata_base.DataMetadata, at: metadata_base.Selector) -> typing.Sequence[int]: | |||||
columns = [] | |||||
for element in inputs_metadata.get_elements(list(at) + [metadata_base.ALL_ELEMENTS]): | |||||
semantic_types = inputs_metadata.query(list(at) + [metadata_base.ALL_ELEMENTS, element]).get('semantic_types', ()) | |||||
# TODO: Should we handle inheritance between semantic types here? | |||||
if self.hyperparams['match_logic'] == 'all': | |||||
matched = all(semantic_type in semantic_types for semantic_type in self.hyperparams['semantic_types']) | |||||
elif self.hyperparams['match_logic'] == 'any': | |||||
matched = any(semantic_type in semantic_types for semantic_type in self.hyperparams['semantic_types']) | |||||
else: | |||||
raise exceptions.UnexpectedValueError("Unknown value of hyper-parameter \"match_logic\": {value}".format(value=self.hyperparams['match_logic'])) | |||||
if matched: | |||||
if element is metadata_base.ALL_ELEMENTS: | |||||
return list(range(inputs_metadata.query(list(at) + [metadata_base.ALL_ELEMENTS]).get('dimension', {}).get('length', 0))) | |||||
else: | |||||
columns.append(typing.cast(int, element)) | |||||
return columns | |||||
def _update_metadata( | |||||
self, inputs_metadata: metadata_base.DataMetadata, resource_id: metadata_base.SelectorSegment, | |||||
column_index: int, at: metadata_base.Selector, | |||||
) -> metadata_base.DataMetadata: | |||||
outputs_metadata = inputs_metadata | |||||
for semantic_type in self.hyperparams['add_semantic_types']: | |||||
outputs_metadata = outputs_metadata.add_semantic_type(tuple(at) + (resource_id, metadata_base.ALL_ELEMENTS, column_index), semantic_type) | |||||
return outputs_metadata |
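# Illustrative sketch: the redaction itself is just the assignment done in produce()
# above ("resource.iloc[:, column_index] = ''"). The guarded demo below shows its
# effect on a plain pandas DataFrame; the column names and values are assumptions
# for illustration only.
if __name__ == '__main__':
    import pandas

    frame = pandas.DataFrame({'d3mIndex': ['0', '1'], 'target': ['cat', 'dog']})
    frame.iloc[:, 1] = ''  # redact the second column in place
    print(frame)  # the "target" values are now empty strings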
@@ -0,0 +1,88 @@ | |||||
import os | |||||
import typing | |||||
import numpy | |||||
import pandas | |||||
from sklearn import model_selection | |||||
from d3m import container, exceptions, utils as d3m_utils | |||||
from d3m.metadata import base as metadata_base, hyperparams | |||||
from d3m.base import primitives | |||||
__all__ = ('TrainScoreDatasetSplitPrimitive',) | |||||
class Hyperparams(hyperparams.Hyperparams): | |||||
train_score_ratio = hyperparams.Uniform( | |||||
lower=0, | |||||
upper=1, | |||||
default=0.75, | |||||
upper_inclusive=True, | |||||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||||
description="The ratio between the train and score data and represents the proportion of the Dataset to include in the train split. The rest is included in the score split.", | |||||
) | |||||
stratified = hyperparams.UniformBool( | |||||
default=False, | |||||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||||
description="Do stratified folds. The folds are made by preserving the percentage of samples for each class.", | |||||
) | |||||
shuffle = hyperparams.UniformBool( | |||||
default=False, | |||||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||||
description="Whether to shuffle the data before splitting into batches.", | |||||
) | |||||
delete_recursive = hyperparams.Hyperparameter[bool]( | |||||
default=False, | |||||
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], | |||||
description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.", | |||||
) | |||||
class TrainScoreDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): | |||||
""" | |||||
A primitive which splits a tabular Dataset into random train and score subsets. | |||||
""" | |||||
metadata = metadata_base.PrimitiveMetadata( | |||||
{ | |||||
'id': '3fcc6dc4-6681-4c86-948e-066d14e7d803', | |||||
'version': '0.1.0', | |||||
'name': "Train-score tabular dataset splits", | |||||
'python_path': 'd3m.primitives.tods.evaluation.train_score_dataset_split', | |||||
'source': { | |||||
'name': 'DATALab@Texas A&M University', | |||||
'contact': 'mailto:mitar.commonprimitives@tnode.com', | |||||
'uris': [ | |||||
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/train_score_split.py', | |||||
'https://gitlab.com/datadrivendiscovery/common-primitives.git', | |||||
], | |||||
}, | |||||
'installation': [{ | |||||
'type': metadata_base.PrimitiveInstallationType.PIP, | |||||
'package_uri': 'git+https://gitlab.com/datadrivendiscovery/common-primitives.git@{git_commit}#egg=common_primitives'.format( | |||||
git_commit=d3m_utils.current_git_commit(os.path.dirname(__file__)), | |||||
), | |||||
}], | |||||
'algorithm_types': [ | |||||
metadata_base.PrimitiveAlgorithmType.HOLDOUT, | |||||
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, | |||||
], | |||||
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, | |||||
}, | |||||
) | |||||
def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: | |||||
if self.hyperparams['stratified'] and not len(targets.columns): | |||||
raise exceptions.InvalidArgumentValueError("Stratified split is requested, but no target columns found.") | |||||
train_data, score_data = model_selection.train_test_split( | |||||
numpy.arange(len(attributes)), | |||||
test_size=None, | |||||
train_size=self.hyperparams['train_score_ratio'], | |||||
random_state=self._random_state, | |||||
shuffle=self.hyperparams['shuffle'], | |||||
stratify=targets if self.hyperparams['stratified'] else None, | |||||
) | |||||
return [(train_data, score_data)] |
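# Illustrative sketch: _get_splits above hands the row indices to scikit-learn's
# train_test_split. The guarded demo below shows the 75%/25% split of row indices
# that the default hyper-parameters request; the toy size is an assumption for
# illustration.
if __name__ == '__main__':
    import numpy
    from sklearn import model_selection

    row_indices = numpy.arange(8)
    train_rows, score_rows = model_selection.train_test_split(
        row_indices, test_size=None, train_size=0.75, shuffle=False,
    )
    print(train_rows)  # the first 6 of 8 indices (75%)
    print(score_rows)  # the remaining 2 indices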
@@ -0,0 +1,192 @@ | |||||
import datetime | |||||
import logging | |||||
import typing | |||||
import dateutil.parser | |||||
import numpy | |||||
from d3m import container, deprecate | |||||
from d3m.base import utils as base_utils | |||||
from d3m.metadata import base as metadata_base | |||||
logger = logging.getLogger(__name__) | |||||
DEFAULT_DATETIME = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) | |||||
@deprecate.function(message="it should not be used anymore") | |||||
def copy_elements_metadata(source_metadata: metadata_base.Metadata, target_metadata: metadata_base.DataMetadata, from_selector: metadata_base.Selector, | |||||
to_selector: metadata_base.Selector = (), *, ignore_all_elements: bool = False, check: bool = True, source: typing.Any = None) -> metadata_base.DataMetadata: | |||||
return source_metadata._copy_elements_metadata(target_metadata, list(from_selector), list(to_selector), [], ignore_all_elements) | |||||
@deprecate.function(message="use Metadata.copy_to method instead") | |||||
def copy_metadata(source_metadata: metadata_base.Metadata, target_metadata: metadata_base.DataMetadata, from_selector: metadata_base.Selector, | |||||
to_selector: metadata_base.Selector = (), *, ignore_all_elements: bool = False, check: bool = True, source: typing.Any = None) -> metadata_base.DataMetadata: | |||||
return source_metadata.copy_to(target_metadata, from_selector, to_selector, ignore_all_elements=ignore_all_elements) | |||||
@deprecate.function(message="use DataFrame.select_columns method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def select_columns(inputs: container.DataFrame, columns: typing.Sequence[metadata_base.SimpleSelectorSegment], *, | |||||
source: typing.Any = None) -> container.DataFrame: | |||||
return inputs.select_columns(columns) | |||||
@deprecate.function(message="use DataMetadata.select_columns method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def select_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns: typing.Sequence[metadata_base.SimpleSelectorSegment], *, | |||||
source: typing.Any = None) -> metadata_base.DataMetadata: | |||||
return inputs_metadata.select_columns(columns) | |||||
@deprecate.function(message="use DataMetadata.list_columns_with_semantic_types method instead") | |||||
def list_columns_with_semantic_types(metadata: metadata_base.DataMetadata, semantic_types: typing.Sequence[str], *, | |||||
at: metadata_base.Selector = ()) -> typing.Sequence[int]: | |||||
return metadata.list_columns_with_semantic_types(semantic_types, at=at) | |||||
@deprecate.function(message="use DataMetadata.list_columns_with_structural_types method instead") | |||||
def list_columns_with_structural_types(metadata: metadata_base.DataMetadata, structural_types: typing.Union[typing.Callable, typing.Sequence[typing.Union[str, type]]], *, | |||||
at: metadata_base.Selector = ()) -> typing.Sequence[int]: | |||||
return metadata.list_columns_with_structural_types(structural_types, at=at) | |||||
@deprecate.function(message="use DataFrame.remove_columns method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def remove_columns(inputs: container.DataFrame, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> container.DataFrame: | |||||
return inputs.remove_columns(column_indices) | |||||
@deprecate.function(message="use DataMetadata.remove_columns method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def remove_columns_metadata(inputs_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> metadata_base.DataMetadata: | |||||
return inputs_metadata.remove_columns(column_indices) | |||||
@deprecate.function(message="use DataFrame.append_columns method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def append_columns(left: container.DataFrame, right: container.DataFrame, *, use_right_metadata: bool = False, source: typing.Any = None) -> container.DataFrame: | |||||
return left.append_columns(right, use_right_metadata=use_right_metadata) | |||||
@deprecate.function(message="use DataMetadata.append_columns method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def append_columns_metadata(left_metadata: metadata_base.DataMetadata, right_metadata: metadata_base.DataMetadata, *, use_right_metadata: bool = False, source: typing.Any = None) -> metadata_base.DataMetadata:
return left_metadata.append_columns(right_metadata, use_right_metadata=use_right_metadata) | |||||
@deprecate.function(message="use DataFrame.insert_columns method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def insert_columns(inputs: container.DataFrame, columns: container.DataFrame, at_column_index: int, *, source: typing.Any = None) -> container.DataFrame: | |||||
return inputs.insert_columns(columns, at_column_index) | |||||
@deprecate.function(message="use DataMetadata.insert_columns method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def insert_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns_metadata: metadata_base.DataMetadata, at_column_index: int, *, source: typing.Any = None) -> metadata_base.DataMetadata: | |||||
return inputs_metadata.insert_columns(columns_metadata, at_column_index) | |||||
@deprecate.function(message="use DataFrame.replace_columns method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def replace_columns(inputs: container.DataFrame, columns: container.DataFrame, column_indices: typing.Sequence[int], *, copy: bool = True, source: typing.Any = None) -> container.DataFrame: | |||||
return inputs.replace_columns(columns, column_indices, copy=copy) | |||||
@deprecate.function(message="use DataMetadata.replace_columns method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def replace_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> metadata_base.DataMetadata: | |||||
return inputs_metadata.replace_columns(columns_metadata, column_indices) | |||||
@deprecate.function(message="use DataMetadata.get_index_columns method instead") | |||||
def get_index_columns(metadata: metadata_base.DataMetadata, *, at: metadata_base.Selector = ()) -> typing.Sequence[int]: | |||||
return metadata.get_index_columns(at=at) | |||||
@deprecate.function(message="use DataFrame.horizontal_concat method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def horizontal_concat(left: container.DataFrame, right: container.DataFrame, *, use_index: bool = True, | |||||
remove_second_index: bool = True, use_right_metadata: bool = False, source: typing.Any = None) -> container.DataFrame: | |||||
return left.horizontal_concat(right, use_index=use_index, remove_second_index=remove_second_index, use_right_metadata=use_right_metadata) | |||||
@deprecate.function(message="use DataMetadata.horizontal_concat method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def horizontal_concat_metadata(left_metadata: metadata_base.DataMetadata, right_metadata: metadata_base.DataMetadata, *, use_index: bool = True, | |||||
remove_second_index: bool = True, use_right_metadata: bool = False, source: typing.Any = None) -> metadata_base.DataMetadata: | |||||
return left_metadata.horizontal_concat(right_metadata, use_index=use_index, remove_second_index=remove_second_index, use_right_metadata=use_right_metadata) | |||||
@deprecate.function(message="use d3m.base.utils.get_columns_to_use function instead") | |||||
def get_columns_to_use(metadata: metadata_base.DataMetadata, use_columns: typing.Sequence[int], exclude_columns: typing.Sequence[int], | |||||
can_use_column: typing.Callable) -> typing.Tuple[typing.List[int], typing.List[int]]: | |||||
return base_utils.get_columns_to_use(metadata, use_columns, exclude_columns, can_use_column) | |||||
@deprecate.function(message="use d3m.base.utils.combine_columns function instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def combine_columns(return_result: str, add_index_columns: bool, inputs: container.DataFrame, column_indices: typing.Sequence[int], | |||||
columns_list: typing.Sequence[container.DataFrame], *, source: typing.Any = None) -> container.DataFrame: | |||||
return base_utils.combine_columns(inputs, column_indices, columns_list, return_result=return_result, add_index_columns=add_index_columns) | |||||
@deprecate.function(message="use d3m.base.utils.combine_columns_metadata function instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def combine_columns_metadata(return_result: str, add_index_columns: bool, inputs_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], | |||||
columns_metadata_list: typing.Sequence[metadata_base.DataMetadata], *, source: typing.Any = None) -> metadata_base.DataMetadata: | |||||
return base_utils.combine_columns_metadata(inputs_metadata, column_indices, columns_metadata_list, return_result=return_result, add_index_columns=add_index_columns) | |||||
@deprecate.function(message="use DataMetadata.set_table_metadata method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def set_table_metadata(inputs_metadata: metadata_base.DataMetadata, *, at: metadata_base.Selector = (), source: typing.Any = None) -> metadata_base.DataMetadata: | |||||
return inputs_metadata.set_table_metadata(at=at) | |||||
@deprecate.function(message="use DataMetadata.get_column_index_from_column_name method instead") | |||||
def get_column_index_from_column_name(inputs_metadata: metadata_base.DataMetadata, column_name: str, *, at: metadata_base.Selector = ()) -> int: | |||||
return inputs_metadata.get_column_index_from_column_name(column_name, at=at) | |||||
@deprecate.function(message="use Dataset.get_relations_graph method instead") | |||||
def build_relation_graph(dataset: container.Dataset) -> typing.Dict[str, typing.List[typing.Tuple[str, bool, int, int, typing.Dict]]]: | |||||
return dataset.get_relations_graph() | |||||
@deprecate.function(message="use d3m.base.utils.get_tabular_resource function instead") | |||||
def get_tabular_resource(dataset: container.Dataset, resource_id: typing.Optional[str], *, | |||||
pick_entry_point: bool = True, pick_one: bool = True, has_hyperparameter: bool = True) -> typing.Tuple[str, container.DataFrame]: | |||||
return base_utils.get_tabular_resource(dataset, resource_id, pick_entry_point=pick_entry_point, pick_one=pick_one, has_hyperparameter=has_hyperparameter) | |||||
@deprecate.function(message="use d3m.base.utils.get_tabular_resource_metadata function instead") | |||||
def get_tabular_resource_metadata(dataset_metadata: metadata_base.DataMetadata, resource_id: typing.Optional[metadata_base.SelectorSegment], *, | |||||
pick_entry_point: bool = True, pick_one: bool = True) -> metadata_base.SelectorSegment: | |||||
return base_utils.get_tabular_resource_metadata(dataset_metadata, resource_id, pick_entry_point=pick_entry_point, pick_one=pick_one) | |||||
@deprecate.function(message="use Dataset.select_rows method instead") | |||||
@deprecate.arguments('source', message="argument ignored") | |||||
def cut_dataset(dataset: container.Dataset, row_indices_to_keep: typing.Mapping[str, typing.Sequence[int]], *, | |||||
source: typing.Any = None) -> container.Dataset: | |||||
return dataset.select_rows(row_indices_to_keep) | |||||
def parse_datetime(value: str, *, fuzzy: bool = True) -> typing.Optional[datetime.datetime]: | |||||
try: | |||||
return dateutil.parser.parse(value, default=DEFAULT_DATETIME, fuzzy=fuzzy) | |||||
except (ValueError, OverflowError, TypeError): | |||||
return None | |||||
def parse_datetime_to_float(value: str, *, fuzzy: bool = True) -> float: | |||||
try: | |||||
parsed = parse_datetime(value, fuzzy=fuzzy) | |||||
if parsed is None: | |||||
return numpy.nan | |||||
else: | |||||
return parsed.timestamp() | |||||
except (ValueError, OverflowError, TypeError): | |||||
return numpy.nan |
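# Illustrative sketch: parse_datetime falls back to DEFAULT_DATETIME (the UTC epoch)
# for any missing date components, and parse_datetime_to_float returns NaN when no
# date can be parsed at all. The sample strings below are assumptions for
# illustration only.
if __name__ == '__main__':
    print(parse_datetime_to_float('2020-01-01'))  # 1577836800.0 (midnight UTC)
    print(parse_datetime_to_float('------'))      # nan (no date information found)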