You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

pytree.py 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from typing import Callable, NamedTuple
  2. SUPPORTED_TYPE = {}
  3. NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
  4. def register_supported_type(type, flatten, unflatten):
  5. SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
  6. register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
  7. register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
  8. register_supported_type(
  9. dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
  10. )
  11. register_supported_type(
  12. slice,
  13. lambda x: ([x.start, x.stop, x.step], None),
  14. lambda x, aux_data: slice(x[0], x[1], x[2]),
  15. )
  16. def tree_flatten(
  17. values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True
  18. ):
  19. if type(values) not in SUPPORTED_TYPE:
  20. assert is_leaf(values)
  21. return [values,], LeafDef(leaf_type(values))
  22. rst = []
  23. children_defs = []
  24. children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
  25. for v in children_values:
  26. v_list, treedef = tree_flatten(v, leaf_type)
  27. rst.extend(v_list)
  28. children_defs.append(treedef)
  29. return rst, TreeDef(type(values), aux_data, children_defs)
  30. class TreeDef:
  31. def __init__(self, type, aux_data, children_defs):
  32. self.type = type
  33. self.aux_data = aux_data
  34. self.children_defs = children_defs
  35. self.num_leaves = sum(ch.num_leaves for ch in children_defs)
  36. def unflatten(self, leaves):
  37. assert len(leaves) == self.num_leaves
  38. start = 0
  39. children = []
  40. for ch in self.children_defs:
  41. children.append(ch.unflatten(leaves[start : start + ch.num_leaves]))
  42. start += ch.num_leaves
  43. return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)
  44. def __eq__(self, other):
  45. return (
  46. self.type == other.type
  47. and self.aux_data == other.aux_data
  48. and self.num_leaves == other.num_leaves
  49. and self.children_defs == other.children_defs
  50. )
  51. def __repr__(self):
  52. return "{}[{}]".format(self.type.__name__, self.children_defs)
  53. class LeafDef(TreeDef):
  54. def __init__(self, type):
  55. super().__init__(type, None, [])
  56. self.num_leaves = 1
  57. def unflatten(self, leaves):
  58. assert len(leaves) == 1
  59. assert isinstance(leaves[0], self.type), self.type
  60. return leaves[0]
  61. def __repr__(self):
  62. return "Leaf({})".format(self.type.__name__)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台