
utils.py 1.2 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
from typing import Any, Dict, List

from ..expr import Expr, is_constant, is_getattr
from ..node import Node, TensorNode


def register_obj(objs: List[Any], _dict: Dict):
    # Accept a single key as well as a list of keys.
    if not isinstance(objs, List):
        objs = [objs]

    def _register(any_obj: Any):
        # Map every key in objs to the decorated object, then return it unchanged.
        for obj in objs:
            _dict[obj] = any_obj
        return any_obj

    return _register


def get_const_value(expr: Expr, fall_back: Any = None):
    # Return a deep copy of the value expr evaluates to, or fall_back when
    # expr is neither a tensor-valued getattr nor a constant.
    value = fall_back
    if isinstance(expr, Node):
        # A Node was passed in: look at the expression that produced it.
        expr = expr.expr
    if is_getattr(expr) and isinstance(expr.outputs[0], TensorNode):
        module = expr.inputs[0].owner
        assert module is not None
        value = copy.deepcopy(expr.interpret(module)[0])
    elif is_constant(expr):
        value = copy.deepcopy(expr.interpret()[0])
    return value
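
The `register_obj` helper above is a decorator factory: it returns a decorator that stores the decorated object in a dictionary under one or more keys. The following is a minimal, standalone sketch of how it might be used. The function body is copied from the listing (with the built-in `list` in the `isinstance` check so the snippet has no package-relative imports), while the `HANDLERS` registry and the `double` and `negate` functions are hypothetical names invented for illustration and are not part of MegEngine.

```python
from typing import Any, Callable, Dict, List


def register_obj(objs: List[Any], _dict: Dict):
    # Same logic as the listing; a single key is wrapped into a list.
    if not isinstance(objs, list):
        objs = [objs]

    def _register(any_obj: Any):
        for obj in objs:
            _dict[obj] = any_obj
        return any_obj

    return _register


# Hypothetical registry, used only for this illustration.
HANDLERS: Dict[str, Callable[[int], int]] = {}


@register_obj(["double", "twice"], HANDLERS)
def double(x: int) -> int:
    return 2 * x


@register_obj("negate", HANDLERS)  # a single key is also accepted
def negate(x: int) -> int:
    return -x


# Both keys point at the same function object.
print(HANDLERS["twice"] is HANDLERS["double"])
print(sorted(HANDLERS))
```

Because `_register` returns the object it receives, the decorated names (`double`, `negate`) stay bound to the original functions and can still be called directly.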

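The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there are no separate CPU and GPU builds. To run GPU programs, make sure the machine has GPU hardware and that its driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.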