{ "cells": [ { "cell_type": "markdown", "id": "fdd7ff16", "metadata": {}, "source": [ "# T6. fastNLP 与 paddle 或 jittor 的结合\n", "\n", "  1   fastNLP 结合 paddle 训练模型\n", " \n", "    1.1   关于 paddle 的简单介绍\n", "\n", "    1.2   使用 paddle 搭建并训练模型\n", "\n", "  2   fastNLP 结合 jittor 训练模型\n", "\n", "    2.1   关于 jittor 的简单介绍\n", "\n", "    2.2   使用 jittor 搭建并训练模型\n", "\n", "  3   fastNLP 实现 paddle 与 pytorch 互转" ] }, { "cell_type": "code", "execution_count": null, "id": "08752c5a", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "sst2data = load_dataset('glue', 'sst2')" ] }, { "cell_type": "code", "execution_count": null, "id": "7e8cc210", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append('..')\n", "\n", "from fastNLP import DataSet\n", "\n", "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", "\n", "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", " progress_bar=\"tqdm\")\n", "dataset.delete_field('sentence')\n", "dataset.delete_field('label')\n", "dataset.delete_field('idx')\n", "\n", "from fastNLP import Vocabulary\n", "\n", "vocab = Vocabulary()\n", "vocab.from_dataset(dataset, field_name='words')\n", "vocab.index_dataset(dataset, field_name='words')\n", "\n", "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n", "print(type(train_dataset), isinstance(train_dataset, DataSet))\n", "\n", "from fastNLP.io import DataBundle\n", "\n", "data_bundle = DataBundle(datasets={'train': train_dataset, 'dev': evaluate_dataset})" ] }, { "cell_type": "markdown", "id": "57a3272f", "metadata": {}, "source": [ "## 1. 
fastNLP 结合 paddle 训练模型\n", "\n", "```python\n", "import paddle\n", "\n", "lstm = paddle.nn.LSTM(16, 32, 2)\n", "\n", "x = paddle.randn((4, 23, 16))\n", "h = paddle.randn((2, 4, 32))\n", "c = paddle.randn((2, 4, 32))\n", "\n", "y, (h, c) = lstm(x, (h, c))\n", "\n", "print(y.shape) # [4, 23, 32]\n", "print(h.shape) # [2, 4, 32]\n", "print(c.shape) # [2, 4, 32]\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "e31b3198", "metadata": {}, "outputs": [], "source": [ "import paddle\n", "import paddle.nn as nn\n", "\n", "\n", "class ClsByPaddle(nn.Layer):\n", "    \"\"\"Sentence classifier implemented with paddle for the fastNLP Trainer.\n", "\n", "    NOTE(review): the LSTM is commented out and `hidden` is replaced by random\n", "    noise in `forward` -- a debugging workaround left in the tutorial, so the\n", "    model cannot actually learn from the input sequence in this state.\n", "    \"\"\"\n", "\n", "    def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", "        nn.Layer.__init__(self)\n", "        self.hidden_dim = hidden_dim\n", "\n", "        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n", "        # self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, \n", "        #                     num_layers=num_layers, direction='bidirectional', dropout=dropout)\n", "        # MLP head consumes the concatenation of the two final hidden states (2 * hidden_dim)\n", "        self.mlp = nn.Sequential(('linear_1', nn.Linear(hidden_dim * 2, hidden_dim * 2)),\n", "                                 ('activate', nn.ReLU()),\n", "                                 ('linear_2', nn.Linear(hidden_dim * 2, output_dim)))\n", "\n", "        self.loss_fn = nn.CrossEntropyLoss()\n", "\n", "    def forward(self, words):\n", "        output = self.embedding(words)\n", "        # output, (hidden, cell) = self.lstm(output)\n", "        hidden = paddle.randn((2, words.shape[0], self.hidden_dim))\n", "        output = self.mlp(paddle.concat((hidden[-1], hidden[-2]), axis=1))\n", "        return output\n", "\n", "    def train_step(self, words, target):\n", "        # fastNLP Trainer hook: must return a dict containing the loss\n", "        pred = self(words)\n", "        return {\"loss\": self.loss_fn(pred, target)}\n", "\n", "    def evaluate_step(self, words, target):\n", "        # fastNLP Evaluator hook. Unlike torch.max, paddle.max returns only the\n", "        # max values (no indices), so `paddle.max(pred, axis=-1)[1]` would index\n", "        # into the values tensor instead of yielding class ids -- use argmax.\n", "        pred = self(words)\n", "        pred = paddle.argmax(pred, axis=-1)\n", "        return {\"pred\": pred, \"target\": target}" ] }, { "cell_type": "code", "execution_count": null, "id": "c63b030f", "metadata": {}, "outputs": [], "source": [ "model = ClsByPaddle(vocab_size=len(vocab), embedding_dim=100, 
output_dim=2)\n", "\n", "model" ] }, { "cell_type": "code", "execution_count": null, "id": "2997c0aa", "metadata": {}, "outputs": [], "source": [ "from paddle.optimizer import AdamW\n", "\n", "optimizers = AdamW(parameters=model.parameters(), learning_rate=1e-2)" ] }, { "cell_type": "code", "execution_count": null, "id": "ead35fb8", "metadata": {}, "outputs": [], "source": [ "from fastNLP import prepare_paddle_dataloader\n", "\n", "# train_dataloader = prepare_paddle_dataloader(train_dataset, batch_size=16, shuffle=True)\n", "# evaluate_dataloader = prepare_paddle_dataloader(evaluate_dataset, batch_size=16)\n", "\n", "dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "25e8da83", "metadata": {}, "outputs": [], "source": [ "from fastNLP import Trainer, Accuracy\n", "\n", "trainer = Trainer(\n", " model=model,\n", " driver='paddle',\n", " device='gpu', # 'cpu', 'gpu', 'gpu:x'\n", " n_epochs=10,\n", " optimizers=optimizers,\n", " train_dataloader=dl_bundle['train'], # train_dataloader,\n", " evaluate_dataloaders=dl_bundle['dev'], # evaluate_dataloader,\n", " metrics={'acc': Accuracy()}\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "d63c5d74", "metadata": {}, "outputs": [], "source": [ "trainer.run(num_eval_batch_per_dl=10) # 然后卡了?" ] }, { "cell_type": "markdown", "id": "cb9a0b3c", "metadata": {}, "source": [ "## 2. 
fastNLP 结合 jittor 训练模型" ] }, { "cell_type": "code", "execution_count": null, "id": "c600191d", "metadata": {}, "outputs": [], "source": [ "import jittor\n", "import jittor.nn as nn\n", "\n", "from jittor import Module\n", "\n", "\n", "class ClsByJittor(Module):\n", "    \"\"\"Bidirectional-LSTM sentence classifier implemented with jittor.\"\"\"\n", "\n", "    def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", "        Module.__init__(self)\n", "        self.hidden_dim = hidden_dim\n", "\n", "        self.embedding = nn.Embedding(num=vocab_size, dim=embedding_dim)\n", "        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, \n", "                            num_layers=num_layers, bidirectional=True, dropout=dropout)\n", "        # MLP head consumes the concatenation of the two final hidden states (2 * hidden_dim)\n", "        self.mlp = nn.Sequential([nn.Linear(hidden_dim * 2, hidden_dim * 2),\n", "                                  nn.ReLU(),\n", "                                  nn.Linear(hidden_dim * 2, output_dim)])\n", "\n", "        # CrossEntropyLoss matches the paddle model above: the MLP emits\n", "        # `output_dim` raw logits and `target` holds integer class indices,\n", "        # which BCELoss (probabilities + same-shape float targets) cannot consume.\n", "        self.loss_fn = nn.CrossEntropyLoss()\n", "\n", "    def execute(self, words):\n", "        output = self.embedding(words)\n", "        output, (hidden, cell) = self.lstm(output)\n", "        # hidden = jittor.randn((2, words.shape[0], self.hidden_dim))\n", "        output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), axis=1))\n", "        return output\n", "\n", "    def train_step(self, words, target):\n", "        # fastNLP Trainer hook: must return a dict containing the loss\n", "        pred = self(words)\n", "        return {\"loss\": self.loss_fn(pred, target)}\n", "\n", "    def evaluate_step(self, words, target):\n", "        # fastNLP Evaluator hook. jittor.max returns only the max values, so\n", "        # `jittor.max(pred, axis=-1)[1]` would NOT give class indices; use\n", "        # jittor.argmax, which returns (indices, values) -- keep the indices.\n", "        pred = self(words)\n", "        pred = jittor.argmax(pred, dim=-1)[0]\n", "        return {\"pred\": pred, \"target\": target}" ] }, { "cell_type": "code", "execution_count": null, "id": "a94ed8c4", "metadata": {}, "outputs": [], "source": [ "model = ClsByJittor(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", "\n", "model" ] }, { "cell_type": "code", "execution_count": null, "id": "6d15ebc1", "metadata": {}, "outputs": [], "source": [ "from jittor.optim import AdamW\n", "\n", "optimizers = AdamW(params=model.parameters(), lr=1e-2)" ] }, { "cell_type": "code", "execution_count": null, "id": "95d8d09e", "metadata": {}, "outputs": [], "source": [ "from fastNLP import prepare_jittor_dataloader\n", 
"\n", "# train_dataloader = prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n", "# evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n", "\n", "dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "917eab81", "metadata": {}, "outputs": [], "source": [ "from fastNLP import Trainer, Accuracy\n", "\n", "trainer = Trainer(\n", " model=model,\n", " driver='jittor',\n", " device='gpu', # 'cpu', 'gpu', 'cuda'\n", " n_epochs=10,\n", " optimizers=optimizers,\n", " train_dataloader=dl_bundle['train'], # train_dataloader,\n", " evaluate_dataloaders=dl_bundle['dev'], # evaluate_dataloader,\n", " metrics={'acc': Accuracy()}\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "f7c4ac5a", "metadata": {}, "outputs": [], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 5 }