You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. import os
  2. import pytest
  3. import numpy as np
  4. from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException
  5. from fastNLP import logger
  6. class TestDataSetInit:
  7. """初始化DataSet的办法有以下几种:
  8. 1) 用dict:
  9. 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
  10. 1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])})
  11. 1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]})
  12. 2) 用list of Instance:
  13. 2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])])
  14. 2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))])
  15. 2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])])
  16. 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])
  17. 只接受纯list或者最外层ndarray
  18. """
  19. def test_init_v1(self):
  20. # 一维list
  21. ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40)
  22. assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True
  23. assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40
  24. assert ds.field_arrays["y"].content == [[5, 6], ] * 40
  25. def test_init_v2(self):
  26. # 用dict
  27. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  28. assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True
  29. assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40
  30. assert ds.field_arrays["y"].content == [[5, 6], ] * 40
  31. def test_init_assert(self):
  32. with pytest.raises(AssertionError):
  33. _ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100})
  34. with pytest.raises(AssertionError):
  35. _ = DataSet([[1, 2, 3, 4]] * 10)
  36. with pytest.raises(ValueError):
  37. _ = DataSet(0.00001)
  38. class TestDataSetMethods:
  39. def test_append(self):
  40. dd = DataSet()
  41. for _ in range(3):
  42. dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))
  43. assert len(dd) == 3
  44. assert dd.field_arrays["x"].content == [[1, 2, 3, 4]] * 3
  45. assert dd.field_arrays["y"].content == [[5, 6]] * 3
  46. def test_add_field(self):
  47. dd = DataSet()
  48. dd.add_field("x", [[1, 2, 3]] * 10)
  49. dd.add_field("y", [[1, 2, 3, 4]] * 10)
  50. dd.add_field("z", [[5, 6]] * 10)
  51. assert len(dd) == 10
  52. assert dd.field_arrays["x"].content == [[1, 2, 3]] * 10
  53. assert dd.field_arrays["y"].content == [[1, 2, 3, 4]] * 10
  54. assert dd.field_arrays["z"].content == [[5, 6]] * 10
  55. with pytest.raises(RuntimeError):
  56. dd.add_field("??", [[1, 2]] * 40)
  57. def test_delete_field(self):
  58. dd = DataSet()
  59. dd.add_field("x", [[1, 2, 3]] * 10)
  60. dd.add_field("y", [[1, 2, 3, 4]] * 10)
  61. dd.delete_field("x")
  62. assert ("x" in dd.field_arrays) == False
  63. assert "y" in dd.field_arrays
  64. def test_delete_instance(self):
  65. dd = DataSet()
  66. old_length = 2
  67. dd.add_field("x", [[1, 2, 3]] * old_length)
  68. dd.add_field("y", [[1, 2, 3, 4]] * old_length)
  69. dd.delete_instance(0)
  70. assert len(dd) == old_length - 1
  71. dd.delete_instance(0)
  72. assert len(dd) == old_length - 2
  73. def test_getitem(self):
  74. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  75. ins_1, ins_0 = ds[0], ds[1]
  76. assert isinstance(ins_1, Instance) and isinstance(ins_0, Instance) == True
  77. assert ins_1["x"] == [1, 2, 3, 4]
  78. assert ins_1["y"] == [5, 6]
  79. assert ins_0["x"] == [1, 2, 3, 4]
  80. assert ins_0["y"] == [5, 6]
  81. sub_ds = ds[:10]
  82. assert isinstance(sub_ds, DataSet) == True
  83. assert len(sub_ds) == 10
  84. sub_ds_1 = ds[[10, 0, 2, 3]]
  85. assert isinstance(sub_ds_1, DataSet) == True
  86. assert len(sub_ds_1) == 4
  87. field_array = ds['x']
  88. assert isinstance(field_array, FieldArray) == True
  89. assert len(field_array) == 40
  90. def test_setitem(self):
  91. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  92. ds.add_field('i', list(range(len(ds))))
  93. assert ds.get_field('i').content == list(range(len(ds)))
  94. import random
  95. random.shuffle(ds)
  96. import numpy as np
  97. np.random.shuffle(ds)
  98. assert ds.get_field('i').content != list(range(len(ds)))
  99. ins1 = ds[1]
  100. ds[2] = ds[1]
  101. assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y']
  102. def test_get_item_error(self):
  103. with pytest.raises(RuntimeError):
  104. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  105. _ = ds[40:]
  106. with pytest.raises(KeyError):
  107. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  108. _ = ds["kom"]
  109. def test_len_(self):
  110. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  111. assert len(ds) == 40
  112. ds = DataSet()
  113. assert len(ds) == 0
  114. def test_add_fieldarray(self):
  115. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  116. ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 40))
  117. assert ds['z'].content == [[7, 8]] * 40
  118. with pytest.raises(RuntimeError):
  119. ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 10))
  120. with pytest.raises(TypeError):
  121. ds.add_fieldarray('z', [1, 2, 4])
  122. def test_copy_field(self):
  123. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  124. ds.copy_field('x', 'z')
  125. assert ds['x'].content == ds['z'].content
  126. def test_has_field(self):
  127. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  128. assert ds.has_field('x') == True
  129. assert ds.has_field('z') == False
  130. def test_get_field(self):
  131. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  132. with pytest.raises(KeyError):
  133. ds.get_field('z')
  134. x_array = ds.get_field('x')
  135. assert x_array.content == [[1, 2, 3, 4]] * 40
  136. def test_get_all_fields(self):
  137. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  138. field_arrays = ds.get_all_fields()
  139. assert field_arrays["x"].content == [[1, 2, 3, 4]] * 40
  140. assert field_arrays['y'].content == [[5, 6]] * 40
  141. def test_get_field_names(self):
  142. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  143. field_names = ds.get_field_names()
  144. assert 'x' in field_names
  145. assert 'y' in field_names
  146. def test_apply(self):
  147. ds = DataSet({"x": [[1, 2, 3, 4]] * 4000, "y": [[5, 6]] * 4000})
  148. ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx", progress_desc='rx')
  149. assert ("rx" in ds.field_arrays) == True
  150. assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1]
  151. ds.apply(lambda ins: len(ins["y"]), new_field_name="y", progress_bar=None)
  152. assert ds.field_arrays["y"].content[0] == 2
  153. res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
  154. assert (isinstance(res, list) and len(res) > 0) == True
  155. assert res[0] == 4
  156. ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k")
  157. # expect no exception raised
  158. def test_apply_progress_bar(self):
  159. import time
  160. ds = DataSet({"x": [[1, 2, 3, 4]] * 400, "y": [[5, 6]] * 400})
  161. def do_nothing(ins):
  162. time.sleep(0.01)
  163. ds.apply(do_nothing, progress_bar='rich', num_proc=0)
  164. ds.apply_field(do_nothing, field_name='x', progress_bar='rich')
  165. def test_apply_cannot_modify_instance(self):
  166. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  167. def modify_inplace(instance):
  168. instance['words'] = 1
  169. ds.apply(modify_inplace)
  170. # with self.assertRaises(TypeError):
  171. # ds.apply(modify_inplace)
  172. def test_apply_more(self):
  173. T = DataSet({"a": [1, 2, 3], "b": [2, 4, 5]})
  174. func_1 = lambda x: {"c": x["a"] * 2, "d": x["a"] ** 2}
  175. func_2 = lambda x: {"c": x * 3, "d": x ** 3}
  176. def func_err_1(x):
  177. if x["a"] == 1:
  178. return {"e": x["a"] * 2, "f": x["a"] ** 2}
  179. else:
  180. return {"e": x["a"] * 2}
  181. def func_err_2(x):
  182. if x == 1:
  183. return {"e": x * 2, "f": x ** 2}
  184. else:
  185. return {"e": x * 2}
  186. T.apply_more(func_1)
  187. # print(T['c'][0, 1, 2])
  188. assert list(T["c"].content) == [2, 4, 6]
  189. assert list(T["d"].content) == [1, 4, 9]
  190. res = T.apply_field_more(func_2, "a", modify_fields=False)
  191. assert list(T["c"].content) == [2, 4, 6]
  192. assert list(T["d"].content) == [1, 4, 9]
  193. assert list(res["c"]) == [3, 6, 9]
  194. assert list(res["d"]) == [1, 8, 27]
  195. with pytest.raises(ApplyResultException) as e:
  196. T.apply_more(func_err_1)
  197. print(e)
  198. with pytest.raises(ApplyResultException) as e:
  199. T.apply_field_more(func_err_2, "a")
  200. print(e)
  201. def test_drop(self):
  202. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
  203. ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)
  204. assert len(ds) == 20
  205. def test_contains(self):
  206. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  207. assert ("x" in ds) == True
  208. assert ("y" in ds) == True
  209. assert ("z" in ds) == False
  210. def test_rename_field(self):
  211. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  212. ds.rename_field("x", "xx")
  213. assert ("xx" in ds) == True
  214. assert ("x" in ds) == False
  215. with pytest.raises(KeyError):
  216. ds.rename_field("yyy", "oo")
  217. def test_split(self):
  218. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  219. d1, d2 = ds.split(0.1)
  220. assert len(d2) == (len(ds) * 0.9)
  221. assert len(d1) == (len(ds) * 0.1)
  222. def test_add_field_v2(self):
  223. ds = DataSet({"x": [3, 4]})
  224. ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']])
  225. # ds.apply(lambda x:[x['x']]*3, new_field_name='y')
  226. print(ds)
  227. def test_save_load(self):
  228. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  229. ds.save("./my_ds.pkl")
  230. assert os.path.exists("./my_ds.pkl") == True
  231. ds_1 = DataSet.load("./my_ds.pkl")
  232. os.remove("my_ds.pkl")
  233. def test_add_null(self):
  234. ds = DataSet()
  235. with pytest.raises(RuntimeError) as RE:
  236. ds.add_field('test', [])
  237. def test_concat(self):
  238. """
  239. 测试两个dataset能否正确concat
  240. """
  241. ds1 = DataSet({"x": [[1, 2, 3, 4] for _ in range(10)], "y": [[5, 6] for _ in range(10)]})
  242. ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]})
  243. ds3 = ds1.concat(ds2)
  244. assert len(ds3) == 20
  245. assert ds1[9]['x'] == [1, 2, 3, 4]
  246. assert ds1[10]['x'] == [4, 3, 2, 1]
  247. ds2[0]['x'][0] = 100
  248. assert ds3[10]['x'][0] == 4 # 不改变copy后的field了
  249. ds3[10]['x'][0] = -100
  250. assert ds2[0]['x'][0] == 100 # 不改变copy前的field了
  251. # 测试inplace
  252. ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
  253. ds2 = DataSet({"x": [[4, 3, 2, 1] for i in range(10)], "y": [[6, 5] for i in range(10)]})
  254. ds3 = ds1.concat(ds2, inplace=True)
  255. ds2[0]['x'][0] = 100
  256. assert ds3[10]['x'][0] == 4 # 不改变copy后的field了
  257. ds3[10]['x'][0] = -100
  258. assert ds2[0]['x'][0] == 100 # 不改变copy前的field了
  259. ds3[0]['x'][0] = 100
  260. assert ds1[0]['x'][0] == 100 # 改变copy前的field了
  261. # 测试mapping
  262. ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
  263. ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]})
  264. ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
  265. assert len(ds3) == 20
  266. # 测试忽略掉多余的
  267. ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
  268. ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)], 'Z': [0] * 10})
  269. ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
  270. # 测试报错
  271. ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
  272. ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]})
  273. with pytest.raises(RuntimeError):
  274. ds3 = ds1.concat(ds2, field_mapping={'X': 'x'})
  275. def test_instance_field_disappear_bug(self):
  276. data = DataSet({'raw_chars': [[0, 1], [2]], 'target': [0, 1]})
  277. data.copy_field(field_name='raw_chars', new_field_name='chars')
  278. _data = data[:1]
  279. for field_name in ['raw_chars', 'target', 'chars']:
  280. assert _data.has_field(field_name) == True
  281. def test_from_pandas(self):
  282. import pandas as pd
  283. df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
  284. ds = DataSet.from_pandas(df)
  285. print(ds)
  286. assert ds['x'].content == [1, 2, 3]
  287. assert ds['y'].content == [4, 5, 6]
  288. def test_to_pandas(self):
  289. ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
  290. df = ds.to_pandas()
  291. def test_to_csv(self):
  292. ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
  293. ds.to_csv("1.csv")
  294. assert os.path.exists("1.csv") == True
  295. os.remove("1.csv")
  296. def test_add_seq_len(self):
  297. ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
  298. ds.add_seq_len('x')
  299. print(ds)
  300. def test_apply_proc(self):
  301. data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
  302. data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=0)
  303. def test_apply_more_proc(self):
  304. def func(x):
  305. print("x")
  306. logger.info("demo")
  307. return len(x)
  308. data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
  309. data.apply_field(func, field_name='x', new_field_name='len_x', num_proc=2)
  310. class TestFieldArrayInit:
  311. """
  312. 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray:
  313. 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
  314. 1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])})
  315. 1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]})
  316. 2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray;
  317. 然后后面的样本使用FieldArray.append进行添加。
  318. 2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])])
  319. 2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))])
  320. 2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])])
  321. 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])
  322. """
  323. def test_init_v1(self):
  324. # 二维list
  325. fa = FieldArray("x", [[1, 2], [3, 4]] * 5)
  326. def test_init_v2(self):
  327. # 二维array
  328. fa = FieldArray("x", np.array([[1, 2], [3, 4]] * 5))
  329. def test_init_v3(self):
  330. # 三维list
  331. fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]])
  332. def test_init_v4(self):
  333. # 一维list
  334. val = [1, 2, 3, 4]
  335. fa = FieldArray("x", [val])
  336. fa.append(val)
  337. def test_init_v5(self):
  338. # 一维array
  339. val = np.array([1, 2, 3, 4])
  340. fa = FieldArray("x", [val])
  341. fa.append(val)
  342. def test_init_v6(self):
  343. # 二维array
  344. val = [[1, 2], [3, 4]]
  345. fa = FieldArray("x", [val])
  346. fa.append(val)
  347. def test_init_v7(self):
  348. # list of array
  349. fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])])
  350. def test_init_v8(self):
  351. # 二维list
  352. val = np.array([[1, 2], [3, 4]])
  353. fa = FieldArray("x", [val])
  354. fa.append(val)
  355. class TestFieldArray:
  356. def test_main(self):
  357. fa = FieldArray("x", [1, 2, 3, 4, 5])
  358. assert len(fa) == 5
  359. fa.append(6)
  360. assert len(fa) == 6
  361. assert fa[-1] == 6
  362. assert fa[0] == 1
  363. fa[-1] = 60
  364. assert fa[-1] == 60
  365. assert fa.get(0) == 1
  366. assert isinstance(fa.get([0, 1, 2]), np.ndarray) == True
  367. assert list(fa.get([0, 1, 2])) == [1, 2, 3]
  368. def test_getitem_v1(self):
  369. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
  370. assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5]
  371. ans = fa[[0, 1]]
  372. assert isinstance(ans, np.ndarray) == True
  373. assert isinstance(ans[0], np.ndarray) == True
  374. assert ans[0].tolist() == [1.1, 2.2, 3.3, 4.4, 5.5]
  375. assert ans[1].tolist() == [1, 2, 3, 4, 5]
  376. assert ans.dtype == np.float64
  377. def test_getitem_v2(self):
  378. x = np.random.rand(10, 5)
  379. fa = FieldArray("my_field", x)
  380. indices = [0, 1, 3, 4, 6]
  381. for a, b in zip(fa[indices], x[indices]):
  382. assert a.tolist() == b.tolist()
  383. def test_append(self):
  384. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
  385. fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
  386. assert len(fa) == 3
  387. assert fa[2] == [1.2, 2.3, 3.4, 4.5, 5.6]
  388. def test_pop(self):
  389. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
  390. fa.pop(0)
  391. assert len(fa) == 1
  392. assert fa[0] == [1.0, 2.0, 3.0, 4.0, 5.0]
  393. fa[0] = [1.1, 2.2, 3.3, 4.4, 5.5]
  394. assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5]
  395. class TestCase:
  396. def test_init(self):
  397. fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
  398. ins = Instance(x=[1, 2, 3], y=[4, 5, 6])
  399. assert isinstance(ins.fields, dict) == True
  400. assert ins.fields == fields
  401. ins = Instance(**fields)
  402. assert ins.fields == fields
  403. def test_add_field(self):
  404. fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
  405. ins = Instance(**fields)
  406. ins.add_field("z", [1, 1, 1])
  407. fields.update({"z": [1, 1, 1]})
  408. assert ins.fields == fields
  409. def test_get_item(self):
  410. fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
  411. ins = Instance(**fields)
  412. assert ins["x"] == [1, 2, 3]
  413. assert ins["y"] == [4, 5, 6]
  414. assert ins["z"] == [1, 1, 1]
  415. def test_repr(self):
  416. fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
  417. ins = Instance(**fields)
  418. # simple print, that is enough.
  419. print(ins)
  420. def test_dataset(self):
  421. from datasets import Dataset as HuggingfaceDataset
  422. # ds = DataSet({"x": ["11sxa", "1sasz"]*100, "y": [0, 1]*100})
  423. ds = HuggingfaceDataset.from_dict({"x": ["11sxa", "1sasz"]*100, "y": [0, 1]*100})
  424. print(DataSet.from_datasets(ds))
  425. # print(ds.from_datasets())