From 88dc4c4e3304c9f2ee2d68cfe26fde0744bae06b Mon Sep 17 00:00:00 2001
From: lhenry15
Date: Wed, 3 Feb 2021 21:47:15 -0600
Subject: [PATCH] resolve construct prediction bug

Former-commit-id: 8fca90698b3ffff3c183e4754cabc6b39fe9b16f
---
 datasets/dataset_utils.py | 297 +++++++++++++++++++
 datasets/hub.py | 559 ++++++++++++++++++++++++++++++++++++
 datasets/tods_dataset_base.py | 139 +++++++++
 datasets/tods_datasets.py | 116 ++++++++
 src/axolotl | 1 -
 src/common-primitives | 1 -
src/d3m | 1 - tods/common/FixedSplit.py | 116 ++++++++ tods/common/KFoldSplit.py | 87 ++++++ tods/common/KFoldSplitTimeseries.py | 187 ++++++++++++ tods/common/NoSplit.py | 52 ++++ tods/common/RedactColumns.py | 160 +++++++++++ tods/common/TrainScoreSplit.py | 88 ++++++ tods/common/__init__.py | 0 tods/common/utils.py | 192 +++++++++++++ 15 files changed, 1993 insertions(+), 3 deletions(-) create mode 100644 datasets/dataset_utils.py create mode 100644 datasets/hub.py create mode 100755 datasets/tods_dataset_base.py create mode 100755 datasets/tods_datasets.py delete mode 160000 src/axolotl delete mode 160000 src/common-primitives delete mode 160000 src/d3m create mode 100644 tods/common/FixedSplit.py create mode 100644 tods/common/KFoldSplit.py create mode 100644 tods/common/KFoldSplitTimeseries.py create mode 100644 tods/common/NoSplit.py create mode 100644 tods/common/RedactColumns.py create mode 100644 tods/common/TrainScoreSplit.py create mode 100644 tods/common/__init__.py create mode 100644 tods/common/utils.py diff --git a/datasets/dataset_utils.py b/datasets/dataset_utils.py new file mode 100644 index 0000000..b5300d7 --- /dev/null +++ b/datasets/dataset_utils.py @@ -0,0 +1,297 @@ +import os +import os.path +import hashlib +import gzip +import errno +import tarfile +from typing import Any, Callable, List, Iterable, Optional, TypeVar +import zipfile + +# import torch +# from torch.utils.model_zoo import tqdm + +from hub import tqdm + +def gen_bar_updater() -> Callable[[int, int, int], None]: + pbar = tqdm(total=None) + + def bar_update(count, block_size, total_size): + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) + + return bar_update + + +def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str: + md5 = hashlib.md5() + with open(fpath, 'rb') as f: + for chunk in iter(lambda: f.read(chunk_size), b''): + md5.update(chunk) + return md5.hexdigest() + + +def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool: + return md5 == calculate_md5(fpath, **kwargs) + + +def check_integrity(fpath: str, md5: Optional[str] = None) -> bool: + if not os.path.isfile(fpath): + return False + if md5 is None: + return True + return check_md5(fpath, md5) + + +def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None: + """Download a file from a url and place it in root. + + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. If None, use the basename of the URL + md5 (str, optional): MD5 checksum of the download. If None, do not check + """ + import urllib + + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + # check if file is already present locally + if check_integrity(fpath, md5): + print('Using downloaded and verified file: ' + fpath) + else: # download the file + try: + print('Downloading ' + url + ' to ' + fpath) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater() + ) + except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] + if url[:5] == 'https': + url = url.replace('https:', 'http:') + print('Failed download. Trying https -> http instead.' 
+ ' Downloading ' + url + ' to ' + fpath) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater() + ) + else: + raise e + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + + +def list_dir(root: str, prefix: bool = False) -> List[str]: + """List all directories at a given root + + Args: + root (str): Path to directory whose folders need to be listed + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the directories found + """ + root = os.path.expanduser(root) + directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))] + if prefix is True: + directories = [os.path.join(root, d) for d in directories] + return directories + + +def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]: + """List all files ending with a suffix at a given root + + Args: + root (str): Path to directory whose folders need to be listed + suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). + It uses the Python "str.endswith" method and is passed directly + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the files found + """ + root = os.path.expanduser(root) + files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)] + if prefix is True: + files = [os.path.join(root, d) for d in files] + return files + + +def _quota_exceeded(response: "requests.models.Response") -> bool: # type: ignore[name-defined] + return "Google Drive - Quota exceeded" in response.text + + +def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None): + """Download a Google Drive file from and place it in root. + + Args: + file_id (str): id of file to be downloaded + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. If None, use the id of the file. + md5 (str, optional): MD5 checksum of the download. If None, do not check + """ + # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url + import requests + url = "https://docs.google.com/uc?export=download" + + root = os.path.expanduser(root) + if not filename: + filename = file_id + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + if os.path.isfile(fpath) and check_integrity(fpath, md5): + print('Using downloaded and verified file: ' + fpath) + else: + session = requests.Session() + + response = session.get(url, params={'id': file_id}, stream=True) + token = _get_confirm_token(response) + + if token: + params = {'id': file_id, 'confirm': token} + response = session.get(url, params=params, stream=True) + + if _quota_exceeded(response): + msg = ( + f"The daily quota of the file {filename} is exceeded and it " + f"can't be downloaded. This is a limitation of Google Drive " + f"and can only be overcome by trying again later." 
+ ) + raise RuntimeError(msg) + + _save_response_content(response, fpath) + + +def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined] + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + + return None + + +def _save_response_content( + response: "requests.models.Response", destination: str, chunk_size: int = 32768, # type: ignore[name-defined] +) -> None: + with open(destination, "wb") as f: + pbar = tqdm(total=None) + progress = 0 + for chunk in response.iter_content(chunk_size): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + progress += len(chunk) + pbar.update(progress - pbar.n) + pbar.close() + + +def _is_tarxz(filename: str) -> bool: + return filename.endswith(".tar.xz") + + +def _is_tar(filename: str) -> bool: + return filename.endswith(".tar") + + +def _is_targz(filename: str) -> bool: + return filename.endswith(".tar.gz") + + +def _is_tgz(filename: str) -> bool: + return filename.endswith(".tgz") + + +def _is_gzip(filename: str) -> bool: + return filename.endswith(".gz") and not filename.endswith(".tar.gz") + + +def _is_zip(filename: str) -> bool: + return filename.endswith(".zip") + + +def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None: + if to_path is None: + to_path = os.path.dirname(from_path) + + if _is_tar(from_path): + with tarfile.open(from_path, 'r') as tar: + tar.extractall(path=to_path) + elif _is_targz(from_path) or _is_tgz(from_path): + with tarfile.open(from_path, 'r:gz') as tar: + tar.extractall(path=to_path) + elif _is_tarxz(from_path): + with tarfile.open(from_path, 'r:xz') as tar: + tar.extractall(path=to_path) + elif _is_gzip(from_path): + to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) + with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: + out_f.write(zip_f.read()) + elif _is_zip(from_path): + with zipfile.ZipFile(from_path, 'r') as z: + z.extractall(to_path) + else: + raise ValueError("Extraction of {} not supported".format(from_path)) + + if remove_finished: + os.remove(from_path) + + +def download_and_extract_archive( + url: str, + download_root: str, + extract_root: Optional[str] = None, + filename: Optional[str] = None, + md5: Optional[str] = None, + remove_finished: bool = False, +) -> None: + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print("Extracting {} to {}".format(archive, extract_root)) + # print(archive) + # print(extract_root) + # extract_archive(archive, extract_root, remove_finished) + + +def iterable_to_str(iterable: Iterable) -> str: + return "'" + "', '".join([str(item) for item in iterable]) + "'" + + +T = TypeVar("T", str, bytes) + + +def verify_str_arg( + value: T, arg: Optional[str] = None, valid_values: Iterable[T] = None, custom_msg: Optional[str] = None, +) -> T: + if not isinstance(value, torch._six.string_classes): + if arg is None: + msg = "Expected type str, but got type {type}." + else: + msg = "Expected type str for argument {arg}, but got type {type}." 
+ msg = msg.format(type=type(value), arg=arg) + raise ValueError(msg) + + if valid_values is None: + return value + + if value not in valid_values: + if custom_msg is not None: + msg = custom_msg + else: + msg = ("Unknown value '{value}' for argument {arg}. " + "Valid values are {{{valid_values}}}.") + msg = msg.format(value=value, arg=arg, + valid_values=iterable_to_str(valid_values)) + raise ValueError(msg) + + return value diff --git a/datasets/hub.py b/datasets/hub.py new file mode 100644 index 0000000..edc4906 --- /dev/null +++ b/datasets/hub.py @@ -0,0 +1,559 @@ +import errno +import hashlib +import os +import re +import shutil +import sys +import tempfile +# import torch +import warnings +import zipfile + +from urllib.request import urlopen, Request +from urllib.parse import urlparse # noqa: F401 + +try: + from tqdm.auto import tqdm # automatically select proper tqdm submodule if available +except ImportError: + try: + from tqdm import tqdm + except ImportError: + # fake tqdm if it's not installed + class tqdm(object): # type: ignore + + def __init__(self, total=None, disable=False, + unit=None, unit_scale=None, unit_divisor=None): + self.total = total + self.disable = disable + self.n = 0 + # ignore unit, unit_scale, unit_divisor; they're just for real tqdm + + def update(self, n): + if self.disable: + return + + self.n += n + if self.total is None: + sys.stderr.write("\r{0:.1f} bytes".format(self.n)) + else: + sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total))) + sys.stderr.flush() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.disable: + return + + sys.stderr.write('\n') + +# # matches bfd8deac from resnet18-bfd8deac.pth +# HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') +# +# MASTER_BRANCH = 'master' +# ENV_TORCH_HOME = 'TORCH_HOME' +# ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +# DEFAULT_CACHE_DIR = '~/.cache' +# VAR_DEPENDENCY = 'dependencies' +# MODULE_HUBCONF = 'hubconf.py' +# READ_DATA_CHUNK = 8192 +# _hub_dir = None +# +# +# # Copied from tools/shared/module_loader to be included in torch package +# def import_module(name, path): +# import importlib.util +# from importlib.abc import Loader +# spec = importlib.util.spec_from_file_location(name, path) +# module = importlib.util.module_from_spec(spec) +# assert isinstance(spec.loader, Loader) +# spec.loader.exec_module(module) +# return module +# +# +# def _remove_if_exists(path): +# if os.path.exists(path): +# if os.path.isfile(path): +# os.remove(path) +# else: +# shutil.rmtree(path) +# +# +# def _git_archive_link(repo_owner, repo_name, branch): +# return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch) +# +# +# def _load_attr_from_module(module, func_name): +# # Check if callable is defined in the module +# if func_name not in dir(module): +# return None +# return getattr(module, func_name) +# +# +# def _get_torch_home(): +# torch_home = os.path.expanduser( +# os.getenv(ENV_TORCH_HOME, +# os.path.join(os.getenv(ENV_XDG_CACHE_HOME, +# DEFAULT_CACHE_DIR), 'torch'))) +# return torch_home +# +# +# def _parse_repo_info(github): +# branch = MASTER_BRANCH +# if ':' in github: +# repo_info, branch = github.split(':') +# else: +# repo_info = github +# repo_owner, repo_name = repo_info.split('/') +# return repo_owner, repo_name, branch +# +# +# def _get_cache_or_reload(github, force_reload, verbose=True): +# # Setup hub_dir to save downloaded files +# hub_dir = get_dir() +# if not os.path.exists(hub_dir): +# os.makedirs(hub_dir) +# # Parse 
github repo information +# repo_owner, repo_name, branch = _parse_repo_info(github) +# # Github allows branch name with slash '/', +# # this causes confusion with path on both Linux and Windows. +# # Backslash is not allowed in Github branch name so no need to +# # to worry about it. +# normalized_br = branch.replace('/', '_') +# # Github renames folder repo-v1.x.x to repo-1.x.x +# # We don't know the repo name before downloading the zip file +# # and inspect name from it. +# # To check if cached repo exists, we need to normalize folder names. +# repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, normalized_br])) +# +# use_cache = (not force_reload) and os.path.exists(repo_dir) +# +# if use_cache: +# if verbose: +# sys.stderr.write('Using cache found in {}\n'.format(repo_dir)) +# else: +# cached_file = os.path.join(hub_dir, normalized_br + '.zip') +# _remove_if_exists(cached_file) +# +# url = _git_archive_link(repo_owner, repo_name, branch) +# sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, cached_file)) +# download_url_to_file(url, cached_file, progress=False) +# +# with zipfile.ZipFile(cached_file) as cached_zipfile: +# extraced_repo_name = cached_zipfile.infolist()[0].filename +# extracted_repo = os.path.join(hub_dir, extraced_repo_name) +# _remove_if_exists(extracted_repo) +# # Unzip the code and rename the base folder +# cached_zipfile.extractall(hub_dir) +# +# _remove_if_exists(cached_file) +# _remove_if_exists(repo_dir) +# shutil.move(extracted_repo, repo_dir) # rename the repo +# +# return repo_dir +# +# +# def _check_module_exists(name): +# if sys.version_info >= (3, 4): +# import importlib.util +# return importlib.util.find_spec(name) is not None +# elif sys.version_info >= (3, 3): +# # Special case for python3.3 +# import importlib.find_loader +# return importlib.find_loader(name) is not None +# else: +# # NB: Python2.7 imp.find_module() doesn't respect PEP 302, +# # it cannot find a package installed as .egg(zip) file. +# # Here we use workaround from: +# # https://stackoverflow.com/questions/28962344/imp-find-module-which-supports-zipped-eggs?lq=1 +# # Also imp doesn't handle hierarchical module names (names contains dots). +# try: +# # 1. Try imp.find_module(), which searches sys.path, but does +# # not respect PEP 302 import hooks. +# import imp +# result = imp.find_module(name) +# if result: +# return True +# except ImportError: +# pass +# path = sys.path +# for item in path: +# # 2. Scan path for import hooks. sys.path_importer_cache maps +# # path items to optional "importer" objects, that implement +# # find_module() etc. Note that path must be a subset of +# # sys.path for this to work. +# importer = sys.path_importer_cache.get(item) +# if importer: +# try: +# result = importer.find_module(name, [item]) +# if result: +# return True +# except ImportError: +# pass +# return False +# +# def _check_dependencies(m): +# dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) +# +# if dependencies is not None: +# missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)] +# if len(missing_deps): +# raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps))) +# +# +# def _load_entry_from_hubconf(m, model): +# if not isinstance(model, str): +# raise ValueError('Invalid input: model should be a string of function name') +# +# # Note that if a missing dependency is imported at top level of hubconf, it will +# # throw before this function. 
It's a chicken and egg situation where we have to +# # load hubconf to know what're the dependencies, but to import hubconf it requires +# # a missing package. This is fine, Python will throw proper error message for users. +# _check_dependencies(m) +# +# func = _load_attr_from_module(m, model) +# +# if func is None or not callable(func): +# raise RuntimeError('Cannot find callable {} in hubconf'.format(model)) +# +# return func +# +# +# def get_dir(): +# r""" +# Get the Torch Hub cache directory used for storing downloaded models & weights. +# +# If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where +# environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. +# ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux +# filesystem layout, with a default value ``~/.cache`` if the environment +# variable is not set. +# """ +# # Issue warning to move data if old env is set +# if os.getenv('TORCH_HUB'): +# warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead') +# +# if _hub_dir is not None: +# return _hub_dir +# return os.path.join(_get_torch_home(), 'hub') +# +# +# def set_dir(d): +# r""" +# Optionally set the Torch Hub directory used to save downloaded models & weights. +# +# Args: +# d (string): path to a local folder to save downloaded models & weights. +# """ +# global _hub_dir +# _hub_dir = d +# +# +# def list(github, force_reload=False): +# r""" +# List all entrypoints available in `github` hubconf. +# +# Args: +# github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional +# tag/branch. The default branch is `master` if not specified. +# Example: 'pytorch/vision[:hub]' +# force_reload (bool, optional): whether to discard the existing cache and force a fresh download. +# Default is `False`. +# Returns: +# entrypoints: a list of available entrypoint names +# +# Example: +# >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True) +# """ +# repo_dir = _get_cache_or_reload(github, force_reload, True) +# +# sys.path.insert(0, repo_dir) +# +# hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) +# +# sys.path.remove(repo_dir) +# +# # We take functions starts with '_' as internal helper functions +# entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')] +# +# return entrypoints +# +# +# def help(github, model, force_reload=False): +# r""" +# Show the docstring of entrypoint `model`. +# +# Args: +# github (string): a string with format with an optional +# tag/branch. The default branch is `master` if not specified. +# Example: 'pytorch/vision[:hub]' +# model (string): a string of entrypoint name defined in repo's hubconf.py +# force_reload (bool, optional): whether to discard the existing cache and force a fresh download. +# Default is `False`. +# Example: +# >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)) +# """ +# repo_dir = _get_cache_or_reload(github, force_reload, True) +# +# sys.path.insert(0, repo_dir) +# +# hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) +# +# sys.path.remove(repo_dir) +# +# entry = _load_entry_from_hubconf(hub_module, model) +# +# return entry.__doc__ +# +# +# # Ideally this should be `def load(github, model, *args, forece_reload=False, **kwargs):`, +# # but Python2 complains syntax error for it. We have to skip force_reload in function +# # signature here but detect it in kwargs instead. 
+# # TODO: fix it after Python2 EOL +# def load(repo_or_dir, model, *args, **kwargs): +# r""" +# Load a model from a github repo or a local directory. +# +# Note: Loading a model is the typical use case, but this can also be used to +# for loading other objects such as tokenizers, loss functions, etc. +# +# If :attr:`source` is ``'github'``, :attr:`repo_or_dir` is expected to be +# of the form ``repo_owner/repo_name[:tag_name]`` with an optional +# tag/branch. +# +# If :attr:`source` is ``'local'``, :attr:`repo_or_dir` is expected to be a +# path to a local directory. +# +# Args: +# repo_or_dir (string): repo name (``repo_owner/repo_name[:tag_name]``), +# if ``source = 'github'``; or a path to a local directory, if +# ``source = 'local'``. +# model (string): the name of a callable (entrypoint) defined in the +# repo/dir's ``hubconf.py``. +# *args (optional): the corresponding args for callable :attr:`model`. +# source (string, optional): ``'github'`` | ``'local'``. Specifies how +# ``repo_or_dir`` is to be interpreted. Default is ``'github'``. +# force_reload (bool, optional): whether to force a fresh download of +# the github repo unconditionally. Does not have any effect if +# ``source = 'local'``. Default is ``False``. +# verbose (bool, optional): If ``False``, mute messages about hitting +# local caches. Note that the message about first download cannot be +# muted. Does not have any effect if ``source = 'local'``. +# Default is ``True``. +# **kwargs (optional): the corresponding kwargs for callable +# :attr:`model`. +# +# Returns: +# The output of the :attr:`model` callable when called with the given +# ``*args`` and ``**kwargs``. +# +# Example: +# >>> # from a github repo +# >>> repo = 'pytorch/vision' +# >>> model = torch.hub.load(repo, 'resnet50', pretrained=True) +# >>> # from a local directory +# >>> path = '/some/local/path/pytorch/vision' +# >>> model = torch.hub.load(path, 'resnet50', pretrained=True) +# """ +# source = kwargs.pop('source', 'github').lower() +# force_reload = kwargs.pop('force_reload', False) +# verbose = kwargs.pop('verbose', True) +# +# if source not in ('github', 'local'): +# raise ValueError( +# f'Unknown source: "{source}". Allowed values: "github" | "local".') +# +# if source == 'github': +# repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose) +# +# model = _load_local(repo_or_dir, model, *args, **kwargs) +# return model +# +# +# def _load_local(hubconf_dir, model, *args, **kwargs): +# r""" +# Load a model from a local directory with a ``hubconf.py``. +# +# Args: +# hubconf_dir (string): path to a local directory that contains a +# ``hubconf.py``. +# model (string): name of an entrypoint defined in the directory's +# `hubconf.py`. +# *args (optional): the corresponding args for callable ``model``. +# **kwargs (optional): the corresponding kwargs for callable ``model``. +# +# Returns: +# a single model with corresponding pretrained weights. +# +# Example: +# >>> path = '/some/local/path/pytorch/vision' +# >>> model = _load_local(path, 'resnet50', pretrained=True) +# """ +# sys.path.insert(0, hubconf_dir) +# +# hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF) +# hub_module = import_module(MODULE_HUBCONF, hubconf_path) +# +# entry = _load_entry_from_hubconf(hub_module, model) +# model = entry(*args, **kwargs) +# +# sys.path.remove(hubconf_dir) +# +# return model +# +# +# def download_url_to_file(url, dst, hash_prefix=None, progress=True): +# r"""Download object at the given URL to a local path. 
+# +# Args: +# url (string): URL of the object to download +# dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file` +# hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`. +# Default: None +# progress (bool, optional): whether or not to display a progress bar to stderr +# Default: True +# +# Example: +# >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') +# +# """ +# file_size = None +# # We use a different API for python2 since urllib(2) doesn't recognize the CA +# # certificates in older Python +# req = Request(url, headers={"User-Agent": "torch.hub"}) +# u = urlopen(req) +# meta = u.info() +# if hasattr(meta, 'getheaders'): +# content_length = meta.getheaders("Content-Length") +# else: +# content_length = meta.get_all("Content-Length") +# if content_length is not None and len(content_length) > 0: +# file_size = int(content_length[0]) +# +# # We deliberately save it in a temp file and move it after +# # download is complete. This prevents a local working checkpoint +# # being overridden by a broken download. +# dst = os.path.expanduser(dst) +# dst_dir = os.path.dirname(dst) +# f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) +# +# try: +# if hash_prefix is not None: +# sha256 = hashlib.sha256() +# with tqdm(total=file_size, disable=not progress, +# unit='B', unit_scale=True, unit_divisor=1024) as pbar: +# while True: +# buffer = u.read(8192) +# if len(buffer) == 0: +# break +# f.write(buffer) +# if hash_prefix is not None: +# sha256.update(buffer) +# pbar.update(len(buffer)) +# +# f.close() +# if hash_prefix is not None: +# digest = sha256.hexdigest() +# if digest[:len(hash_prefix)] != hash_prefix: +# raise RuntimeError('invalid hash value (expected "{}", got "{}")' +# .format(hash_prefix, digest)) +# shutil.move(f.name, dst) +# finally: +# f.close() +# if os.path.exists(f.name): +# os.remove(f.name) +# +# def _download_url_to_file(url, dst, hash_prefix=None, progress=True): +# warnings.warn('torch.hub._download_url_to_file has been renamed to\ +# torch.hub.download_url_to_file to be a public API,\ +# _download_url_to_file will be removed in after 1.3 release') +# download_url_to_file(url, dst, hash_prefix, progress) +# +# # Hub used to support automatically extracts from zipfile manually compressed by users. +# # The legacy zip format expects only one file from torch.save() < 1.6 in the zip. +# # We should remove this support since zipfile is now default zipfile format for torch.save(). +# def _is_legacy_zip_format(filename): +# if zipfile.is_zipfile(filename): +# infolist = zipfile.ZipFile(filename).infolist() +# return len(infolist) == 1 and not infolist[0].is_dir() +# return False +# +# def _legacy_zip_load(filename, model_dir, map_location): +# warnings.warn('Falling back to the old format < 1.6. This support will be ' +# 'deprecated in favor of default zipfile format introduced in 1.6. ' +# 'Please redo torch.save() to save it in the new zipfile format.') +# # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. +# # We deliberately don't handle tarfile here since our legacy serialization format was in tar. +# # E.g. resnet18-5c106cde.pth which is widely used. 
+# with zipfile.ZipFile(filename) as f: +# members = f.infolist() +# if len(members) != 1: +# raise RuntimeError('Only one file(not dir) is allowed in the zipfile') +# f.extractall(model_dir) +# extraced_name = members[0].filename +# extracted_file = os.path.join(model_dir, extraced_name) +# return torch.load(extracted_file, map_location=map_location) +# +# def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None): +# r"""Loads the Torch serialized object at the given URL. +# +# If downloaded file is a zip file, it will be automatically +# decompressed. +# +# If the object is already present in `model_dir`, it's deserialized and +# returned. +# The default value of `model_dir` is ``/checkpoints`` where +# `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. +# +# Args: +# url (string): URL of the object to download +# model_dir (string, optional): directory in which to save the object +# map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) +# progress (bool, optional): whether or not to display a progress bar to stderr. +# Default: True +# check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention +# ``filename-.ext`` where ```` is the first eight or more +# digits of the SHA256 hash of the contents of the file. The hash is used to +# ensure unique names and to verify the contents of the file. +# Default: False +# file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set. +# +# Example: +# >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') +# +# """ +# # Issue warning to move data if old env is set +# if os.getenv('TORCH_MODEL_ZOO'): +# warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') +# +# if model_dir is None: +# hub_dir = get_dir() +# model_dir = os.path.join(hub_dir, 'checkpoints') +# +# try: +# os.makedirs(model_dir) +# except OSError as e: +# if e.errno == errno.EEXIST: +# # Directory already exists, ignore. +# pass +# else: +# # Unexpected OSError, re-raise. 
+# raise +# +# parts = urlparse(url) +# filename = os.path.basename(parts.path) +# if file_name is not None: +# filename = file_name +# cached_file = os.path.join(model_dir, filename) +# if not os.path.exists(cached_file): +# sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) +# hash_prefix = None +# if check_hash: +# r = HASH_REGEX.search(filename) # r is Optional[Match[str]] +# hash_prefix = r.group(1) if r else None +# download_url_to_file(url, cached_file, hash_prefix, progress=progress) +# +# if _is_legacy_zip_format(cached_file): +# return _legacy_zip_load(cached_file, model_dir, map_location) +# return torch.load(cached_file, map_location=map_location) diff --git a/datasets/tods_dataset_base.py b/datasets/tods_dataset_base.py new file mode 100755 index 0000000..82b1982 --- /dev/null +++ b/datasets/tods_dataset_base.py @@ -0,0 +1,139 @@ +import warnings +import os +import os.path +import numpy as np +import codecs +import string +import gzip +import lzma +from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union +from dataset_utils import download_url, download_and_extract_archive, extract_archive, verify_str_arg + +# tqdm >= 4.31.1 + +from tods import generate_dataset +from sklearn import preprocessing +import pandas as pd + +class TODS_dataset: + resources = [] + training_file = None + testing_file = None + ground_truth_index = None + _repr_indent = None + + @property + def raw_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__, 'raw') + + @property + def processed_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__, 'processed') + + def __init__(self, root, train, transform=None, download=True): + + self.root = root + self.train = train + self.transform = self.transform_init(transform) + + if download: + self.download() + pass + + self.process() + + + def _check_exists(self) -> bool: + return (os.path.exists(os.path.join(self.processed_folder, + self.training_file)) and + os.path.exists(os.path.join(self.processed_folder, + self.testing_file))) + + + def download(self) -> None: + + if self._check_exists(): + return + + os.makedirs(self.raw_folder, exist_ok=True) + + # download files + for url, md5 in self.resources: + filename = url.rpartition('/')[2] + download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) + + + def process(self) -> None: + + pass + + + def process_dataframe(self) -> None: + + if self.transform is None: + pass + + else: + self.transform.fit(self.training_set_dataframe) + self.training_set_array = self.transform.transform(self.training_set_dataframe.values) + self.testing_set_array = self.transform.transform(self.testing_set_dataframe.values) + self.training_set_dataframe = pd.DataFrame(self.training_set_array) + self.testing_set_dataframe = pd.DataFrame(self.testing_set_array) + + + def transform_init(self, transform_str): + + if transform_str is None: + return None + elif transform_str == 'standardscale': + return preprocessing.StandardScaler() + elif transform_str == 'normalize': + return preprocessing.Normalizer() + elif transform_str == 'minmaxscale': + return preprocessing.MinMaxScaler() + elif transform_str == 'maxabsscale': + return preprocessing.MaxAbsScaler() + elif transform_str == 'binarize': + return preprocessing.Binarizer() + else: + raise ValueError("Input parameter transform must take value of 'standardscale', 'normalize', " + + "'minmaxscale', 'maxabsscale' or 'binarize'." 
+ ) + + + def to_axolotl_dataset(self): + if self.train: + return generate_dataset(self.training_set_dataframe, self.ground_truth_index) + else: + return generate_dataset(self.testing_set_dataframe, self.ground_truth_index) + + + def __repr__(self) -> str: + head = "Dataset " + self.__class__.__name__ + body = ["Number of datapoints: {}".format(self.__len__())] + if self.root is not None: + body.append("Root location: {}".format(self.root)) + body += self.extra_repr().splitlines() + if hasattr(self, "transforms") and self.transforms is not None: + body += [repr(self.transforms)] + lines = [head] + [" " * self._repr_indent + line for line in body] + + print(self.training_set_dataframe) + + return '\n'.join(lines) + + + def __len__(self) -> int: + return len(self.training_set_dataframe) + + + def extra_repr(self) -> str: + return "" + + +# kpi(root='./datasets', train=True) + +# class yahoo5: +# +# def __init__(self): +# pass \ No newline at end of file diff --git a/datasets/tods_datasets.py b/datasets/tods_datasets.py new file mode 100755 index 0000000..e90efc6 --- /dev/null +++ b/datasets/tods_datasets.py @@ -0,0 +1,116 @@ +import os +import pandas as pd + +from tods_dataset_base import TODS_dataset +from shutil import copyfile + +class kpi_dataset(TODS_dataset): + resources = [ + # ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), + # ("https://github.com/datamllab/tods/blob/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), + # ("https://github.com/NetManAIOps/KPI-Anomaly-Detection/blob/master/Preliminary_dataset/train.csv", None), + ("https://hegsns.github.io/tods_datasets/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), # it needs md5 to check if local learningData.csv is the same with online. + ("https://hegsns.github.io/tods_datasets/kpi/TRAIN/dataset_TRAIN/datasetDoc.json", None), + # needs a server to store the dataset. + # ("https://raw.githubusercontent.com/datamllab/tods/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), # it needs md5 to check if local learningData.csv is the same with online. 
+ ] + + training_file = 'learningData.csv' + testing_file = 'testingData.csv' + ground_truth_index = 3 + _repr_indent = 4 + + # def __init__(self, root, train, transform=None, target_transform=None, download=True): + # super().__init__(root, train, transform=None, target_transform=None, download=True) + + def process(self) -> None: + + print('Processing...') + + os.makedirs(self.processed_folder, exist_ok=True) + os.makedirs(os.path.join(self.processed_folder, 'tables'), exist_ok=True) + + training_set_fname = os.path.join(self.raw_folder, 'learningData.csv') + self.training_set_dataframe = pd.read_csv(training_set_fname) + testing_set_fname = os.path.join(self.raw_folder, 'learningData.csv') # temperarily same with training set + self.testing_set_dataframe = pd.read_csv(testing_set_fname) + + self.process_dataframe() + self.training_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.training_file)) + self.testing_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.testing_file)) + copyfile(os.path.join(self.raw_folder, 'datasetDoc.json'), os.path.join(self.processed_folder, 'datasetDoc.json')) + + print('Done!') + + +class yahoo_dataset(TODS_dataset): + resources = [ + # ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), + # ("https://github.com/datamllab/tods/blob/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), + # ("https://github.com/NetManAIOps/KPI-Anomaly-Detection/blob/master/Preliminary_dataset/train.csv", None), + ("https://hegsns.github.io/tods_datasets/yahoo_sub_5/TRAIN/dataset_TRAIN/tables/learningData.csv", None), # it needs md5 to check if local learningData.csv is the same with online. + ("https://hegsns.github.io/tods_datasets/yahoo_sub_5/TRAIN/dataset_TRAIN/datasetDoc.json", None), + # needs a server to store the dataset. + # ("https://raw.githubusercontent.com/datamllab/tods/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), # it needs md5 to check if local learningData.csv is the same with online. + ] + + training_file = 'learningData.csv' + testing_file = 'testingData.csv' + ground_truth_index = 7 + _repr_indent = 4 + + def process(self) -> None: + + print('Processing...') + + os.makedirs(self.processed_folder, exist_ok=True) + os.makedirs(os.path.join(self.processed_folder, 'tables'), exist_ok=True) + + training_set_fname = os.path.join(self.raw_folder, 'learningData.csv') + self.training_set_dataframe = pd.read_csv(training_set_fname) + testing_set_fname = os.path.join(self.raw_folder, 'learningData.csv') # temperarily same with training set + self.testing_set_dataframe = pd.read_csv(testing_set_fname) + + self.process_dataframe() + self.training_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.training_file)) + self.testing_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.testing_file)) + copyfile(os.path.join(self.raw_folder, 'datasetDoc.json'), os.path.join(self.processed_folder, 'datasetDoc.json')) + + print('Done!') + + +class NAB_dataset(TODS_dataset): + resources = [ + ("https://hegsns.github.io/tods_datasets/NAB/realTweets/labeled_Twitter_volume_AMZN.csv", None), + # it needs md5 to check if local learningData.csv is the same with online. + ("https://hegsns.github.io/tods_datasets/NAB/realTweets/labeled_Twitter_volume_AMZN.json", None), + # needs a server to store the dataset. 
+ ] + + training_file = 'learningData.csv' + testing_file = 'testingData.csv' + ground_truth_index = 2 + _repr_indent = 4 + + def process(self) -> None: + print('Processing...') + + os.makedirs(self.processed_folder, exist_ok=True) + os.makedirs(os.path.join(self.processed_folder, 'tables'), exist_ok=True) + + training_set_fname = os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.csv') + self.training_set_dataframe = pd.read_csv(training_set_fname) + testing_set_fname = os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.csv') # temperarily same with training set + self.testing_set_dataframe = pd.read_csv(testing_set_fname) + + self.process_dataframe() + self.training_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.training_file)) + self.testing_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.testing_file)) + copyfile(os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.json'), + os.path.join(self.processed_folder, 'datasetDoc.json')) + + print('Done!') + +# kpi_dataset(root='./datasets', train=True, transform='binarize') +# yahoo_dataset(root='./datasets', train=True, transform='binarize') +# NAB_dataset(root='./datasets', train=True, transform='binarize') diff --git a/src/axolotl b/src/axolotl deleted file mode 160000 index af54e69..0000000 --- a/src/axolotl +++ /dev/null @@ -1 +0,0 @@ -Subproject commit af54e6970476a081bf0cd65990c9f56a1200d8a2 diff --git a/src/common-primitives b/src/common-primitives deleted file mode 160000 index 046b20d..0000000 --- a/src/common-primitives +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 046b20d2f6d4543dcbe18f0a1d4bcbb1f61cf518 diff --git a/src/d3m b/src/d3m deleted file mode 160000 index 70aeefe..0000000 --- a/src/d3m +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 70aeefed6b7307941581357c4b7858bb3f88e1da diff --git a/tods/common/FixedSplit.py b/tods/common/FixedSplit.py new file mode 100644 index 0000000..3fc17c9 --- /dev/null +++ b/tods/common/FixedSplit.py @@ -0,0 +1,116 @@ +import os +import typing + +import numpy +import pandas + +from d3m import container, exceptions, utils as d3m_utils +from d3m.metadata import base as metadata_base, hyperparams +from d3m.base import primitives + +__all__ = ('FixedSplitDatasetSplitPrimitive',) + + +class Hyperparams(hyperparams.Hyperparams): + primary_index_values = hyperparams.Set( + elements=hyperparams.Hyperparameter[str](''), + default=(), + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description='A set of primary index values of the main resource belonging to the test (score) split. Cannot be set together with "row_indices".', + ) + row_indices = hyperparams.Set( + elements=hyperparams.Hyperparameter[int](-1), + default=(), + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description='A set of row indices of the main resource belonging to the test (score) split. 
Cannot be set together with "primary_index_values".', + ) + delete_recursive = hyperparams.Hyperparameter[bool]( + default=False, + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.", + ) + + +class FixedSplitDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): + """ + A primitive which splits a tabular Dataset in a way that uses for the test + (score) split a fixed list of primary index values or row indices of the main + resource to be used. All other rows are added used for the train split. + """ + + metadata = metadata_base.PrimitiveMetadata( + { + 'id': '1654f000-2178-4520-be4c-a95bc26b8d3a', + 'version': '0.1.0', + 'name': "Fixed split tabular dataset splits", + 'python_path': 'd3m.primitives.tods.evaluation.fixed_split_dataset_split', + 'source': { + 'name': "DATALab@TexasA&M University", + 'contact': 'mailto:mitar.commonprimitives@tnode.com', + 'uris': [ + 'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/fixed_split.py', + 'https://gitlab.com/datadrivendiscovery/common-primitives.git', + ], + }, + 'algorithm_types': [ + metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, + ], + 'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, + }, + ) + + def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: + # This should be handled by "Set" hyper-parameter, but we check it here again just to be sure. + if d3m_utils.has_duplicates(self.hyperparams['primary_index_values']): + raise exceptions.InvalidArgumentValueError("\"primary_index_values\" hyper-parameter has duplicate values.") + if d3m_utils.has_duplicates(self.hyperparams['row_indices']): + raise exceptions.InvalidArgumentValueError("\"row_indices\" hyper-parameter has duplicate values.") + + if self.hyperparams['primary_index_values'] and self.hyperparams['row_indices']: + raise exceptions.InvalidArgumentValueError("Both \"primary_index_values\" and \"row_indices\" cannot be provided.") + + if self.hyperparams['primary_index_values']: + primary_index_values = numpy.array(self.hyperparams['primary_index_values']) + + index_columns = dataset.metadata.get_index_columns(at=(main_resource_id,)) + + if not index_columns: + raise exceptions.InvalidArgumentValueError("Cannot find index columns in the main resource of the dataset, but \"primary_index_values\" is provided.") + + main_resource = dataset[main_resource_id] + # We reset the index so that the index corresponds to row indices. + main_resource = main_resource.reset_index(drop=True) + + # We use just the "d3mIndex" column and ignore multi-key indices. + # This works for now because it seems that every current multi-key + # dataset in fact has an unique value in "d3mIndex" alone. 
+ # See: https://gitlab.datadrivendiscovery.org/MIT-LL/d3m_data_supply/issues/117 + index_column = index_columns[0] + + score_data = numpy.array(main_resource.loc[main_resource.iloc[:, index_column].isin(primary_index_values)].index) + score_data_set = set(score_data) + + assert len(score_data) == len(score_data_set), (len(score_data), len(score_data_set)) + + if len(score_data) != len(primary_index_values): + raise exceptions.InvalidArgumentValueError("\"primary_index_values\" contains values which do not exist.") + + else: + score_data = numpy.array(self.hyperparams['row_indices']) + score_data_set = set(score_data) + + all_data_set = set(numpy.arange(len(attributes))) + + if not score_data_set <= all_data_set: + raise exceptions.InvalidArgumentValueError("\"row_indices\" contains indices which do not exist, e.g., {indices}.".format( + indices=sorted(score_data_set - all_data_set)[:5], + )) + + train_data = [] + for i in numpy.arange(len(attributes)): + if i not in score_data_set: + train_data.append(i) + + assert len(train_data) + len(score_data) == len(attributes), (len(train_data), len(score_data), len(attributes)) + + return [(numpy.array(train_data), score_data)] diff --git a/tods/common/KFoldSplit.py b/tods/common/KFoldSplit.py new file mode 100644 index 0000000..283024d --- /dev/null +++ b/tods/common/KFoldSplit.py @@ -0,0 +1,87 @@ +import os +import typing + +import numpy +import pandas +from sklearn import model_selection + +from d3m import container, exceptions, utils as d3m_utils +from d3m.metadata import base as metadata_base, hyperparams +from d3m.base import primitives + + +__all__ = ('KFoldDatasetSplitPrimitive',) + + +class Hyperparams(hyperparams.Hyperparams): + number_of_folds = hyperparams.Bounded[int]( + lower=2, + upper=None, + default=5, + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description="Number of folds for k-folds cross-validation.", + ) + stratified = hyperparams.UniformBool( + default=False, + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description="Do stratified folds. The folds are made by preserving the percentage of samples for each class.", + ) + shuffle = hyperparams.UniformBool( + default=False, + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description="Whether to shuffle the data before splitting into batches.", + ) + delete_recursive = hyperparams.Hyperparameter[bool]( + default=False, + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.", + ) + + +class KFoldDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): + """ + A primitive which splits a tabular Dataset for k-fold cross-validation. 
+ """ + + __author__ = 'Mingjie Sun ' + metadata = metadata_base.PrimitiveMetadata( + { + 'id': 'bfedaf3a-6dd0-4a83-ad83-3a50fe882bf8', + 'version': '0.1.0', + 'name': "K-fold cross-validation tabular dataset splits", + 'python_path': 'd3m.primitives.tods.evaluation.kfold_dataset_split', + 'source': { + 'name': 'DATALab@Texas A&M University', + 'contact': 'mailto:sunmj15@gmail.com', + 'uris': [ + 'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/kfold_split.py', + 'https://gitlab.com/datadrivendiscovery/common-primitives.git', + ], + }, + 'algorithm_types': [ + metadata_base.PrimitiveAlgorithmType.K_FOLD, + metadata_base.PrimitiveAlgorithmType.CROSS_VALIDATION, + metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, + ], + 'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, + }, + ) + + def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: + if self.hyperparams['stratified']: + if not len(targets.columns): + raise exceptions.InvalidArgumentValueError("Stratified split is requested, but no target columns found.") + + k_fold = model_selection.StratifiedKFold( + n_splits=self.hyperparams['number_of_folds'], + shuffle=self.hyperparams['shuffle'], + random_state=self._random_state, + ) + else: + k_fold = model_selection.KFold( + n_splits=self.hyperparams['number_of_folds'], + shuffle=self.hyperparams['shuffle'], + random_state=self._random_state, + ) + + return list(k_fold.split(attributes, targets)) diff --git a/tods/common/KFoldSplitTimeseries.py b/tods/common/KFoldSplitTimeseries.py new file mode 100644 index 0000000..56b609e --- /dev/null +++ b/tods/common/KFoldSplitTimeseries.py @@ -0,0 +1,187 @@ +import os +import uuid +import typing +from collections import OrderedDict + +import numpy +import pandas +from sklearn import model_selection + +from d3m import container, exceptions, utils as d3m_utils +from d3m.metadata import base as metadata_base, hyperparams +from d3m.base import primitives + +import utils + +__all__ = ('KFoldTimeSeriesSplitPrimitive',) + + +class Hyperparams(hyperparams.Hyperparams): + number_of_folds = hyperparams.Bounded[int]( + lower=2, + upper=None, + default=5, + semantic_types=[ + 'https://metadata.datadrivendiscovery.org/types/ControlParameter' + ], + description="Number of folds for k-folds cross-validation.", + ) + number_of_window_folds = hyperparams.Union[typing.Union[int, None]]( + configuration=OrderedDict( + fixed=hyperparams.Bounded[int]( + lower=1, + upper=None, + default=1, + description="Number of folds in train set (window). These folds come directly " + "before test set (streaming window).", + ), + all_records=hyperparams.Constant( + default=None, + description="Number of folds in train set (window) = maximum number possible.", + ), + ), + default='all_records', + semantic_types=[ + 'https://metadata.datadrivendiscovery.org/types/ControlParameter' + ], + description="Maximum size for a single training set.", + ) + time_column_index = hyperparams.Union[typing.Union[int, None]]( + configuration=OrderedDict( + fixed=hyperparams.Bounded[int]( + lower=1, + upper=None, + default=1, + description="Specific column that contains the time index", + ), + one_column=hyperparams.Constant( + default=None, + description="Only one column contains a time index. 
" + "It is detected automatically using semantic types.", + ), + ), + default='one_column', + semantic_types=[ + 'https://metadata.datadrivendiscovery.org/types/ControlParameter' + ], + description="Column index to use as datetime index. " + "If None, it is required that only one column with time column role semantic type is " + "present and otherwise an exception is raised. " + "If column index specified is not a datetime column an exception is" + "also raised.", + ) + fuzzy_time_parsing = hyperparams.UniformBool( + default=True, + semantic_types=[ + 'https://metadata.datadrivendiscovery.org/types/ControlParameter' + ], + description="Use fuzzy time parsing.", + ) + + +class KFoldTimeSeriesSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): + """ + A primitive which splits a tabular time-series Dataset for k-fold cross-validation. + + Primitive sorts the time column so care should be taken to assure sorting of a + column is reasonable. E.g., if column is not numeric but of string structural type, + strings should be formatted so that sorting by them also sorts by time. + """ + + __author__ = 'Distil' + __version__ = '0.3.0' + __contact__ = 'mailto:jeffrey.gleason@yonder.co' + + metadata = metadata_base.PrimitiveMetadata( + { + 'id': '002f9ad1-46e3-40f4-89ed-eeffbb3a102b', + 'version': __version__, + 'name': "K-fold cross-validation timeseries dataset splits", + 'python_path': 'd3m.primitives.tods.evaluation.kfold_time_series_split', + 'source': { + 'name': 'DATALab@Texas A&M University', + 'contact': __contact__, + 'uris': [ + 'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/kfold_split_timeseries.py', + 'https://gitlab.com/datadrivendiscovery/common-primitives.git', + ], + }, + 'algorithm_types': [ + metadata_base.PrimitiveAlgorithmType.K_FOLD, + metadata_base.PrimitiveAlgorithmType.CROSS_VALIDATION, + metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, + ], + 'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, + }, + ) + + def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: + time_column_indices = dataset.metadata.list_columns_with_semantic_types(['https://metadata.datadrivendiscovery.org/types/Time'], at=(main_resource_id,)) + attribute_column_indices = dataset.metadata.list_columns_with_semantic_types(['https://metadata.datadrivendiscovery.org/types/Attribute'], at=(main_resource_id,)) + + # We want only time columns which are also attributes. + time_column_indices = [time_column_index for time_column_index in time_column_indices if time_column_index in attribute_column_indices] + + if self.hyperparams['time_column_index'] is None: + if len(time_column_indices) != 1: + raise exceptions.InvalidArgumentValueError( + "If \"time_column_index\" hyper-parameter is \"None\", it is required that exactly one column with time column role semantic type is present.", + ) + else: + # We know it exists because "time_column_indices" is a subset of "attribute_column_indices". + time_column_index = attribute_column_indices.index( + time_column_indices[0], + ) + else: + if self.hyperparams['time_column_index'] not in time_column_indices: + raise exceptions.InvalidArgumentValueError( + "Time column index specified does not have a time column role semantic type.", + ) + else: + time_column_index = attribute_column_indices.index( + self.hyperparams['time_column_index'], + ) + + # We first reset index. 
+ attributes = attributes.reset_index(drop=True) + + # Then convert datetime column to consistent datetime representation + attributes.insert( + loc=0, + column=uuid.uuid4(), # use uuid to ensure we are inserting a new column name + value=self._parse_time_data( + attributes, time_column_index, self.hyperparams['fuzzy_time_parsing'], + ), + ) + + # Then sort dataframe by new datetime column. Index contains original row order. + attributes = attributes.sort_values(by=attributes.columns[0]) + + # Remove datetime representation used for sorting (primitives might choose to parse this str col differently). + attributes = attributes.drop(attributes.columns[0], axis=1) + + max_train_size: typing.Optional[int] = None + if self.hyperparams['number_of_window_folds'] is not None: + max_train_size = int(attributes.shape[0] * self.hyperparams['number_of_window_folds'] / self.hyperparams['number_of_folds']) + + k_fold = model_selection.TimeSeriesSplit( + n_splits=self.hyperparams['number_of_folds'], + max_train_size=max_train_size + ) + + # We sorted "attributes" so we have to map indices on sorted "attributes" back to original + # indices. We do that by using DataFrame's index which contains original row order. + return [ + ( + numpy.array([attributes.index[val] for val in train]), + numpy.array([attributes.index[val] for val in test]), + ) + for train, test in k_fold.split(attributes) + ] + + @classmethod + def _parse_time_data(cls, inputs: container.DataFrame, column_index: metadata_base.SimpleSelectorSegment, fuzzy: bool) -> typing.List[float]: + return [ + utils.parse_datetime_to_float(value, fuzzy=fuzzy) + for value in inputs.iloc[:, column_index] + ] diff --git a/tods/common/NoSplit.py b/tods/common/NoSplit.py new file mode 100644 index 0000000..637312d --- /dev/null +++ b/tods/common/NoSplit.py @@ -0,0 +1,52 @@ +import os +import typing + +import numpy +import pandas + +from d3m import container, utils as d3m_utils +from d3m.metadata import base as metadata_base, hyperparams +from d3m.base import primitives + + +__all__ = ('NoSplitDatasetSplitPrimitive',) + + +class Hyperparams(hyperparams.Hyperparams): + pass + + +class NoSplitDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): + """ + A primitive which splits a tabular Dataset in a way that for all splits it + produces the same (full) Dataset. Useful for unsupervised learning tasks. . + """ + + metadata = metadata_base.PrimitiveMetadata( + { + 'id': '48c683ad-da9e-48cf-b3a0-7394dba5e5d2', + 'version': '0.1.0', + 'name': "No-split tabular dataset splits", + 'python_path': 'd3m.primitives.tods.evaluation.no_split_dataset_split', + 'source': { + 'name': 'DATALab@Texas A&M University', + 'contact': 'mailto:mitar.commonprimitives@tnode.com', + 'uris': [ + 'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/no_split.py', + 'https://gitlab.com/datadrivendiscovery/common-primitives.git', + ], + }, + 'algorithm_types': [ + metadata_base.PrimitiveAlgorithmType.IDENTITY_FUNCTION, + metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, + ], + 'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, + }, + ) + + def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: + # We still go through the whole splitting process to assure full compatibility + # (and error conditions) of a regular split, but we use all data for both splits. 
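+        # Concretely, for a Dataset with 4 rows this returns the single pair
+        # ([0, 1, 2, 3], [0, 1, 2, 3]), i.e. the train and score splits are both
+        # the full Dataset.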
+ all_data = numpy.arange(len(attributes)) + + return [(all_data, all_data)] diff --git a/tods/common/RedactColumns.py b/tods/common/RedactColumns.py new file mode 100644 index 0000000..1965e84 --- /dev/null +++ b/tods/common/RedactColumns.py @@ -0,0 +1,160 @@ +import copy +import os +import typing + +from d3m import container, exceptions, utils as d3m_utils +from d3m.metadata import base as metadata_base, hyperparams +from d3m.primitive_interfaces import base, transformer + + +Inputs = container.List +Outputs = container.List + + +class Hyperparams(hyperparams.Hyperparams): + match_logic = hyperparams.Enumeration( + values=['all', 'any'], + default='any', + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description="Should a column have all of semantic types in \"semantic_types\" to be redacted, or any of them?", + ) + semantic_types = hyperparams.Set( + elements=hyperparams.Hyperparameter[str](''), + default=(), + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description="Redact columns with these semantic types. Only columns having semantic types listed here will be operated on, based on \"match_logic\".", + ) + add_semantic_types = hyperparams.Set( + elements=hyperparams.Hyperparameter[str](''), + default=(), + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description="Semantic types to add to redacted columns. All listed semantic types will be added to all columns which were redacted.", + ) + + +# TODO: Make clear the assumption that both container type (List) and Datasets should have metadata. +# Primitive is modifying metadata of Datasets, while there is officially no reason for them +# to really have metadata: metadata is stored available on the input container type, not +# values inside it. +class RedactColumnsPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]): + """ + A primitive which takes as an input a list of ``Dataset`` objects and redacts values of all columns matching + a given semantic type or types. + + Redaction is done by setting all values in a redacted column to an empty string. + + It operates only on DataFrame resources inside datasets. 
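+
+    For illustration, with ``semantic_types`` set to the true-target semantic type,
+    a resource such as (column names are only an example)::
+
+        d3mIndex,timestamp,value,ground_truth
+        0,1,0.98,0
+        1,2,1.02,1
+
+    is returned with the ``ground_truth`` values replaced by empty strings while
+    all other columns are left untouched.
+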
+ """ + + metadata = metadata_base.PrimitiveMetadata( + { + 'id': '744c4090-e2f6-489e-8efc-8b1e051bfad6', + 'version': '0.2.0', + 'name': "Redact columns for evaluation", + 'python_path': 'd3m.primitives.tods.evaluation.redact_columns', + 'source': { + 'name': 'DATALab@Texas A&M University', + 'contact': 'mailto:sunmj15@gmail.com', + 'uris': [ + 'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/redact_columns.py', + 'https://gitlab.com/datadrivendiscovery/common-primitives.git', + ], + }, + 'installation': [{ + 'type': metadata_base.PrimitiveInstallationType.PIP, + 'package_uri': 'git+https://gitlab.com/datadrivendiscovery/common-primitives.git@{git_commit}#egg=common_primitives'.format( + git_commit=d3m_utils.current_git_commit(os.path.dirname(__file__)), + ), + }], + 'algorithm_types': [ + metadata_base.PrimitiveAlgorithmType.DATA_CONVERSION, + ], + 'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, + }, + ) + + def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]: + output_datasets = container.List(generate_metadata=True) + + for dataset in inputs: + resources = {} + metadata = dataset.metadata + + for resource_id, resource in dataset.items(): + if not isinstance(resource, container.DataFrame): + resources[resource_id] = resource + continue + + columns_to_redact = self._get_columns_to_redact(metadata, (resource_id,)) + + if not columns_to_redact: + resources[resource_id] = resource + continue + + resource = copy.copy(resource) + + for column_index in columns_to_redact: + column_metadata = dataset.metadata.query((resource_id, metadata_base.ALL_ELEMENTS, column_index)) + if 'structural_type' in column_metadata and issubclass(column_metadata['structural_type'], str): + resource.iloc[:, column_index] = '' + else: + raise TypeError("Primitive can operate only on columns with structural type \"str\", not \"{type}\".".format( + type=column_metadata.get('structural_type', None), + )) + + metadata = self._update_metadata(metadata, resource_id, column_index, ()) + + resources[resource_id] = resource + + dataset = container.Dataset(resources, metadata) + + output_datasets.append(dataset) + + output_datasets.metadata = metadata_base.DataMetadata({ + 'schema': metadata_base.CONTAINER_SCHEMA_VERSION, + 'structural_type': container.List, + 'dimension': { + 'length': len(output_datasets), + }, + }) + + # We update metadata based on metadata of each dataset. + # TODO: In the future this might be done automatically by generate_metadata. + # See: https://gitlab.com/datadrivendiscovery/d3m/issues/119 + for index, dataset in enumerate(output_datasets): + output_datasets.metadata = dataset.metadata.copy_to(output_datasets.metadata, (), (index,)) + + return base.CallResult(output_datasets) + + def _get_columns_to_redact(self, inputs_metadata: metadata_base.DataMetadata, at: metadata_base.Selector) -> typing.Sequence[int]: + columns = [] + + for element in inputs_metadata.get_elements(list(at) + [metadata_base.ALL_ELEMENTS]): + semantic_types = inputs_metadata.query(list(at) + [metadata_base.ALL_ELEMENTS, element]).get('semantic_types', ()) + + # TODO: Should we handle inheritance between semantic types here? 
+ if self.hyperparams['match_logic'] == 'all': + matched = all(semantic_type in semantic_types for semantic_type in self.hyperparams['semantic_types']) + elif self.hyperparams['match_logic'] == 'any': + matched = any(semantic_type in semantic_types for semantic_type in self.hyperparams['semantic_types']) + else: + raise exceptions.UnexpectedValueError("Unknown value of hyper-parameter \"match_logic\": {value}".format(value=self.hyperparams['match_logic'])) + + if matched: + if element is metadata_base.ALL_ELEMENTS: + return list(range(inputs_metadata.query(list(at) + [metadata_base.ALL_ELEMENTS]).get('dimension', {}).get('length', 0))) + else: + columns.append(typing.cast(int, element)) + + return columns + + def _update_metadata( + self, inputs_metadata: metadata_base.DataMetadata, resource_id: metadata_base.SelectorSegment, + column_index: int, at: metadata_base.Selector, + ) -> metadata_base.DataMetadata: + outputs_metadata = inputs_metadata + + for semantic_type in self.hyperparams['add_semantic_types']: + outputs_metadata = outputs_metadata.add_semantic_type(tuple(at) + (resource_id, metadata_base.ALL_ELEMENTS, column_index), semantic_type) + + return outputs_metadata diff --git a/tods/common/TrainScoreSplit.py b/tods/common/TrainScoreSplit.py new file mode 100644 index 0000000..00a71da --- /dev/null +++ b/tods/common/TrainScoreSplit.py @@ -0,0 +1,88 @@ +import os +import typing + +import numpy +import pandas +from sklearn import model_selection + +from d3m import container, exceptions, utils as d3m_utils +from d3m.metadata import base as metadata_base, hyperparams +from d3m.base import primitives + + +__all__ = ('TrainScoreDatasetSplitPrimitive',) + + +class Hyperparams(hyperparams.Hyperparams): + train_score_ratio = hyperparams.Uniform( + lower=0, + upper=1, + default=0.75, + upper_inclusive=True, + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description="The ratio between the train and score data and represents the proportion of the Dataset to include in the train split. The rest is included in the score split.", + ) + stratified = hyperparams.UniformBool( + default=False, + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description="Do stratified folds. The folds are made by preserving the percentage of samples for each class.", + ) + shuffle = hyperparams.UniformBool( + default=False, + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description="Whether to shuffle the data before splitting into batches.", + ) + delete_recursive = hyperparams.Hyperparameter[bool]( + default=False, + semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'], + description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.", + ) + + +class TrainScoreDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): + """ + A primitive which splits a tabular Dataset into random train and score subsets. 
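+
+    The split itself is delegated to scikit-learn's ``train_test_split``. A minimal
+    sketch of the underlying call with the default 0.75 ratio and no shuffling or
+    stratification (plain scikit-learn, not this primitive's interface)::
+
+        from sklearn.model_selection import train_test_split
+
+        train_rows, score_rows = train_test_split(
+            list(range(100)), train_size=0.75, shuffle=False,
+        )
+        # len(train_rows) == 75, len(score_rows) == 25
+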
+ """ + + metadata = metadata_base.PrimitiveMetadata( + { + 'id': '3fcc6dc4-6681-4c86-948e-066d14e7d803', + 'version': '0.1.0', + 'name': "Train-score tabular dataset splits", + 'python_path': 'd3m.primitives.tods.evaluation.train_score_dataset_split', + 'source': { + 'name': 'DATALab@Texas A&M University', + 'contact': 'mailto:mitar.commonprimitives@tnode.com', + 'uris': [ + 'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/train_score_split.py', + 'https://gitlab.com/datadrivendiscovery/common-primitives.git', + ], + }, + 'installation': [{ + 'type': metadata_base.PrimitiveInstallationType.PIP, + 'package_uri': 'git+https://gitlab.com/datadrivendiscovery/common-primitives.git@{git_commit}#egg=common_primitives'.format( + git_commit=d3m_utils.current_git_commit(os.path.dirname(__file__)), + ), + }], + 'algorithm_types': [ + metadata_base.PrimitiveAlgorithmType.HOLDOUT, + metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, + ], + 'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, + }, + ) + + def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]: + if self.hyperparams['stratified'] and not len(targets.columns): + raise exceptions.InvalidArgumentValueError("Stratified split is requested, but no target columns found.") + + train_data, score_data = model_selection.train_test_split( + numpy.arange(len(attributes)), + test_size=None, + train_size=self.hyperparams['train_score_ratio'], + random_state=self._random_state, + shuffle=self.hyperparams['shuffle'], + stratify=targets if self.hyperparams['stratified'] else None, + ) + + return [(train_data, score_data)] diff --git a/tods/common/__init__.py b/tods/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tods/common/utils.py b/tods/common/utils.py new file mode 100644 index 0000000..bc6879c --- /dev/null +++ b/tods/common/utils.py @@ -0,0 +1,192 @@ +import datetime +import logging +import typing + +import dateutil.parser +import numpy + +from d3m import container, deprecate +from d3m.base import utils as base_utils +from d3m.metadata import base as metadata_base + +logger = logging.getLogger(__name__) + +DEFAULT_DATETIME = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) + + +@deprecate.function(message="it should not be used anymore") +def copy_elements_metadata(source_metadata: metadata_base.Metadata, target_metadata: metadata_base.DataMetadata, from_selector: metadata_base.Selector, + to_selector: metadata_base.Selector = (), *, ignore_all_elements: bool = False, check: bool = True, source: typing.Any = None) -> metadata_base.DataMetadata: + return source_metadata._copy_elements_metadata(target_metadata, list(from_selector), list(to_selector), [], ignore_all_elements) + + +@deprecate.function(message="use Metadata.copy_to method instead") +def copy_metadata(source_metadata: metadata_base.Metadata, target_metadata: metadata_base.DataMetadata, from_selector: metadata_base.Selector, + to_selector: metadata_base.Selector = (), *, ignore_all_elements: bool = False, check: bool = True, source: typing.Any = None) -> metadata_base.DataMetadata: + return source_metadata.copy_to(target_metadata, from_selector, to_selector, ignore_all_elements=ignore_all_elements) + + +@deprecate.function(message="use DataFrame.select_columns method instead") +@deprecate.arguments('source', message="argument ignored") +def select_columns(inputs: 
container.DataFrame, columns: typing.Sequence[metadata_base.SimpleSelectorSegment], *, + source: typing.Any = None) -> container.DataFrame: + return inputs.select_columns(columns) + + +@deprecate.function(message="use DataMetadata.select_columns method instead") +@deprecate.arguments('source', message="argument ignored") +def select_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns: typing.Sequence[metadata_base.SimpleSelectorSegment], *, + source: typing.Any = None) -> metadata_base.DataMetadata: + return inputs_metadata.select_columns(columns) + + +@deprecate.function(message="use DataMetadata.list_columns_with_semantic_types method instead") +def list_columns_with_semantic_types(metadata: metadata_base.DataMetadata, semantic_types: typing.Sequence[str], *, + at: metadata_base.Selector = ()) -> typing.Sequence[int]: + return metadata.list_columns_with_semantic_types(semantic_types, at=at) + + +@deprecate.function(message="use DataMetadata.list_columns_with_structural_types method instead") +def list_columns_with_structural_types(metadata: metadata_base.DataMetadata, structural_types: typing.Union[typing.Callable, typing.Sequence[typing.Union[str, type]]], *, + at: metadata_base.Selector = ()) -> typing.Sequence[int]: + return metadata.list_columns_with_structural_types(structural_types, at=at) + + +@deprecate.function(message="use DataFrame.remove_columns method instead") +@deprecate.arguments('source', message="argument ignored") +def remove_columns(inputs: container.DataFrame, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> container.DataFrame: + return inputs.remove_columns(column_indices) + + +@deprecate.function(message="use DataMetadata.remove_columns method instead") +@deprecate.arguments('source', message="argument ignored") +def remove_columns_metadata(inputs_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> metadata_base.DataMetadata: + return inputs_metadata.remove_columns(column_indices) + + +@deprecate.function(message="use DataFrame.append_columns method instead") +@deprecate.arguments('source', message="argument ignored") +def append_columns(left: container.DataFrame, right: container.DataFrame, *, use_right_metadata: bool = False, source: typing.Any = None) -> container.DataFrame: + return left.append_columns(right, use_right_metadata=use_right_metadata) + + +@deprecate.function(message="use DataMetadata.append_columns method instead") +@deprecate.arguments('source', message="argument ignored") +def append_columns_metadata(left_metadata: metadata_base.DataMetadata, right_metadata: metadata_base.DataMetadata, use_right_metadata: bool = False, source: typing.Any = None) -> metadata_base.DataMetadata: + return left_metadata.append_columns(right_metadata, use_right_metadata=use_right_metadata) + + +@deprecate.function(message="use DataFrame.insert_columns method instead") +@deprecate.arguments('source', message="argument ignored") +def insert_columns(inputs: container.DataFrame, columns: container.DataFrame, at_column_index: int, *, source: typing.Any = None) -> container.DataFrame: + return inputs.insert_columns(columns, at_column_index) + + +@deprecate.function(message="use DataMetadata.insert_columns method instead") +@deprecate.arguments('source', message="argument ignored") +def insert_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns_metadata: metadata_base.DataMetadata, at_column_index: int, *, source: typing.Any = None) -> 
metadata_base.DataMetadata: + return inputs_metadata.insert_columns(columns_metadata, at_column_index) + + +@deprecate.function(message="use DataFrame.replace_columns method instead") +@deprecate.arguments('source', message="argument ignored") +def replace_columns(inputs: container.DataFrame, columns: container.DataFrame, column_indices: typing.Sequence[int], *, copy: bool = True, source: typing.Any = None) -> container.DataFrame: + return inputs.replace_columns(columns, column_indices, copy=copy) + + +@deprecate.function(message="use DataMetadata.replace_columns method instead") +@deprecate.arguments('source', message="argument ignored") +def replace_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> metadata_base.DataMetadata: + return inputs_metadata.replace_columns(columns_metadata, column_indices) + + +@deprecate.function(message="use DataMetadata.get_index_columns method instead") +def get_index_columns(metadata: metadata_base.DataMetadata, *, at: metadata_base.Selector = ()) -> typing.Sequence[int]: + return metadata.get_index_columns(at=at) + + +@deprecate.function(message="use DataFrame.horizontal_concat method instead") +@deprecate.arguments('source', message="argument ignored") +def horizontal_concat(left: container.DataFrame, right: container.DataFrame, *, use_index: bool = True, + remove_second_index: bool = True, use_right_metadata: bool = False, source: typing.Any = None) -> container.DataFrame: + return left.horizontal_concat(right, use_index=use_index, remove_second_index=remove_second_index, use_right_metadata=use_right_metadata) + + +@deprecate.function(message="use DataMetadata.horizontal_concat method instead") +@deprecate.arguments('source', message="argument ignored") +def horizontal_concat_metadata(left_metadata: metadata_base.DataMetadata, right_metadata: metadata_base.DataMetadata, *, use_index: bool = True, + remove_second_index: bool = True, use_right_metadata: bool = False, source: typing.Any = None) -> metadata_base.DataMetadata: + return left_metadata.horizontal_concat(right_metadata, use_index=use_index, remove_second_index=remove_second_index, use_right_metadata=use_right_metadata) + + +@deprecate.function(message="use d3m.base.utils.get_columns_to_use function instead") +def get_columns_to_use(metadata: metadata_base.DataMetadata, use_columns: typing.Sequence[int], exclude_columns: typing.Sequence[int], + can_use_column: typing.Callable) -> typing.Tuple[typing.List[int], typing.List[int]]: + return base_utils.get_columns_to_use(metadata, use_columns, exclude_columns, can_use_column) + + +@deprecate.function(message="use d3m.base.utils.combine_columns function instead") +@deprecate.arguments('source', message="argument ignored") +def combine_columns(return_result: str, add_index_columns: bool, inputs: container.DataFrame, column_indices: typing.Sequence[int], + columns_list: typing.Sequence[container.DataFrame], *, source: typing.Any = None) -> container.DataFrame: + return base_utils.combine_columns(inputs, column_indices, columns_list, return_result=return_result, add_index_columns=add_index_columns) + + +@deprecate.function(message="use d3m.base.utils.combine_columns_metadata function instead") +@deprecate.arguments('source', message="argument ignored") +def combine_columns_metadata(return_result: str, add_index_columns: bool, inputs_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], + 
columns_metadata_list: typing.Sequence[metadata_base.DataMetadata], *, source: typing.Any = None) -> metadata_base.DataMetadata: + return base_utils.combine_columns_metadata(inputs_metadata, column_indices, columns_metadata_list, return_result=return_result, add_index_columns=add_index_columns) + + +@deprecate.function(message="use DataMetadata.set_table_metadata method instead") +@deprecate.arguments('source', message="argument ignored") +def set_table_metadata(inputs_metadata: metadata_base.DataMetadata, *, at: metadata_base.Selector = (), source: typing.Any = None) -> metadata_base.DataMetadata: + return inputs_metadata.set_table_metadata(at=at) + + +@deprecate.function(message="use DataMetadata.get_column_index_from_column_name method instead") +def get_column_index_from_column_name(inputs_metadata: metadata_base.DataMetadata, column_name: str, *, at: metadata_base.Selector = ()) -> int: + return inputs_metadata.get_column_index_from_column_name(column_name, at=at) + + +@deprecate.function(message="use Dataset.get_relations_graph method instead") +def build_relation_graph(dataset: container.Dataset) -> typing.Dict[str, typing.List[typing.Tuple[str, bool, int, int, typing.Dict]]]: + return dataset.get_relations_graph() + + +@deprecate.function(message="use d3m.base.utils.get_tabular_resource function instead") +def get_tabular_resource(dataset: container.Dataset, resource_id: typing.Optional[str], *, + pick_entry_point: bool = True, pick_one: bool = True, has_hyperparameter: bool = True) -> typing.Tuple[str, container.DataFrame]: + return base_utils.get_tabular_resource(dataset, resource_id, pick_entry_point=pick_entry_point, pick_one=pick_one, has_hyperparameter=has_hyperparameter) + + +@deprecate.function(message="use d3m.base.utils.get_tabular_resource_metadata function instead") +def get_tabular_resource_metadata(dataset_metadata: metadata_base.DataMetadata, resource_id: typing.Optional[metadata_base.SelectorSegment], *, + pick_entry_point: bool = True, pick_one: bool = True) -> metadata_base.SelectorSegment: + return base_utils.get_tabular_resource_metadata(dataset_metadata, resource_id, pick_entry_point=pick_entry_point, pick_one=pick_one) + + +@deprecate.function(message="use Dataset.select_rows method instead") +@deprecate.arguments('source', message="argument ignored") +def cut_dataset(dataset: container.Dataset, row_indices_to_keep: typing.Mapping[str, typing.Sequence[int]], *, + source: typing.Any = None) -> container.Dataset: + return dataset.select_rows(row_indices_to_keep) + + +def parse_datetime(value: str, *, fuzzy: bool = True) -> typing.Optional[datetime.datetime]: + try: + return dateutil.parser.parse(value, default=DEFAULT_DATETIME, fuzzy=fuzzy) + except (ValueError, OverflowError, TypeError): + return None + + +def parse_datetime_to_float(value: str, *, fuzzy: bool = True) -> float: + try: + parsed = parse_datetime(value, fuzzy=fuzzy) + if parsed is None: + return numpy.nan + else: + return parsed.timestamp() + except (ValueError, OverflowError, TypeError): + return numpy.nan
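+
+
+# Illustrative behaviour of the two datetime helpers above (results assume
+# dateutil's standard parsing rules; the sample strings are arbitrary):
+#
+#     parse_datetime("2020-01-01T00:00:00Z")            # datetime with UTC tzinfo
+#     parse_datetime_to_float("2020-01-01T00:00:00Z")   # 1577836800.0
+#     parse_datetime_to_float("sampled on 2020-01-01")  # fuzzy parsing still succeeds
+#     parse_datetime_to_float("not a timestamp")        # nan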