|
@@ -15,6 +15,10 @@ from .elemwise import abs, maximum, minimum |
|
|
from .math import topk as _topk |
|
|
from .math import topk as _topk |
|
|
from .tensor import broadcast_to, transpose |
|
|
from .tensor import broadcast_to, transpose |
|
|
|
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
|
|
"topk_accuracy", |
|
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def topk_accuracy( |
|
|
def topk_accuracy( |
|
|
logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1 |
|
|
logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1 |
|
|