You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Iterable, Union
  10. import megengine._internal as mgb
  11. from ..core.graph import _use_default_if_none
  12. from ..core.tensor import Tensor, wrap_io_tensor
  13. from .elemwise import equal
  14. from .sort import top_k
  15. def _decide_comp_node_and_comp_graph(*args: mgb.SymbolVar):
  16. for i in args:
  17. if isinstance(i, mgb.SymbolVar):
  18. return i.comp_node, i.owner_graph
  19. return _use_default_if_none(None, None)
  20. def accuracy(
  21. logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1
  22. ) -> Union[Tensor, Iterable[Tensor]]:
  23. r"""
  24. Calculate the classification accuracy given predicted logits and ground-truth labels.
  25. :param logits: Model predictions of shape [batch_size, num_classes],
  26. representing the probability (likelyhood) of each class.
  27. :param target: Ground-truth labels, 1d tensor of int32
  28. :param topk: Specifies the topk values, could be an int or tuple of ints. Default: 1
  29. :return: Tensor(s) of classification accuracy between 0.0 and 1.0
  30. Examples:
  31. .. testcode::
  32. import numpy as np
  33. from megengine import tensor
  34. import megengine.functional as F
  35. logits = tensor(np.arange(80, dtype=np.int32).reshape(8,10))
  36. target = tensor(np.arange(8, dtype=np.int32))
  37. top1, top5 = F.accuracy(logits, target, (1, 5))
  38. print(top1.numpy(), top5.numpy())
  39. Outputs:
  40. .. testoutput::
  41. [0.] [0.375]
  42. """
  43. if isinstance(topk, int):
  44. topk = (topk,)
  45. _, pred = top_k(logits, k=max(topk), descending=True)
  46. accs = []
  47. for k in topk:
  48. correct = equal(
  49. pred[:, :k], target.dimshuffle(0, "x").broadcast(target.shapeof(0), k)
  50. )
  51. accs.append(correct.sum() / target.shapeof(0))
  52. if len(topk) == 1: # type: ignore[arg-type]
  53. accs = accs[0]
  54. return accs
  55. @wrap_io_tensor
  56. def zero_grad(inp: Tensor) -> Tensor:
  57. r"""
  58. Returns a tensor which is treated as constant during backward gradient calcuation,
  59. i.e. its gradient is zero.
  60. :param inp: Input tensor.
  61. See implementation of :func:`~.softmax` for example.
  62. """
  63. return mgb.opr.zero_grad(inp)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。