{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "213d538c",
   "metadata": {},
   "source": [
    "# T3. Internal structure and basic usage of the dataloader\n",
    "\n",
    "&emsp; 1 &ensp; The dataloader in fastNLP\n",
    "\n",
    "&emsp; &emsp; 1.1 &ensp; Responsibilities of the dataloader\n",
    "\n",
    "&emsp; &emsp; 1.2 &ensp; Basic usage of the dataloader\n",
    "\n",
    "&emsp; 2 &ensp; Extensions of the dataloader in fastNLP\n",
    "\n",
    "&emsp; &emsp; 2.1 &ensp; The collator: concept and usage\n",
    "\n",
    "&emsp; &emsp; 2.2 &ensp; The sampler: concept and usage"
   ]
  },
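  {
   "cell_type": "markdown",
   "id": "85857115",
   "metadata": {},
   "source": [
    "## 1. The dataloader in fastNLP\n",
    "\n",
    "### 1.1 Responsibilities of the dataloader\n",
    "\n",
    "In `fastNLP 0.8`, the data-loading module `DataLoader` is preceded by several other modules that prepare the text data, for example\n",
    "\n",
    "&emsp; zero-padding sequences to a common length, done by the **`collator` module**, and tokenizing the text, done by the **`tokenizer` module**\n",
    "\n",
    "&emsp; This section introduces the `collator` and related components of `fastNLP`; the `tokenizer` is covered in detail in the next section\n",
    "\n",
    "In `fastNLP 0.8`, **the `collator` module is responsible for zero-padding text sequences to a common length**, via"
   ]
  },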
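  {
   "cell_type": "markdown",
   "id": "3c1e9a47",
   "metadata": {},
   "source": [
    "As a rough illustration of what padding to a common length means, the next cell pads a batch of token-id lists to the length of the longest one. It is a minimal pure-Python sketch under our own assumptions, not `fastNLP`'s `collator` implementation: the helper name `pad_batch` and the default `pad_val=0` are hypothetical choices made for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b2d7f10",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Tuple\n",
    "\n",
    "\n",
    "# hypothetical helper: right-pad every sequence to the longest length in the batch\n",
    "def pad_batch(seqs: List[List[int]], pad_val: int = 0) -> Tuple[List[List[int]], List[int]]:\n",
    "    max_len = max(len(s) for s in seqs)  # assumes a non-empty batch\n",
    "    padded = [s + [pad_val] * (max_len - len(s)) for s in seqs]\n",
    "    lengths = [len(s) for s in seqs]  # original lengths, e.g. for building masks\n",
    "    return padded, lengths\n",
    "\n",
    "\n",
    "batch = [[101, 7592, 102], [101, 102], [101, 2023, 2003, 102]]\n",
    "padded, lengths = pad_batch(batch)\n",
    "print(padded)   # [[101, 7592, 102, 0], [101, 102, 0, 0], [101, 2023, 2003, 102]]\n",
    "print(lengths)  # [3, 2, 4]"
   ]
  },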
  {
   "cell_type": "markdown",
   "id": "eb8fb51c",
   "metadata": {},
   "source": [
    "### 1.2 Basic usage of the dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aca72b49",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from functools import partial\n",
    "from fastNLP.transformers.torch import BertTokenizer\n",
    "\n",
    "from fastNLP import DataSet\n",
    "from fastNLP import Vocabulary\n",
    "from fastNLP.io import DataBundle\n",
    "\n",
    "\n",
    "class PipeDemo:\n",
    "    def __init__(self, tokenizer='bert-base-uncased', num_proc=1):\n",
    "        self.tokenizer = BertTokenizer.from_pretrained(tokenizer)\n",
    "        self.num_proc = num_proc\n",
    "\n",
    "    def process_from_file(self, path='./data/test4dataset.tsv'):\n",
    "        # the input file is tab-separated, so pandas must split on tabs\n",
    "        datasets = DataSet.from_pandas(pd.read_csv(path, sep='\\t'))\n",
    "        train_ds, test_ds = datasets.split(ratio=0.7)\n",
    "        # split the training portion again, so the dev set does not overlap the test set\n",
    "        train_ds, dev_ds = train_ds.split(ratio=0.8)\n",
    "        data_bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds, 'test': test_ds})\n",
    "\n",
    "        # tokenize the 'text' field; each key of the returned dict (input_ids, attention_mask, ...) becomes a new field\n",
    "        encode = partial(self.tokenizer.encode_plus, max_length=100, truncation=True,\n",
    "                         return_attention_mask=True)\n",
    "        data_bundle.apply_field_more(encode, field_name='text', num_proc=self.num_proc)\n",
    "\n",
    "        # labels need neither a padding nor an unknown entry\n",
    "        target_vocab = Vocabulary(padding=None, unknown=None)\n",
    "\n",
    "        target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label')\n",
    "        target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label',\n",
    "                                   new_field_name='target')\n",
    "\n",
    "        # pad input_ids with the tokenizer's pad id and keep the raw fields out of batches\n",
    "        data_bundle.set_pad('input_ids', pad_val=self.tokenizer.pad_token_id)\n",
    "        data_bundle.set_ignore('label', 'text')\n",
    "        return data_bundle"
   ]
  },
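  {
   "cell_type": "markdown",
   "id": "8e4a6c21",
   "metadata": {},
   "source": [
    "The label vocabulary above is built with `Vocabulary(padding=None, unknown=None)`, so no padding or unknown entries are reserved and the labels map onto contiguous ids starting at 0. The next cell mimics that mapping in plain Python. Ordering labels by frequency is our assumption about the default behaviour, `build_label_vocab` is a hypothetical name rather than a `fastNLP` API, and the labels are toy values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f0b3d58",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from typing import Dict, List\n",
    "\n",
    "\n",
    "# hypothetical helper: map each distinct label to an id, reserving no <pad>/<unk> slots\n",
    "def build_label_vocab(labels: List[str]) -> Dict[str, int]:\n",
    "    counts = Counter(labels)\n",
    "    # most frequent label gets id 0; ties keep first-seen order\n",
    "    return {label: idx for idx, (label, _) in enumerate(counts.most_common())}\n",
    "\n",
    "\n",
    "labels = ['positive', 'negative', 'positive', 'neutral']  # toy labels\n",
    "label2id = build_label_vocab(labels)\n",
    "targets = [label2id[lab] for lab in labels]  # analogous to the new 'target' field\n",
    "print(label2id)  # {'positive': 0, 'negative': 1, 'neutral': 2}\n",
    "print(targets)   # [0, 1, 0, 2]"
   ]
  },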
  {
   "cell_type": "markdown",
   "id": "de53bff4",
   "metadata": {},
   "source": [
    "&emsp; "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57a29cb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = PipeDemo(tokenizer='bert-base-uncased', num_proc=4)\n",
    "\n",
    "data_bundle = pipe.process_from_file('./data/test4dataset.tsv')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "226bb081",
   "metadata": {},
   "source": [
    "&emsp; "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7827557d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import prepare_torch_dataloader\n",
    "\n",
    "batch_size = 16  # example value; adjust as needed\n",
    "dl_bundle = prepare_torch_dataloader(data_bundle, train_batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d898cf40",
   "metadata": {},
   "source": [
    "&emsp; \n",
    "\n",
    "```python\n",
    "from fastNLP import Trainer, Accuracy\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataloader=dl_bundle['train'],\n",
    "    optimizers=optimizer,\n",
    "    # ...\n",
    "    driver='torch',\n",
    "    device='cuda',\n",
    "    # ...\n",
    "    evaluate_dataloaders={'dev': dl_bundle['dev'], 'test': dl_bundle['test']},\n",
    "    metrics={'acc': Accuracy()},\n",
    "    # ...\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d74d0523",
   "metadata": {},
   "source": [
    "## 2. Extensions of the dataloader in fastNLP\n",
    "\n",
    "### 2.1 The collator: concept and usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "651baef6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from fastNLP import prepare_torch_dataloader\n",
    "\n",
    "dl_bundle = prepare_torch_dataloader(data_bundle, train_batch_size=2)\n",
    "\n",
    "print(type(dl_bundle), type(dl_bundle['train']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "726ba357",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# `datasets` only exists inside PipeDemo.process_from_file, so take the split from data_bundle\n",
    "dataloader = prepare_torch_dataloader(data_bundle.get_dataset('train'), train_batch_size=2)\n",
    "print(type(dataloader))\n",
    "print(dir(dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0795b3e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "dataloader.collate_fn"
   ]
  },
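  {
   "cell_type": "markdown",
   "id": "a7c2e914",
   "metadata": {},
   "source": [
    "`dataloader.collate_fn` is the callable that turns a list of instances into one batch. The next cell imitates, in plain Python, the behaviour configured earlier with `set_pad` and `set_ignore`: fields are gathered column by column, ignored fields are dropped, and list-valued fields are padded with a per-field pad value. It is an illustrative sketch under these assumptions, not `fastNLP`'s collator, and it returns plain lists rather than tensors; `make_collate_fn`, `pad_vals` and `ignore` are hypothetical names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d8f602",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Callable, Dict, Iterable, List\n",
    "\n",
    "\n",
    "# hypothetical factory: build a collate function that batches dict instances field by field\n",
    "def make_collate_fn(pad_vals: Dict[str, int], ignore: Iterable[str] = ()) -> Callable:\n",
    "    ignored = set(ignore)\n",
    "\n",
    "    def collate(instances: List[Dict[str, Any]]) -> Dict[str, List[Any]]:\n",
    "        batch: Dict[str, List[Any]] = {}\n",
    "        for field in instances[0]:\n",
    "            if field in ignored:\n",
    "                continue  # e.g. raw 'text' / 'label' strings are not batched\n",
    "            values = [ins[field] for ins in instances]\n",
    "            if field in pad_vals:\n",
    "                max_len = max(len(v) for v in values)\n",
    "                values = [v + [pad_vals[field]] * (max_len - len(v)) for v in values]\n",
    "            batch[field] = values\n",
    "        return batch\n",
    "\n",
    "    return collate\n",
    "\n",
    "\n",
    "collate = make_collate_fn(pad_vals={'input_ids': 0, 'attention_mask': 0}, ignore=('text', 'label'))\n",
    "instances = [\n",
    "    {'text': 'hi', 'label': 'pos', 'input_ids': [101, 102], 'attention_mask': [1, 1], 'target': 0},\n",
    "    {'text': 'ok go', 'label': 'neg', 'input_ids': [101, 7, 102], 'attention_mask': [1, 1, 1], 'target': 1},\n",
    "]\n",
    "print(collate(instances))\n",
    "# {'input_ids': [[101, 102, 0], [101, 7, 102]], 'attention_mask': [[1, 1, 0], [1, 1, 1]], 'target': [0, 1]}"
   ]
  },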
  {
   "cell_type": "markdown",
   "id": "f9bbd9a7",
   "metadata": {},
   "source": [
    "### 2.2 The sampler: concept and usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0c3c58d",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "dataloader.batch_sampler"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51bf0878",
   "metadata": {},
   "source": [
    "&emsp; "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fd2486f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
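  {
   "cell_type": "markdown",
   "id": "d1e5a3b7",
   "metadata": {},
   "source": [
    "`dataloader.batch_sampler` decides which indices form each batch. As an illustration only, the next cell implements a length-bucketing batch sampler in plain Python: it shuffles the indices, sorts them by length inside chunks of `batch_size * bucket_factor`, then cuts batches, so sequences of similar length share a batch and need less padding. The class `BucketBatchSampler` and its parameters are hypothetical; `fastNLP` ships its own samplers, which this sketch does not reproduce."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6f9c025",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from typing import Iterator, List\n",
    "\n",
    "\n",
    "# hypothetical sampler: yields lists of indices whose sequences have similar lengths\n",
    "class BucketBatchSampler:\n",
    "    def __init__(self, lengths: List[int], batch_size: int, bucket_factor: int = 4, seed: int = 0):\n",
    "        self.lengths = lengths\n",
    "        self.batch_size = batch_size\n",
    "        self.bucket_size = batch_size * bucket_factor  # how many indices are sorted together\n",
    "        self.rng = random.Random(seed)\n",
    "\n",
    "    def __iter__(self) -> Iterator[List[int]]:\n",
    "        indices = list(range(len(self.lengths)))\n",
    "        self.rng.shuffle(indices)\n",
    "        batches = []\n",
    "        for start in range(0, len(indices), self.bucket_size):\n",
    "            bucket = sorted(indices[start:start + self.bucket_size], key=lambda i: self.lengths[i])\n",
    "            batches.extend(bucket[j:j + self.batch_size] for j in range(0, len(bucket), self.batch_size))\n",
    "        self.rng.shuffle(batches)  # avoid always visiting the short batches first\n",
    "        return iter(batches)\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        # buckets hold a multiple of batch_size indices, so this equals ceil(n / batch_size)\n",
    "        return (len(self.lengths) + self.batch_size - 1) // self.batch_size\n",
    "\n",
    "\n",
    "sampler = BucketBatchSampler(lengths=[5, 12, 3, 8, 7, 15, 2, 9], batch_size=2, bucket_factor=2)\n",
    "for idx_batch in sampler:\n",
    "    print(idx_batch, [sampler.lengths[i] for i in idx_batch])\n",
    "print(len(sampler))  # 4"
   ]
  },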
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}