You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pytree.py 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. from typing import Callable, NamedTuple
  11. SUPPORTED_TYPE = {}
  12. NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
  13. def register_supported_type(type, flatten, unflatten):
  14. SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
  15. def _dict_flatten(inp):
  16. aux_data = []
  17. results = []
  18. for key, value in sorted(inp.items()):
  19. results.append(value)
  20. aux_data.append(key)
  21. return results, aux_data
  22. def _dict_unflatten(inps, aux_data):
  23. return dict(zip(aux_data, inps))
  24. register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
  25. register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
  26. register_supported_type(dict, _dict_flatten, _dict_unflatten)
  27. register_supported_type(
  28. slice,
  29. lambda x: ([x.start, x.stop, x.step], None),
  30. lambda x, aux_data: slice(x[0], x[1], x[2]),
  31. )
  32. def tree_flatten(
  33. values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True
  34. ):
  35. if type(values) not in SUPPORTED_TYPE:
  36. assert is_leaf(values)
  37. return [values,], LeafDef(leaf_type(values))
  38. rst = []
  39. children_defs = []
  40. children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
  41. for v in children_values:
  42. v_list, treedef = tree_flatten(v, leaf_type)
  43. rst.extend(v_list)
  44. children_defs.append(treedef)
  45. return rst, TreeDef(type(values), aux_data, children_defs)
  46. class TreeDef:
  47. def __init__(self, type, aux_data, children_defs):
  48. self.type = type
  49. self.aux_data = aux_data
  50. self.children_defs = children_defs
  51. self.num_leaves = sum(ch.num_leaves for ch in children_defs)
  52. def unflatten(self, leaves):
  53. assert len(leaves) == self.num_leaves
  54. start = 0
  55. children = []
  56. for ch in self.children_defs:
  57. children.append(ch.unflatten(leaves[start : start + ch.num_leaves]))
  58. start += ch.num_leaves
  59. return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)
  60. def __eq__(self, other):
  61. return (
  62. self.type == other.type
  63. and self.aux_data == other.aux_data
  64. and self.num_leaves == other.num_leaves
  65. and self.children_defs == other.children_defs
  66. )
  67. def __repr__(self):
  68. return "{}[{}]".format(self.type.__name__, self.children_defs)
  69. class LeafDef(TreeDef):
  70. def __init__(self, type):
  71. if not isinstance(type, collections.abc.Sequence):
  72. type = (type,)
  73. super().__init__(type, None, [])
  74. self.num_leaves = 1
  75. def unflatten(self, leaves):
  76. assert len(leaves) == 1
  77. assert isinstance(leaves[0], self.type), self.type
  78. return leaves[0]
  79. def __repr__(self):
  80. return "Leaf({})".format(", ".join(t.__name__ for t in self.type))

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台