You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module_utils.py 1.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import contextlib
  2. from ..module import Sequential
  3. from ..module.module import Module, _access_structure
  4. from ..tensor import Tensor
  5. def get_expand_structure(obj: Module, key: str):
  6. r"""Gets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
  7. Supports handling structure containing list or dict.
  8. Args:
  9. obj: Module:
  10. key: str:
  11. """
  12. def f(_, __, cur):
  13. return cur
  14. return _access_structure(obj, key, callback=f)
  15. def set_expand_structure(obj: Module, key: str, value):
  16. r"""Sets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
  17. Supports handling structure containing list or dict.
  18. """
  19. def f(parent, key, cur):
  20. if isinstance(parent, (Tensor, Module)):
  21. # cannnot use setattr to be compatible with Sequential's ``__setitem__``
  22. if isinstance(cur, Sequential):
  23. parent[int(key)] = value
  24. else:
  25. setattr(parent, key, value)
  26. else:
  27. parent[key] = value
  28. _access_structure(obj, key, callback=f)
  29. @contextlib.contextmanager
  30. def set_module_mode_safe(
  31. module: Module, training: bool = False,
  32. ):
  33. r"""Adjust module to training/eval mode temporarily.
  34. Args:
  35. module: used module.
  36. training: training (bool): training mode. True for train mode, False fro eval mode.
  37. """
  38. backup_stats = {}
  39. def recursive_backup_stats(module, mode):
  40. for m in module.modules():
  41. backup_stats[m] = m.training
  42. m.train(mode, recursive=False)
  43. def recursive_recover_stats(module):
  44. for m in module.modules():
  45. m.training = backup_stats.pop(m)
  46. recursive_backup_stats(module, mode=training)
  47. yield module
  48. recursive_recover_stats(module)