import os
import os.path
import hashlib
import gzip
import errno
import shutil
import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar
import zipfile

# NOTE: the torch-based imports below are disabled; the progress bar class is
# taken from the project's ``hub`` module instead.
# import torch
# from torch.utils.model_zoo import tqdm
from hub import tqdm


def gen_bar_updater() -> Callable[[int, int, int], None]:
    """Build a ``reporthook`` callback that drives a progress bar.

    Returns:
        A callable taking ``(count, block_size, total_size)`` that advances
        a freshly created ``tqdm`` bar to ``count * block_size``.
    """
    pbar = tqdm(total=None)

    def bar_update(count: int, block_size: int, total_size: int) -> None:
        # The total is only known once the first report arrives.
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        # update() takes an increment, so feed it the delta to the new position.
        pbar.update(progress_bytes - pbar.n)

    return bar_update


def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    """Return the hexadecimal MD5 digest of the file at ``fpath``.

    Args:
        fpath (str): Path of the file to hash
        chunk_size (int, optional): Number of bytes read per iteration
    """
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
    """Return True if the MD5 digest of ``fpath`` equals ``md5``.

    Extra keyword arguments are forwarded to :func:`calculate_md5`.
    """
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
    """Return True if ``fpath`` is an existing file that matches ``md5``.

    Args:
        fpath (str): Path of the file to check
        md5 (str, optional): Expected MD5 checksum. If None, only the
            existence of the file is checked
    """
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)


def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Raises:
        RuntimeError: If the file is missing or fails the MD5 check after
            the download.
    """
    # Import the submodules explicitly: a bare ``import urllib`` does not
    # guarantee that ``urllib.request`` / ``urllib.error`` are loaded.
    import urllib.error
    import urllib.request

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:
        # download the file
        try:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(
                url, fpath,
                reporthook=gen_bar_updater()
            )
        except (urllib.error.URLError, IOError):
            if url[:5] == 'https':
                # Retry once over plain http before giving up.
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      ' Downloading ' + url + ' to ' + fpath)
                urllib.request.urlretrieve(
                    url, fpath,
                    reporthook=gen_bar_updater()
                )
            else:
                raise
        # check integrity of downloaded file
        if not check_integrity(fpath, md5):
            raise RuntimeError("File not found or corrupted.")


def list_dir(root: str, prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files


def _quota_exceeded(response: "requests.models.Response") -> bool:  # type: ignore[name-defined]
    """Return True if the response body contains Google Drive's quota message."""
    return "Google Drive - Quota exceeded" in response.text


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
    """Download a file from Google Drive and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Raises:
        RuntimeError: If the download quota of the file is exceeded, or if
            the downloaded file is missing or fails the MD5 check.
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    import requests
    url = "https://docs.google.com/uc?export=download"

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check_integrity() already returns False for a missing file.
    if check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:
        session = requests.Session()

        response = session.get(url, params={'id': file_id}, stream=True)
        token = _get_confirm_token(response)

        if token:
            # A confirmation token was issued; repeat the request with it.
            params = {'id': file_id, 'confirm': token}
            response = session.get(url, params=params, stream=True)

        if _quota_exceeded(response):
            msg = (
                f"The daily quota of the file {filename} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )
            raise RuntimeError(msg)

        _save_response_content(response, fpath)

        # Verify the fresh download the same way download_url() does, so a
        # corrupted or unexpected payload is not silently accepted.
        if not check_integrity(fpath, md5):
            raise RuntimeError("File not found or corrupted.")


def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:  # type: ignore[name-defined]
    """Return the value of the first ``download_warning*`` cookie, if any."""
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None


def _save_response_content(
    response: "requests.models.Response", destination: str, chunk_size: int = 32768,  # type: ignore[name-defined]
) -> None:
    """Stream the body of ``response`` into the file ``destination``.

    Args:
        response: Response whose content is written out chunk by chunk
        destination (str): Path of the file to (over)write
        chunk_size (int, optional): Chunk size passed to ``iter_content``
    """
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        progress = 0
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress += len(chunk)
                pbar.update(progress - pbar.n)
        pbar.close()


def _is_tarxz(filename: str) -> bool:
    """Return True if ``filename`` ends with ``.tar.xz``."""
    return filename.endswith(".tar.xz")


def _is_tar(filename: str) -> bool:
    """Return True if ``filename`` ends with ``.tar``."""
    return filename.endswith(".tar")


def _is_targz(filename: str) -> bool:
    """Return True if ``filename`` ends with ``.tar.gz``."""
    return filename.endswith(".tar.gz")


def _is_tgz(filename: str) -> bool:
    """Return True if ``filename`` ends with ``.tgz``."""
    return filename.endswith(".tgz")


def _is_gzip(filename: str) -> bool:
    """Return True if ``filename`` ends with ``.gz`` but not ``.tar.gz``."""
    return filename.endswith(".gz") and not filename.endswith(".tar.gz")


def _is_zip(filename: str) -> bool:
    """Return True if ``filename`` ends with ``.zip``."""
    return filename.endswith(".zip")


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
    """Extract an archive, choosing the format from the file name suffix.

    Supported suffixes are ``.tar``, ``.tar.gz``, ``.tgz``, ``.tar.xz``,
    ``.gz`` and ``.zip``.

    Args:
        from_path (str): Path of the archive to extract
        to_path (str, optional): Directory to extract into. If None, use the
            directory containing ``from_path``
        remove_finished (bool, optional): If True, delete the archive after
            extraction

    Raises:
        ValueError: If the suffix of ``from_path`` is not supported.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    # NOTE(review): extractall() trusts the member paths stored inside the
    # archive; only use this on archives from trusted sources.
    if _is_tar(from_path):
        with tarfile.open(from_path, 'r') as tar:
            tar.extractall(path=to_path)
    elif _is_targz(from_path) or _is_tgz(from_path):
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif _is_tarxz(from_path):
        with tarfile.open(from_path, 'r:xz') as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        # A bare .gz holds a single file: write it next to the other outputs
        # under the archive's name with the ".gz" suffix stripped.
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            # Stream the payload instead of loading it fully into memory.
            shutil.copyfileobj(zip_f, out_f)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    """Download an archive from ``url`` into ``download_root``.

    Args:
        url (str): URL to download the archive from
        download_root (str): Directory to place the downloaded archive in
        extract_root (str, optional): Directory reported as the extraction
            target. If None, use ``download_root``
        filename (str, optional): Name to save the archive under. If None,
            use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        remove_finished (bool, optional): Currently unused, see the note below
    """
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    # NOTE(review): the extraction call below is disabled in this file, so
    # the archive is only downloaded and ``remove_finished`` has no effect.
    # Confirm whether this is intentional before re-enabling it.
    # extract_archive(archive, extract_root, remove_finished)


def iterable_to_str(iterable: Iterable) -> str:
    """Render the items of ``iterable`` as a single-quoted, comma-separated string."""
    return "'" + "', '".join([str(item) for item in iterable]) + "'"


T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T, arg: Optional[str] = None, valid_values: Optional[Iterable[T]] = None, custom_msg: Optional[str] = None,
) -> T:
    """Validate that ``value`` is a string and, optionally, an allowed one.

    Args:
        value (str or bytes): Value to validate
        arg (str, optional): Name of the argument, used in error messages
        valid_values (iterable, optional): Allowed values. If None, only the
            type of ``value`` is checked
        custom_msg (str, optional): Message to use instead of the default one
            when ``value`` is not in ``valid_values``

    Returns:
        ``value`` unchanged.

    Raises:
        ValueError: If ``value`` is not a ``str``/``bytes`` instance or is
            not contained in ``valid_values``.
    """
    # ``torch`` is not imported in this module, so check against the plain
    # string types directly (these are the types ``T`` is constrained to).
    if not isinstance(value, (str, bytes)):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = ("Unknown value '{value}' for argument {arg}. "
                   "Valid values are {{{valid_values}}}.")
            msg = msg.format(value=value, arg=arg,
                             valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value