You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module_utils.py 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import contextlib
  9. from ..module import Sequential
  10. from ..module.module import Module, _access_structure
  11. from ..tensor import Tensor
  12. def get_expand_structure(obj: Module, key: str):
  13. r"""Gets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
  14. Supports handling structure containing list or dict.
  15. Args:
  16. obj: Module:
  17. key: str:
  18. """
  19. def f(_, __, cur):
  20. return cur
  21. return _access_structure(obj, key, callback=f)
  22. def set_expand_structure(obj: Module, key: str, value):
  23. r"""Sets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
  24. Supports handling structure containing list or dict.
  25. """
  26. def f(parent, key, cur):
  27. if isinstance(parent, (Tensor, Module)):
  28. # cannnot use setattr to be compatible with Sequential's ``__setitem__``
  29. if isinstance(cur, Sequential):
  30. parent[int(key)] = value
  31. else:
  32. setattr(parent, key, value)
  33. else:
  34. parent[key] = value
  35. _access_structure(obj, key, callback=f)
  36. @contextlib.contextmanager
  37. def set_module_mode_safe(
  38. module: Module, training: bool = False,
  39. ):
  40. r"""Adjust module to training/eval mode temporarily.
  41. Args:
  42. module: used module.
  43. training: training (bool): training mode. True for train mode, False fro eval mode.
  44. """
  45. backup_stats = {}
  46. def recursive_backup_stats(module, mode):
  47. for m in module.modules():
  48. backup_stats[m] = m.training
  49. m.train(mode, recursive=False)
  50. def recursive_recover_stats(module):
  51. for m in module.modules():
  52. m.training = backup_stats.pop(m)
  53. recursive_backup_stats(module, mode=training)
  54. yield module
  55. recursive_recover_stats(module)