|
@@ -6,7 +6,6 @@ import re |
|
|
|
|
|
|
|
|
import pytest |
|
|
import pytest |
|
|
from fastNLP.core.controllers.trainer import Trainer |
|
|
from fastNLP.core.controllers.trainer import Trainer |
|
|
from torchmetrics import Accuracy |
|
|
|
|
|
from fastNLP.core.callbacks import CheckpointCallback |
|
|
from fastNLP.core.callbacks import CheckpointCallback |
|
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 |
|
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 |
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset |
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset |
|
@@ -18,6 +17,7 @@ if _NEED_IMPORT_TORCH: |
|
|
import torch.distributed as dist |
|
|
import torch.distributed as dist |
|
|
from torch.optim import SGD |
|
|
from torch.optim import SGD |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from torchmetrics import Accuracy |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
@dataclass |
|
|