From af36813296665919042ae7dd0c7845c5ae7a63d8 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 18 May 2022 08:33:43 +0000 Subject: [PATCH] =?UTF-8?q?1.=20test=5Ffleet.py=20=E5=92=8C=20test=5Fddp.p?= =?UTF-8?q?y=20=E4=B8=AD=E7=9A=84=20load=5Fcheckpoint=20=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E7=9A=84=E6=B5=8B=E8=AF=95=E4=BE=8B=E6=B7=BB=E5=8A=A0=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=202.=E4=BF=AE=E6=94=B9=20Jenkinsfile=20=E4=B8=AD=20Te?= =?UTF-8?q?st=20Other=20=E7=9A=84=20pytest=20=E6=A0=87=E7=AD=BE=E5=BD=A2?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Jenkinsfile | 2 +- tests/core/drivers/paddle_driver/test_fleet.py | 3 +++ tests/core/drivers/torch_driver/test_ddp.py | 7 +++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index 4ef51291..9af78a62 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -15,7 +15,7 @@ pipeline { } } steps { - sh 'pytest ./tests --durations=0 -m "not torch and not paddle and not paddledist and not jittor and not torchpaddle and not torchjittor"' + sh 'pytest ./tests --durations=0 -m "not (torch or paddle or paddledist or jittor or torchpaddle or torchjittor)"' } } stage('Test Torch-1.11') { diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index d3bffb9f..ad680dcb 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -629,6 +629,7 @@ class TestSaveLoad: self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) else: self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) + dist.barrier() # 加载 # 更改 batch_size dataloader = DataLoader( @@ -644,8 +645,10 @@ class TestSaveLoad: rank=self.driver2.global_rank, pad=True ) + dist.barrier() load_states = self.driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") + dist.barrier() # 1. 检查 optimizer 的状态 # TODO optimizer 的 state_dict 总是为空 diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py index cb7ed68c..a7c4705a 100644 --- a/tests/core/drivers/torch_driver/test_ddp.py +++ b/tests/core/drivers/torch_driver/test_ddp.py @@ -649,6 +649,7 @@ class TestSaveLoad: sampler_states = dataloader.batch_sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + dist.barrier() # 加载 # 更改 batch_size dataloader = dataloader_with_bucketedbatchsampler( @@ -663,8 +664,12 @@ class TestSaveLoad: rank=driver2.global_rank, pad=True ) + dist.barrier() + print("========load=======", driver1.global_rank, driver2.global_rank) load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) + dist.barrier() replaced_loader = load_states.pop("dataloader") + # 1. 检查 optimizer 的状态 # TODO optimizer 的 state_dict 总是为空 @@ -708,8 +713,10 @@ class TestSaveLoad: assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas + dist.barrier() finally: rank_zero_rm(path) + print("=======delete======") if dist.is_initialized(): dist.destroy_process_group()