- # -*- coding: utf-8 -*-
- import collections
- import functools
-
-
def get_ndtuple(value, *, n, allow_zero=True):
    """Convert a scalar or a length-``n`` iterable into a tuple of ``n`` ints.

    A non-iterable ``value`` is converted with ``int`` and repeated ``n``
    times. An iterable ``value`` must yield exactly ``n`` items, each of
    which is converted with ``int``. Iterables that do not support ``len``
    (generators, ``map`` objects) are accepted because the items are
    materialized before they are counted. Strings are iterable, so they are
    processed character by character.

    Args:
        value: a scalar accepted by ``int``, or an iterable of such scalars.
        n: required length of the returned tuple.
        allow_zero: whether to allow zero tuple value. When False every
            element must be at least 1, otherwise at least 0.

    Returns:
        A tuple of ``n`` ints.

    Raises:
        AssertionError: if an iterable ``value`` does not hold exactly ``n``
            items, or if any element is below the permitted minimum.
    """
    # Imported locally so the check does not depend on the ``collections.abc``
    # submodule having been loaded as a side effect of ``import collections``.
    from collections.abc import Iterable

    if isinstance(value, Iterable):
        # Materialize first: generators and other unsized iterables have no
        # ``len`` and can only be consumed once.
        items = tuple(value)
        if len(items) != n:
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped under ``python -O``; the exception type is kept.
            raise AssertionError(
                "tuple len is not equal to n: {}".format(items))
        result = tuple(map(int, items))
    else:
        result = (int(value),) * n

    minv = 0 if allow_zero else 1
    # ``any`` over an empty tuple is False, so ``n == 0`` passes trivially
    # instead of failing inside ``min()``.
    if any(item < minv for item in result):
        raise AssertionError("invalid value: {}".format(result))
    return result
-
-
# Fixed-arity shorthands for ``get_ndtuple``; each one pins ``n`` and the
# ``allow_zero`` policy and leaves ``value`` as the only argument to supply.

# Variants that accept zero elements.
_single = functools.partial(get_ndtuple, n=1, allow_zero=True)
_pair = functools.partial(get_ndtuple, n=2, allow_zero=True)
_triple = functools.partial(get_ndtuple, n=3, allow_zero=True)
_quadruple = functools.partial(get_ndtuple, n=4, allow_zero=True)

# Variants that require every element to be at least 1.
_pair_nonzero = functools.partial(get_ndtuple, n=2, allow_zero=False)
_triple_nonzero = functools.partial(get_ndtuple, n=3, allow_zero=False)
|