|
|
@@ -649,6 +649,7 @@ class TestSaveLoad: |
|
|
|
sampler_states = dataloader.batch_sampler.state_dict() |
|
|
|
save_states = {"num_consumed_batches": num_consumed_batches} |
|
|
|
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) |
|
|
|
dist.barrier() |
|
|
|
# 加载 |
|
|
|
# 更改 batch_size |
|
|
|
dataloader = dataloader_with_bucketedbatchsampler( |
|
|
@@ -663,8 +664,12 @@ class TestSaveLoad: |
|
|
|
rank=driver2.global_rank, |
|
|
|
pad=True |
|
|
|
) |
|
|
|
dist.barrier() |
|
|
|
print("========load=======", driver1.global_rank, driver2.global_rank) |
|
|
|
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) |
|
|
|
dist.barrier() |
|
|
|
replaced_loader = load_states.pop("dataloader") |
|
|
|
|
|
|
|
# 1. 检查 optimizer 的状态 |
|
|
|
# TODO optimizer 的 state_dict 总是为空 |
|
|
|
|
|
|
@@ -708,8 +713,10 @@ class TestSaveLoad: |
|
|
|
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas |
|
|
|
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas |
|
|
|
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas |
|
|
|
dist.barrier() |
|
|
|
finally: |
|
|
|
rank_zero_rm(path) |
|
|
|
print("=======delete======") |
|
|
|
|
|
|
|
if dist.is_initialized(): |
|
|
|
dist.destroy_process_group() |
|
|
|