You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dtr_config.py 2.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. # -*- coding: utf-8 -*-
  2. class DTRConfig:
  3. r"""Configuration for DTR memory optimization.
  4. Args:
  5. eviction_threshold: eviction threshold in bytes. When GPU memory usage
  6. exceeds this value, DTR will heuristically select and evict resident
  7. tensors until the amount of used memory falls below this threshold.
  8. evictee_minimum_size: memory threshold of tensors in bytes. Only tensors
  9. whose size exceeds this threshold will be added to the candidate set.
  10. A tensor that is not added to the candidate set will never be evicted
  11. during its lifetime. Default: 1048576.
  12. recomp_memory_factor: hyperparameter of the estimated memory of recomputing
  13. the tensor. The larger this value is, the less memory-consuming
  14. tensor will be evicted in heuristic strategies. This value is greater
  15. than or equal to 0. Default: 1.
  16. recomp_time_factor: hyperparameter of the estimated time of recomputing
  17. the tensor. The larger this value is, the less time-consuming
  18. tensor will be evicted in heuristic strategies. This value is greater
  19. than or equal to 0. Default: 1.
  20. """
  21. def __init__(
  22. self,
  23. eviction_threshold: int = 0,
  24. evictee_minimum_size: int = 1 << 20,
  25. recomp_memory_factor: float = 1,
  26. recomp_time_factor: float = 1,
  27. ):
  28. assert eviction_threshold > 0, "eviction_threshold must be greater to zero"
  29. self.eviction_threshold = eviction_threshold
  30. assert (
  31. evictee_minimum_size >= 0
  32. ), "evictee_minimum_size must be greater or equal to zero"
  33. self.evictee_minimum_size = evictee_minimum_size
  34. assert (
  35. recomp_memory_factor >= 0
  36. ), "recomp_memory_factor must be greater or equal to zero"
  37. self.recomp_memory_factor = recomp_memory_factor
  38. assert (
  39. recomp_time_factor >= 0
  40. ), "recomp_time_factor must be greater or equal to zero"
  41. self.recomp_time_factor = recomp_time_factor