
fastnlp_tutorial_3.ipynb

{
"cells": [
{
"cell_type": "markdown",
"id": "213d538c",
"metadata": {},
"source": [
"# T3. The internal structure and basic usage of dataloader\n",
"\n",
"  1   dataloader in fastNLP\n",
" \n",
"    1.1   Introduction to dataloader\n",
"\n",
"    1.2   Creating a dataloader with helper functions\n",
"\n",
"  2   Extensions of dataloader in fastNLP\n",
"\n",
"    2.1   The collator: concept and usage\n",
"\n",
"    2.2   The sampler: concept and usage"
]
},
{
"cell_type": "markdown",
"id": "85857115",
"metadata": {},
"source": [
"## 1. dataloader in fastNLP\n",
"\n",
"### 1.1 Introduction to dataloader\n",
"\n",
"The most important goal in the development of `fastNLP 0.8` is **compatibility with the current mainstream machine-learning frameworks**:\n",
"\n",
"  **the popular `pytorch`**, as well as **the domestically developed `paddle` and `jittor`**, which broadens the audience and also supports the domestic frameworks\n",
"\n",
"Following a divide-and-conquer approach, `fastNLP 0.8`'s compatibility with the `pytorch`, `paddle`, and `jittor` frameworks can be divided into\n",
"\n",
"    **four parts**: **data preprocessing**, **batch splitting and padding**, **model training**, and **model evaluation**\n",
"\n",
"  For data preprocessing, we already introduced the use of `dataset` and `vocabulary` in `tutorial-1`\n",
"\n",
"    and, combined with `tutorial-0`, we can see that **data preprocessing is essentially framework-agnostic**,\n",
"\n",
"    since the raw data formats read under different frameworks differ little and are easy to convert between\n",
"\n",
"Only tensors and models expose each framework's particular flavor: **`pytorch` has `tensor` and `nn.Module`**,\n",
"\n",
"    **`paddle` calls them `tensor` and `nn.Layer`**, **and `jittor` calls them `Var` and `Module`**\n",
"\n",
"    **Model training and model evaluation are therefore the hard part of compatibility**; we cover them in detail in `tutorial-5`\n",
"\n",
"  Batch handling, the transition in `fastNLP 0.8` from the framework-agnostic part to the framework-specific part,\n",
"\n",
"    is the responsibility of the `dataloader` module, and it is the focus of this tutorial, `tutorial-3`\n",
"\n",
"**The responsibilities of the `dataloader` module** break down into three parts: **sampling and splitting, padding and alignment, and framework matching**\n",
"\n",
"    first, fix the `batch` size and the sampling strategy; after splitting, iterating yields the sequence of batches\n",
"\n",
"    second, for sequence data, which is `fastNLP`'s main target, align the data within each `batch` by padding\n",
"\n",
"    third, **the data format inside a `batch` must match the framework**, **while the `batch` structure stays consistent**, so parameters can be matched to the model\n",
"\n",
"  For this, `fastNLP 0.8` provides **`TorchDataLoader`, `PaddleDataLoader`, and `JittorDataLoader`**,\n",
"\n",
"    each targeting and matching one framework while keeping similar parameter names, attributes, and methods; the first two are roughly as in the table below\n",
"\n",
"| <div align=\"center\">Name</div> | <div align=\"center\">Parameter</div> | <div align=\"center\">Attribute</div> | <div align=\"center\">Function</div> | <div align=\"center\">Default</div> |\n",
"|:--|:--:|:--:|:--|:--|\n",
"| **`dataset`** | √ | √ | the data the `dataloader` iterates over | |\n",
"| `batch_size` | √ | √ | the `batch` size of the `dataloader` | `16` |\n",
"| `shuffle` | √ | √ | whether the `dataloader` shuffles the data | `False` |\n",
"| `collate_fn` | √ | √ | how the `dataloader` packs a `batch` | framework-dependent |\n",
"| `sampler` | √ | √ | how single sample indices are drawn | `None` |\n",
"| `batch_sampler` | √ | √ | how whole batches of indices are drawn | `None` |\n",
"| `drop_last` | √ | √ | whether the `dataloader` drops the leftover samples when splitting batches | `False` |\n",
"| `cur_batch_indices` | | √ | indices of the `batch` currently being iterated | |\n",
"| `num_workers` | √ | √ | number of subprocesses the `dataloader` spawns | `0` |\n",
"| `worker_init_fn` | √ | √ | initialization function for `dataloader` subprocesses | `None` |\n",
"| `generator` | √ | √ | random-number generator used for sampling and worker seeding | `None` |\n",
"| `prefetch_factor` | | √ | number of batches prefetched by each `worker` | `2` |"
]
},
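{
"cell_type": "markdown",
"id": "3b1f2a90",
"metadata": {},
"source": [
"The second responsibility, padding and alignment, is easiest to see in isolation. Below is a minimal sketch (not part of the original tutorial) of what \"pad and align\" means for one batch of variable-length token sequences, assuming a pad value of `0`:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# two tokenized sentences of different lengths\n",
"seqs = [[101, 2023, 102], [101, 102]]\n",
"\n",
"# pad every sequence to the length of the longest one, then stack into one tensor\n",
"max_len = max(len(s) for s in seqs)\n",
"batch = torch.tensor([s + [0] * (max_len - len(s)) for s in seqs])\n",
"print(batch.shape)  # torch.Size([2, 3])\n",
"```"
]
},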
{
"cell_type": "markdown",
"id": "60a8a224",
"metadata": {},
"source": [
"&emsp; As for the `dataloader`'s own methods: `get_batch_indices` returns the indices of the `batch` currently being iterated; the others,\n",
"\n",
"&emsp; &emsp; including `set_ignore` and `set_pad`, behave like their `databundle` counterparts, see `tutorial-2`; we will not repeat them here\n",
"\n",
"&emsp; &emsp; the cell below repeats the preprocessing pipeline already covered in `tutorial-2`; the resulting data is then handed to a `dataloader`"
]
},
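{
"cell_type": "markdown",
"id": "7c4e9d12",
"metadata": {},
"source": [
"As a hedged sketch of those methods (assuming a `TorchDataLoader` named `dl`, built from the data prepared below; the calls mirror the `databundle` interface from `tutorial-2`):\n",
"\n",
"```python\n",
"dl.set_pad('input_ids', pad_val=0)  # pad this field with 0 inside each batch\n",
"dl.set_ignore('SentenceId')         # exclude this field from batches\n",
"indices = dl.get_batch_indices()    # indices of the batch currently being iterated\n",
"```"
]
},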
{
"cell_type": "code",
"execution_count": 5,
"id": "aca72b49",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing:   0%|          | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing:   0%|          | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing:   0%|          | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------------+------------------+-----------+------------------+--------------------+--------------------+\n",
"| SentenceId |     Sentence     | Sentiment |    input_ids     |   token_type_ids   |   attention_mask   |\n",
"+------------+------------------+-----------+------------------+--------------------+--------------------+\n",
"|     5      | A comedy-dram... | positive  | [101, 1037, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
"|     2      | This quiet , ... | positive  | [101, 2023, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
"|     1      | A series of e... | negative  | [101, 1037, 2... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
"|     6      | The Importanc... |  neutral  | [101, 1996, 5... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
"+------------+------------------+-----------+------------------+--------------------+--------------------+\n"
]
}
],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"import pandas as pd\n",
"from functools import partial\n",
"from fastNLP.transformers.torch import BertTokenizer\n",
"\n",
"from fastNLP import DataSet\n",
"from fastNLP import Vocabulary\n",
"from fastNLP.io import DataBundle\n",
"\n",
"\n",
"class PipeDemo:\n",
"    def __init__(self, tokenizer='bert-base-uncased'):\n",
"        self.tokenizer = BertTokenizer.from_pretrained(tokenizer)\n",
"\n",
"    def process_from_file(self, path='./data/test4dataset.tsv'):\n",
"        datasets = DataSet.from_pandas(pd.read_csv(path, sep='\\t'))\n",
"        train_ds, test_ds = datasets.split(ratio=0.7)\n",
"        train_ds, dev_ds = train_ds.split(ratio=0.8)  # split dev out of train, not out of the full set\n",
"        data_bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds, 'test': test_ds})\n",
"\n",
"        encode = partial(self.tokenizer.encode_plus, max_length=100, truncation=True,\n",
"                         return_attention_mask=True)\n",
"        data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n",
"\n",
"        target_vocab = Vocabulary(padding=None, unknown=None)\n",
"\n",
"        target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n",
"        target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n",
"                                   new_field_name='target')\n",
"\n",
"        data_bundle.set_pad('input_ids', pad_val=self.tokenizer.pad_token_id)\n",
"        data_bundle.set_ignore('SentenceId', 'Sentence', 'Sentiment')\n",
"        return data_bundle\n",
"\n",
"\n",
"pipe = PipeDemo(tokenizer='bert-base-uncased')\n",
"\n",
"data_bundle = pipe.process_from_file('./data/test4dataset.tsv')"
]
},
{
"cell_type": "markdown",
"id": "76e6b8ab",
"metadata": {},
"source": [
"### 1.2 Creating a dataloader with helper functions\n",
"\n",
"In `fastNLP 0.8`, **the more convenient, and probably more common, way to create a `dataloader` is through the `prepare_xx_dataloader` functions**\n",
"\n",
"&emsp; for example `prepare_torch_dataloader` below: given the required arguments, it reads the dataset and produces the corresponding `dataloader`\n",
"\n",
"&emsp; of type `TorchDataLoader`, usable only with the `pytorch` framework, so the matching `trainer` is initialized with `driver='torch'`"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5fd60e42",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import prepare_torch_dataloader\n",
"\n",
"train_dataset = data_bundle.get_dataset('train')\n",
"evaluate_dataset = data_bundle.get_dataset('dev')\n",
"\n",
"train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
]
},
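{
"cell_type": "markdown",
"id": "2e8d4f61",
"metadata": {},
"source": [
"For comparison, a sketch of the direct route: constructing the `TorchDataLoader` class from the table in section 1.1 yourself instead of calling `prepare_torch_dataloader`. The top-level import path is an assumption based on that table, not something this tutorial demonstrates:\n",
"\n",
"```python\n",
"from fastNLP import TorchDataLoader\n",
"\n",
"# under the assumption above, equivalent to the prepare_torch_dataloader call\n",
"train_dataloader = TorchDataLoader(train_dataset, batch_size=16, shuffle=True)\n",
"```"
]
},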
{
"cell_type": "markdown",
"id": "7c53f181",
"metadata": {},
"source": [
"```python\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    train_dataloader=train_dataloader,\n",
"    optimizers=optimizer,\n",
"    ...\n",
"    driver='torch',\n",
"    device='cuda',\n",
"    ...\n",
"    evaluate_dataloaders=evaluate_dataloader,\n",
"    metrics={'acc': Accuracy()},\n",
"    ...\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "9f457a6e",
"metadata": {},
"source": [
"The `prepare_xx_dataloader` functions are called more convenient because **their input can be not only a `DataSet`** **but also**\n",
"\n",
"&emsp; **a `DataBundle`**; in that case, the dataset names must be `'train'`, `'dev'`, `'test'` so that `fastNLP` can recognize them\n",
"\n",
"&emsp; for example, the cell below **uses `prepare_paddle_dataloader` to build a dict of `PaddleDataLoader` objects directly**\n",
"\n",
"&emsp; the `trainer` is then initialized as shown after it, with `driver='paddle'` being the only framework-specific change\n",
"\n",
"&emsp; &emsp; it also shows **the usefulness of `evaluate_dataloaders`**: a single evaluation run can cover several datasets"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7827557d",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import prepare_paddle_dataloader\n",
"\n",
"dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)"
]
},
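{
"cell_type": "markdown",
"id": "5a7b3c28",
"metadata": {},
"source": [
"Since the text above says the return value is a dict keyed by split name, it can be inspected directly; a quick sketch:\n",
"\n",
"```python\n",
"# one PaddleDataLoader per split: 'train', 'dev', 'test'\n",
"for name, dl in dl_bundle.items():\n",
"    print(name, type(dl))\n",
"```"
]
},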
{
"cell_type": "markdown",
"id": "d898cf40",
"metadata": {},
"source": [
"```python\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    train_dataloader=dl_bundle['train'],\n",
"    optimizers=optimizer,\n",
"    ...\n",
"    driver='paddle',\n",
"    device='gpu',\n",
"    ...\n",
"    evaluate_dataloaders={'dev': dl_bundle['dev'], 'test': dl_bundle['test']},\n",
"    metrics={'acc': Accuracy()},\n",
"    ...\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "d74d0523",
"metadata": {},
"source": [
"## 2. Extensions of dataloader in fastNLP\n",
"\n",
"### 2.1 The collator: concept and usage\n",
"\n",
"In `fastNLP 0.8`, a few other modules sit in front of the data-loading module `DataLoader`: one pads and aligns text\n",
"\n",
"&emsp; sequences, the **`collator` module**, and one tokenizes and labels them, the **`tokenizer` module**\n",
"\n",
"&emsp; this section introduces the `collator`; the `tokenizer` will be covered in detail in the next section\n",
"\n",
"In `fastNLP 0.8`, **the `collator` module is responsible for padding and aligning text sequences**; it is exposed through the dataloader's `collate_fn`, inspected in the cells below"
]
},
  323. },
  324. {
  325. "cell_type": "code",
  326. "execution_count": null,
  327. "id": "651baef6",
  328. "metadata": {
  329. "pycharm": {
  330. "name": "#%%\n"
  331. }
  332. },
  333. "outputs": [],
  334. "source": [
  335. "from fastNLP import prepare_torch_dataloader\n",
  336. "\n",
  337. "dl_bundle = prepare_torch_dataloader(data_bundle, train_batch_size=2)\n",
  338. "\n",
  339. "print(type(dl_bundle), type(dl_bundle['train']))"
  340. ]
  341. },
  342. {
  343. "cell_type": "markdown",
  344. "id": "5f816ef5",
  345. "metadata": {},
  346. "source": [
  347. "&emsp; "
  348. ]
  349. },
  350. {
  351. "cell_type": "code",
  352. "execution_count": null,
  353. "id": "726ba357",
  354. "metadata": {
  355. "pycharm": {
  356. "name": "#%%\n"
  357. }
  358. },
  359. "outputs": [],
  360. "source": [
  361. "dataloader = prepare_torch_dataloader(datasets['train'], train_batch_size=2)\n",
  362. "print(type(dataloader))\n",
  363. "print(dir(dataloader))"
  364. ]
  365. },
{
"cell_type": "code",
"execution_count": null,
"id": "d0795b3e",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"dataloader.collate_fn"
]
},
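{
"cell_type": "markdown",
"id": "1c9a5b72",
"metadata": {},
"source": [
"If the default collator is not wanted, the table in section 1.1 lists `collate_fn` as a constructor parameter; assuming `prepare_torch_dataloader` forwards it to the underlying dataloader (an assumption, not confirmed by this tutorial), a custom function could be swapped in:\n",
"\n",
"```python\n",
"def raw_collate(samples):\n",
"    return samples  # hand back the raw list of samples, no padding at all\n",
"\n",
"# assumes collate_fn is forwarded to the underlying dataloader\n",
"dl = prepare_torch_dataloader(data_bundle.get_dataset('train'), batch_size=2, collate_fn=raw_collate)\n",
"```"
]
},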
{
"cell_type": "markdown",
"id": "f9bbd9a7",
"metadata": {},
"source": [
"### 2.2 The sampler: concept and usage"
]
},
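{
"cell_type": "markdown",
"id": "6d2f8a19",
"metadata": {},
"source": [
"As background for the cell below, here is a sketch of the sampler / batch-sampler split using plain PyTorch (not fastNLP-specific): a `sampler` yields single indices, while a `batch_sampler` yields lists of them:\n",
"\n",
"```python\n",
"from torch.utils.data import BatchSampler, RandomSampler\n",
"\n",
"sampler = RandomSampler(range(10))  # one shuffled index at a time\n",
"batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)  # groups of indices\n",
"print(list(batch_sampler))  # e.g. [[3, 7, 0, 9], [5, 1, 8, 2], [6, 4]]\n",
"```"
]
},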
{
"cell_type": "code",
"execution_count": null,
"id": "b0c3c58d",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"dataloader.batch_sampler"
]
},
{
"cell_type": "markdown",
"id": "51bf0878",
"metadata": {},
"source": [
"&emsp; "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3fd2486f",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}