You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_accuracy_paddle.py 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import os
  2. import pytest
  3. from fastNLP.core.metrics import Accuracy
  4. from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher
  5. from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
  6. if _NEED_IMPORT_PADDLE:
  7. import paddle
  8. import paddle.distributed
  9. import paddle.distributed.fleet.base.role_maker as role_maker
  10. import paddle.distributed.fleet as fleet
  11. ############################################################################
  12. #
  13. # 测试 单机单卡情况下的Accuracy (test Accuracy on a single machine with a single device)
  14. #
  15. ############################################################################
  @pytest.mark.paddle
  def test_accuracy_single():
      """Check ``Accuracy`` on a single machine with a single device.

      ``pred`` holds scores for 10 samples over 6 classes, and ``tg`` holds the
      target class index of each sample. If each sample's prediction is the
      argmax of its row (which is what the expected value assumes), the
      predictions are ``[0, 5, 1, 3, 1, 1, 2, 0, 2, 5]``. Only rows 2, 3 and 9
      match their targets, so the expected accuracy is 3 / 10 = 0.3.
      """
      pred = paddle.to_tensor([
          [1.19812393, -0.82041764, -0.53517765, -0.73061031, -1.45006669, 0.46514302],
          [-0.85775983, -2.18273783, -1.07505429, -1.45561373, 0.40011844, 1.02202022],
          [-0.39487389, 0.65682763, -0.62424040, 0.53692561, -0.28390560, -0.02559055],
          [-0.22586937, -0.07676325, -0.95977223, 0.36395910, -0.91758579, -0.83857095],
          [0.25136873, 2.49652624, 1.06251311, 1.60194016, 1.01451588, 0.08403367],
          [0.10844281, 1.19017303, -0.11378096, 1.12686944, -0.08654942, 0.48605862],
          [1.27320433, -1.13902378, 1.47072780, -0.98665696, -0.42589864, 0.64618838],
          [0.83809763, -0.05356205, 0.03042423, -0.28371972, 0.81611472, -0.45802942],
          [0.38535264, 0.09721313, 2.27187467, 0.32045507, -0.20711982, -0.13550705],
          [-0.75228405, -1.34161997, 1.08697927, 0.33218071, -1.19470012, 2.58735061],
      ])
      tg = paddle.to_tensor([1, 2, 1, 3, 5, 4, 4, 2, 1, 5])

      acc_metric = Accuracy()
      acc_metric.update(pred, tg)
      result = acc_metric.get_metric()['acc']

      # The metric is a float produced by a division, so compare with a
      # tolerance instead of relying on exact floating-point equality.
      assert result == pytest.approx(0.3)
  43. ############################################################################
  44. #
  45. # 测试 单机多卡情况下的Accuracy (test Accuracy on a single machine with multiple devices)
  46. #
  47. ############################################################################
  48. # def test_accuracy_ddp():
  49. # launcher = FleetLauncher(devices=[0, 1])
  50. # launcher.launch()
  51. # role = role_maker.PaddleCloudRoleMaker(is_collective=True)
  52. # fleet.init(role)
  53. # if fleet.is_server():
  54. # pass
  55. # elif fleet.is_worker():
  56. # print(os.getenv("PADDLE_TRAINER_ID"))