GitOrigin-RevId: a48e107623
tags/v1.7.1.m1
@@ -106,9 +106,7 @@ class CIFAR10(VisionDataset): | |||||
def download(self): | def download(self): | ||||
url = self.url_path + self.raw_file_name | url = self.url_path + self.raw_file_name | ||||
load_raw_data_from_url( | |||||
url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout | |||||
) | |||||
load_raw_data_from_url(url, self.raw_file_name, self.raw_file_md5, self.root) | |||||
self.process() | self.process() | ||||
def untar(self, file_path, dirs): | def untar(self, file_path, dirs): | ||||
@@ -118,7 +118,7 @@ class MNIST(VisionDataset): | |||||
def download(self): | def download(self): | ||||
for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5): | for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5): | ||||
url = self.url_path + file_name | url = self.url_path + file_name | ||||
load_raw_data_from_url(url, file_name, md5, self.root, self.timeout) | |||||
load_raw_data_from_url(url, file_name, md5, self.root) | |||||
def process(self, train): | def process(self, train): | ||||
# load raw files and transform them into meta data and datasets Tuple(np.array) | # load raw files and transform them into meta data and datasets Tuple(np.array) | ||||
@@ -27,9 +27,7 @@ def _default_dataset_root(): | |||||
return default_dataset_root | return default_dataset_root | ||||
def load_raw_data_from_url( | |||||
url: str, filename: str, target_md5: str, raw_data_dir: str, timeout: int | |||||
): | |||||
def load_raw_data_from_url(url: str, filename: str, target_md5: str, raw_data_dir: str): | |||||
cached_file = os.path.join(raw_data_dir, filename) | cached_file = os.path.join(raw_data_dir, filename) | ||||
logger.debug( | logger.debug( | ||||
"load_raw_data_from_url: downloading to or using cached %s ...", cached_file | "load_raw_data_from_url: downloading to or using cached %s ...", cached_file | ||||
@@ -41,7 +39,7 @@ def load_raw_data_from_url( | |||||
" File may be downloaded multiple times. We recommend\n" | " File may be downloaded multiple times. We recommend\n" | ||||
" users to download in single process first." | " users to download in single process first." | ||||
) | ) | ||||
md5 = download_from_url(url, cached_file, http_read_timeout=timeout) | |||||
md5 = download_from_url(url, cached_file) | |||||
else: | else: | ||||
md5 = calculate_md5(cached_file) | md5 = calculate_md5(cached_file) | ||||
if target_md5 == md5: | if target_md5 == md5: | ||||
@@ -25,7 +25,6 @@ from .const import ( | |||||
DEFAULT_PROTOCOL, | DEFAULT_PROTOCOL, | ||||
ENV_MGE_HOME, | ENV_MGE_HOME, | ||||
ENV_XDG_CACHE_HOME, | ENV_XDG_CACHE_HOME, | ||||
HTTP_READ_TIMEOUT, | |||||
HUBCONF, | HUBCONF, | ||||
HUBDEPENDENCY, | HUBDEPENDENCY, | ||||
) | ) | ||||
@@ -263,14 +262,14 @@ def load_serialized_obj_from_url(url: str, model_dir=None) -> Any: | |||||
" File may be downloaded multiple times. We recommend\n" | " File may be downloaded multiple times. We recommend\n" | ||||
" users to download in single process first." | " users to download in single process first." | ||||
) | ) | ||||
download_from_url(url, cached_file, HTTP_READ_TIMEOUT) | |||||
download_from_url(url, cached_file) | |||||
state_dict = _mge_load_serialized(cached_file) | state_dict = _mge_load_serialized(cached_file) | ||||
return state_dict | return state_dict | ||||
class pretrained: | class pretrained: | ||||
r"""Decorator which helps to download pretrained weights from the given url. | |||||
r"""Decorator which helps to download pretrained weights from the given url. Including fs, s3, http(s). | |||||
For example, we can decorate a resnet18 function as follows | For example, we can decorate a resnet18 function as follows | ||||
@@ -12,6 +12,7 @@ import shutil | |||||
from tempfile import NamedTemporaryFile | from tempfile import NamedTemporaryFile | ||||
import requests | import requests | ||||
from megfile import smart_copy, smart_getmd5, smart_getsize | |||||
from tqdm import tqdm | from tqdm import tqdm | ||||
from ..logger import get_logger | from ..logger import get_logger | ||||
@@ -26,41 +27,21 @@ class HTTPDownloadError(BaseException): | |||||
r"""The class that represents http request error.""" | r"""The class that represents http request error.""" | ||||
def download_from_url(url: str, dst: str, http_read_timeout=120): | |||||
class Bar: | |||||
def __init__(self, total=100): | |||||
self._bar = tqdm(total=total, unit="iB", unit_scale=True, ncols=80) | |||||
def __call__(self, bytes_num): | |||||
self._bar.update(bytes_num) | |||||
def download_from_url(url: str, dst: str): | |||||
r"""Downloads file from given url to ``dst``. | r"""Downloads file from given url to ``dst``. | ||||
Args: | Args: | ||||
url: source URL. | url: source URL. | ||||
dst: saving path. | dst: saving path. | ||||
http_read_timeout: how many seconds to wait for data before giving up. | |||||
""" | """ | ||||
dst = os.path.expanduser(dst) | dst = os.path.expanduser(dst) | ||||
dst_dir = os.path.dirname(dst) | |||||
resp = requests.get( | |||||
url, timeout=(HTTP_CONNECTION_TIMEOUT, http_read_timeout), stream=True | |||||
) | |||||
if resp.status_code != 200: | |||||
raise HTTPDownloadError("An error occured when downloading from {}".format(url)) | |||||
md5 = hashlib.md5() | |||||
total_size = int(resp.headers.get("Content-Length", 0)) | |||||
bar = tqdm( | |||||
total=total_size, unit="iB", unit_scale=True, ncols=80 | |||||
) # pylint: disable=blacklisted-name | |||||
try: | |||||
with NamedTemporaryFile("w+b", delete=False, suffix=".tmp", dir=dst_dir) as f: | |||||
logger.info("Download file to temp file %s", f.name) | |||||
for chunk in resp.iter_content(CHUNK_SIZE): | |||||
if not chunk: | |||||
break | |||||
bar.update(len(chunk)) | |||||
f.write(chunk) | |||||
md5.update(chunk) | |||||
bar.close() | |||||
shutil.move(f.name, dst) | |||||
finally: | |||||
# ensure tmp file is removed | |||||
if os.path.exists(f.name): | |||||
os.remove(f.name) | |||||
return md5.hexdigest() | |||||
smart_copy(url, dst, callback=Bar(total=smart_getsize(url))) | |||||
return smart_getmd5(dst) |
@@ -8,3 +8,4 @@ redispy | |||||
deprecated | deprecated | ||||
mprop | mprop | ||||
wheel | wheel | ||||
megfile>=0.0.10 |