Browse Source

1. 为 test_fleet.py 和 test_ddp.py 中与 load_checkpoint 相关的测试用例添加同步（dist.barrier()） 2. 修改 Jenkinsfile 中 Test Other 阶段的 pytest 标记（-m）表达式形式

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
af36813296
3 changed files with 11 additions and 1 deletion
  1. +1
    -1
      Jenkinsfile
  2. +3
    -0
      tests/core/drivers/paddle_driver/test_fleet.py
  3. +7
    -0
      tests/core/drivers/torch_driver/test_ddp.py

+ 1
- 1
Jenkinsfile View File

@@ -15,7 +15,7 @@ pipeline {
}
}
steps {
- sh 'pytest ./tests --durations=0 -m "not torch and not paddle and not paddledist and not jittor and not torchpaddle and not torchjittor"'
+ sh 'pytest ./tests --durations=0 -m "not (torch or paddle or paddledist or jittor or torchpaddle or torchjittor)"'
}
}
stage('Test Torch-1.11') {


+ 3
- 0
tests/core/drivers/paddle_driver/test_fleet.py View File

@@ -629,6 +629,7 @@ class TestSaveLoad:
self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
dist.barrier()
# 加载
# 更改 batch_size
dataloader = DataLoader(
@@ -644,8 +645,10 @@ class TestSaveLoad:
rank=self.driver2.global_rank,
pad=True
)
dist.barrier()
load_states = self.driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
dist.barrier()
# 1. 检查 optimizer 的状态
# TODO optimizer 的 state_dict 总是为空



+ 7
- 0
tests/core/drivers/torch_driver/test_ddp.py View File

@@ -649,6 +649,7 @@ class TestSaveLoad:
sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
dist.barrier()
# 加载
# 更改 batch_size
dataloader = dataloader_with_bucketedbatchsampler(
@@ -663,8 +664,12 @@ class TestSaveLoad:
rank=driver2.global_rank,
pad=True
)
dist.barrier()
print("========load=======", driver1.global_rank, driver2.global_rank)
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
dist.barrier()
replaced_loader = load_states.pop("dataloader")
# 1. 检查 optimizer 的状态
# TODO optimizer 的 state_dict 总是为空

@@ -708,8 +713,10 @@ class TestSaveLoad:
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
dist.barrier()
finally:
rank_zero_rm(path)
print("=======delete======")

if dist.is_initialized():
dist.destroy_process_group()


Loading…
Cancel
Save