You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

naming.py 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from ..core._imperative_rt.core2 import pop_scope, push_scope
  class AutoNaming:
      r"""
      Name all executed operators automatically during tracing and record all
      tensors renamed by the user.

      Every method is a classmethod operating on the class-level containers
      below, so the state is shared by all users of the class; call
      :meth:`clear` to reset it.
      """

      # Stack of entered scopes. ``None`` entries are kept so that every
      # ``push_scope`` is balanced by one ``pop_scope``, but they do not
      # contribute to the dotted scope path.
      scopes = []
      # Only emptied by ``clear`` in this class; NOTE(review): presumably
      # populated elsewhere — confirm its contents with the callers.
      c_ops = []
      # Operator name -> distinct operators registered under that name.
      name2ops = {}
      # Tensor handle -> name the user assigned to that tensor.
      handle2names = {}
      # Names of the class-level containers that ``clear`` empties.
      __cls_attributes__ = {"scopes", "c_ops", "name2ops", "handle2names"}

      @classmethod
      def clear(cls):
          """Empty every container listed in ``__cls_attributes__``."""
          for attr in cls.__cls_attributes__:
              getattr(cls, attr).clear()

      @classmethod
      def push_scope(cls, scope):
          """Enter ``scope``.

          A non-``None`` scope is also forwarded to the ``push_scope`` imported
          from ``core._imperative_rt.core2``; ``None`` is only recorded locally.
          """
          if scope is not None:
              # Bare name resolves to the module-level import, not this method.
              push_scope(scope)
          cls.scopes.append(scope)

      @classmethod
      def pop_scope(cls):
          """Leave the innermost scope entered with :meth:`push_scope`.

          A non-``None`` scope is also forwarded to the imported ``pop_scope``.

          Raises:
              IndexError: if no scope has been pushed.
          """
          scope = cls.scopes.pop()
          if scope is not None:
              # Bare name resolves to the module-level import, not this method.
              pop_scope(scope)

      @classmethod
      def get_scope(cls) -> str:
          """Return the entered scopes joined by ``"."``, skipping ``None``."""
          return ".".join(s for s in cls.scopes if s is not None)

      @classmethod
      def gen_name(cls, x) -> str:
          """Return the name of ``x`` prefixed with the current scope path.

          ``x.c_name`` is used when truthy, otherwise ``x._name``. With no
          active scope the bare name is returned.
          """
          scope = cls.get_scope()
          name = x.c_name if x.c_name else x._name
          return scope + "." + name if scope else name

      @classmethod
      def record_var_name(cls, handle, name):
          """Remember that the tensor identified by ``handle`` is called ``name``."""
          cls.handle2names[handle] = name

      @classmethod
      def get_var_name(cls, handle):
          """Return and forget the name recorded for ``handle``.

          Returns ``None`` if no name is recorded. The entry is removed, so a
          second lookup for the same handle returns ``None``.
          """
          return cls.handle2names.pop(handle, None)

      @classmethod
      def record_opnode(cls, op):
          """Register ``op`` under ``op.name`` unless it is already registered."""
          ops = cls.name2ops.setdefault(op.name, [])
          if op not in ops:
              ops.append(op)

      @classmethod
      def remove_duplicate_names(cls):
          """Make operator names unique, then reset ``name2ops``.

          Each name shared by several registered operators is rewritten as
          ``name[i]``, where ``i`` is the operator's registration order. For
          operators with more than one output, occurrences of the old name in
          each output variable's name are replaced by the new name.
          """
          for key, ops in cls.name2ops.items():
              if len(ops) == 1:
                  continue
              for i, op in enumerate(ops):
                  op.name = key + "[%d]" % i
                  # Single-output ops keep their output var names unchanged.
                  if len(op.outputs) == 1:
                      continue
                  for var in op.outputs:
                      # NOTE(review): str.replace rewrites *every* occurrence of
                      # ``key`` in the var name, not just a leading prefix.
                      var.name = var.name.replace(key, op.name)
          cls.name2ops.clear()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。