|
|
@@ -26,7 +26,8 @@ def _decide_comp_node_and_comp_graph(*args: mgb.SymbolVar): |
|
|
|
def accuracy( |
|
|
|
logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1 |
|
|
|
) -> Union[Tensor, Iterable[Tensor]]: |
|
|
|
r"""Calculate the classification accuracy given predicted logits and ground-truth labels. |
|
|
|
r""" |
|
|
|
Calculate the classification accuracy given predicted logits and ground-truth labels. |
|
|
|
|
|
|
|
:param logits: Model predictions of shape [batch_size, num_classes], |
|
|
|
representing the probability (likelyhood) of each class. |
|
|
|