diff --git a/fastNLP/core/metrics/accuracy.py b/fastNLP/core/metrics/accuracy.py index c8928086..47d5e114 100644 --- a/fastNLP/core/metrics/accuracy.py +++ b/fastNLP/core/metrics/accuracy.py @@ -64,7 +64,7 @@ class Accuracy(Metric): if np.prod(pred.shape) != np.prod(target.shape): raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers." f" while target have shape:{target.shape}, " - f"pred have shape: {target.shape}") + f"pred have shape: {pred.shape}") elif pred.ndim == target.ndim + 1: pred = pred.argmax(axis=-1)