From 9b0a30c8fb9b7f79e319a9e3d6bcf11227442cd4 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 29 May 2022 13:51:10 +0000 Subject: [PATCH] small bug --- fastNLP/core/drivers/jittor_driver/jittor_driver.py | 3 +-- tests/core/drivers/paddle_driver/test_dist_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index 5b38747d..e486df8e 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -1,7 +1,6 @@ import os from pathlib import Path from typing import Union, Optional, Dict -from contextlib import nullcontext from dataclasses import dataclass from fastNLP.envs.imports import _NEED_IMPORT_JITTOR @@ -9,7 +8,7 @@ from fastNLP.core.drivers.driver import Driver from fastNLP.core.dataloaders import JittorDataLoader from fastNLP.core.samplers import ReproducibleSampler, RandomSampler from fastNLP.core.log import logger -from fastNLP.core.utils import apply_to_collection +from fastNLP.core.utils import apply_to_collection, nullcontext from fastNLP.envs import ( FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, diff --git a/tests/core/drivers/paddle_driver/test_dist_utils.py b/tests/core/drivers/paddle_driver/test_dist_utils.py index e3a3eb5d..30ce9d29 100644 --- a/tests/core/drivers/paddle_driver/test_dist_utils.py +++ b/tests/core/drivers/paddle_driver/test_dist_utils.py @@ -84,7 +84,7 @@ class TestAllGatherAndBroadCast: @classmethod def setup_class(cls): - devices = [0,1,2] + devices = [0,1] output_from_new_proc = "all" launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc) @@ -150,7 +150,7 @@ class TestAllGatherAndBroadCast: dist.barrier() @magic_argv_env_context - @pytest.mark.parametrize("src_rank", ([0, 1, 2])) + @pytest.mark.parametrize("src_rank", ([0, 1])) def test_fastnlp_paddle_broadcast_object(self, src_rank): if self.local_rank == src_rank: obj = {