You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sort.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. from typing import Optional, Tuple, Union
  11. import megengine._internal as mgb
  12. from ..core.tensor import Tensor, wrap_io_tensor
  13. __all__ = ["argsort", "sort", "top_k"]
  14. @wrap_io_tensor
  15. def argsort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
  16. r"""
  17. Sort the target 2d matrix by row, return both the sorted tensor and indices.
  18. :param inp: The input tensor, if 2d, each row will be sorted
  19. :param descending: Sort in descending order, where the largest comes first. Default: ``False``
  20. :return: Tuple of two tensors (sorted_tensor, indices_of_int32)
  21. Examples:
  22. .. testcode::
  23. import numpy as np
  24. from megengine import tensor
  25. import megengine.functional as F
  26. data = tensor(np.array([1,2], dtype=np.float32))
  27. sorted, indices = F.argsort(data)
  28. print(sorted.numpy(), indices.numpy())
  29. Outputs:
  30. .. testoutput::
  31. :options: +NUMBER
  32. [1. 2.] [0 1]
  33. """
  34. assert len(inp.imm_shape) <= 2, "Input should be 1d or 2d"
  35. if descending:
  36. order = mgb.opr_param_defs.Argsort.Order.DESCENDING
  37. else:
  38. order = mgb.opr_param_defs.Argsort.Order.ASCENDING
  39. if len(inp.imm_shape) == 1:
  40. inp = inp.reshape(1, -1)
  41. tns, ind = mgb.opr.argsort(inp, order=order)
  42. return tns[0], ind[0]
  43. return mgb.opr.argsort(inp, order=order)
  44. @functools.wraps(argsort)
  45. def sort(*args, **kwargs):
  46. return argsort(*args, **kwargs)
  47. @wrap_io_tensor
  48. def top_k(
  49. inp: Tensor,
  50. k: int,
  51. descending: bool = False,
  52. kth_only: bool = False,
  53. no_sort: bool = False,
  54. ) -> Tuple[Tensor, Tensor]:
  55. r"""
  56. Selected the Top-K (by default) smallest elements of 2d matrix by row.
  57. :param inp: The input tensor, if 2d, each row will be sorted
  58. :param k: The number of elements needed
  59. :param descending: If true, return the largest elements instead. Default: ``False``
  60. :param kth_only: If true, only the k-th element will be returned. Default: ``False``
  61. :param no_sort: If true, the returned elements can be unordered. Default: ``False``
  62. :return: Tuple of two tensors (topk_tensor, indices_of_int32)
  63. Examples:
  64. .. testcode::
  65. import numpy as np
  66. from megengine import tensor
  67. import megengine.functional as F
  68. data = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
  69. top, indices = F.top_k(data, 5)
  70. print(top.numpy(), indices.numpy())
  71. Outputs:
  72. .. testoutput::
  73. :options: +NUMBER
  74. [1. 2. 3. 4. 5.] [7 0 6 1 5]
  75. """
  76. assert len(inp.imm_shape) <= 2, "Input should be 1d or 2d"
  77. if kth_only:
  78. raise NotImplementedError(
  79. "TODO: would enconter:"
  80. "NotImplementedError: SymbolVar var could not be itered"
  81. )
  82. if descending:
  83. inp = -inp
  84. Mode = mgb.opr_param_defs.TopK.Mode
  85. if kth_only:
  86. mode = Mode.KTH_ONLY
  87. elif no_sort:
  88. mode = Mode.VALUE_IDX_NOSORT
  89. else:
  90. mode = Mode.VALUE_IDX_SORTED
  91. if len(inp.imm_shape) == 1:
  92. inp = inp.reshape(1, -1)
  93. tns, ind = mgb.opr.top_k(inp, k, mode=mode)
  94. tns = tns[0]
  95. ind = ind[0]
  96. else:
  97. tns, ind = mgb.opr.top_k(inp, k, mode=mode)
  98. if descending:
  99. tns = -tns
  100. return tns, ind

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发的感觉，欢迎访问 MegStudio 平台。

Contributors (1)