You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

collator.py 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) 2016- Facebook, Inc (Adam Paszke)
  3. # Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
  4. # Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
  5. # Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
  6. # Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
  7. # Copyright (c) 2011-2013 NYU (Clement Farabet)
  8. # Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
  9. # Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
  10. # Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
  11. # ---------------------------------------------------------------------
  12. #
  13. # This file has been modified by Megvii ("Megvii Modifications").
  14. # All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
  15. # ----------------------------------------------------------------------
  16. import collections.abc
  17. import re
  18. import numpy as np
  19. np_str_obj_array_pattern = re.compile(r"[aO]")
  20. default_collate_err_msg_format = (
  21. "default_collator: inputs must contain numpy arrays, numbers, "
  22. "Unicode strings, bytes, dicts or lists; found {}"
  23. )
  24. class Collator:
  25. r"""Used for merging a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a dataset.
  26. Modified from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
  27. """
  28. def apply(self, inputs):
  29. elem = inputs[0]
  30. elem_type = type(elem)
  31. if (
  32. elem_type.__module__ == "numpy"
  33. and elem_type.__name__ != "str_"
  34. and elem_type.__name__ != "string_"
  35. ):
  36. elem = inputs[0]
  37. if elem_type.__name__ == "ndarray":
  38. # array of string classes and object
  39. if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
  40. raise TypeError(default_collate_err_msg_format.format(elem.dtype))
  41. return np.ascontiguousarray(np.stack(inputs))
  42. elif elem.shape == (): # scalars
  43. return np.array(inputs)
  44. elif isinstance(elem, float):
  45. return np.array(inputs, dtype=np.float64)
  46. elif isinstance(elem, int):
  47. return np.array(inputs)
  48. elif isinstance(elem, (str, bytes)):
  49. return inputs
  50. elif isinstance(elem, collections.abc.Mapping):
  51. return {key: self.apply([d[key] for d in inputs]) for key in elem}
  52. elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
  53. return elem_type(*(self.apply(samples) for samples in zip(*inputs)))
  54. elif isinstance(elem, collections.abc.Sequence):
  55. transposed = zip(*inputs)
  56. return [self.apply(samples) for samples in transposed]
  57. raise TypeError(default_collate_err_msg_format.format(elem_type))