
fastnlp_tutorial_3.ipynb

{
"cells": [
{
"cell_type": "markdown",
"id": "213d538c",
"metadata": {},
"source": [
"# T3. The internal structure and basic usage of dataloader\n",
"\n",
"  1   dataloader in fastNLP\n",
" \n",
"    1.1   The responsibilities of dataloader\n",
"\n",
"    1.2   Basic usage of dataloader\n",
"\n",
"  2   Extensions of dataloader in fastNLP\n",
"\n",
"    2.1   The concept and usage of collator\n",
"\n",
"    2.2   The concept and usage of sampler"
]
},
{
"cell_type": "markdown",
"id": "85857115",
"metadata": {},
"source": [
"## 1. dataloader in fastNLP\n",
"\n",
"### 1.1 The responsibilities of dataloader\n",
"\n",
"In `fastNLP 0.8`, the data-loading module `DataLoader` sits after the dataset-construction and preprocessing modules: it takes the processed `DataSet`s, groups their samples into padded batches, and feeds those batches to training and evaluation."
]
},
{
"cell_type": "markdown",
"id": "eb8fb51c",
"metadata": {},
"source": [
"### 1.2 Basic usage of dataloader\n",
"\n",
"In `fastNLP 0.8`, a dataloader is usually not built by hand: the preprocessing steps are packaged in a pipe that produces a `DataBundle`, and `prepare_torch_dataloader` then wraps each dataset in it, as the following cells show."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aca72b49",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import pandas as pd\n",
"from functools import partial\n",
"from fastNLP.transformers.torch import BertTokenizer\n",
"\n",
"from fastNLP import DataSet\n",
"from fastNLP import Vocabulary\n",
"from fastNLP.io import DataBundle\n",
"\n",
"\n",
"class PipeDemo:\n",
"    def __init__(self, tokenizer='bert-base-uncased', num_proc=1):\n",
"        self.tokenizer = BertTokenizer.from_pretrained(tokenizer)\n",
"        self.num_proc = num_proc\n",
"\n",
"    def process_from_file(self, path='./data/test4dataset.tsv'):\n",
"        # read the tab-separated raw file and split it into train/dev/test\n",
"        datasets = DataSet.from_pandas(pd.read_csv(path, sep='\\t'))\n",
"        train_ds, test_ds = datasets.split(ratio=0.7)\n",
"        train_ds, dev_ds = train_ds.split(ratio=0.8)\n",
"        data_bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds, 'test': test_ds})\n",
"\n",
"        # tokenize the 'text' field with the BERT tokenizer\n",
"        encode = partial(self.tokenizer.encode_plus, max_length=100, truncation=True,\n",
"                         return_attention_mask=True)\n",
"        data_bundle.apply_field_more(encode, field_name='text', num_proc=self.num_proc)\n",
"\n",
"        # build the label vocabulary and add the indexed 'target' field\n",
"        target_vocab = Vocabulary(padding=None, unknown=None)\n",
"\n",
"        target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label')\n",
"        target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label',\n",
"                                   new_field_name='target')\n",
"\n",
"        # pad 'input_ids' with the tokenizer's pad token; drop raw 'label'/'text' from batches\n",
"        data_bundle.set_pad('input_ids', pad_val=self.tokenizer.pad_token_id)\n",
"        data_bundle.set_ignore('label', 'text')\n",
"        return data_bundle"
]
},
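{
"cell_type": "markdown",
"id": "4c8e2f5a",
"metadata": {},
"source": [
"`PipeDemo` expects `./data/test4dataset.tsv` with a `text` and a `label` column. If that file is not at hand, the following cell (a stand-in with made-up toy data, not the tutorial's original dataset) creates one so the remaining cells can run:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d9f3a6b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"\n",
"# hypothetical toy data standing in for the original test4dataset.tsv (columns: text, label)\n",
"os.makedirs('./data', exist_ok=True)\n",
"toy = pd.DataFrame({\n",
"    'text': ['this movie is great', 'terrible plot and acting',\n",
"             'an absolute masterpiece', 'not worth watching',\n",
"             'a pleasant surprise', 'boring from start to finish'],\n",
"    'label': ['pos', 'neg', 'pos', 'neg', 'pos', 'neg'],\n",
"})\n",
"toy.to_csv('./data/test4dataset.tsv', sep='\\t', index=False)"
]
},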
{
"cell_type": "markdown",
"id": "de53bff4",
"metadata": {},
"source": [
"Instantiate `PipeDemo` and process the raw file into a `DataBundle`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "57a29cb9",
"metadata": {},
"outputs": [],
"source": [
"pipe = PipeDemo(tokenizer='bert-base-uncased', num_proc=4)\n",
"\n",
"data_bundle = pipe.process_from_file('./data/test4dataset.tsv')"
]
},
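{
"cell_type": "markdown",
"id": "6e1b4c7d",
"metadata": {},
"source": [
"As an optional sanity check (not part of the original tutorial), printing the resulting `DataBundle` confirms that the three splits and the tokenized fields are in place:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f2c5d8e",
"metadata": {},
"outputs": [],
"source": [
"# list the splits in the bundle and look at the processed train set\n",
"print(data_bundle)\n",
"print(data_bundle.get_dataset('train'))"
]
},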
{
"cell_type": "markdown",
"id": "226bb081",
"metadata": {},
"source": [
"With the `DataBundle` ready, `prepare_torch_dataloader` wraps each of its datasets in a torch dataloader:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7827557d",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import prepare_torch_dataloader\n",
"\n",
"batch_size = 2  # the original snippet read this from command-line args (arg.batch_size)\n",
"dl_bundle = prepare_torch_dataloader(data_bundle, batch_size=batch_size)"
]
},
{
"cell_type": "markdown",
"id": "d898cf40",
"metadata": {},
"source": [
"The dataloaders in `dl_bundle` can then be passed straight to `Trainer`, e.g.\n",
"\n",
"```python\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    train_dataloader=dl_bundle['train'],\n",
"    optimizers=optimizer,\n",
"    ...\n",
"    driver=\"torch\",\n",
"    device='cuda',\n",
"    ...\n",
"    evaluate_dataloaders={'dev': dl_bundle['dev'], 'test': dl_bundle['test']},\n",
"    metrics={'acc': Accuracy()},\n",
"    ...\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "d74d0523",
"metadata": {},
"source": [
"## 2. Extensions of dataloader in fastNLP\n",
"\n",
"### 2.1 The concept and usage of collator\n",
"\n",
"In `fastNLP 0.8`, before the data-loading module `DataLoader` there are some other modules, responsible for things such as\n",
"\n",
"  padding and aligning text data, i.e. the **`collator` module**, and tokenizing and tagging it, i.e. the **`tokenizer` module**\n",
"\n",
"  This section introduces the `collator` in `fastNLP`; the `tokenizer` is covered in detail in the next tutorial\n",
"\n",
"In `fastNLP 0.8`, **the `collator` module is responsible for padding and aligning text sequences**, and it is attached to each dataloader as its `collate_fn`, as the cells below show"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "651baef6",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from fastNLP import prepare_torch_dataloader\n",
"\n",
"dl_bundle = prepare_torch_dataloader(data_bundle, train_batch_size=2)\n",
"\n",
"print(type(dl_bundle), type(dl_bundle['train']))"
]
},
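{
"cell_type": "markdown",
"id": "8a3d6e9f",
"metadata": {},
"source": [
"As a quick, optional check (not part of the original tutorial), iterating the train dataloader once shows what a collated, padded batch looks like:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b4e7f0a",
"metadata": {},
"outputs": [],
"source": [
"# fetch a single batch to inspect the padded tensors produced by the collator\n",
"for batch in dl_bundle['train']:\n",
"    print(batch)\n",
"    break"
]
},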
{
"cell_type": "code",
"execution_count": null,
"id": "726ba357",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# a single DataSet can also be wrapped directly; fetch it from the bundle\n",
"dataloader = prepare_torch_dataloader(data_bundle.get_dataset('train'), train_batch_size=2)\n",
"print(type(dataloader))\n",
"print(dir(dataloader))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0795b3e",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"dataloader.collate_fn"
]
},
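{
"cell_type": "markdown",
"id": "0c5f8a1b",
"metadata": {},
"source": [
"To make the collator's job concrete, here is a minimal plain-PyTorch sketch of what a padding `collate_fn` does (an illustration only, not fastNLP's actual implementation): it receives a list of variable-length samples and pads them into one rectangular batch tensor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d6a9b2c",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"\n",
"def toy_collate_fn(samples, pad_val=0):\n",
"    # samples: a list of dicts, each with a variable-length 'input_ids' list\n",
"    seqs = [torch.tensor(s['input_ids']) for s in samples]\n",
"    # pad every sequence to the length of the longest one in this batch\n",
"    return {'input_ids': pad_sequence(seqs, batch_first=True, padding_value=pad_val)}\n",
"\n",
"batch = toy_collate_fn([{'input_ids': [101, 7592, 102]},\n",
"                        {'input_ids': [101, 7592, 2088, 999, 102]}])\n",
"print(batch['input_ids'])"
]
},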
{
"cell_type": "markdown",
"id": "f9bbd9a7",
"metadata": {},
"source": [
"### 2.2 The concept and usage of sampler"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0c3c58d",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"dataloader.batch_sampler"
]
},
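{
"cell_type": "markdown",
"id": "2e7b0c3d",
"metadata": {},
"source": [
"As a rough plain-PyTorch illustration (toy indices and standard torch samplers, not fastNLP's own sampler classes): a `sampler` only decides the visiting order of single examples, while a `batch_sampler` decides which example indices are grouped into each batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f8c1d4e",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import RandomSampler, SequentialSampler, BatchSampler\n",
"\n",
"data = list(range(10))                       # ten toy examples, identified by index\n",
"\n",
"sampler = SequentialSampler(data)            # yields indices 0, 1, 2, ... one at a time\n",
"batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)\n",
"print(list(batch_sampler))                   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]\n",
"\n",
"shuffled = BatchSampler(RandomSampler(data), batch_size=4, drop_last=False)\n",
"print(list(shuffled))                        # same indices, grouped in a random order"
]
},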
{
"cell_type": "markdown",
"id": "51bf0878",
"metadata": {},
"source": [
"  "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3fd2486f",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}