You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. from typing import Iterable, Union
  11. import numpy as np
  12. from ..core._wrap import device as as_device
  13. from ..core.ops.builtin import Copy, Identity
  14. from ..core.tensor import Tensor
  15. from ..core.tensor.core import apply
  16. from .math import topk as _topk
  17. from .tensor import broadcast_to, transpose
# Names exported by ``from <this module> import *``: the two functions
# defined below.
__all__ = [
    "topk_accuracy",
    "copy",
]
  22. def topk_accuracy(
  23. logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1
  24. ) -> Union[Tensor, Iterable[Tensor]]:
  25. r"""
  26. Calculates the classification accuracy given predicted logits and ground-truth labels.
  27. :param logits: model predictions of shape `[batch_size, num_classes]`,
  28. representing the probability (likelyhood) of each class.
  29. :param target: ground-truth labels, 1d tensor of int32.
  30. :param topk: specifies the topk values, could be an int or tuple of ints. Default: 1
  31. :return: tensor(s) of classification accuracy between 0.0 and 1.0.
  32. Examples:
  33. .. testcode::
  34. import numpy as np
  35. from megengine import tensor
  36. import megengine.functional as F
  37. logits = tensor(np.arange(80, dtype=np.int32).reshape(8,10))
  38. target = tensor(np.arange(8, dtype=np.int32))
  39. top1, top5 = F.topk_accuracy(logits, target, (1, 5))
  40. print(top1.numpy(), top5.numpy())
  41. Outputs:
  42. .. testoutput::
  43. [0.] [0.375]
  44. """
  45. if isinstance(topk, int):
  46. topk = (topk,)
  47. _, pred = _topk(logits, k=max(topk), descending=True)
  48. accs = []
  49. for k in topk:
  50. correct = pred[:, :k].detach() == broadcast_to(
  51. transpose(target, (0, "x")), (target.shape[0], k)
  52. )
  53. accs.append(correct.astype(np.float32).sum() / target.shape[0])
  54. if len(topk) == 1: # type: ignore[arg-type]
  55. accs = accs[0]
  56. return accs
  57. def copy(inp, device=None):
  58. r"""
  59. Copies tensor to another device.
  60. :param inp: input tensor.
  61. :param device: destination device.
  62. Examples:
  63. .. testcode::
  64. import numpy as np
  65. from megengine import tensor
  66. import megengine.functional as F
  67. x = tensor([1, 2, 3], np.int32)
  68. y = F.copy(x, "xpu1")
  69. print(y.numpy())
  70. Outputs:
  71. .. testoutput::
  72. [1 2 3]
  73. """
  74. if device is None:
  75. return apply(Identity(), inp)[0]
  76. return apply(Copy(comp_node=as_device(device).to_c()), inp)[0]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台