You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

collator.py 3.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) 2016- Facebook, Inc (Adam Paszke)
  3. # Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
  4. # Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
  5. # Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
  6. # Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
  7. # Copyright (c) 2011-2013 NYU (Clement Farabet)
  8. # Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
  9. # Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
  10. # Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
  11. # ---------------------------------------------------------------------
  12. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  13. #
  14. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  15. #
  16. # Unless required by applicable law or agreed to in writing,
  17. # software distributed under the License is distributed on an
  18. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. #
  20. # This file has been modified by Megvii ("Megvii Modifications").
  21. # All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
  22. # ----------------------------------------------------------------------
  23. import collections.abc
  24. import re
  25. import numpy as np
  26. np_str_obj_array_pattern = re.compile(r"[aO]")
  27. default_collate_err_msg_format = (
  28. "default_collator: inputs must contain numpy arrays, numbers, "
  29. "Unicode strings, bytes, dicts or lists; found {}"
  30. )
  31. class Collator:
  32. r"""
  33. Used for merging a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a dataset.
  34. Modified from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
  35. """
  36. def apply(self, inputs):
  37. """
  38. :param inputs: sequence_N(tuple(CHW, C, CK)).
  39. :return: tuple(NCHW, NC, NCK).
  40. """
  41. elem = inputs[0]
  42. elem_type = type(elem)
  43. if (
  44. elem_type.__module__ == "numpy"
  45. and elem_type.__name__ != "str_"
  46. and elem_type.__name__ != "string_"
  47. ):
  48. elem = inputs[0]
  49. if elem_type.__name__ == "ndarray":
  50. # array of string classes and object
  51. if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
  52. raise TypeError(default_collate_err_msg_format.format(elem.dtype))
  53. return np.ascontiguousarray(np.stack(inputs))
  54. elif elem.shape == (): # scalars
  55. return np.array(inputs)
  56. elif isinstance(elem, float):
  57. return np.array(inputs, dtype=np.float64)
  58. elif isinstance(elem, int):
  59. return np.array(inputs)
  60. elif isinstance(elem, (str, bytes)):
  61. return inputs
  62. elif isinstance(elem, collections.abc.Mapping):
  63. return {key: self.apply([d[key] for d in inputs]) for key in elem}
  64. elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
  65. return elem_type(*(self.apply(samples) for samples in zip(*inputs)))
  66. elif isinstance(elem, collections.abc.Sequence):
  67. transposed = zip(*inputs)
  68. return [self.apply(samples) for samples in transposed]
  69. raise TypeError(default_collate_err_msg_format.format(elem_type))

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。