You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

metric.py 1.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # -*- coding: utf-8 -*-
  2. from typing import Iterable, Union
  3. import numpy as np
  4. from ..tensor import Tensor
  5. from .elemwise import abs, maximum, minimum
  6. from .math import topk as _topk
  7. from .tensor import broadcast_to, transpose
  8. __all__ = [
  9. "topk_accuracy",
  10. ]
  11. def topk_accuracy(
  12. logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1
  13. ) -> Union[Tensor, Iterable[Tensor]]:
  14. r"""Calculates the classification accuracy given predicted logits and ground-truth labels.
  15. Args:
  16. logits: model predictions of shape `[batch_size, num_classes]`,
  17. representing the probability (likelyhood) of each class.
  18. target: ground-truth labels, 1d tensor of int32.
  19. topk: specifies the topk values, could be an int or tuple of ints. Default: 1
  20. Returns:
  21. tensor(s) of classification accuracy between 0.0 and 1.0.
  22. """
  23. if isinstance(topk, int):
  24. topk = (topk,)
  25. _, pred = _topk(logits, k=max(topk), descending=True)
  26. accs = []
  27. for k in topk:
  28. correct = pred[:, :k].detach() == broadcast_to(
  29. transpose(target, (0, "x")), (target.shape[0], k)
  30. )
  31. accs.append(correct.astype(np.float32).sum() / target.shape[0])
  32. if len(topk) == 1: # type: ignore[arg-type]
  33. accs = accs[0]
  34. return accs