
test_distributed.py 2.8 kB

import os

from fastNLP.envs.distributed import rank_zero_call, all_rank_call
from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context


@rank_zero_call
def write_something():
    # Prints the rank repeated five times, e.g. '00000' on rank 0, '11111' on rank 1.
    print(os.environ.get('RANK', '0') * 5, flush=True)


def write_other_thing():
    print(os.environ.get('RANK', '0') * 5, flush=True)


class PaddleTest:
    # @x54-729
    def test_rank_zero_call(self):
        pass

    def test_all_rank_run(self):
        pass


class JittorTest:
    # @x54-729
    def test_rank_zero_call(self):
        pass

    def test_all_rank_run(self):
        pass


class TestTorch:
    @magic_argv_env_context
    def test_rank_zero_call(self):
        # Set up a minimal two-process torch distributed environment.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ:
            os.environ['LOCAL_RANK'] = '0'
            os.environ['RANK'] = '0'
            os.environ['WORLD_SIZE'] = '2'
        re_run_current_cmd_for_torch(1, output_from_new_proc='all')

        # A function decorated with rank_zero_call should only produce output on rank 0.
        with Capturing() as output:
            write_something()
        output = output[0]
        if os.environ['LOCAL_RANK'] == '0':
            assert '00000' in output and '11111' not in output
        else:
            assert '00000' not in output and '11111' not in output

        # Wrapping a plain function with rank_zero_call at call time should behave the same way.
        with Capturing() as output:
            rank_zero_call(write_other_thing)()
        output = output[0]
        if os.environ['LOCAL_RANK'] == '0':
            assert '00000' in output and '11111' not in output
        else:
            assert '00000' not in output and '11111' not in output

    @magic_argv_env_context
    def test_all_rank_run(self):
        # Set up a minimal two-process torch distributed environment.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ:
            os.environ['LOCAL_RANK'] = '0'
            os.environ['RANK'] = '0'
            os.environ['WORLD_SIZE'] = '2'
        re_run_current_cmd_for_torch(1, output_from_new_proc='all')
        # torch.distributed.init_process_group(backend='nccl')
        # torch.distributed.barrier()

        # Inside the all_rank_call() context, rank_zero_call-decorated functions run on every rank.
        with all_rank_call():
            with Capturing(no_del=True) as output:
                write_something()
            output = output[0]
            if os.environ['LOCAL_RANK'] == '0':
                assert '00000' in output
            else:
                assert '11111' in output

        with all_rank_call():
            with Capturing(no_del=True) as output:
                rank_zero_call(write_other_thing)()
            output = output[0]
            if os.environ['LOCAL_RANK'] == '0':
                assert '00000' in output
            else:
                assert '11111' in output