You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_data_bundle.py 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import pytest
  2. from fastNLP.core.dataset import DataSet
  3. from fastNLP.io.data_bundle import DataBundle
  4. def test_add_seq_len():
  5. dataset1 = DataSet({
  6. "x": [[0,1,2], [5,3,2,3], [5,21,5,10], [3,6,8,1]]
  7. })
  8. dataset2 = DataSet({
  9. "x": [[0,1,2,3,4], [5,3,2,3], [5,20,45,1,98], [3,6,8,3,6,31]]
  10. })
  11. dataset3 = DataSet({
  12. "x": [[0,1,2,7,5,2], [5,3], [0], [3,6,8]]
  13. })
  14. data_bundle = DataBundle(datasets={
  15. "dataset1": dataset1,
  16. "dataset2": dataset2,
  17. "dataset3": dataset3
  18. })
  19. data_bundle.add_seq_len("x")
  20. print(data_bundle.get_dataset("dataset1"))
  21. for i, data in enumerate(data_bundle.get_dataset("dataset1")):
  22. print(data["seq_len"], dataset1["x"][i])
  23. assert data["seq_len"] == len(dataset1["x"][i])
  24. for i, data in enumerate(data_bundle.get_dataset("dataset2")):
  25. assert data["seq_len"] == len(dataset2["x"][i])
  26. for i, data in enumerate(data_bundle.get_dataset("dataset3")):
  27. assert data["seq_len"] == len(dataset3["x"][i])
  28. @pytest.mark.parametrize("inplace", [True, False])
  29. def test_drop(inplace):
  30. dataset1 = DataSet({
  31. "x": [0, 1, 1, 4, 2, 1, 0, 1, 1, 6, 7, 1]
  32. })
  33. dataset2 = DataSet({
  34. "x": [0, 0, 0, 0, 0]
  35. })
  36. dataset3 = DataSet({
  37. "x": [1, 1, 1, 1, 1, 2, 3, 4]
  38. })
  39. data_bundle = DataBundle(datasets={
  40. "dataset1": dataset1,
  41. "dataset2": dataset2,
  42. "dataset3": dataset3
  43. })
  44. res = data_bundle.drop(lambda x: x["x"] == 0, inplace)
  45. if inplace:
  46. assert res is data_bundle
  47. else:
  48. assert not (res is data_bundle)
  49. assert data_bundle.get_dataset("dataset1")["x"] == dataset1["x"]
  50. assert data_bundle.get_dataset("dataset2")["x"] == dataset2["x"]
  51. assert data_bundle.get_dataset("dataset3")["x"] == dataset3["x"]
  52. dataset1_drop = [1, 1, 4, 2, 1, 1, 1, 6, 7, 1]
  53. for i, data in enumerate(res.get_dataset("dataset1")["x"]):
  54. assert data == dataset1_drop[i]
  55. dataset2_drop = []
  56. for i, data in enumerate(res.get_dataset("dataset2")["x"]):
  57. assert data == dataset2_drop[i]
  58. dataset3_drop = [1, 1, 1, 1, 1, 2, 3, 4]
  59. for i, data in enumerate(res.get_dataset("dataset3")["x"]):
  60. assert data == dataset3_drop[i]