You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sort.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. from typing import Optional, Tuple, Union
  11. import megengine._internal as mgb
  12. from ..core.tensor import Tensor, wrap_io_tensor
  13. __all__ = ["argsort", "sort", "top_k"]
  14. @wrap_io_tensor
  15. def argsort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
  16. r"""
  17. Sort the target 2d matrix by row, return both the sorted tensor and indices.
  18. :param inp: The input tensor, if 2d, each row will be sorted
  19. :param descending: Sort in descending order, where the largest comes first. Default: ``False``
  20. :return: Tuple of two tensors (sorted_tensor, indices_of_int32)
  21. Examples:
  22. .. testcode::
  23. import numpy as np
  24. from megengine import tensor
  25. import megengine.functional as F
  26. data = tensor(np.array([1,2], dtype=np.float32))
  27. sorted, indices = F.argsort(data)
  28. print(sorted.numpy(), indices.numpy())
  29. Outputs:
  30. .. testoutput::
  31. [1. 2.] [0 1]
  32. """
  33. assert len(inp.imm_shape) <= 2, "Input should be 1d or 2d"
  34. if descending:
  35. order = mgb.opr_param_defs.Argsort.Order.DESCENDING
  36. else:
  37. order = mgb.opr_param_defs.Argsort.Order.ASCENDING
  38. if len(inp.imm_shape) == 1:
  39. inp = inp.reshape(1, -1)
  40. tns, ind = mgb.opr.argsort(inp, order=order)
  41. return tns[0], ind[0]
  42. return mgb.opr.argsort(inp, order=order)
  43. @functools.wraps(argsort)
  44. def sort(*args, **kwargs):
  45. return argsort(*args, **kwargs)
  46. @wrap_io_tensor
  47. def top_k(
  48. inp: Tensor,
  49. k: int,
  50. descending: bool = False,
  51. kth_only: bool = False,
  52. no_sort: bool = False,
  53. ) -> Tuple[Tensor, Tensor]:
  54. r"""
  55. Selected the Top-K (by default) smallest elements of 2d matrix by row.
  56. :param inp: The input tensor, if 2d, each row will be sorted
  57. :param k: The number of elements needed
  58. :param descending: If true, return the largest elements instead. Default: ``False``
  59. :param kth_only: If true, only the k-th element will be returned. Default: ``False``
  60. :param no_sort: If true, the returned elements can be unordered. Default: ``False``
  61. :return: Tuple of two tensors (topk_tensor, indices_of_int32)
  62. Examples:
  63. .. testcode::
  64. import numpy as np
  65. from megengine import tensor
  66. import megengine.functional as F
  67. data = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
  68. top, indices = F.top_k(data, 5)
  69. print(top.numpy(), indices.numpy())
  70. Outputs:
  71. .. testoutput::
  72. [1. 2. 3. 4. 5.] [7 0 6 1 5]
  73. """
  74. assert len(inp.imm_shape) <= 2, "Input should be 1d or 2d"
  75. if kth_only:
  76. raise NotImplementedError(
  77. "TODO: would enconter:"
  78. "NotImplementedError: SymbolVar var could not be itered"
  79. )
  80. if descending:
  81. inp = -inp
  82. Mode = mgb.opr_param_defs.TopK.Mode
  83. if kth_only:
  84. mode = Mode.KTH_ONLY
  85. elif no_sort:
  86. mode = Mode.VALUE_IDX_NOSORT
  87. else:
  88. mode = Mode.VALUE_IDX_SORTED
  89. if len(inp.imm_shape) == 1:
  90. inp = inp.reshape(1, -1)
  91. tns, ind = mgb.opr.top_k(inp, k, mode=mode)
  92. tns = tns[0]
  93. ind = ind[0]
  94. else:
  95. tns, ind = mgb.opr.top_k(inp, k, mode=mode)
  96. if descending:
  97. tns = -tns
  98. return tns, ind

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台