You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module_utils.py 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
from collections.abc import Iterable

from ..module import Sequential
from ..module.module import Module, _access_structure
from ..tensor import Tensor
  13. def get_expand_structure(obj: Module, key: str):
  14. """
  15. Gets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
  16. Supports handling structure containing list or dict.
  17. """
  18. def f(_, __, cur):
  19. return cur
  20. return _access_structure(obj, key, callback=f)
  21. def set_expand_structure(obj: Module, key: str, value):
  22. """
  23. Sets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
  24. Supports handling structure containing list or dict.
  25. """
  26. def f(parent, key, cur):
  27. if isinstance(parent, (Tensor, Module)):
  28. # cannnot use setattr to be compatible with Sequential's ``__setitem__``
  29. if isinstance(cur, Sequential):
  30. parent[int(key)] = value
  31. else:
  32. setattr(parent, key, value)
  33. else:
  34. parent[key] = value
  35. _access_structure(obj, key, callback=f)
  36. @contextlib.contextmanager
  37. def set_module_mode_safe(
  38. module: Module, training: bool = False,
  39. ):
  40. """Adjust module to training/eval mode temporarily.
  41. :param module: used module.
  42. :param training: training (bool): training mode. True for train mode, False fro eval mode.
  43. """
  44. backup_stats = {}
  45. def recursive_backup_stats(module, mode):
  46. for m in module.modules():
  47. backup_stats[m] = m.training
  48. m.train(mode, recursive=False)
  49. def recursive_recover_stats(module):
  50. for m in module.modules():
  51. m.training = backup_stats.pop(m)
  52. recursive_backup_stats(module, mode=training)
  53. yield module
  54. recursive_recover_stats(module)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。