You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_paddle_utils.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import os
  2. import pytest
  3. from fastNLP.core.utils.paddle_utils import _convert_data_device, paddle_to, paddle_move_data_to_device
  4. from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
  5. if _NEED_IMPORT_PADDLE:
  6. import paddle
  7. @pytest.mark.parametrize(
  8. ("user_visible_devices, cuda_visible_devices, device, correct"),
  9. (
  10. (None, None, 1, "gpu:1"),
  11. (None, "2,4,5,6", 2, "gpu:2"),
  12. (None, "3,4,5", 1, "gpu:1"),
  13. ("0,1,2,3,4,5,6,7", "0", "cpu", "cpu"),
  14. ("3,4,5,6,7", "0", "cpu", "cpu"),
  15. ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", "gpu:1"),
  16. ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", "gpu:2"),
  17. ("3,4,5,6", "3,5", 0, "gpu:0"),
  18. ("3,6,7,8", "6,7,8", "gpu:2", "gpu:1"),
  19. )
  20. )
  21. def test_convert_data_device(user_visible_devices, cuda_visible_devices, device, correct):
  22. _cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
  23. _user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES")
  24. if cuda_visible_devices is not None:
  25. os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
  26. if user_visible_devices is not None:
  27. os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices
  28. res = _convert_data_device(device)
  29. assert res == correct
  30. # 还原环境变量
  31. if _cuda_visible_devices is None:
  32. os.environ.pop("CUDA_VISIBLE_DEVICES", None)
  33. else:
  34. os.environ["CUDA_VISIBLE_DEVICES"] = _cuda_visible_devices
  35. if _user_visible_devices is None:
  36. os.environ.pop("USER_CUDA_VISIBLE_DEVICES", None)
  37. else:
  38. os.environ["USER_CUDA_VISIBLE_DEVICES"] = _user_visible_devices
  39. ############################################################################
  40. #
  41. # 测试仅将单个paddle张量迁移到指定设备
  42. #
  43. ############################################################################
  44. @pytest.mark.paddle
  45. class TestPaddleToDevice:
  46. def test_case(self):
  47. tensor = paddle.rand((4, 5))
  48. res = paddle_to(tensor, "gpu")
  49. assert res.place.is_gpu_place()
  50. assert res.place.gpu_device_id() == 0
  51. res = paddle_to(tensor, "cpu")
  52. assert res.place.is_cpu_place()
  53. ############################################################################
  54. #
  55. # 测试将参数中包含的所有paddle张量迁移到指定设备
  56. #
  57. ############################################################################
  58. @pytest.mark.paddle
  59. class TestPaddleMoveDataToDevice:
  60. def check_gpu(self, tensor, idx):
  61. """
  62. 检查张量是否在指定的设备上的工具函数
  63. """
  64. assert tensor.place.is_gpu_place()
  65. assert tensor.place.gpu_device_id() == idx
  66. def check_cpu(self, tensor):
  67. """
  68. 检查张量是否在cpu上的工具函数
  69. """
  70. assert tensor.place.is_cpu_place()
  71. def test_tensor_transfer(self):
  72. """
  73. 测试单个张量的迁移
  74. """
  75. paddle_tensor = paddle.rand((3, 4, 5)).cpu()
  76. res = paddle_move_data_to_device(paddle_tensor, device=None)
  77. self.check_cpu(res)
  78. res = paddle_move_data_to_device(paddle_tensor, device="gpu:0")
  79. self.check_gpu(res, 0)
  80. def test_list_transfer(self):
  81. """
  82. 测试张量列表的迁移
  83. """
  84. paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)]
  85. res = paddle_move_data_to_device(paddle_list, device="cpu")
  86. assert isinstance(res, list)
  87. for r in res:
  88. self.check_cpu(r)
  89. res = paddle_move_data_to_device(paddle_list, device="gpu:0")
  90. assert isinstance(res, list)
  91. for r in res:
  92. self.check_gpu(r, 0)
  93. def test_tensor_tuple_transfer(self):
  94. """
  95. 测试张量元组的迁移
  96. """
  97. paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)]
  98. paddle_tuple = tuple(paddle_list)
  99. res = paddle_move_data_to_device(paddle_tuple, device="cpu")
  100. assert isinstance(res, tuple)
  101. for r in res:
  102. self.check_cpu(r)
  103. res = paddle_move_data_to_device(paddle_tuple, device="gpu:0")
  104. assert isinstance(res, tuple)
  105. for r in res:
  106. self.check_gpu(r, 0)
  107. def test_dict_transfer(self):
  108. """
  109. 测试字典结构的迁移
  110. """
  111. paddle_dict = {
  112. "tensor": paddle.rand((3, 4)),
  113. "list": [paddle.rand((6, 4, 2)) for i in range(10)],
  114. "dict":{
  115. "list": [paddle.rand((6, 4, 2)) for i in range(10)],
  116. "tensor": paddle.rand((3, 4))
  117. },
  118. "int": 2,
  119. "string": "test string"
  120. }
  121. res = paddle_move_data_to_device(paddle_dict, device="gpu:0")
  122. assert isinstance(res, dict)
  123. self.check_gpu(res["tensor"], 0)
  124. assert isinstance(res["list"], list)
  125. for t in res["list"]:
  126. self.check_gpu(t, 0)
  127. assert isinstance(res["int"], int)
  128. assert isinstance(res["string"], str)
  129. assert isinstance(res["dict"], dict)
  130. assert isinstance(res["dict"]["list"], list)
  131. for t in res["dict"]["list"]:
  132. self.check_gpu(t, 0)
  133. self.check_gpu(res["dict"]["tensor"], 0)
  134. res = paddle_move_data_to_device(paddle_dict, device="gpu:0")
  135. assert isinstance(res, dict)
  136. self.check_gpu(res["tensor"], 0)
  137. assert isinstance(res["list"], list)
  138. for t in res["list"]:
  139. self.check_gpu(t, 0)
  140. assert isinstance(res["int"], int)
  141. assert isinstance(res["string"], str)
  142. assert isinstance(res["dict"], dict)
  143. assert isinstance(res["dict"]["list"], list)
  144. for t in res["dict"]["list"]:
  145. self.check_gpu(t, 0)
  146. self.check_gpu(res["dict"]["tensor"], 0)
  147. res = paddle_move_data_to_device(paddle_dict, device="cpu")
  148. assert isinstance(res, dict)
  149. self.check_cpu(res["tensor"])
  150. assert isinstance(res["list"], list)
  151. for t in res["list"]:
  152. self.check_cpu(t)
  153. assert isinstance(res["int"], int)
  154. assert isinstance(res["string"], str)
  155. assert isinstance(res["dict"], dict)
  156. assert isinstance(res["dict"]["list"], list)
  157. for t in res["dict"]["list"]:
  158. self.check_cpu(t)
  159. self.check_cpu(res["dict"]["tensor"])