You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module_utils.py 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from collections import Iterable
  9. from ..module import Sequential
  10. from ..module.module import Module, _access_structure
  11. from ..tensor import Tensor
  12. def get_expand_structure(obj: Module, key: str):
  13. """
  14. Gets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
  15. Supports handling structure containing list or dict.
  16. """
  17. def f(_, __, cur):
  18. return cur
  19. return _access_structure(obj, key, callback=f)
  20. def set_expand_structure(obj: Module, key: str, value):
  21. """
  22. Sets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
  23. Supports handling structure containing list or dict.
  24. """
  25. def f(parent, key, cur):
  26. if isinstance(parent, (Tensor, Module)):
  27. # cannnot use setattr to be compatible with Sequential's ``__setitem__``
  28. if isinstance(cur, Sequential):
  29. parent[int(key)] = value
  30. else:
  31. setattr(parent, key, value)
  32. else:
  33. parent[key] = value
  34. _access_structure(obj, key, callback=f)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台