|
|
@@ -64,7 +64,7 @@ class Accuracy(Metric): |
|
|
|
if np.prod(pred.shape) != np.prod(target.shape): |
|
|
|
raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers." |
|
|
|
f" while target have shape:{target.shape}, " |
|
|
|
f"pred have shape: {target.shape}") |
|
|
|
f"pred have shape: {pred.shape}") |
|
|
|
|
|
|
|
elif pred.ndim == target.ndim + 1: |
|
|
|
pred = pred.argmax(axis=-1) |
|
|
|