You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

collator.py 3.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) 2016- Facebook, Inc (Adam Paszke)
  3. # Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
  4. # Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
  5. # Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
  6. # Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
  7. # Copyright (c) 2011-2013 NYU (Clement Farabet)
  8. # Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
  9. # Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
  10. # Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
  11. # ---------------------------------------------------------------------
  12. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  13. #
  14. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  15. #
  16. # Unless required by applicable law or agreed to in writing,
  17. # software distributed under the License is distributed on an
  18. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. #
  20. # This file has been modified by Megvii ("Megvii Modifications").
  21. # All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
  22. # ----------------------------------------------------------------------
  23. import collections.abc
  24. import re
  25. import numpy as np
  26. np_str_obj_array_pattern = re.compile(r"[aO]")
  27. default_collate_err_msg_format = (
  28. "default_collator: inputs must contain numpy arrays, numbers, "
  29. "Unicode strings, bytes, dicts or lists; found {}"
  30. )
  31. class Collator:
  32. r"""Used for merging a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a dataset.
  33. Modified from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
  34. """
  35. def apply(self, inputs):
  36. elem = inputs[0]
  37. elem_type = type(elem)
  38. if (
  39. elem_type.__module__ == "numpy"
  40. and elem_type.__name__ != "str_"
  41. and elem_type.__name__ != "string_"
  42. ):
  43. elem = inputs[0]
  44. if elem_type.__name__ == "ndarray":
  45. # array of string classes and object
  46. if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
  47. raise TypeError(default_collate_err_msg_format.format(elem.dtype))
  48. return np.ascontiguousarray(np.stack(inputs))
  49. elif elem.shape == (): # scalars
  50. return np.array(inputs)
  51. elif isinstance(elem, float):
  52. return np.array(inputs, dtype=np.float64)
  53. elif isinstance(elem, int):
  54. return np.array(inputs)
  55. elif isinstance(elem, (str, bytes)):
  56. return inputs
  57. elif isinstance(elem, collections.abc.Mapping):
  58. return {key: self.apply([d[key] for d in inputs]) for key in elem}
  59. elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
  60. return elem_type(*(self.apply(samples) for samples in zip(*inputs)))
  61. elif isinstance(elem, collections.abc.Sequence):
  62. transposed = zip(*inputs)
  63. return [self.apply(samples) for samples in transposed]
  64. raise TypeError(default_collate_err_msg_format.format(elem_type))