You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tuple_function.py 1.4 kB

123456789101112131415161718192021222324252627282930313233343536373839
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import functools
  11. def get_ndtuple(value, *, n, allow_zero=True):
  12. r"""Converts possibly 1D tuple to nd tuple.
  13. Args:
  14. allow_zero: whether to allow zero tuple value
  15. """
  16. if not isinstance(value, collections.abc.Iterable):
  17. value = int(value)
  18. value = tuple([value for i in range(n)])
  19. else:
  20. assert len(value) == n, "tuple len is not equal to n: {}".format(value)
  21. spatial_axis = map(int, value)
  22. value = tuple(spatial_axis)
  23. if allow_zero:
  24. minv = 0
  25. else:
  26. minv = 1
  27. assert min(value) >= minv, "invalid value: {}".format(value)
  28. return value
  29. _single = functools.partial(get_ndtuple, n=1, allow_zero=True)
  30. _pair = functools.partial(get_ndtuple, n=2, allow_zero=True)
  31. _pair_nonzero = functools.partial(get_ndtuple, n=2, allow_zero=False)
  32. _triple = functools.partial(get_ndtuple, n=3, allow_zero=True)
  33. _triple_nonzero = functools.partial(get_ndtuple, n=3, allow_zero=False)
  34. _quadruple = functools.partial(get_ndtuple, n=4, allow_zero=True)