
resolve construct prediction bug

master
lhenry15 4 years ago
commit 88dc4c4e33
15 changed files with 1993 additions and 3 deletions
  1. +297 -0   datasets/dataset_utils.py
  2. +559 -0   datasets/hub.py
  3. +139 -0   datasets/tods_dataset_base.py
  4. +116 -0   datasets/tods_datasets.py
  5. +0 -1     src/axolotl
  6. +0 -1     src/common-primitives
  7. +0 -1     src/d3m
  8. +116 -0   tods/common/FixedSplit.py
  9. +87 -0    tods/common/KFoldSplit.py
  10. +187 -0  tods/common/KFoldSplitTimeseries.py
  11. +52 -0   tods/common/NoSplit.py
  12. +160 -0  tods/common/RedactColumns.py
  13. +88 -0   tods/common/TrainScoreSplit.py
  14. +0 -0    tods/common/__init__.py
  15. +192 -0  tods/common/utils.py

+297 -0 datasets/dataset_utils.py

@@ -0,0 +1,297 @@
import os
import os.path
import hashlib
import gzip
import errno
import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar
import zipfile

# import torch
# from torch.utils.model_zoo import tqdm

from hub import tqdm

def gen_bar_updater() -> Callable[[int, int, int], None]:
pbar = tqdm(total=None)

def bar_update(count, block_size, total_size):
if pbar.total is None and total_size:
pbar.total = total_size
progress_bytes = count * block_size
pbar.update(progress_bytes - pbar.n)

return bar_update


def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
md5 = hashlib.md5()
with open(fpath, 'rb') as f:
for chunk in iter(lambda: f.read(chunk_size), b''):
md5.update(chunk)
return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
if not os.path.isfile(fpath):
return False
if md5 is None:
return True
return check_md5(fpath, md5)


def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
"""Download a file from a url and place it in root.

Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the basename of the URL
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
import urllib.request
import urllib.error

root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename)

os.makedirs(root, exist_ok=True)

# check if file is already present locally
if check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
else: # download the file
try:
print('Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater()
)
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
if url[:5] == 'https':
url = url.replace('https:', 'http:')
print('Failed download. Trying https -> http instead.'
' Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater()
)
else:
raise e
# check integrity of downloaded file
if not check_integrity(fpath, md5):
raise RuntimeError("File not found or corrupted.")
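
A minimal usage sketch for download_url; the URL, target directory, and checksum below are hypothetical placeholders rather than real TODS resources:

# Hypothetical example: fetch a CSV into ./data and (optionally) verify it by MD5.
download_url(
    url='https://example.com/learningData.csv',
    root='./data',
    filename='learningData.csv',
    md5=None,  # pass a real MD5 hex digest here to enable the integrity check
)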


def list_dir(root: str, prefix: bool = False) -> List[str]:
"""List all directories at a given root

Args:
root (str): Path to directory whose folders need to be listed
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the directories found
"""
root = os.path.expanduser(root)
directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
if prefix is True:
directories = [os.path.join(root, d) for d in directories]
return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
"""List all files ending with a suffix at a given root

Args:
root (str): Path to directory whose folders need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
It uses the Python "str.endswith" method and is passed directly
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the files found
"""
root = os.path.expanduser(root)
files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
if prefix is True:
files = [os.path.join(root, d) for d in files]
return files


def _quota_exceeded(response: "requests.models.Response") -> bool: # type: ignore[name-defined]
return "Google Drive - Quota exceeded" in response.text


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
"""Download a Google Drive file from and place it in root.

Args:
file_id (str): id of file to be downloaded
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the id of the file.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
import requests
url = "https://docs.google.com/uc?export=download"

root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)

os.makedirs(root, exist_ok=True)

if os.path.isfile(fpath) and check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
else:
session = requests.Session()

response = session.get(url, params={'id': file_id}, stream=True)
token = _get_confirm_token(response)

if token:
params = {'id': file_id, 'confirm': token}
response = session.get(url, params=params, stream=True)

if _quota_exceeded(response):
msg = (
f"The daily quota of the file {filename} is exceeded and it "
f"can't be downloaded. This is a limitation of Google Drive "
f"and can only be overcome by trying again later."
)
raise RuntimeError(msg)

_save_response_content(response, fpath)


def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined]
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value

return None


def _save_response_content(
response: "requests.models.Response", destination: str, chunk_size: int = 32768, # type: ignore[name-defined]
) -> None:
with open(destination, "wb") as f:
pbar = tqdm(total=None)
progress = 0
for chunk in response.iter_content(chunk_size):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
progress += len(chunk)
pbar.update(progress - pbar.n)
pbar.close()


def _is_tarxz(filename: str) -> bool:
return filename.endswith(".tar.xz")


def _is_tar(filename: str) -> bool:
return filename.endswith(".tar")


def _is_targz(filename: str) -> bool:
return filename.endswith(".tar.gz")


def _is_tgz(filename: str) -> bool:
return filename.endswith(".tgz")


def _is_gzip(filename: str) -> bool:
return filename.endswith(".gz") and not filename.endswith(".tar.gz")


def _is_zip(filename: str) -> bool:
return filename.endswith(".zip")


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
if to_path is None:
to_path = os.path.dirname(from_path)

if _is_tar(from_path):
with tarfile.open(from_path, 'r') as tar:
tar.extractall(path=to_path)
elif _is_targz(from_path) or _is_tgz(from_path):
with tarfile.open(from_path, 'r:gz') as tar:
tar.extractall(path=to_path)
elif _is_tarxz(from_path):
with tarfile.open(from_path, 'r:xz') as tar:
tar.extractall(path=to_path)
elif _is_gzip(from_path):
to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
out_f.write(zip_f.read())
elif _is_zip(from_path):
with zipfile.ZipFile(from_path, 'r') as z:
z.extractall(to_path)
else:
raise ValueError("Extraction of {} not supported".format(from_path))

if remove_finished:
os.remove(from_path)
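
A short sketch of the suffix-based dispatch above; both paths are hypothetical:

# extract_archive picks the handler from the file suffix.
extract_archive('./data/dataset.tar.gz', to_path='./data/extracted')  # opened with tarfile mode 'r:gz'
extract_archive('./data/dataset.zip', remove_finished=True)           # unzipped next to the archive, then the .zip is deleted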


def download_and_extract_archive(
url: str,
download_root: str,
extract_root: Optional[str] = None,
filename: Optional[str] = None,
md5: Optional[str] = None,
remove_finished: bool = False,
) -> None:
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if not filename:
filename = os.path.basename(url)

download_url(url, download_root, filename, md5)

archive = os.path.join(download_root, filename)
print("Extracting {} to {}".format(archive, extract_root))
# print(archive)
# print(extract_root)
# extract_archive(archive, extract_root, remove_finished)


def iterable_to_str(iterable: Iterable) -> str:
return "'" + "', '".join([str(item) for item in iterable]) + "'"


T = TypeVar("T", str, bytes)


def verify_str_arg(
value: T, arg: Optional[str] = None, valid_values: Optional[Iterable[T]] = None, custom_msg: Optional[str] = None,
) -> T:
if not isinstance(value, str):
if arg is None:
msg = "Expected type str, but got type {type}."
else:
msg = "Expected type str for argument {arg}, but got type {type}."
msg = msg.format(type=type(value), arg=arg)
raise ValueError(msg)

if valid_values is None:
return value

if value not in valid_values:
if custom_msg is not None:
msg = custom_msg
else:
msg = ("Unknown value '{value}' for argument {arg}. "
"Valid values are {{{valid_values}}}.")
msg = msg.format(value=value, arg=arg,
valid_values=iterable_to_str(valid_values))
raise ValueError(msg)

return value
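
A small sketch of verify_str_arg; the argument name 'split' is only illustrative:

split = verify_str_arg('train', arg='split', valid_values=('train', 'test'))
# verify_str_arg('validation', arg='split', valid_values=('train', 'test'))
# would raise ValueError: Unknown value 'validation' for argument split. Valid values are {'train', 'test'}.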

+559 -0 datasets/hub.py

@@ -0,0 +1,559 @@
import errno
import hashlib
import os
import re
import shutil
import sys
import tempfile
# import torch
import warnings
import zipfile

from urllib.request import urlopen, Request
from urllib.parse import urlparse # noqa: F401

try:
from tqdm.auto import tqdm # automatically select proper tqdm submodule if available
except ImportError:
try:
from tqdm import tqdm
except ImportError:
# fake tqdm if it's not installed
class tqdm(object): # type: ignore

def __init__(self, total=None, disable=False,
unit=None, unit_scale=None, unit_divisor=None):
self.total = total
self.disable = disable
self.n = 0
# ignore unit, unit_scale, unit_divisor; they're just for real tqdm

def update(self, n):
if self.disable:
return

self.n += n
if self.total is None:
sys.stderr.write("\r{0:.1f} bytes".format(self.n))
else:
sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
sys.stderr.flush()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.disable:
return

sys.stderr.write('\n')
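
A quick sketch of the fallback behaviour when the real tqdm is not installed; it simply rewrites a byte count or percentage on stderr:

pbar = tqdm(total=1000)
pbar.update(250)   # the fallback class writes "\r25.0%" to stderr
pbar.update(750)   # the fallback class writes "\r100.0%" to stderr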

# # matches bfd8deac from resnet18-bfd8deac.pth
# HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
#
# MASTER_BRANCH = 'master'
# ENV_TORCH_HOME = 'TORCH_HOME'
# ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
# DEFAULT_CACHE_DIR = '~/.cache'
# VAR_DEPENDENCY = 'dependencies'
# MODULE_HUBCONF = 'hubconf.py'
# READ_DATA_CHUNK = 8192
# _hub_dir = None
#
#
# # Copied from tools/shared/module_loader to be included in torch package
# def import_module(name, path):
# import importlib.util
# from importlib.abc import Loader
# spec = importlib.util.spec_from_file_location(name, path)
# module = importlib.util.module_from_spec(spec)
# assert isinstance(spec.loader, Loader)
# spec.loader.exec_module(module)
# return module
#
#
# def _remove_if_exists(path):
# if os.path.exists(path):
# if os.path.isfile(path):
# os.remove(path)
# else:
# shutil.rmtree(path)
#
#
# def _git_archive_link(repo_owner, repo_name, branch):
# return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch)
#
#
# def _load_attr_from_module(module, func_name):
# # Check if callable is defined in the module
# if func_name not in dir(module):
# return None
# return getattr(module, func_name)
#
#
# def _get_torch_home():
# torch_home = os.path.expanduser(
# os.getenv(ENV_TORCH_HOME,
# os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
# DEFAULT_CACHE_DIR), 'torch')))
# return torch_home
#
#
# def _parse_repo_info(github):
# branch = MASTER_BRANCH
# if ':' in github:
# repo_info, branch = github.split(':')
# else:
# repo_info = github
# repo_owner, repo_name = repo_info.split('/')
# return repo_owner, repo_name, branch
#
#
# def _get_cache_or_reload(github, force_reload, verbose=True):
# # Setup hub_dir to save downloaded files
# hub_dir = get_dir()
# if not os.path.exists(hub_dir):
# os.makedirs(hub_dir)
# # Parse github repo information
# repo_owner, repo_name, branch = _parse_repo_info(github)
# # Github allows branch name with slash '/',
# # this causes confusion with path on both Linux and Windows.
# # Backslash is not allowed in Github branch name so no need to
# # to worry about it.
# normalized_br = branch.replace('/', '_')
# # Github renames folder repo-v1.x.x to repo-1.x.x
# # We don't know the repo name before downloading the zip file
# # and inspect name from it.
# # To check if cached repo exists, we need to normalize folder names.
# repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, normalized_br]))
#
# use_cache = (not force_reload) and os.path.exists(repo_dir)
#
# if use_cache:
# if verbose:
# sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
# else:
# cached_file = os.path.join(hub_dir, normalized_br + '.zip')
# _remove_if_exists(cached_file)
#
# url = _git_archive_link(repo_owner, repo_name, branch)
# sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, cached_file))
# download_url_to_file(url, cached_file, progress=False)
#
# with zipfile.ZipFile(cached_file) as cached_zipfile:
# extraced_repo_name = cached_zipfile.infolist()[0].filename
# extracted_repo = os.path.join(hub_dir, extraced_repo_name)
# _remove_if_exists(extracted_repo)
# # Unzip the code and rename the base folder
# cached_zipfile.extractall(hub_dir)
#
# _remove_if_exists(cached_file)
# _remove_if_exists(repo_dir)
# shutil.move(extracted_repo, repo_dir) # rename the repo
#
# return repo_dir
#
#
# def _check_module_exists(name):
# if sys.version_info >= (3, 4):
# import importlib.util
# return importlib.util.find_spec(name) is not None
# elif sys.version_info >= (3, 3):
# # Special case for python3.3
# import importlib.find_loader
# return importlib.find_loader(name) is not None
# else:
# # NB: Python2.7 imp.find_module() doesn't respect PEP 302,
# # it cannot find a package installed as .egg(zip) file.
# # Here we use workaround from:
# # https://stackoverflow.com/questions/28962344/imp-find-module-which-supports-zipped-eggs?lq=1
# # Also imp doesn't handle hierarchical module names (names contains dots).
# try:
# # 1. Try imp.find_module(), which searches sys.path, but does
# # not respect PEP 302 import hooks.
# import imp
# result = imp.find_module(name)
# if result:
# return True
# except ImportError:
# pass
# path = sys.path
# for item in path:
# # 2. Scan path for import hooks. sys.path_importer_cache maps
# # path items to optional "importer" objects, that implement
# # find_module() etc. Note that path must be a subset of
# # sys.path for this to work.
# importer = sys.path_importer_cache.get(item)
# if importer:
# try:
# result = importer.find_module(name, [item])
# if result:
# return True
# except ImportError:
# pass
# return False
#
# def _check_dependencies(m):
# dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)
#
# if dependencies is not None:
# missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
# if len(missing_deps):
# raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))
#
#
# def _load_entry_from_hubconf(m, model):
# if not isinstance(model, str):
# raise ValueError('Invalid input: model should be a string of function name')
#
# # Note that if a missing dependency is imported at top level of hubconf, it will
# # throw before this function. It's a chicken and egg situation where we have to
# # load hubconf to know what're the dependencies, but to import hubconf it requires
# # a missing package. This is fine, Python will throw proper error message for users.
# _check_dependencies(m)
#
# func = _load_attr_from_module(m, model)
#
# if func is None or not callable(func):
# raise RuntimeError('Cannot find callable {} in hubconf'.format(model))
#
# return func
#
#
# def get_dir():
# r"""
# Get the Torch Hub cache directory used for storing downloaded models & weights.
#
# If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where
# environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
# ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux
# filesystem layout, with a default value ``~/.cache`` if the environment
# variable is not set.
# """
# # Issue warning to move data if old env is set
# if os.getenv('TORCH_HUB'):
# warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')
#
# if _hub_dir is not None:
# return _hub_dir
# return os.path.join(_get_torch_home(), 'hub')
#
#
# def set_dir(d):
# r"""
# Optionally set the Torch Hub directory used to save downloaded models & weights.
#
# Args:
# d (string): path to a local folder to save downloaded models & weights.
# """
# global _hub_dir
# _hub_dir = d
#
#
# def list(github, force_reload=False):
# r"""
# List all entrypoints available in `github` hubconf.
#
# Args:
# github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
# tag/branch. The default branch is `master` if not specified.
# Example: 'pytorch/vision[:hub]'
# force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
# Default is `False`.
# Returns:
# entrypoints: a list of available entrypoint names
#
# Example:
# >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
# """
# repo_dir = _get_cache_or_reload(github, force_reload, True)
#
# sys.path.insert(0, repo_dir)
#
# hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
#
# sys.path.remove(repo_dir)
#
# # We take functions starts with '_' as internal helper functions
# entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
#
# return entrypoints
#
#
# def help(github, model, force_reload=False):
# r"""
# Show the docstring of entrypoint `model`.
#
# Args:
# github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional
# tag/branch. The default branch is `master` if not specified.
# Example: 'pytorch/vision[:hub]'
# model (string): a string of entrypoint name defined in repo's hubconf.py
# force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
# Default is `False`.
# Example:
# >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
# """
# repo_dir = _get_cache_or_reload(github, force_reload, True)
#
# sys.path.insert(0, repo_dir)
#
# hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
#
# sys.path.remove(repo_dir)
#
# entry = _load_entry_from_hubconf(hub_module, model)
#
# return entry.__doc__
#
#
# # Ideally this should be `def load(github, model, *args, forece_reload=False, **kwargs):`,
# # but Python2 complains syntax error for it. We have to skip force_reload in function
# # signature here but detect it in kwargs instead.
# # TODO: fix it after Python2 EOL
# def load(repo_or_dir, model, *args, **kwargs):
# r"""
# Load a model from a github repo or a local directory.
#
# Note: Loading a model is the typical use case, but this can also be used to
# for loading other objects such as tokenizers, loss functions, etc.
#
# If :attr:`source` is ``'github'``, :attr:`repo_or_dir` is expected to be
# of the form ``repo_owner/repo_name[:tag_name]`` with an optional
# tag/branch.
#
# If :attr:`source` is ``'local'``, :attr:`repo_or_dir` is expected to be a
# path to a local directory.
#
# Args:
# repo_or_dir (string): repo name (``repo_owner/repo_name[:tag_name]``),
# if ``source = 'github'``; or a path to a local directory, if
# ``source = 'local'``.
# model (string): the name of a callable (entrypoint) defined in the
# repo/dir's ``hubconf.py``.
# *args (optional): the corresponding args for callable :attr:`model`.
# source (string, optional): ``'github'`` | ``'local'``. Specifies how
# ``repo_or_dir`` is to be interpreted. Default is ``'github'``.
# force_reload (bool, optional): whether to force a fresh download of
# the github repo unconditionally. Does not have any effect if
# ``source = 'local'``. Default is ``False``.
# verbose (bool, optional): If ``False``, mute messages about hitting
# local caches. Note that the message about first download cannot be
# muted. Does not have any effect if ``source = 'local'``.
# Default is ``True``.
# **kwargs (optional): the corresponding kwargs for callable
# :attr:`model`.
#
# Returns:
# The output of the :attr:`model` callable when called with the given
# ``*args`` and ``**kwargs``.
#
# Example:
# >>> # from a github repo
# >>> repo = 'pytorch/vision'
# >>> model = torch.hub.load(repo, 'resnet50', pretrained=True)
# >>> # from a local directory
# >>> path = '/some/local/path/pytorch/vision'
# >>> model = torch.hub.load(path, 'resnet50', pretrained=True)
# """
# source = kwargs.pop('source', 'github').lower()
# force_reload = kwargs.pop('force_reload', False)
# verbose = kwargs.pop('verbose', True)
#
# if source not in ('github', 'local'):
# raise ValueError(
# f'Unknown source: "{source}". Allowed values: "github" | "local".')
#
# if source == 'github':
# repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose)
#
# model = _load_local(repo_or_dir, model, *args, **kwargs)
# return model
#
#
# def _load_local(hubconf_dir, model, *args, **kwargs):
# r"""
# Load a model from a local directory with a ``hubconf.py``.
#
# Args:
# hubconf_dir (string): path to a local directory that contains a
# ``hubconf.py``.
# model (string): name of an entrypoint defined in the directory's
# `hubconf.py`.
# *args (optional): the corresponding args for callable ``model``.
# **kwargs (optional): the corresponding kwargs for callable ``model``.
#
# Returns:
# a single model with corresponding pretrained weights.
#
# Example:
# >>> path = '/some/local/path/pytorch/vision'
# >>> model = _load_local(path, 'resnet50', pretrained=True)
# """
# sys.path.insert(0, hubconf_dir)
#
# hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
# hub_module = import_module(MODULE_HUBCONF, hubconf_path)
#
# entry = _load_entry_from_hubconf(hub_module, model)
# model = entry(*args, **kwargs)
#
# sys.path.remove(hubconf_dir)
#
# return model
#
#
# def download_url_to_file(url, dst, hash_prefix=None, progress=True):
# r"""Download object at the given URL to a local path.
#
# Args:
# url (string): URL of the object to download
# dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
# hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`.
# Default: None
# progress (bool, optional): whether or not to display a progress bar to stderr
# Default: True
#
# Example:
# >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
#
# """
# file_size = None
# # We use a different API for python2 since urllib(2) doesn't recognize the CA
# # certificates in older Python
# req = Request(url, headers={"User-Agent": "torch.hub"})
# u = urlopen(req)
# meta = u.info()
# if hasattr(meta, 'getheaders'):
# content_length = meta.getheaders("Content-Length")
# else:
# content_length = meta.get_all("Content-Length")
# if content_length is not None and len(content_length) > 0:
# file_size = int(content_length[0])
#
# # We deliberately save it in a temp file and move it after
# # download is complete. This prevents a local working checkpoint
# # being overridden by a broken download.
# dst = os.path.expanduser(dst)
# dst_dir = os.path.dirname(dst)
# f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
#
# try:
# if hash_prefix is not None:
# sha256 = hashlib.sha256()
# with tqdm(total=file_size, disable=not progress,
# unit='B', unit_scale=True, unit_divisor=1024) as pbar:
# while True:
# buffer = u.read(8192)
# if len(buffer) == 0:
# break
# f.write(buffer)
# if hash_prefix is not None:
# sha256.update(buffer)
# pbar.update(len(buffer))
#
# f.close()
# if hash_prefix is not None:
# digest = sha256.hexdigest()
# if digest[:len(hash_prefix)] != hash_prefix:
# raise RuntimeError('invalid hash value (expected "{}", got "{}")'
# .format(hash_prefix, digest))
# shutil.move(f.name, dst)
# finally:
# f.close()
# if os.path.exists(f.name):
# os.remove(f.name)
#
# def _download_url_to_file(url, dst, hash_prefix=None, progress=True):
# warnings.warn('torch.hub._download_url_to_file has been renamed to\
# torch.hub.download_url_to_file to be a public API,\
# _download_url_to_file will be removed in after 1.3 release')
# download_url_to_file(url, dst, hash_prefix, progress)
#
# # Hub used to support automatically extracts from zipfile manually compressed by users.
# # The legacy zip format expects only one file from torch.save() < 1.6 in the zip.
# # We should remove this support since zipfile is now default zipfile format for torch.save().
# def _is_legacy_zip_format(filename):
# if zipfile.is_zipfile(filename):
# infolist = zipfile.ZipFile(filename).infolist()
# return len(infolist) == 1 and not infolist[0].is_dir()
# return False
#
# def _legacy_zip_load(filename, model_dir, map_location):
# warnings.warn('Falling back to the old format < 1.6. This support will be '
# 'deprecated in favor of default zipfile format introduced in 1.6. '
# 'Please redo torch.save() to save it in the new zipfile format.')
# # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
# # We deliberately don't handle tarfile here since our legacy serialization format was in tar.
# # E.g. resnet18-5c106cde.pth which is widely used.
# with zipfile.ZipFile(filename) as f:
# members = f.infolist()
# if len(members) != 1:
# raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
# f.extractall(model_dir)
# extraced_name = members[0].filename
# extracted_file = os.path.join(model_dir, extraced_name)
# return torch.load(extracted_file, map_location=map_location)
#
# def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
# r"""Loads the Torch serialized object at the given URL.
#
# If downloaded file is a zip file, it will be automatically
# decompressed.
#
# If the object is already present in `model_dir`, it's deserialized and
# returned.
# The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
# `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
#
# Args:
# url (string): URL of the object to download
# model_dir (string, optional): directory in which to save the object
# map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
# progress (bool, optional): whether or not to display a progress bar to stderr.
# Default: True
# check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
# ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
# digits of the SHA256 hash of the contents of the file. The hash is used to
# ensure unique names and to verify the contents of the file.
# Default: False
# file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set.
#
# Example:
# >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
#
# """
# # Issue warning to move data if old env is set
# if os.getenv('TORCH_MODEL_ZOO'):
# warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
#
# if model_dir is None:
# hub_dir = get_dir()
# model_dir = os.path.join(hub_dir, 'checkpoints')
#
# try:
# os.makedirs(model_dir)
# except OSError as e:
# if e.errno == errno.EEXIST:
# # Directory already exists, ignore.
# pass
# else:
# # Unexpected OSError, re-raise.
# raise
#
# parts = urlparse(url)
# filename = os.path.basename(parts.path)
# if file_name is not None:
# filename = file_name
# cached_file = os.path.join(model_dir, filename)
# if not os.path.exists(cached_file):
# sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
# hash_prefix = None
# if check_hash:
# r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
# hash_prefix = r.group(1) if r else None
# download_url_to_file(url, cached_file, hash_prefix, progress=progress)
#
# if _is_legacy_zip_format(cached_file):
# return _legacy_zip_load(cached_file, model_dir, map_location)
# return torch.load(cached_file, map_location=map_location)

+139 -0 datasets/tods_dataset_base.py

@@ -0,0 +1,139 @@
import warnings
import os
import os.path
import numpy as np
import codecs
import string
import gzip
import lzma
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
from dataset_utils import download_url, download_and_extract_archive, extract_archive, verify_str_arg

# tqdm >= 4.31.1

from tods import generate_dataset
from sklearn import preprocessing
import pandas as pd

class TODS_dataset:
resources = []
training_file = None
testing_file = None
ground_truth_index = None
_repr_indent = None

@property
def raw_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, 'raw')

@property
def processed_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, 'processed')

def __init__(self, root, train, transform=None, download=True):

self.root = root
self.train = train
self.transform = self.transform_init(transform)

if download:
self.download()

self.process()


def _check_exists(self) -> bool:
return (os.path.exists(os.path.join(self.processed_folder,
self.training_file)) and
os.path.exists(os.path.join(self.processed_folder,
self.testing_file)))


def download(self) -> None:

if self._check_exists():
return

os.makedirs(self.raw_folder, exist_ok=True)

# download files
for url, md5 in self.resources:
filename = url.rpartition('/')[2]
download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)


def process(self) -> None:

pass


def process_dataframe(self) -> None:

if self.transform is None:
pass

else:
self.transform.fit(self.training_set_dataframe)
self.training_set_array = self.transform.transform(self.training_set_dataframe.values)
self.testing_set_array = self.transform.transform(self.testing_set_dataframe.values)
self.training_set_dataframe = pd.DataFrame(self.training_set_array)
self.testing_set_dataframe = pd.DataFrame(self.testing_set_array)


def transform_init(self, transform_str):

if transform_str is None:
return None
elif transform_str == 'standardscale':
return preprocessing.StandardScaler()
elif transform_str == 'normalize':
return preprocessing.Normalizer()
elif transform_str == 'minmaxscale':
return preprocessing.MinMaxScaler()
elif transform_str == 'maxabsscale':
return preprocessing.MaxAbsScaler()
elif transform_str == 'binarize':
return preprocessing.Binarizer()
else:
raise ValueError("Input parameter transform must take value of 'standardscale', 'normalize', " +
"'minmaxscale', 'maxabsscale' or 'binarize'."
)


def to_axolotl_dataset(self):
if self.train:
return generate_dataset(self.training_set_dataframe, self.ground_truth_index)
else:
return generate_dataset(self.testing_set_dataframe, self.ground_truth_index)


def __repr__(self) -> str:
head = "Dataset " + self.__class__.__name__
body = ["Number of datapoints: {}".format(self.__len__())]
if self.root is not None:
body.append("Root location: {}".format(self.root))
body += self.extra_repr().splitlines()
if hasattr(self, "transform") and self.transform is not None:
body += [repr(self.transform)]
lines = [head] + [" " * self._repr_indent + line for line in body]

print(self.training_set_dataframe)

return '\n'.join(lines)


def __len__(self) -> int:
return len(self.training_set_dataframe)


def extra_repr(self) -> str:
return ""


# kpi(root='./datasets', train=True)

# class yahoo5:
#
# def __init__(self):
# pass
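
A quick sanity check of the transform_init mapping above, exercised directly with sklearn; the toy data is hypothetical:

# transform_init('minmaxscale') returns a sklearn MinMaxScaler; applied to a tiny frame:
import pandas as pd
from sklearn import preprocessing

frame = pd.DataFrame({'value': [1.0, 2.0, 4.0]})
scaler = preprocessing.MinMaxScaler()
print(scaler.fit_transform(frame.values))  # [[0.], [0.33333333], [1.]]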

+116 -0 datasets/tods_datasets.py

@@ -0,0 +1,116 @@
import os
import pandas as pd

from tods_dataset_base import TODS_dataset
from shutil import copyfile

class kpi_dataset(TODS_dataset):
resources = [
# ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
# ("https://github.com/datamllab/tods/blob/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None),
# ("https://github.com/NetManAIOps/KPI-Anomaly-Detection/blob/master/Preliminary_dataset/train.csv", None),
("https://hegsns.github.io/tods_datasets/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), # it needs md5 to check if local learningData.csv is the same with online.
("https://hegsns.github.io/tods_datasets/kpi/TRAIN/dataset_TRAIN/datasetDoc.json", None),
# needs a server to store the dataset.
# ("https://raw.githubusercontent.com/datamllab/tods/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), # it needs md5 to check if local learningData.csv is the same with online.
]

training_file = 'learningData.csv'
testing_file = 'testingData.csv'
ground_truth_index = 3
_repr_indent = 4

# def __init__(self, root, train, transform=None, target_transform=None, download=True):
# super().__init__(root, train, transform=None, target_transform=None, download=True)

def process(self) -> None:

print('Processing...')

os.makedirs(self.processed_folder, exist_ok=True)
os.makedirs(os.path.join(self.processed_folder, 'tables'), exist_ok=True)

training_set_fname = os.path.join(self.raw_folder, 'learningData.csv')
self.training_set_dataframe = pd.read_csv(training_set_fname)
testing_set_fname = os.path.join(self.raw_folder, 'learningData.csv') # temporarily the same as the training set
self.testing_set_dataframe = pd.read_csv(testing_set_fname)

self.process_dataframe()
self.training_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.training_file))
self.testing_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.testing_file))
copyfile(os.path.join(self.raw_folder, 'datasetDoc.json'), os.path.join(self.processed_folder, 'datasetDoc.json'))

print('Done!')


class yahoo_dataset(TODS_dataset):
resources = [
# ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
# ("https://github.com/datamllab/tods/blob/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None),
# ("https://github.com/NetManAIOps/KPI-Anomaly-Detection/blob/master/Preliminary_dataset/train.csv", None),
("https://hegsns.github.io/tods_datasets/yahoo_sub_5/TRAIN/dataset_TRAIN/tables/learningData.csv", None), # it needs md5 to check if local learningData.csv is the same with online.
("https://hegsns.github.io/tods_datasets/yahoo_sub_5/TRAIN/dataset_TRAIN/datasetDoc.json", None),
# needs a server to store the dataset.
# ("https://raw.githubusercontent.com/datamllab/tods/master/datasets/anomaly/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None), # it needs md5 to check if local learningData.csv is the same with online.
]

training_file = 'learningData.csv'
testing_file = 'testingData.csv'
ground_truth_index = 7
_repr_indent = 4

def process(self) -> None:

print('Processing...')

os.makedirs(self.processed_folder, exist_ok=True)
os.makedirs(os.path.join(self.processed_folder, 'tables'), exist_ok=True)

training_set_fname = os.path.join(self.raw_folder, 'learningData.csv')
self.training_set_dataframe = pd.read_csv(training_set_fname)
testing_set_fname = os.path.join(self.raw_folder, 'learningData.csv') # temporarily the same as the training set
self.testing_set_dataframe = pd.read_csv(testing_set_fname)

self.process_dataframe()
self.training_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.training_file))
self.testing_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.testing_file))
copyfile(os.path.join(self.raw_folder, 'datasetDoc.json'), os.path.join(self.processed_folder, 'datasetDoc.json'))

print('Done!')


class NAB_dataset(TODS_dataset):
resources = [
("https://hegsns.github.io/tods_datasets/NAB/realTweets/labeled_Twitter_volume_AMZN.csv", None),
# an MD5 checksum is needed to verify that the local learningData.csv matches the online copy.
("https://hegsns.github.io/tods_datasets/NAB/realTweets/labeled_Twitter_volume_AMZN.json", None),
# needs a server to store the dataset.
]

training_file = 'learningData.csv'
testing_file = 'testingData.csv'
ground_truth_index = 2
_repr_indent = 4

def process(self) -> None:
print('Processing...')

os.makedirs(self.processed_folder, exist_ok=True)
os.makedirs(os.path.join(self.processed_folder, 'tables'), exist_ok=True)

training_set_fname = os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.csv')
self.training_set_dataframe = pd.read_csv(training_set_fname)
testing_set_fname = os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.csv') # temporarily the same as the training set
self.testing_set_dataframe = pd.read_csv(testing_set_fname)

self.process_dataframe()
self.training_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.training_file))
self.testing_set_dataframe.to_csv(os.path.join(self.processed_folder, 'tables', self.testing_file))
copyfile(os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.json'),
os.path.join(self.processed_folder, 'datasetDoc.json'))

print('Done!')

# kpi_dataset(root='./datasets', train=True, transform='binarize')
# yahoo_dataset(root='./datasets', train=True, transform='binarize')
# NAB_dataset(root='./datasets', train=True, transform='binarize')
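
A hedged end-to-end sketch mirroring the commented-out calls above; the root path and transform choice are illustrative assumptions:

# Downloads (or reuses) the raw CSV under ./datasets/kpi_dataset/, scales it,
# and hands the processed frame to axolotl via generate_dataset().
train_data = kpi_dataset(root='./datasets', train=True, transform='minmaxscale')
axolotl_dataset = train_data.to_axolotl_dataset()
print(train_data)  # the repr reports the class name, number of datapoints and root location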

+0 -1 src/axolotl

@@ -1 +0,0 @@
Subproject commit af54e6970476a081bf0cd65990c9f56a1200d8a2

+0 -1 src/common-primitives

@@ -1 +0,0 @@
Subproject commit 046b20d2f6d4543dcbe18f0a1d4bcbb1f61cf518

+0 -1 src/d3m

@@ -1 +0,0 @@
Subproject commit 70aeefed6b7307941581357c4b7858bb3f88e1da

+116 -0 tods/common/FixedSplit.py

@@ -0,0 +1,116 @@
import os
import typing

import numpy
import pandas

from d3m import container, exceptions, utils as d3m_utils
from d3m.metadata import base as metadata_base, hyperparams
from d3m.base import primitives

__all__ = ('FixedSplitDatasetSplitPrimitive',)


class Hyperparams(hyperparams.Hyperparams):
primary_index_values = hyperparams.Set(
elements=hyperparams.Hyperparameter[str](''),
default=(),
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description='A set of primary index values of the main resource belonging to the test (score) split. Cannot be set together with "row_indices".',
)
row_indices = hyperparams.Set(
elements=hyperparams.Hyperparameter[int](-1),
default=(),
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description='A set of row indices of the main resource belonging to the test (score) split. Cannot be set together with "primary_index_values".',
)
delete_recursive = hyperparams.Hyperparameter[bool](
default=False,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.",
)


class FixedSplitDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
"""
A primitive which splits a tabular Dataset using a fixed list of primary index
values or row indices of the main resource as the test (score) split. All other
rows are used for the train split.
"""

metadata = metadata_base.PrimitiveMetadata(
{
'id': '1654f000-2178-4520-be4c-a95bc26b8d3a',
'version': '0.1.0',
'name': "Fixed split tabular dataset splits",
'python_path': 'd3m.primitives.tods.evaluation.fixed_split_dataset_split',
'source': {
'name': "DATALab@TexasA&M University",
'contact': 'mailto:mitar.commonprimitives@tnode.com',
'uris': [
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/fixed_split.py',
'https://gitlab.com/datadrivendiscovery/common-primitives.git',
],
},
'algorithm_types': [
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
],
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
},
)

def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
# This should be handled by "Set" hyper-parameter, but we check it here again just to be sure.
if d3m_utils.has_duplicates(self.hyperparams['primary_index_values']):
raise exceptions.InvalidArgumentValueError("\"primary_index_values\" hyper-parameter has duplicate values.")
if d3m_utils.has_duplicates(self.hyperparams['row_indices']):
raise exceptions.InvalidArgumentValueError("\"row_indices\" hyper-parameter has duplicate values.")

if self.hyperparams['primary_index_values'] and self.hyperparams['row_indices']:
raise exceptions.InvalidArgumentValueError("Both \"primary_index_values\" and \"row_indices\" cannot be provided.")

if self.hyperparams['primary_index_values']:
primary_index_values = numpy.array(self.hyperparams['primary_index_values'])

index_columns = dataset.metadata.get_index_columns(at=(main_resource_id,))

if not index_columns:
raise exceptions.InvalidArgumentValueError("Cannot find index columns in the main resource of the dataset, but \"primary_index_values\" is provided.")

main_resource = dataset[main_resource_id]
# We reset the index so that the index corresponds to row indices.
main_resource = main_resource.reset_index(drop=True)

# We use just the "d3mIndex" column and ignore multi-key indices.
# This works for now because it seems that every current multi-key
# dataset in fact has a unique value in "d3mIndex" alone.
# See: https://gitlab.datadrivendiscovery.org/MIT-LL/d3m_data_supply/issues/117
index_column = index_columns[0]

score_data = numpy.array(main_resource.loc[main_resource.iloc[:, index_column].isin(primary_index_values)].index)
score_data_set = set(score_data)

assert len(score_data) == len(score_data_set), (len(score_data), len(score_data_set))

if len(score_data) != len(primary_index_values):
raise exceptions.InvalidArgumentValueError("\"primary_index_values\" contains values which do not exist.")

else:
score_data = numpy.array(self.hyperparams['row_indices'])
score_data_set = set(score_data)

all_data_set = set(numpy.arange(len(attributes)))

if not score_data_set <= all_data_set:
raise exceptions.InvalidArgumentValueError("\"row_indices\" contains indices which do not exist, e.g., {indices}.".format(
indices=sorted(score_data_set - all_data_set)[:5],
))

train_data = []
for i in numpy.arange(len(attributes)):
if i not in score_data_set:
train_data.append(i)

assert len(train_data) + len(score_data) == len(attributes), (len(train_data), len(score_data), len(attributes))

return [(numpy.array(train_data), score_data)]
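
A hedged sketch of driving this split primitive through the d3m interface, following the pattern used for common-primitives' split primitives; the dataset variable is assumed to be a container.Dataset loaded elsewhere:

hyperparams = Hyperparams.defaults().replace({'row_indices': (0, 1, 2)})
primitive = FixedSplitDatasetSplitPrimitive(hyperparams=hyperparams)
primitive.set_training_data(dataset=dataset)
primitive.fit()
# Request fold 0: the train split and the matching test (score) split.
train_split = primitive.produce(inputs=container.List([0], generate_metadata=True)).value
score_split = primitive.produce_score_data(inputs=container.List([0], generate_metadata=True)).value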

+87 -0 tods/common/KFoldSplit.py

@@ -0,0 +1,87 @@
import os
import typing

import numpy
import pandas
from sklearn import model_selection

from d3m import container, exceptions, utils as d3m_utils
from d3m.metadata import base as metadata_base, hyperparams
from d3m.base import primitives


__all__ = ('KFoldDatasetSplitPrimitive',)


class Hyperparams(hyperparams.Hyperparams):
number_of_folds = hyperparams.Bounded[int](
lower=2,
upper=None,
default=5,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Number of folds for k-folds cross-validation.",
)
stratified = hyperparams.UniformBool(
default=False,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Do stratified folds. The folds are made by preserving the percentage of samples for each class.",
)
shuffle = hyperparams.UniformBool(
default=False,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Whether to shuffle the data before splitting into batches.",
)
delete_recursive = hyperparams.Hyperparameter[bool](
default=False,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.",
)


class KFoldDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
"""
A primitive which splits a tabular Dataset for k-fold cross-validation.
"""

__author__ = 'Mingjie Sun <sunmj15@gmail.com>'
metadata = metadata_base.PrimitiveMetadata(
{
'id': 'bfedaf3a-6dd0-4a83-ad83-3a50fe882bf8',
'version': '0.1.0',
'name': "K-fold cross-validation tabular dataset splits",
'python_path': 'd3m.primitives.tods.evaluation.kfold_dataset_split',
'source': {
'name': 'DATALab@Texas A&M University',
'contact': 'mailto:sunmj15@gmail.com',
'uris': [
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/kfold_split.py',
'https://gitlab.com/datadrivendiscovery/common-primitives.git',
],
},
'algorithm_types': [
metadata_base.PrimitiveAlgorithmType.K_FOLD,
metadata_base.PrimitiveAlgorithmType.CROSS_VALIDATION,
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
],
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
},
)

def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
if self.hyperparams['stratified']:
if not len(targets.columns):
raise exceptions.InvalidArgumentValueError("Stratified split is requested, but no target columns found.")

k_fold = model_selection.StratifiedKFold(
n_splits=self.hyperparams['number_of_folds'],
shuffle=self.hyperparams['shuffle'],
random_state=self._random_state,
)
else:
k_fold = model_selection.KFold(
n_splits=self.hyperparams['number_of_folds'],
shuffle=self.hyperparams['shuffle'],
random_state=self._random_state,
)

return list(k_fold.split(attributes, targets))

+187 -0 tods/common/KFoldSplitTimeseries.py

@@ -0,0 +1,187 @@
import os
import uuid
import typing
from collections import OrderedDict

import numpy
import pandas
from sklearn import model_selection

from d3m import container, exceptions, utils as d3m_utils
from d3m.metadata import base as metadata_base, hyperparams
from d3m.base import primitives

import utils

__all__ = ('KFoldTimeSeriesSplitPrimitive',)


class Hyperparams(hyperparams.Hyperparams):
number_of_folds = hyperparams.Bounded[int](
lower=2,
upper=None,
default=5,
semantic_types=[
'https://metadata.datadrivendiscovery.org/types/ControlParameter'
],
description="Number of folds for k-folds cross-validation.",
)
number_of_window_folds = hyperparams.Union[typing.Union[int, None]](
configuration=OrderedDict(
fixed=hyperparams.Bounded[int](
lower=1,
upper=None,
default=1,
description="Number of folds in train set (window). These folds come directly "
"before test set (streaming window).",
),
all_records=hyperparams.Constant(
default=None,
description="Number of folds in train set (window) = maximum number possible.",
),
),
default='all_records',
semantic_types=[
'https://metadata.datadrivendiscovery.org/types/ControlParameter'
],
description="Maximum size for a single training set.",
)
time_column_index = hyperparams.Union[typing.Union[int, None]](
configuration=OrderedDict(
fixed=hyperparams.Bounded[int](
lower=1,
upper=None,
default=1,
description="Specific column that contains the time index",
),
one_column=hyperparams.Constant(
default=None,
description="Only one column contains a time index. "
"It is detected automatically using semantic types.",
),
),
default='one_column',
semantic_types=[
'https://metadata.datadrivendiscovery.org/types/ControlParameter'
],
description="Column index to use as datetime index. "
"If None, it is required that only one column with time column role semantic type is "
"present and otherwise an exception is raised. "
"If column index specified is not a datetime column an exception is"
"also raised.",
)
fuzzy_time_parsing = hyperparams.UniformBool(
default=True,
semantic_types=[
'https://metadata.datadrivendiscovery.org/types/ControlParameter'
],
description="Use fuzzy time parsing.",
)


class KFoldTimeSeriesSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
"""
A primitive which splits a tabular time-series Dataset for k-fold cross-validation.

The primitive sorts by the time column, so care should be taken to ensure that sorting
the column is meaningful. E.g., if the column is not numeric but has a string structural type,
strings should be formatted so that sorting by them also sorts by time.
"""

__author__ = 'Distil'
__version__ = '0.3.0'
__contact__ = 'mailto:jeffrey.gleason@yonder.co'

metadata = metadata_base.PrimitiveMetadata(
{
'id': '002f9ad1-46e3-40f4-89ed-eeffbb3a102b',
'version': __version__,
'name': "K-fold cross-validation timeseries dataset splits",
'python_path': 'd3m.primitives.tods.evaluation.kfold_time_series_split',
'source': {
'name': 'DATALab@Texas A&M University',
'contact': __contact__,
'uris': [
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/kfold_split_timeseries.py',
'https://gitlab.com/datadrivendiscovery/common-primitives.git',
],
},
'algorithm_types': [
metadata_base.PrimitiveAlgorithmType.K_FOLD,
metadata_base.PrimitiveAlgorithmType.CROSS_VALIDATION,
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
],
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
},
)

def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
time_column_indices = dataset.metadata.list_columns_with_semantic_types(['https://metadata.datadrivendiscovery.org/types/Time'], at=(main_resource_id,))
attribute_column_indices = dataset.metadata.list_columns_with_semantic_types(['https://metadata.datadrivendiscovery.org/types/Attribute'], at=(main_resource_id,))

# We want only time columns which are also attributes.
time_column_indices = [time_column_index for time_column_index in time_column_indices if time_column_index in attribute_column_indices]

if self.hyperparams['time_column_index'] is None:
if len(time_column_indices) != 1:
raise exceptions.InvalidArgumentValueError(
"If \"time_column_index\" hyper-parameter is \"None\", it is required that exactly one column with time column role semantic type is present.",
)
else:
# We know it exists because "time_column_indices" is a subset of "attribute_column_indices".
time_column_index = attribute_column_indices.index(
time_column_indices[0],
)
else:
if self.hyperparams['time_column_index'] not in time_column_indices:
raise exceptions.InvalidArgumentValueError(
"Time column index specified does not have a time column role semantic type.",
)
else:
time_column_index = attribute_column_indices.index(
self.hyperparams['time_column_index'],
)

# We first reset index.
attributes = attributes.reset_index(drop=True)

# Then convert datetime column to consistent datetime representation
attributes.insert(
loc=0,
column=uuid.uuid4(), # use uuid to ensure we are inserting a new column name
value=self._parse_time_data(
attributes, time_column_index, self.hyperparams['fuzzy_time_parsing'],
),
)

# Then sort dataframe by new datetime column. Index contains original row order.
attributes = attributes.sort_values(by=attributes.columns[0])

# Remove datetime representation used for sorting (primitives might choose to parse this str col differently).
attributes = attributes.drop(attributes.columns[0], axis=1)

max_train_size: typing.Optional[int] = None
if self.hyperparams['number_of_window_folds'] is not None:
max_train_size = int(attributes.shape[0] * self.hyperparams['number_of_window_folds'] / self.hyperparams['number_of_folds'])

k_fold = model_selection.TimeSeriesSplit(
n_splits=self.hyperparams['number_of_folds'],
max_train_size=max_train_size
)

# We sorted "attributes" so we have to map indices on sorted "attributes" back to original
# indices. We do that by using DataFrame's index which contains original row order.
return [
(
numpy.array([attributes.index[val] for val in train]),
numpy.array([attributes.index[val] for val in test]),
)
for train, test in k_fold.split(attributes)
]

@classmethod
def _parse_time_data(cls, inputs: container.DataFrame, column_index: metadata_base.SimpleSelectorSegment, fuzzy: bool) -> typing.List[float]:
return [
utils.parse_datetime_to_float(value, fuzzy=fuzzy)
for value in inputs.iloc[:, column_index]
]
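
A small standalone sketch of the windowing arithmetic in _get_splits: with number_of_folds=5 and number_of_window_folds=1 over 100 rows, max_train_size becomes 20, so each training window keeps at most the 20 rows immediately preceding its test fold (the row count here is hypothetical):

import numpy
from sklearn import model_selection

rows = numpy.arange(100)                      # stand-in for the sorted attributes
max_train_size = int(len(rows) * 1 / 5)       # number_of_window_folds=1, number_of_folds=5 -> 20
k_fold = model_selection.TimeSeriesSplit(n_splits=5, max_train_size=max_train_size)
for train_index, test_index in k_fold.split(rows):
    print(len(train_index), len(test_index))  # 20 16 on every fold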

+52 -0 tods/common/NoSplit.py

@@ -0,0 +1,52 @@
import os
import typing

import numpy
import pandas

from d3m import container, utils as d3m_utils
from d3m.metadata import base as metadata_base, hyperparams
from d3m.base import primitives


__all__ = ('NoSplitDatasetSplitPrimitive',)


class Hyperparams(hyperparams.Hyperparams):
pass


class NoSplitDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
"""
A primitive which splits a tabular Dataset in a way that every split
produces the same (full) Dataset. Useful for unsupervised learning tasks.
"""

metadata = metadata_base.PrimitiveMetadata(
{
'id': '48c683ad-da9e-48cf-b3a0-7394dba5e5d2',
'version': '0.1.0',
'name': "No-split tabular dataset splits",
'python_path': 'd3m.primitives.tods.evaluation.no_split_dataset_split',
'source': {
'name': 'DATALab@Texas A&M University',
'contact': 'mailto:mitar.commonprimitives@tnode.com',
'uris': [
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/no_split.py',
'https://gitlab.com/datadrivendiscovery/common-primitives.git',
],
},
'algorithm_types': [
metadata_base.PrimitiveAlgorithmType.IDENTITY_FUNCTION,
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
],
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
},
)

def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        # We still go through the whole splitting process to ensure full compatibility
        # (including error conditions) with a regular split, but we use all data for both splits.
all_data = numpy.arange(len(attributes))

return [(all_data, all_data)]
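
For context, a tiny sketch (toy data, outside the primitive) of what this no-split strategy yields: the same full set of row positions is returned for both the train and the score side.

import numpy
import pandas

# Hypothetical attribute table with three rows.
attributes = pandas.DataFrame({"value": [0.1, 0.2, 0.3]})

all_data = numpy.arange(len(attributes))
splits = [(all_data, all_data)]
print(splits)  # [(array([0, 1, 2]), array([0, 1, 2]))]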

+ 160
- 0
tods/common/RedactColumns.py

@@ -0,0 +1,160 @@
import copy
import os
import typing

from d3m import container, exceptions, utils as d3m_utils
from d3m.metadata import base as metadata_base, hyperparams
from d3m.primitive_interfaces import base, transformer


Inputs = container.List
Outputs = container.List


class Hyperparams(hyperparams.Hyperparams):
match_logic = hyperparams.Enumeration(
values=['all', 'any'],
default='any',
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Should a column have all of the semantic types in \"semantic_types\" to be redacted, or any of them?",
)
semantic_types = hyperparams.Set(
elements=hyperparams.Hyperparameter[str](''),
default=(),
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Redact columns with these semantic types. Only columns having semantic types listed here will be operated on, based on \"match_logic\".",
)
add_semantic_types = hyperparams.Set(
elements=hyperparams.Hyperparameter[str](''),
default=(),
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Semantic types to add to redacted columns. All listed semantic types will be added to all columns which were redacted.",
)


# TODO: Make clear the assumption that both the container type (List) and the Datasets inside it have metadata.
#     The primitive modifies the metadata of the Datasets, although there is officially no reason for them
#     to carry metadata themselves: metadata is stored on the input container type, not
#     on the values inside it.
class RedactColumnsPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]):
"""
A primitive which takes as an input a list of ``Dataset`` objects and redacts values of all columns matching
a given semantic type or types.

Redaction is done by setting all values in a redacted column to an empty string.

It operates only on DataFrame resources inside datasets.
"""

metadata = metadata_base.PrimitiveMetadata(
{
'id': '744c4090-e2f6-489e-8efc-8b1e051bfad6',
'version': '0.2.0',
'name': "Redact columns for evaluation",
'python_path': 'd3m.primitives.tods.evaluation.redact_columns',
'source': {
'name': 'DATALab@Texas A&M University',
'contact': 'mailto:sunmj15@gmail.com',
'uris': [
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/redact_columns.py',
'https://gitlab.com/datadrivendiscovery/common-primitives.git',
],
},
'installation': [{
'type': metadata_base.PrimitiveInstallationType.PIP,
'package_uri': 'git+https://gitlab.com/datadrivendiscovery/common-primitives.git@{git_commit}#egg=common_primitives'.format(
git_commit=d3m_utils.current_git_commit(os.path.dirname(__file__)),
),
}],
'algorithm_types': [
metadata_base.PrimitiveAlgorithmType.DATA_CONVERSION,
],
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
},
)

def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
output_datasets = container.List(generate_metadata=True)

for dataset in inputs:
resources = {}
metadata = dataset.metadata

for resource_id, resource in dataset.items():
if not isinstance(resource, container.DataFrame):
resources[resource_id] = resource
continue

columns_to_redact = self._get_columns_to_redact(metadata, (resource_id,))

if not columns_to_redact:
resources[resource_id] = resource
continue

resource = copy.copy(resource)

for column_index in columns_to_redact:
column_metadata = dataset.metadata.query((resource_id, metadata_base.ALL_ELEMENTS, column_index))
if 'structural_type' in column_metadata and issubclass(column_metadata['structural_type'], str):
resource.iloc[:, column_index] = ''
else:
raise TypeError("Primitive can operate only on columns with structural type \"str\", not \"{type}\".".format(
type=column_metadata.get('structural_type', None),
))

metadata = self._update_metadata(metadata, resource_id, column_index, ())

resources[resource_id] = resource

dataset = container.Dataset(resources, metadata)

output_datasets.append(dataset)

output_datasets.metadata = metadata_base.DataMetadata({
'schema': metadata_base.CONTAINER_SCHEMA_VERSION,
'structural_type': container.List,
'dimension': {
'length': len(output_datasets),
},
})

# We update metadata based on metadata of each dataset.
# TODO: In the future this might be done automatically by generate_metadata.
# See: https://gitlab.com/datadrivendiscovery/d3m/issues/119
for index, dataset in enumerate(output_datasets):
output_datasets.metadata = dataset.metadata.copy_to(output_datasets.metadata, (), (index,))

return base.CallResult(output_datasets)

def _get_columns_to_redact(self, inputs_metadata: metadata_base.DataMetadata, at: metadata_base.Selector) -> typing.Sequence[int]:
columns = []

for element in inputs_metadata.get_elements(list(at) + [metadata_base.ALL_ELEMENTS]):
semantic_types = inputs_metadata.query(list(at) + [metadata_base.ALL_ELEMENTS, element]).get('semantic_types', ())

# TODO: Should we handle inheritance between semantic types here?
if self.hyperparams['match_logic'] == 'all':
matched = all(semantic_type in semantic_types for semantic_type in self.hyperparams['semantic_types'])
elif self.hyperparams['match_logic'] == 'any':
matched = any(semantic_type in semantic_types for semantic_type in self.hyperparams['semantic_types'])
else:
raise exceptions.UnexpectedValueError("Unknown value of hyper-parameter \"match_logic\": {value}".format(value=self.hyperparams['match_logic']))

if matched:
if element is metadata_base.ALL_ELEMENTS:
return list(range(inputs_metadata.query(list(at) + [metadata_base.ALL_ELEMENTS]).get('dimension', {}).get('length', 0)))
else:
columns.append(typing.cast(int, element))

return columns

def _update_metadata(
self, inputs_metadata: metadata_base.DataMetadata, resource_id: metadata_base.SelectorSegment,
column_index: int, at: metadata_base.Selector,
) -> metadata_base.DataMetadata:
outputs_metadata = inputs_metadata

for semantic_type in self.hyperparams['add_semantic_types']:
outputs_metadata = outputs_metadata.add_semantic_type(tuple(at) + (resource_id, metadata_base.ALL_ELEMENTS, column_index), semantic_type)

return outputs_metadata
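
A minimal sketch of the redaction step itself, on a plain pandas DataFrame and without the d3m metadata handling; the column indices are assumed to come from the semantic-type matching shown above.

import pandas

resource = pandas.DataFrame({"d3mIndex": ["0", "1"], "label": ["1", "0"]})
columns_to_redact = [1]  # hypothetical output of _get_columns_to_redact

# Work on a copy and blank out every matched string column.
resource = resource.copy()
for column_index in columns_to_redact:
    resource.iloc[:, column_index] = ""
print(resource)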

+ 88
- 0
tods/common/TrainScoreSplit.py

@@ -0,0 +1,88 @@
import os
import typing

import numpy
import pandas
from sklearn import model_selection

from d3m import container, exceptions, utils as d3m_utils
from d3m.metadata import base as metadata_base, hyperparams
from d3m.base import primitives


__all__ = ('TrainScoreDatasetSplitPrimitive',)


class Hyperparams(hyperparams.Hyperparams):
train_score_ratio = hyperparams.Uniform(
lower=0,
upper=1,
default=0.75,
upper_inclusive=True,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="The ratio between the train and score data; it represents the proportion of the Dataset to include in the train split. The rest is included in the score split.",
)
stratified = hyperparams.UniformBool(
default=False,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Do stratified folds. The folds are made by preserving the percentage of samples for each class.",
)
shuffle = hyperparams.UniformBool(
default=False,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Whether to shuffle the data before splitting into batches.",
)
delete_recursive = hyperparams.Hyperparameter[bool](
default=False,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.",
)


class TrainScoreDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
"""
A primitive which splits a tabular Dataset into random train and score subsets.
"""

metadata = metadata_base.PrimitiveMetadata(
{
'id': '3fcc6dc4-6681-4c86-948e-066d14e7d803',
'version': '0.1.0',
'name': "Train-score tabular dataset splits",
'python_path': 'd3m.primitives.tods.evaluation.train_score_dataset_split',
'source': {
'name': 'DATALab@Texas A&M University',
'contact': 'mailto:mitar.commonprimitives@tnode.com',
'uris': [
'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/train_score_split.py',
'https://gitlab.com/datadrivendiscovery/common-primitives.git',
],
},
'installation': [{
'type': metadata_base.PrimitiveInstallationType.PIP,
'package_uri': 'git+https://gitlab.com/datadrivendiscovery/common-primitives.git@{git_commit}#egg=common_primitives'.format(
git_commit=d3m_utils.current_git_commit(os.path.dirname(__file__)),
),
}],
'algorithm_types': [
metadata_base.PrimitiveAlgorithmType.HOLDOUT,
metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
],
'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
},
)

def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
if self.hyperparams['stratified'] and not len(targets.columns):
raise exceptions.InvalidArgumentValueError("Stratified split is requested, but no target columns found.")

train_data, score_data = model_selection.train_test_split(
numpy.arange(len(attributes)),
test_size=None,
train_size=self.hyperparams['train_score_ratio'],
random_state=self._random_state,
shuffle=self.hyperparams['shuffle'],
stratify=targets if self.hyperparams['stratified'] else None,
)

return [(train_data, score_data)]
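
A minimal sketch (toy targets, plain scikit-learn) of the split performed above, with the hyper-parameters mapped onto train_test_split arguments; note that stratification only makes sense together with shuffling.

import numpy
from sklearn import model_selection

targets = numpy.array([0, 0, 0, 1, 1, 1, 0, 1])  # hypothetical target column

train_rows, score_rows = model_selection.train_test_split(
    numpy.arange(len(targets)),
    train_size=0.75,       # train_score_ratio
    test_size=None,        # the rest goes to the score split
    shuffle=True,          # required when stratifying
    stratify=targets,      # only when "stratified" is enabled
    random_state=0,
)
print(sorted(train_rows), sorted(score_rows))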

+ 0
- 0
tods/common/__init__.py


+ 192
- 0
tods/common/utils.py

@@ -0,0 +1,192 @@
import datetime
import logging
import typing

import dateutil.parser
import numpy

from d3m import container, deprecate
from d3m.base import utils as base_utils
from d3m.metadata import base as metadata_base

logger = logging.getLogger(__name__)

DEFAULT_DATETIME = datetime.datetime.fromtimestamp(0, datetime.timezone.utc)


@deprecate.function(message="it should not be used anymore")
def copy_elements_metadata(source_metadata: metadata_base.Metadata, target_metadata: metadata_base.DataMetadata, from_selector: metadata_base.Selector,
to_selector: metadata_base.Selector = (), *, ignore_all_elements: bool = False, check: bool = True, source: typing.Any = None) -> metadata_base.DataMetadata:
return source_metadata._copy_elements_metadata(target_metadata, list(from_selector), list(to_selector), [], ignore_all_elements)


@deprecate.function(message="use Metadata.copy_to method instead")
def copy_metadata(source_metadata: metadata_base.Metadata, target_metadata: metadata_base.DataMetadata, from_selector: metadata_base.Selector,
to_selector: metadata_base.Selector = (), *, ignore_all_elements: bool = False, check: bool = True, source: typing.Any = None) -> metadata_base.DataMetadata:
return source_metadata.copy_to(target_metadata, from_selector, to_selector, ignore_all_elements=ignore_all_elements)


@deprecate.function(message="use DataFrame.select_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def select_columns(inputs: container.DataFrame, columns: typing.Sequence[metadata_base.SimpleSelectorSegment], *,
source: typing.Any = None) -> container.DataFrame:
return inputs.select_columns(columns)


@deprecate.function(message="use DataMetadata.select_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def select_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns: typing.Sequence[metadata_base.SimpleSelectorSegment], *,
source: typing.Any = None) -> metadata_base.DataMetadata:
return inputs_metadata.select_columns(columns)


@deprecate.function(message="use DataMetadata.list_columns_with_semantic_types method instead")
def list_columns_with_semantic_types(metadata: metadata_base.DataMetadata, semantic_types: typing.Sequence[str], *,
at: metadata_base.Selector = ()) -> typing.Sequence[int]:
return metadata.list_columns_with_semantic_types(semantic_types, at=at)


@deprecate.function(message="use DataMetadata.list_columns_with_structural_types method instead")
def list_columns_with_structural_types(metadata: metadata_base.DataMetadata, structural_types: typing.Union[typing.Callable, typing.Sequence[typing.Union[str, type]]], *,
at: metadata_base.Selector = ()) -> typing.Sequence[int]:
return metadata.list_columns_with_structural_types(structural_types, at=at)


@deprecate.function(message="use DataFrame.remove_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def remove_columns(inputs: container.DataFrame, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> container.DataFrame:
return inputs.remove_columns(column_indices)


@deprecate.function(message="use DataMetadata.remove_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def remove_columns_metadata(inputs_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> metadata_base.DataMetadata:
return inputs_metadata.remove_columns(column_indices)


@deprecate.function(message="use DataFrame.append_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def append_columns(left: container.DataFrame, right: container.DataFrame, *, use_right_metadata: bool = False, source: typing.Any = None) -> container.DataFrame:
return left.append_columns(right, use_right_metadata=use_right_metadata)


@deprecate.function(message="use DataMetadata.append_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def append_columns_metadata(left_metadata: metadata_base.DataMetadata, right_metadata: metadata_base.DataMetadata, use_right_metadata: bool = False, source: typing.Any = None) -> metadata_base.DataMetadata:
return left_metadata.append_columns(right_metadata, use_right_metadata=use_right_metadata)


@deprecate.function(message="use DataFrame.insert_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def insert_columns(inputs: container.DataFrame, columns: container.DataFrame, at_column_index: int, *, source: typing.Any = None) -> container.DataFrame:
return inputs.insert_columns(columns, at_column_index)


@deprecate.function(message="use DataMetadata.insert_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def insert_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns_metadata: metadata_base.DataMetadata, at_column_index: int, *, source: typing.Any = None) -> metadata_base.DataMetadata:
return inputs_metadata.insert_columns(columns_metadata, at_column_index)


@deprecate.function(message="use DataFrame.replace_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def replace_columns(inputs: container.DataFrame, columns: container.DataFrame, column_indices: typing.Sequence[int], *, copy: bool = True, source: typing.Any = None) -> container.DataFrame:
return inputs.replace_columns(columns, column_indices, copy=copy)


@deprecate.function(message="use DataMetadata.replace_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def replace_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> metadata_base.DataMetadata:
return inputs_metadata.replace_columns(columns_metadata, column_indices)


@deprecate.function(message="use DataMetadata.get_index_columns method instead")
def get_index_columns(metadata: metadata_base.DataMetadata, *, at: metadata_base.Selector = ()) -> typing.Sequence[int]:
return metadata.get_index_columns(at=at)


@deprecate.function(message="use DataFrame.horizontal_concat method instead")
@deprecate.arguments('source', message="argument ignored")
def horizontal_concat(left: container.DataFrame, right: container.DataFrame, *, use_index: bool = True,
remove_second_index: bool = True, use_right_metadata: bool = False, source: typing.Any = None) -> container.DataFrame:
return left.horizontal_concat(right, use_index=use_index, remove_second_index=remove_second_index, use_right_metadata=use_right_metadata)


@deprecate.function(message="use DataMetadata.horizontal_concat method instead")
@deprecate.arguments('source', message="argument ignored")
def horizontal_concat_metadata(left_metadata: metadata_base.DataMetadata, right_metadata: metadata_base.DataMetadata, *, use_index: bool = True,
remove_second_index: bool = True, use_right_metadata: bool = False, source: typing.Any = None) -> metadata_base.DataMetadata:
return left_metadata.horizontal_concat(right_metadata, use_index=use_index, remove_second_index=remove_second_index, use_right_metadata=use_right_metadata)


@deprecate.function(message="use d3m.base.utils.get_columns_to_use function instead")
def get_columns_to_use(metadata: metadata_base.DataMetadata, use_columns: typing.Sequence[int], exclude_columns: typing.Sequence[int],
can_use_column: typing.Callable) -> typing.Tuple[typing.List[int], typing.List[int]]:
return base_utils.get_columns_to_use(metadata, use_columns, exclude_columns, can_use_column)


@deprecate.function(message="use d3m.base.utils.combine_columns function instead")
@deprecate.arguments('source', message="argument ignored")
def combine_columns(return_result: str, add_index_columns: bool, inputs: container.DataFrame, column_indices: typing.Sequence[int],
columns_list: typing.Sequence[container.DataFrame], *, source: typing.Any = None) -> container.DataFrame:
return base_utils.combine_columns(inputs, column_indices, columns_list, return_result=return_result, add_index_columns=add_index_columns)


@deprecate.function(message="use d3m.base.utils.combine_columns_metadata function instead")
@deprecate.arguments('source', message="argument ignored")
def combine_columns_metadata(return_result: str, add_index_columns: bool, inputs_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int],
columns_metadata_list: typing.Sequence[metadata_base.DataMetadata], *, source: typing.Any = None) -> metadata_base.DataMetadata:
return base_utils.combine_columns_metadata(inputs_metadata, column_indices, columns_metadata_list, return_result=return_result, add_index_columns=add_index_columns)


@deprecate.function(message="use DataMetadata.set_table_metadata method instead")
@deprecate.arguments('source', message="argument ignored")
def set_table_metadata(inputs_metadata: metadata_base.DataMetadata, *, at: metadata_base.Selector = (), source: typing.Any = None) -> metadata_base.DataMetadata:
return inputs_metadata.set_table_metadata(at=at)


@deprecate.function(message="use DataMetadata.get_column_index_from_column_name method instead")
def get_column_index_from_column_name(inputs_metadata: metadata_base.DataMetadata, column_name: str, *, at: metadata_base.Selector = ()) -> int:
return inputs_metadata.get_column_index_from_column_name(column_name, at=at)


@deprecate.function(message="use Dataset.get_relations_graph method instead")
def build_relation_graph(dataset: container.Dataset) -> typing.Dict[str, typing.List[typing.Tuple[str, bool, int, int, typing.Dict]]]:
return dataset.get_relations_graph()


@deprecate.function(message="use d3m.base.utils.get_tabular_resource function instead")
def get_tabular_resource(dataset: container.Dataset, resource_id: typing.Optional[str], *,
pick_entry_point: bool = True, pick_one: bool = True, has_hyperparameter: bool = True) -> typing.Tuple[str, container.DataFrame]:
return base_utils.get_tabular_resource(dataset, resource_id, pick_entry_point=pick_entry_point, pick_one=pick_one, has_hyperparameter=has_hyperparameter)


@deprecate.function(message="use d3m.base.utils.get_tabular_resource_metadata function instead")
def get_tabular_resource_metadata(dataset_metadata: metadata_base.DataMetadata, resource_id: typing.Optional[metadata_base.SelectorSegment], *,
pick_entry_point: bool = True, pick_one: bool = True) -> metadata_base.SelectorSegment:
return base_utils.get_tabular_resource_metadata(dataset_metadata, resource_id, pick_entry_point=pick_entry_point, pick_one=pick_one)


@deprecate.function(message="use Dataset.select_rows method instead")
@deprecate.arguments('source', message="argument ignored")
def cut_dataset(dataset: container.Dataset, row_indices_to_keep: typing.Mapping[str, typing.Sequence[int]], *,
source: typing.Any = None) -> container.Dataset:
return dataset.select_rows(row_indices_to_keep)


def parse_datetime(value: str, *, fuzzy: bool = True) -> typing.Optional[datetime.datetime]:
try:
return dateutil.parser.parse(value, default=DEFAULT_DATETIME, fuzzy=fuzzy)
except (ValueError, OverflowError, TypeError):
return None


def parse_datetime_to_float(value: str, *, fuzzy: bool = True) -> float:
try:
parsed = parse_datetime(value, fuzzy=fuzzy)
if parsed is None:
return numpy.nan
else:
return parsed.timestamp()
except (ValueError, OverflowError, TypeError):
return numpy.nan
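
A small sketch of the datetime helpers at the bottom of this file: fuzzy parsing maps assorted strings to POSIX timestamps, and unparseable values become NaN (the example strings are made up).

import datetime

import dateutil.parser
import numpy

DEFAULT_DATETIME = datetime.datetime.fromtimestamp(0, datetime.timezone.utc)

for value in ["2020-01-02 03:04:05Z", "noon on Jan 2, 2020", "not a date"]:
    try:
        parsed = dateutil.parser.parse(value, default=DEFAULT_DATETIME, fuzzy=True)
        print(value, "->", parsed.timestamp())
    except (ValueError, OverflowError, TypeError):
        print(value, "->", numpy.nan)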
