
feat(imperative): mge hub pretrained support s3

Author: Megvii Engine Team (XindaH)
Commit: 315b908c14 (tags/v1.7.1.m1)
GitOrigin-RevId: a48e107623

The requests-based downloader in http_download.py is replaced with megfile, so datasets and hub pretrained weights can be fetched from fs, s3, and http(s) URLs alike. The per-call HTTP read-timeout parameters are dropped along the way, and megfile becomes a runtime dependency.
6 changed files with 19 additions and 42 deletions
  1. imperative/python/megengine/data/dataset/vision/cifar.py (+1, -3)
  2. imperative/python/megengine/data/dataset/vision/mnist.py (+1, -1)
  3. imperative/python/megengine/data/dataset/vision/utils.py (+2, -4)
  4. imperative/python/megengine/hub/hub.py (+2, -3)
  5. imperative/python/megengine/utils/http_download.py (+12, -31)
  6. imperative/python/requires.txt (+1, -0)

imperative/python/megengine/data/dataset/vision/cifar.py (+1, -3)

@@ -106,9 +106,7 @@ class CIFAR10(VisionDataset):

     def download(self):
         url = self.url_path + self.raw_file_name
-        load_raw_data_from_url(
-            url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout
-        )
+        load_raw_data_from_url(url, self.raw_file_name, self.raw_file_md5, self.root)
         self.process()

     def untar(self, file_path, dirs):


imperative/python/megengine/data/dataset/vision/mnist.py (+1, -1)

@@ -118,7 +118,7 @@ class MNIST(VisionDataset):
     def download(self):
         for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
             url = self.url_path + file_name
-            load_raw_data_from_url(url, file_name, md5, self.root, self.timeout)
+            load_raw_data_from_url(url, file_name, md5, self.root)

     def process(self, train):
         # load raw files and transform them into meta data and datasets Tuple(np.array)
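
The dataset entry points are unchanged apart from the dropped timeout argument. A minimal usage sketch (the root path is hypothetical); download=True routes through load_raw_data_from_url and, on a cache miss, the new download_from_url:

    # Hypothetical usage sketch; the root path is invented for illustration.
    from megengine.data.dataset import MNIST

    # download=True fetches the raw files through download_from_url (now
    # megfile-backed) and verifies their md5 checksums against raw_file_md5.
    train_set = MNIST(root="/data/datasets/mnist", train=True, download=True)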


imperative/python/megengine/data/dataset/vision/utils.py (+2, -4)

@@ -27,9 +27,7 @@ def _default_dataset_root():
     return default_dataset_root


-def load_raw_data_from_url(
-    url: str, filename: str, target_md5: str, raw_data_dir: str, timeout: int
-):
+def load_raw_data_from_url(url: str, filename: str, target_md5: str, raw_data_dir: str):
     cached_file = os.path.join(raw_data_dir, filename)
     logger.debug(
         "load_raw_data_from_url: downloading to or using cached %s ...", cached_file
@@ -41,7 +39,7 @@ def load_raw_data_from_url(
             " File may be downloaded multiple times. We recommend\n"
             " users to download in single process first."
         )
-        md5 = download_from_url(url, cached_file, http_read_timeout=timeout)
+        md5 = download_from_url(url, cached_file)
     else:
         md5 = calculate_md5(cached_file)
     if target_md5 == md5:


imperative/python/megengine/hub/hub.py (+2, -3)

@@ -25,7 +25,6 @@ from .const import (
     DEFAULT_PROTOCOL,
     ENV_MGE_HOME,
     ENV_XDG_CACHE_HOME,
-    HTTP_READ_TIMEOUT,
     HUBCONF,
     HUBDEPENDENCY,
 )
@@ -263,14 +262,14 @@ def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
             " File may be downloaded multiple times. We recommend\n"
             " users to download in single process first."
         )
-        download_from_url(url, cached_file, HTTP_READ_TIMEOUT)
+        download_from_url(url, cached_file)

     state_dict = _mge_load_serialized(cached_file)
     return state_dict


 class pretrained:
-    r"""Decorator which helps to download pretrained weights from the given url.
+    r"""Decorator which helps to download pretrained weights from the given url. Including fs, s3, http(s).

     For example, we can decorate a resnet18 function as follows
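
With the megfile-backed downloader, the URL handed to hub.pretrained may equally be a local fs path, an s3 path, or an http(s) URL. A minimal sketch, assuming a hypothetical s3 bucket and a toy module:

    import megengine.hub as hub
    import megengine.module as M

    # Hypothetical weight location; any fs, s3, or http(s) URL that megfile
    # understands is accepted.
    @hub.pretrained("s3://models-bucket/tiny_net.pkl")
    def tiny_net():
        return M.Sequential(M.Linear(4, 8), M.ReLU(), M.Linear(8, 2))

    # pretrained=True downloads the serialized state dict via
    # load_serialized_obj_from_url and loads it before returning the module.
    net = tiny_net(pretrained=True)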



imperative/python/megengine/utils/http_download.py (+12, -31)

@@ -12,6 +12,7 @@ import shutil
 from tempfile import NamedTemporaryFile

 import requests
+from megfile import smart_copy, smart_getmd5, smart_getsize
 from tqdm import tqdm

 from ..logger import get_logger
@@ -26,41 +27,21 @@ class HTTPDownloadError(BaseException):
     r"""The class that represents http request error."""


-def download_from_url(url: str, dst: str, http_read_timeout=120):
+class Bar:
+    def __init__(self, total=100):
+        self._bar = tqdm(total=total, unit="iB", unit_scale=True, ncols=80)
+
+    def __call__(self, bytes_num):
+        self._bar.update(bytes_num)
+
+
+def download_from_url(url: str, dst: str):
     r"""Downloads file from given url to ``dst``.

     Args:
         url: source URL.
         dst: saving path.
-        http_read_timeout: how many seconds to wait for data before giving up.
     """
     dst = os.path.expanduser(dst)
-    dst_dir = os.path.dirname(dst)
-
-    resp = requests.get(
-        url, timeout=(HTTP_CONNECTION_TIMEOUT, http_read_timeout), stream=True
-    )
-    if resp.status_code != 200:
-        raise HTTPDownloadError("An error occured when downloading from {}".format(url))
-
-    md5 = hashlib.md5()
-    total_size = int(resp.headers.get("Content-Length", 0))
-    bar = tqdm(
-        total=total_size, unit="iB", unit_scale=True, ncols=80
-    )  # pylint: disable=blacklisted-name
-    try:
-        with NamedTemporaryFile("w+b", delete=False, suffix=".tmp", dir=dst_dir) as f:
-            logger.info("Download file to temp file %s", f.name)
-            for chunk in resp.iter_content(CHUNK_SIZE):
-                if not chunk:
-                    break
-                bar.update(len(chunk))
-                f.write(chunk)
-                md5.update(chunk)
-            bar.close()
-            shutil.move(f.name, dst)
-    finally:
-        # ensure tmp file is removed
-        if os.path.exists(f.name):
-            os.remove(f.name)
-    return md5.hexdigest()
+    smart_copy(url, dst, callback=Bar(total=smart_getsize(url)))
+    return smart_getmd5(dst)

imperative/python/requires.txt (+1, -0)

@@ -8,3 +8,4 @@ redispy
 deprecated
 mprop
 wheel
+megfile>=0.0.10
