From 315b908c1402ce302e1d149cca7bbfd3966368e2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 21 Nov 2021 19:08:57 +0800 Subject: [PATCH] feat(imperative): mge hub pretrained support s3 GitOrigin-RevId: a48e107623e9992ba15c84a55959dbb414df2117 --- .../python/megengine/data/dataset/vision/cifar.py | 4 +- .../python/megengine/data/dataset/vision/mnist.py | 2 +- .../python/megengine/data/dataset/vision/utils.py | 6 +-- imperative/python/megengine/hub/hub.py | 5 +-- imperative/python/megengine/utils/http_download.py | 43 ++++++---------------- imperative/python/requires.txt | 1 + 6 files changed, 19 insertions(+), 42 deletions(-) diff --git a/imperative/python/megengine/data/dataset/vision/cifar.py b/imperative/python/megengine/data/dataset/vision/cifar.py index 16e68a22..500b47b5 100644 --- a/imperative/python/megengine/data/dataset/vision/cifar.py +++ b/imperative/python/megengine/data/dataset/vision/cifar.py @@ -106,9 +106,7 @@ class CIFAR10(VisionDataset): def download(self): url = self.url_path + self.raw_file_name - load_raw_data_from_url( - url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout - ) + load_raw_data_from_url(url, self.raw_file_name, self.raw_file_md5, self.root) self.process() def untar(self, file_path, dirs): diff --git a/imperative/python/megengine/data/dataset/vision/mnist.py b/imperative/python/megengine/data/dataset/vision/mnist.py index efa81628..ae0d9435 100644 --- a/imperative/python/megengine/data/dataset/vision/mnist.py +++ b/imperative/python/megengine/data/dataset/vision/mnist.py @@ -118,7 +118,7 @@ class MNIST(VisionDataset): def download(self): for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5): url = self.url_path + file_name - load_raw_data_from_url(url, file_name, md5, self.root, self.timeout) + load_raw_data_from_url(url, file_name, md5, self.root) def process(self, train): # load raw files and transform them into meta data and datasets Tuple(np.array) diff --git a/imperative/python/megengine/data/dataset/vision/utils.py b/imperative/python/megengine/data/dataset/vision/utils.py index ed077842..ca878ce9 100644 --- a/imperative/python/megengine/data/dataset/vision/utils.py +++ b/imperative/python/megengine/data/dataset/vision/utils.py @@ -27,9 +27,7 @@ def _default_dataset_root(): return default_dataset_root -def load_raw_data_from_url( - url: str, filename: str, target_md5: str, raw_data_dir: str, timeout: int -): +def load_raw_data_from_url(url: str, filename: str, target_md5: str, raw_data_dir: str): cached_file = os.path.join(raw_data_dir, filename) logger.debug( "load_raw_data_from_url: downloading to or using cached %s ...", cached_file @@ -41,7 +39,7 @@ def load_raw_data_from_url( " File may be downloaded multiple times. We recommend\n" " users to download in single process first." ) - md5 = download_from_url(url, cached_file, http_read_timeout=timeout) + md5 = download_from_url(url, cached_file) else: md5 = calculate_md5(cached_file) if target_md5 == md5: diff --git a/imperative/python/megengine/hub/hub.py b/imperative/python/megengine/hub/hub.py index 953714bb..bad0dec8 100644 --- a/imperative/python/megengine/hub/hub.py +++ b/imperative/python/megengine/hub/hub.py @@ -25,7 +25,6 @@ from .const import ( DEFAULT_PROTOCOL, ENV_MGE_HOME, ENV_XDG_CACHE_HOME, - HTTP_READ_TIMEOUT, HUBCONF, HUBDEPENDENCY, ) @@ -263,14 +262,14 @@ def load_serialized_obj_from_url(url: str, model_dir=None) -> Any: " File may be downloaded multiple times. We recommend\n" " users to download in single process first." ) - download_from_url(url, cached_file, HTTP_READ_TIMEOUT) + download_from_url(url, cached_file) state_dict = _mge_load_serialized(cached_file) return state_dict class pretrained: - r"""Decorator which helps to download pretrained weights from the given url. + r"""Decorator which helps to download pretrained weights from the given url. Including fs, s3, http(s). For example, we can decorate a resnet18 function as follows diff --git a/imperative/python/megengine/utils/http_download.py b/imperative/python/megengine/utils/http_download.py index 6342be48..47742535 100644 --- a/imperative/python/megengine/utils/http_download.py +++ b/imperative/python/megengine/utils/http_download.py @@ -12,6 +12,7 @@ import shutil from tempfile import NamedTemporaryFile import requests +from megfile import smart_copy, smart_getmd5, smart_getsize from tqdm import tqdm from ..logger import get_logger @@ -26,41 +27,21 @@ class HTTPDownloadError(BaseException): r"""The class that represents http request error.""" -def download_from_url(url: str, dst: str, http_read_timeout=120): +class Bar: + def __init__(self, total=100): + self._bar = tqdm(total=total, unit="iB", unit_scale=True, ncols=80) + + def __call__(self, bytes_num): + self._bar.update(bytes_num) + + +def download_from_url(url: str, dst: str): r"""Downloads file from given url to ``dst``. Args: url: source URL. dst: saving path. - http_read_timeout: how many seconds to wait for data before giving up. """ dst = os.path.expanduser(dst) - dst_dir = os.path.dirname(dst) - - resp = requests.get( - url, timeout=(HTTP_CONNECTION_TIMEOUT, http_read_timeout), stream=True - ) - if resp.status_code != 200: - raise HTTPDownloadError("An error occured when downloading from {}".format(url)) - - md5 = hashlib.md5() - total_size = int(resp.headers.get("Content-Length", 0)) - bar = tqdm( - total=total_size, unit="iB", unit_scale=True, ncols=80 - ) # pylint: disable=blacklisted-name - try: - with NamedTemporaryFile("w+b", delete=False, suffix=".tmp", dir=dst_dir) as f: - logger.info("Download file to temp file %s", f.name) - for chunk in resp.iter_content(CHUNK_SIZE): - if not chunk: - break - bar.update(len(chunk)) - f.write(chunk) - md5.update(chunk) - bar.close() - shutil.move(f.name, dst) - finally: - # ensure tmp file is removed - if os.path.exists(f.name): - os.remove(f.name) - return md5.hexdigest() + smart_copy(url, dst, callback=Bar(total=smart_getsize(url))) + return smart_getmd5(dst) diff --git a/imperative/python/requires.txt b/imperative/python/requires.txt index 670193dc..58a806c0 100644 --- a/imperative/python/requires.txt +++ b/imperative/python/requires.txt @@ -8,3 +8,4 @@ redispy deprecated mprop wheel +megfile>=0.0.10 \ No newline at end of file