from typing import Callable, NamedTuple SUPPORTED_TYPE = {} NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)]) def register_supported_type(type, flatten, unflatten): SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x)) register_supported_type( dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x)) ) register_supported_type( slice, lambda x: ([x.start, x.stop, x.step], None), lambda x, aux_data: slice(x[0], x[1], x[2]), ) def tree_flatten( values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True ): if type(values) not in SUPPORTED_TYPE: assert is_leaf(values) return [values,], LeafDef(leaf_type(values)) rst = [] children_defs = [] children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values) for v in children_values: v_list, treedef = tree_flatten(v, leaf_type) rst.extend(v_list) children_defs.append(treedef) return rst, TreeDef(type(values), aux_data, children_defs) class TreeDef: def __init__(self, type, aux_data, children_defs): self.type = type self.aux_data = aux_data self.children_defs = children_defs self.num_leaves = sum(ch.num_leaves for ch in children_defs) def unflatten(self, leaves): assert len(leaves) == self.num_leaves start = 0 children = [] for ch in self.children_defs: children.append(ch.unflatten(leaves[start : start + ch.num_leaves])) start += ch.num_leaves return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data) def __eq__(self, other): return ( self.type == other.type and self.aux_data == other.aux_data and self.num_leaves == other.num_leaves and self.children_defs == other.children_defs ) def __repr__(self): return "{}[{}]".format(self.type.__name__, self.children_defs) class LeafDef(TreeDef): def __init__(self, type): super().__init__(type, None, []) self.num_leaves = 1 def unflatten(self, leaves): assert len(leaves) == 1 assert isinstance(leaves[0], self.type), self.type return leaves[0] def __repr__(self): return "Leaf({})".format(self.type.__name__)