{
"cells": [
{
"cell_type": "markdown",
"id": "213d538c",
"metadata": {},
"source": [
"# T3. The internal structure and basic usage of dataloader\n",
"\n",
"  1   dataloader in fastNLP\n",
" \n",
"    1.1   A brief introduction to dataloader\n",
"\n",
"    1.2   Creating a dataloader via helper functions\n",
"\n",
"  2   Extensions of dataloader in fastNLP\n",
"\n",
"    2.1   The concept and usage of collator\n",
"\n",
"    2.2   Integrating fastNLP with datasets"
]
},
{
"cell_type": "markdown",
"id": "85857115",
"metadata": {},
"source": [
"## 1. dataloader in fastNLP\n",
"\n",
"### 1.1 A brief introduction to dataloader\n",
"\n",
"The most important goal in developing `fastNLP 0.8` was **making `fastNLP` compatible with today's mainstream machine-learning frameworks**, namely\n",
"\n",
"  **the popular `pytorch`**, together with **the domestic frameworks `paddle` and `jittor`**, broadening the audience while also supporting home-grown frameworks\n",
"\n",
"Following a divide-and-conquer approach, the compatibility of `fastNLP 0.8` with `pytorch`, `paddle`, and `jittor` can be divided into\n",
"\n",
"    **compatibility in four parts**: **data preprocessing**, **`batch` splitting and padding**, **model training**, and **model evaluation**\n",
"\n",
"  As for data preprocessing, `tutorial-1` already introduced the usage of `dataset` and `vocabulary`\n",
"\n",
"    and together with `tutorial-0` it shows that **the data-preprocessing stage is essentially framework-agnostic**,\n",
"\n",
"    since the raw data formats read under the different frameworks differ little and are easy to convert between\n",
"\n",
"Only once tensors and models are involved do the frameworks show their individual traits: **`tensor` and `nn.Module` in `pytorch`**\n",
"\n",
"    **are called `tensor` and `nn.Layer` in `paddle`**, **and `Var` and `Module` in `jittor`**\n",
"\n",
"    Hence **model training and model evaluation are the hard core of the compatibility work**; they will be covered in detail in `tutorial-5`\n",
"\n",
"  As for `batch` handling, the transition from the framework-agnostic part of `fastNLP 0.8` to the framework-specific part\n",
"\n",
"    is the responsibility of the `dataloader` module, which is the focus of this tutorial, `tutorial-3`\n",
"\n",
"**The responsibilities of the `dataloader` module** can be broken down into three parts: **sampling and splitting**, **zero-padding alignment**, and **framework matching**\n",
"\n",
"    First, fix the `batch` size and the sampling strategy; after splitting, the sequence of `batch`es is obtained through an iterator\n",
"\n",
"    Second, align the data within each `batch`; for sequence data this padding is what `fastNLP` mainly targets\n",
"\n",
"    Third, **the data format inside a `batch` must match the framework**, **while the `batch` structure stays the same across frameworks**, which underpins the **parameter-matching mechanism**\n",
"\n",
"  For this, `fastNLP 0.8` provides **`TorchDataLoader`, `PaddleDataLoader`, and `JittorDataLoader`**,\n",
"\n",
"    each targeting one framework, yet with similar parameter names, attributes, and methods; the first two are summarized in the table below, and a short hedged sketch after this cell tries some of these parameters on a toy dataset\n",
"\n",
"| <div align=\"center\">name</div> | <div align=\"center\">parameter</div> | <div align=\"center\">attribute</div> | <div align=\"center\">purpose</div> | <div align=\"center\">notes</div> |\n",
"|:--|:--:|:--:|:--|:--|\n",
"| **`dataset`** | √ | √ | the data content served by the `dataloader` | |\n",
"| `batch_size` | √ | √ | the `batch` size of the `dataloader` | defaults to `16` |\n",
"| `shuffle` | √ | √ | whether the `dataloader` shuffles the data | defaults to `False` |\n",
"| `collate_fn` | √ | √ | how the `dataloader` packs each `batch` | framework-dependent |\n",
"| `sampler` | √ | √ | the implementation behind the `dataloader`'s `__len__` and `__iter__` | defaults to `None` |\n",
"| `batch_sampler` | √ | √ | the implementation behind the `dataloader`'s `__len__` and `__iter__`, per `batch` | defaults to `None` |\n",
"| `drop_last` | √ | √ | whether the `dataloader` drops the last incomplete `batch` | defaults to `False` |\n",
"| `cur_batch_indices` | | √ | the sample indices of the `batch` currently being iterated | |\n",
"| `num_workers` | √ | √ | the number of worker subprocesses the `dataloader` spawns | defaults to `0` |\n",
"| `worker_init_fn` | √ | √ | the initialization function of each `dataloader` subprocess | defaults to `None` |\n",
"| `generator` | √ | √ | the random-seed `generator` for the subprocesses | defaults to `None` |\n",
"| `prefetch_factor` | | √ | the number of `batch`es each `worker` loads in advance | defaults to `2` |"
]
},
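{
"cell_type": "markdown",
"id": "3b9e7a50",
"metadata": {},
"source": [
"&emsp; Before the full pipeline below, here is a minimal hedged sketch of those parameters on a toy `DataSet`; it assumes that `TorchDataLoader` can be imported from `fastNLP.core.dataloaders` (the class path printed in section 1.2) and constructed directly, and the field names `x` and `y` are made up for illustration"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d41c5f6",
"metadata": {},
"outputs": [],
"source": [
"# a minimal sketch, not canonical usage: build a TorchDataLoader directly,\n",
"# exercising the dataset / batch_size / shuffle / drop_last parameters above\n",
"from fastNLP import DataSet\n",
"from fastNLP.core.dataloaders import TorchDataLoader  # assumed import path\n",
"\n",
"toy_ds = DataSet({'x': [[1, 2], [2, 3, 4], [4, 5], [6]], 'y': [0, 1, 2, 0]})\n",
"\n",
"toy_dl = TorchDataLoader(toy_ds, batch_size=2, shuffle=False, drop_last=False)\n",
"\n",
"for toy_batch in toy_dl:\n",
"    # 'x' should come out zero-padded to the longest sequence in each batch\n",
"    print(toy_batch['x'], toy_batch['y'])"
]
},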
{
"cell_type": "markdown",
"id": "60a8a224",
"metadata": {},
"source": [
"&emsp; As for the methods of `dataloader`: `get_batch_indices` returns the indices of the `batch` currently being iterated, while the others,\n",
"\n",
"&emsp; &emsp; including `set_ignore` and `set_pad`, behave like their `databundle` counterparts; see `tutorial-2`, they are not detailed again here\n",
"\n",
"&emsp; &emsp; The cell below repeats the data-preprocessing pipeline introduced in `tutorial-2`; the resulting data is then handed to a `dataloader`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "aca72b49",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[38;5;2m[i 0604 15:44:29.773860 92 log.cc:351] Load log_sync: 1\u001b[m\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing:   0%|          | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing:   0%|          | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing:   0%|          | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n",
"| SentenceId |    Sentence    | Sentiment |   input_ids    |   token_type_ids   |   attention_mask   | target |\n",
"+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n",
"|     1      | A series of... | negative  | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |   1    |\n",
"|     4      | A positivel... |  neutral  | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |   2    |\n",
"|     3      | Even fans o... | negative  | [101, 2130,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |   1    |\n",
"|     5      | A comedy-dr... | positive  | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |   0    |\n",
"+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n"
]
}
],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"import pandas as pd\n",
"from functools import partial\n",
"from fastNLP.transformers.torch import BertTokenizer\n",
"\n",
"from fastNLP import DataSet\n",
"from fastNLP import Vocabulary\n",
"from fastNLP.io import DataBundle\n",
"\n",
"\n",
"class PipeDemo:\n",
"    def __init__(self, tokenizer='bert-base-uncased'):\n",
"        self.tokenizer = BertTokenizer.from_pretrained(tokenizer)\n",
"\n",
"    def process_from_file(self, path='./data/test4dataset.tsv'):\n",
"        datasets = DataSet.from_pandas(pd.read_csv(path, sep='\\t'))\n",
"        train_ds, test_ds = datasets.split(ratio=0.7)\n",
"        train_ds, dev_ds = train_ds.split(ratio=0.8)  # split dev off the train part, not the full set again\n",
"        data_bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds, 'test': test_ds})\n",
"\n",
"        encode = partial(self.tokenizer.encode_plus, max_length=100, truncation=True,\n",
"                         return_attention_mask=True)\n",
"        data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n",
"\n",
"        target_vocab = Vocabulary(padding=None, unknown=None)\n",
"\n",
"        target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n",
"        target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n",
"                                   new_field_name='target')\n",
"\n",
"        data_bundle.set_pad('input_ids', pad_val=self.tokenizer.pad_token_id)\n",
"        data_bundle.set_ignore('SentenceId', 'Sentence', 'Sentiment')\n",
"        return data_bundle\n",
"\n",
"\n",
"pipe = PipeDemo(tokenizer='bert-base-uncased')\n",
"\n",
"data_bundle = pipe.process_from_file('./data/test4dataset.tsv')\n",
"\n",
"print(data_bundle.get_dataset('train'))"
]
},
{
"cell_type": "markdown",
"id": "76e6b8ab",
"metadata": {},
"source": [
"### 1.2 Creating a dataloader via helper functions\n",
"\n",
"In `fastNLP 0.8`, **the more convenient, and probably more common, way to create a `dataloader` is through the `prepare_xx_dataloader` functions**\n",
"\n",
"&emsp; For example, the `prepare_torch_dataloader` function below takes the required parameters, reads in the dataset, and produces the corresponding `dataloader`\n",
"\n",
"&emsp; of type `TorchDataLoader`, which only works with the `pytorch` framework and thus matches `driver='torch'` at `trainer` initialization\n",
"\n",
"We can also observe that in `fastNLP 0.8` **a `batch` is represented as a `dict`**, **whose keys are the fields of the original dataset**\n",
"\n",
"&emsp; **minus those hidden via the `DataBundle.set_ignore` function**, and whose values are of `pytorch`'s `torch.Tensor` type"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "5fd60e42",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'fastNLP.core.dataloaders.torch_dataloader.fdl.TorchDataLoader'>\n",
"<class 'dict'> <class 'torch.Tensor'> ['input_ids', 'token_type_ids', 'attention_mask', 'target']\n",
"{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
"         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
"        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
"         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
"        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
"         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
"        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
"         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),\n",
" 'input_ids': tensor([[  101,  1037,  4038,  1011,  3689,  1997,  3053,  8680, 19173, 15685,\n",
"          1999,  1037, 18006,  2836,  2011,  1996,  2516,  2839, 14996,  3054,\n",
"         15509,  5325,  1012,   102,     0,     0,     0,     0,     0,     0,\n",
"             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
"             0,     0,     0,     0],\n",
"        [  101,  1037,  2186,  1997,  9686, 17695, 18673, 14313,  1996, 15262,\n",
"          3351,  2008,  2054,  2003,  2204,  2005,  1996, 13020,  2003,  2036,\n",
"          2204,  2005,  1996, 25957,  4063,  1010,  2070,  1997,  2029,  5681,\n",
"          2572, 25581,  2021,  3904,  1997,  2029,  8310,  2000,  2172,  1997,\n",
"          1037,  2466,  1012,   102],\n",
"        [  101,  2130,  4599,  1997, 19214,  6432,  1005,  1055,  2147,  1010,\n",
"          1045,  8343,  1010,  2052,  2031,  1037,  2524,  2051,  3564,  2083,\n",
"          2023,  2028,  1012,   102,     0,     0,     0,     0,     0,     0,\n",
"             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
"             0,     0,     0,     0],\n",
"        [  101,  1037, 13567, 26162,  5257,  1997,  3802,  7295,  9888,  1998,\n",
"          2035,  1996, 20014, 27611,  1010, 14583,  1010, 11703, 20175,  1998,\n",
"          4028,  1997,  1037,  8101,  2319, 10576,  2030,  1037, 28900,  7815,\n",
"          3850,  1012,   102,     0,     0,     0,     0,     0,     0,     0,\n",
"             0,     0,     0,     0]]),\n",
" 'target': tensor([0, 1, 1, 2]),\n",
" 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
"         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
"        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
"         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
"        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
"         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
"        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
"         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n"
]
}
],
"source": [
"from fastNLP import prepare_torch_dataloader\n",
"\n",
"train_dataset = data_bundle.get_dataset('train')\n",
"evaluate_dataset = data_bundle.get_dataset('dev')\n",
"\n",
"train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)\n",
"\n",
"print(type(train_dataloader))\n",
"\n",
"import pprint\n",
"\n",
"for batch in train_dataloader:\n",
"    print(type(batch), type(batch['input_ids']), list(batch))\n",
"    pprint.pprint(batch, width=1)"
]
},
{
"cell_type": "markdown",
"id": "9f457a6e",
"metadata": {},
"source": [
"What makes the `prepare_xx_dataloader` functions convenient is that **their input can be not only of `DataSet` type**, **but also**\n",
"\n",
"&emsp; **of `DataBundle` type**, provided the dataset names are `'train'`, `'dev'`, and `'test'` so that `fastNLP` recognizes them\n",
"\n",
"For example, the cell below **directly generates a dict of `PaddleDataLoader`s through the `prepare_paddle_dataloader` function**\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7827557d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'fastNLP.core.dataloaders.paddle_dataloader.fdl.PaddleDataLoader'>\n"
]
}
],
"source": [
"from fastNLP import prepare_paddle_dataloader\n",
"\n",
"dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)\n",
"\n",
"print(type(dl_bundle['train']))"
]
},
{
"cell_type": "markdown",
"id": "d898cf40",
"metadata": {},
"source": [
"&emsp; In the subsequent initialization of the `trainer`, the bundle is then used as below; apart from `driver='paddle'` at initialization,\n",
"\n",
"&emsp; this also shows that the design of **`evaluate_dataloaders` in the `trainer` module allows evaluation on multiple datasets**\n",
"\n",
"```python\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    train_dataloader=dl_bundle['train'],\n",
"    optimizers=optimizer,\n",
"    ...\n",
"    driver='paddle',\n",
"    device='gpu',\n",
"    ...\n",
"    evaluate_dataloaders={'dev': dl_bundle['dev'], 'test': dl_bundle['test']},\n",
"    metrics={'acc': Accuracy()},\n",
"    ...\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "d74d0523",
"metadata": {},
"source": [
"## 2. Extensions of dataloader in fastNLP\n",
"\n",
"### 2.1 The concept and usage of collator\n",
"\n",
"Inside the data-loading module `dataloader` of `fastNLP 0.8`, as the earlier table listed, there are some further components\n",
"\n",
"&emsp; for example the **`collator` module, which implements the zero-padding alignment of sequences**; the name comes from the verb `collate`, to sort and check against each other\n",
"\n",
"In `fastNLP 0.8`, although the `dataloader` varies with the framework, the `collator` module is shared across frameworks; its main attributes and methods are listed in the table below, and a short sketch after this cell replays the `set_*` methods on the dataloader from section 1.2\n",
"\n",
"| <div align=\"center\">name</div> | <div align=\"center\">attribute</div> | <div align=\"center\">method</div> | <div align=\"center\">purpose</div> | <div align=\"center\">notes</div> |\n",
"|:--|:--:|:--:|:--|:--|\n",
"| `backend` | √ | | the framework the `collator` targets | string, e.g. `'torch'` |\n",
"| `padders` | √ | | the `padder` of each field, each doing the actual zero padding&emsp; | dict |\n",
"| `ignore_fields` | √ | | the fields skipped when the `dataloader` samples a `batch` | set |\n",
"| `input_fields` | √ | | the pad value, data type, etc. of each `collator` field | dict |\n",
"| `set_backend` | | √ | sets the framework the `collator` targets | string, e.g. `'torch'` |\n",
"| `set_ignore` | | √ | sets the fields skipped when the `dataloader` samples a `batch` | strings, the `field_name`s&emsp; |\n",
"| `set_pad` | | √ | sets the pad value, data type, etc. of each `collator` field | |"
]
},
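{
"cell_type": "markdown",
"id": "5e20fc71",
"metadata": {},
"source": [
"&emsp; As noted in section 1.1, `set_pad`, `set_ignore`, and `get_batch_indices` are also exposed on the `dataloader` itself; the sketch below replays them on the `train_dataloader` from section 1.2, assuming the keyword usage mirrors the `databundle.set_pad` / `set_ignore` calls in the `PipeDemo` cell"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b47d1c26",
"metadata": {},
"outputs": [],
"source": [
"# a sketch of the collator-backed methods on the dataloader itself, assuming\n",
"# they mirror the databundle.set_pad / set_ignore calls used in PipeDemo\n",
"train_dataloader.set_pad('input_ids', pad_val=0)       # per-field pad value\n",
"train_dataloader.set_ignore('SentenceId', 'Sentence')  # fields left out of batches\n",
"\n",
"for demo_batch in train_dataloader:\n",
"    # get_batch_indices reports which samples formed the current batch\n",
"    print(train_dataloader.get_batch_indices())\n",
"    break"
]
},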
{
"cell_type": "code",
"execution_count": 4,
"id": "d0795b3e",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'function'>\n"
]
}
],
"source": [
"train_dataloader.collate_fn\n",
"\n",
"print(type(train_dataloader.collate_fn))"
]
},
{
"cell_type": "markdown",
"id": "5f816ef5",
"metadata": {},
"source": [
"Besides this, one can also **define the `collate_fn` of the `dataloader` by hand**, instead of using the built-in `collator` module of `fastNLP 0.8`\n",
"\n",
"&emsp; Such a function can look roughly like the one below; note that **before writing a `collate_fn` you need to know the dict format of a `batch`**\n",
"\n",
"&emsp; The function is passed to the `dataloader` via the `collate_fn` parameter and is **called when a `batch` is assembled and dispatched** (**not when the `batch` is split off**)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ff8e405e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def collate_fn(batch):\n",
"    # batch is a list of instance dicts; collect each field and track its max length\n",
"    input_ids, token_type_ids, attention_mask = [], [], []\n",
"    max_length = [0] * 3\n",
"    for each_item in batch:\n",
"        input_ids.append(each_item['input_ids'])\n",
"        max_length[0] = max(len(each_item['input_ids']), max_length[0])\n",
"        token_type_ids.append(each_item['token_type_ids'])\n",
"        max_length[1] = max(len(each_item['token_type_ids']), max_length[1])\n",
"        attention_mask.append(each_item['attention_mask'])\n",
"        max_length[2] = max(len(each_item['attention_mask']), max_length[2])\n",
"\n",
"    # zero-pad every sequence to the max length of its field\n",
"    for i, each in enumerate((input_ids, token_type_ids, attention_mask)):\n",
"        for item in each:\n",
"            item.extend([0] * (max_length[i] - len(item)))\n",
"    # stack each field into a 2-d tensor of shape (batch_size, max_length)\n",
"    return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n",
"            'token_type_ids': torch.cat([torch.tensor([item]) for item in token_type_ids], dim=0),\n",
"            'attention_mask': torch.cat([torch.tensor([item]) for item in attention_mask], dim=0)}"
]
},
{
"cell_type": "markdown",
"id": "487b75fb",
"metadata": {},
"source": [
"Note: with a user-defined `collate_fn`, the `collate_fn` attribute of the `dataloader` is likewise a plain `function`, as the cell below shows"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e916d1ac",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'fastNLP.core.dataloaders.torch_dataloader.fdl.TorchDataLoader'>\n",
"<class 'function'>\n",
"{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
"        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n",
"        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
"        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
"        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n",
"        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
"        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n",
"        0, 0, 0, 0, 0, 0, 0, 0]),\n",
" 'input_ids': tensor([[  101,  1037,  4038,  1011,  3689,  1997,  3053,  8680, 19173, 15685,\n",
"          1999,  1037, 18006,  2836,  2011,  1996,  2516,  2839, 14996,  3054,\n",
"         15509,  5325,  1012,   102,     0,     0,     0,     0,     0,     0,\n",
"             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
"             0,     0,     0,     0],\n",
"        [  101,  1037,  2186,  1997,  9686, 17695, 18673, 14313,  1996, 15262,\n",
"          3351,  2008,  2054,  2003,  2204,  2005,  1996, 13020,  2003,  2036,\n",
"          2204,  2005,  1996, 25957,  4063,  1010,  2070,  1997,  2029,  5681,\n",
"          2572, 25581,  2021,  3904,  1997,  2029,  8310,  2000,  2172,  1997,\n",
"          1037,  2466,  1012,   102],\n",
"        [  101,  2130,  4599,  1997, 19214,  6432,  1005,  1055,  2147,  1010,\n",
"          1045,  8343,  1010,  2052,  2031,  1037,  2524,  2051,  3564,  2083,\n",
"          2023,  2028,  1012,   102,     0,     0,     0,     0,     0,     0,\n",
"             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
"             0,     0,     0,     0],\n",
"        [  101,  1037, 13567, 26162,  5257,  1997,  3802,  7295,  9888,  1998,\n",
"          2035,  1996, 20014, 27611,  1010, 14583,  1010, 11703, 20175,  1998,\n",
"          4028,  1997,  1037,  8101,  2319, 10576,  2030,  1037, 28900,  7815,\n",
"          3850,  1012,   102,     0,     0,     0,     0,     0,     0,     0,\n",
"             0,     0,     0,     0]]),\n",
" 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
"         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
"        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
"         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
"        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
"         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
"        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
"         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n"
]
}
],
"source": [
"train_dataloader = prepare_torch_dataloader(train_dataset, collate_fn=collate_fn, shuffle=True)\n",
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, collate_fn=collate_fn, shuffle=True)\n",
"\n",
"print(type(train_dataloader))\n",
"print(type(train_dataloader.collate_fn))\n",
"\n",
"for batch in train_dataloader:\n",
"    pprint.pprint(batch, width=1)"
]
},
{
"cell_type": "markdown",
"id": "0bd98365",
"metadata": {},
"source": [
"### 2.2 Integrating fastNLP with datasets\n",
"\n",
"From `tutorial-1` to `tutorial-3`, we have now walked through the whole `fastNLP 0.8` pipeline of reading, preprocessing, and loading data\n",
"\n",
"&emsp; In practice, though, data is often read in an even simpler way, for example with the `datasets` module from `huggingface`\n",
"\n",
"**Calling the `load_dataset` function of the `datasets` module** with the two-level dataset name, here **the `SST-2` dataset of the `GLUE` benchmark**,\n",
"\n",
"&emsp; quickly downloads and reads in `SST-2`, which is then turned into a `fastNLP.DataSet` with a `pandas.DataFrame` as the intermediate\n",
"\n",
"&emsp; From there on, everything follows the usage already described for `dataset`, `databundle`, `vocabulary`, and `dataloader`; a short sketch of that hand-off follows the cell below"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "91879c30",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "639a0ad3c63944c6abef4e8ee1f7bf7c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"  0%|          | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"sst2data = load_dataset('glue', 'sst2')\n",
"\n",
"dataset = DataSet.from_pandas(sst2data['train'].to_pandas())"
]
},
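{
"cell_type": "markdown",
"id": "a7f5c3d9",
"metadata": {},
"source": [
"&emsp; As a sketch of that hand-off: the remaining splits (named `'validation'` and `'test'` by `datasets` for `GLUE` tasks) are converted the same way and wrapped into a `DataBundle`, after which the `tutorial-2` pipeline and section 1 above apply unchanged; the names `dataset_dev`, `dataset_test`, and `sst2_bundle` are made up here for illustration"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2d84e07",
"metadata": {},
"outputs": [],
"source": [
"# convert the remaining SST-2 splits and wrap everything into a DataBundle,\n",
"# mirroring the PipeDemo flow from section 1.1\n",
"dataset_dev = DataSet.from_pandas(sst2data['validation'].to_pandas())\n",
"dataset_test = DataSet.from_pandas(sst2data['test'].to_pandas())\n",
"\n",
"sst2_bundle = DataBundle(datasets={'train': dataset, 'dev': dataset_dev, 'test': dataset_test})\n",
"\n",
"print(len(sst2_bundle.get_dataset('dev')))"
]
}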
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}