You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 2.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # -*- coding: utf-8 -*-
  2. import hashlib
  3. import os
  4. import tarfile
  5. from ....distributed.group import is_distributed
  6. from ....logger import get_logger
  7. from ....utils.http_download import download_from_url
  8. IMG_EXT = (".jpg", ".png", ".jpeg", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
  9. logger = get_logger(__name__)
  10. def _default_dataset_root():
  11. default_dataset_root = os.path.expanduser(
  12. os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "megengine")
  13. )
  14. return default_dataset_root
  15. def load_raw_data_from_url(url: str, filename: str, target_md5: str, raw_data_dir: str):
  16. cached_file = os.path.join(raw_data_dir, filename)
  17. logger.debug(
  18. "load_raw_data_from_url: downloading to or using cached %s ...", cached_file
  19. )
  20. if not os.path.exists(cached_file):
  21. if is_distributed():
  22. logger.warning(
  23. "Downloading raw data in DISTRIBUTED mode\n"
  24. " File may be downloaded multiple times. We recommend\n"
  25. " users to download in single process first."
  26. )
  27. md5 = download_from_url(url, cached_file)
  28. else:
  29. md5 = calculate_md5(cached_file)
  30. if target_md5 == md5:
  31. logger.debug("%s exists with correct md5: %s", filename, target_md5)
  32. else:
  33. os.remove(cached_file)
  34. raise RuntimeError("{} exists but fail to match md5".format(filename))
  35. def calculate_md5(filename):
  36. m = hashlib.md5()
  37. with open(filename, "rb") as f:
  38. while True:
  39. data = f.read(4096)
  40. if not data:
  41. break
  42. m.update(data)
  43. return m.hexdigest()
  44. def is_img(filename):
  45. return filename.lower().endswith(IMG_EXT)
  46. def untar(path, to=None, remove=False):
  47. if to is None:
  48. to = os.path.dirname(path)
  49. with tarfile.open(path, "r") as tar:
  50. tar.extractall(path=to)
  51. if remove:
  52. os.remove(path)
  53. def untargz(path, to=None, remove=False):
  54. if path.endswith(".tar.gz"):
  55. if to is None:
  56. to = os.path.dirname(path)
  57. with tarfile.open(path, "r:gz") as tar:
  58. tar.extractall(path=to)
  59. else:
  60. raise ValueError("path %s does not end with .tar" % path)
  61. if remove:
  62. os.remove(path)