{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fdd7ff16",
   "metadata": {},
   "source": [
    "# T4. An in-depth look at trainer and evaluator\n",
    "\n",
    "&emsp; 1 &emsp; More metric types in fastNLP\n",
    "\n",
    "&emsp; &emsp; 1.1 &emsp; Predefined metric types\n",
    "\n",
    "&emsp; &emsp; 1.2 &emsp; Custom metric types\n",
    "\n",
    "&emsp; 2 &emsp; More on the trainer in fastNLP\n",
    "\n",
    "&emsp; &emsp; 2.1 &emsp; The design idea behind trainer\n",
    "\n",
    "&emsp; &emsp; 2.2 &emsp; The internal structure of trainer\n",
    "\n",
    "&emsp; &emsp; 2.3 &emsp; Example\n",
    "\n",
    "&emsp; 3 &emsp; driver and device in fastNLP\n",
    "\n",
    "&emsp; &emsp; 3.1 &emsp; The design idea behind driver\n",
    "\n",
    "&emsp; &emsp; 3.2 &emsp; device and multi-GPU training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d19220c",
   "metadata": {},
   "source": [
    "## 1. More metric types in fastNLP\n",
    "\n",
    "### 1.1 Predefined metric types\n",
    "\n",
    "In `fastNLP 0.8`, besides **accuracy (`Accuracy`)**, which appeared often in the earlier `tutorial`s, there are other **predefined evaluation metrics (`metric`)**.\n",
    "\n",
    "&emsp; They include **`Metric`, the base class of all `metric`s**; `TransformersAccuracy`, an accuracy metric adapted to the relevant models in `Transformers`;\n",
    "\n",
    "&emsp; **`ClassifyFPreRecMetric`, the `F1` score for classification** (which also reports **precision `Pre`** and **recall `Rec`**);\n",
    "\n",
    "&emsp; and **`SpanFPreRecMetric`, the `F1` score for extraction tasks**. Their basic information is summarized in the table below, followed by a detailed analysis.\n",
    "\n",
    "| <div align=\"center\">Name</div> | <div align=\"center\">Brief description</div> | <div align=\"center\">Source path</div> |\n",
    "|:--|:--|:--|\n",
    "| `Metric` | Base class to inherit from when defining a `metric` | `/core/metrics/metric.py` |\n",
    "| `Accuracy` | Accuracy; the most commonly used metric | `/core/metrics/accuracy.py` |\n",
    "| `TransformersAccuracy` | Accuracy, made compatible with the relevant models in `Transformers` | `/core/metrics/accuracy.py` |\n",
    "| `ClassifyFPreRecMetric` | Precision, recall, and F1 score, for **classification problems** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n",
    "| `SpanFPreRecMetric` | Precision, recall, and F1 score, for **extraction problems** | `/core/metrics/span_f1_pre_rec_metric.py` |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdc083a3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "Briefly describe each metric and give the formula used to compute it."
   ]
  },
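  {
   "cell_type": "markdown",
   "id": "3f9a1c02",
   "metadata": {},
   "source": [
    "The standard textbook definitions are given below; they are not copied from the fastNLP source, and how fastNLP aggregates them over several classes (for example micro versus macro averaging) is not covered here. Let $TP$, $FP$ and $FN$ be the numbers of true positives, false positives and false negatives for a chosen class:\n",
    "\n",
    "$$\\text{Acc} = \\frac{N_{\\text{correct}}}{N_{\\text{total}}}, \\qquad \\text{Pre} = \\frac{TP}{TP + FP}, \\qquad \\text{Rec} = \\frac{TP}{TP + FN}, \\qquad F_1 = \\frac{2 \\cdot \\text{Pre} \\cdot \\text{Rec}}{\\text{Pre} + \\text{Rec}}$$\n",
    "\n",
    "The next cell is a plain-Python sketch of these formulas that does not call fastNLP, so the numbers can be checked by hand. The helper names `accuracy` and `prf_for_class` are hypothetical, and returning `0.0` when a denominator is zero is an assumption of this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f9a1c03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain-Python sketch of the formulas above; it does not use fastNLP.\n",
    "# accuracy and prf_for_class are hypothetical helpers for illustration only.\n",
    "\n",
    "def accuracy(pred, target):\n",
    "    # fraction of positions where the predicted label equals the true label\n",
    "    correct = sum(1 for p, t in zip(pred, target) if p == t)\n",
    "    return correct / len(target)\n",
    "\n",
    "def prf_for_class(pred, target, positive):\n",
    "    # count TP / FP / FN, treating `positive` as the positive class\n",
    "    tp = sum(1 for p, t in zip(pred, target) if p == positive and t == positive)\n",
    "    fp = sum(1 for p, t in zip(pred, target) if p == positive and t != positive)\n",
    "    fn = sum(1 for p, t in zip(pred, target) if p != positive and t == positive)\n",
    "    # assumption of this sketch: report 0.0 when a denominator is zero\n",
    "    pre = tp / (tp + fp) if tp + fp else 0.0\n",
    "    rec = tp / (tp + fn) if tp + fn else 0.0\n",
    "    f1 = 2 * pre * rec / (pre + rec) if pre + rec else 0.0\n",
    "    return {'pre': pre, 'rec': rec, 'f1': f1}\n",
    "\n",
    "pred = [1, 0, 1, 1, 0, 2]\n",
    "target = [1, 0, 0, 1, 2, 2]\n",
    "print(accuracy(pred, target))                   # 4 of 6 positions correct\n",
    "print(prf_for_class(pred, target, positive=1))  # TP = 2, FP = 1, FN = 0"
   ]
  },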
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9775ea5e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "8a22f522",
   "metadata": {},
   "source": [
    "### 1.2 Custom metric types\n",
    "\n",
    "In `fastNLP 0.8`,&emsp; a custom `metric` is introduced here through an example; training with it is left to the `trainer` section."
   ]
  },
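  {
   "cell_type": "markdown",
   "id": "7b2e4d10",
   "metadata": {},
   "source": [
    "Below is a minimal sketch of one way a custom metric can be organized, written in plain Python so that it runs without fastNLP: statistics from each batch are accumulated in `update`, and the final result is computed once in `get_metric`. The class name `MyAccuracy` is hypothetical, and the method names `reset` / `update` / `get_metric` are assumptions for illustration; before subclassing `Metric`, check `/core/metrics/metric.py` for the methods it actually requires."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b2e4d11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical custom metric in plain Python (not a fastNLP subclass).\n",
    "# It follows an accumulate-then-report pattern: update() once per batch,\n",
    "# get_metric() once at the end of evaluation.\n",
    "\n",
    "class MyAccuracy:\n",
    "    def __init__(self):\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        # running counts over all batches seen since the last reset\n",
    "        self.total = 0\n",
    "        self.right = 0\n",
    "\n",
    "    def update(self, pred, target):\n",
    "        # pred and target are equal-length sequences of labels for one batch\n",
    "        self.total += len(target)\n",
    "        self.right += sum(1 for p, t in zip(pred, target) if p == t)\n",
    "\n",
    "    def get_metric(self):\n",
    "        # accuracy over everything accumulated so far\n",
    "        acc = self.right / self.total if self.total else 0.0\n",
    "        return {'acc': acc}\n",
    "\n",
    "metric = MyAccuracy()\n",
    "metric.update([1, 0, 1], [1, 1, 1])  # batch 1: 2 of 3 correct\n",
    "metric.update([0, 0], [0, 0])        # batch 2: 2 of 2 correct\n",
    "print(metric.get_metric())           # → {'acc': 0.8}"
   ]
  },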
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8caba1d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e6247dd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "08752c5a",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 2. More on the trainer in fastNLP\n",
    "\n",
    "### 2.1 The design idea behind trainer\n",
    "\n",
    "In `fastNLP 0.8`,&emsp; "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "977a6355",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69203cdc",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ab1cea7d",
   "metadata": {},
   "source": [
    "### 2.2 The internal structure of trainer\n",
    "\n",
    "In `fastNLP 0.8`,&emsp; a `Trainer` object exposes the following attributes and methods:\n",
    "\n",
    "'accumulation_steps', 'add_callback_fn', 'backward', 'batch_idx_in_epoch', 'batch_step_fn',\n",
    "'callback_manager', 'check_batch_step_fn', 'cur_epoch_idx', 'data_device', 'dataloader',\n",
    "'device', 'driver', 'driver_name', 'epoch_evaluate', 'evaluate_batch_step_fn', 'evaluate_dataloaders',\n",
    "'evaluate_every', 'evaluate_fn', 'evaluator', 'extract_loss_from_outputs', 'fp16',\n",
    "'get_no_sync_context', 'global_forward_batches', 'has_checked_train_batch_loop',\n",
    "'input_mapping', 'kwargs', 'larger_better', 'load_checkpoint', 'load_model', 'marker',\n",
    "'metrics', 'model', 'model_device', 'monitor', 'move_data_to_device', 'n_epochs', 'num_batches_per_epoch',\n",
    "'on', 'on_after_backward', 'on_after_optimizers_step', 'on_after_trainer_initialized',\n",
    "'on_after_zero_grad', 'on_before_backward', 'on_before_optimizers_step', 'on_before_zero_grad',\n",
    "'on_evaluate_begin', 'on_evaluate_end', 'on_exception', 'on_fetch_data_begin', 'on_fetch_data_end',\n",
    "'on_load_checkpoint', 'on_load_model', 'on_sanity_check_begin', 'on_sanity_check_end',\n",
    "'on_save_checkpoint', 'on_save_model', 'on_train_batch_begin', 'on_train_batch_end',\n",
    "'on_train_begin', 'on_train_end', 'on_train_epoch_begin', 'on_train_epoch_end',\n",
    "'optimizers', 'output_mapping', 'progress_bar', 'run', 'run_evaluate',\n",
    "'save_checkpoint', 'save_model', 'start_batch_idx_in_epoch', 'state',\n",
    "'step', 'step_evaluate', 'total_batches', 'train_batch_loop', 'train_dataloader', 'train_fn', 'train_step',\n",
    "'trainer_state', 'zero_grad'\n",
    "\n",
    "&emsp; The signature of `run` is `run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None)`"
   ]
  },
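  {
   "cell_type": "markdown",
   "id": "c41d8e20",
   "metadata": {},
   "source": [
    "Many names in the list above are callback hooks, such as `on_train_begin`, `on_train_batch_begin` and `on_before_backward`. The next cell is a hypothetical, framework-free training loop (`ToyTrainer` and `LossPrinter` are invented names) that shows one plausible order in which such hooks could fire around the forward pass, `backward`, the optimizer step and `zero_grad`. It is not fastNLP's implementation, and the real `Trainer` may call its hooks in a different order or with different arguments. Its `run` borrows the parameter name `num_train_batch_per_epoch` from the signature above, under the assumption that `-1` means 'use every batch'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c41d8e21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy trainer: NOT fastNLP's Trainer. It only illustrates where\n",
    "# hooks named like those in the list above might be called during training.\n",
    "\n",
    "class ToyTrainer:\n",
    "    def __init__(self, batches, n_epochs, callbacks=()):\n",
    "        self.batches = list(batches)     # any sequence of batches\n",
    "        self.n_epochs = n_epochs\n",
    "        self.callbacks = list(callbacks)\n",
    "        self.fired = []                  # names of hooks, in call order\n",
    "\n",
    "    def _call(self, hook, *args):\n",
    "        # record the hook, then forward it to every callback that defines it\n",
    "        self.fired.append(hook)\n",
    "        for cb in self.callbacks:\n",
    "            fn = getattr(cb, hook, None)\n",
    "            if fn is not None:\n",
    "                fn(self, *args)\n",
    "\n",
    "    def train_step(self, batch):\n",
    "        # stand-in for the forward pass and loss computation\n",
    "        return float(batch)\n",
    "\n",
    "    def run(self, num_train_batch_per_epoch=-1):\n",
    "        # assumption: -1 means 'use every batch in the loader'\n",
    "        if num_train_batch_per_epoch == -1:\n",
    "            n = len(self.batches)\n",
    "        else:\n",
    "            n = num_train_batch_per_epoch\n",
    "        self._call('on_train_begin')\n",
    "        for _ in range(self.n_epochs):\n",
    "            self._call('on_train_epoch_begin')\n",
    "            for batch in self.batches[:n]:\n",
    "                self._call('on_train_batch_begin', batch)\n",
    "                loss = self.train_step(batch)\n",
    "                self._call('on_before_backward', loss)\n",
    "                # a real trainer would call backward() here\n",
    "                self._call('on_after_backward')\n",
    "                self._call('on_before_optimizers_step')\n",
    "                # a real trainer would call optimizer.step() here\n",
    "                self._call('on_after_optimizers_step')\n",
    "                self._call('on_before_zero_grad')\n",
    "                # a real trainer would call zero_grad() here\n",
    "                self._call('on_after_zero_grad')\n",
    "                self._call('on_train_batch_end')\n",
    "            self._call('on_train_epoch_end')\n",
    "        self._call('on_train_end')\n",
    "\n",
    "class LossPrinter:\n",
    "    # a callback only defines the hooks it cares about\n",
    "    def on_before_backward(self, trainer, loss):\n",
    "        print('loss =', loss)\n",
    "\n",
    "trainer = ToyTrainer(batches=[0.5, 0.25, 0.125], n_epochs=1, callbacks=[LossPrinter()])\n",
    "trainer.run(num_train_batch_per_epoch=2)  # prints the loss of the first 2 batches only\n",
    "print(trainer.fired[:4])"
   ]
  },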
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3c8342e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d28f2624",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ce6322b4",
   "metadata": {},
   "source": [
    "### 2.3 Example\n",
    "\n",
    "In `fastNLP 0.8`,&emsp; "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43be274f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c348864c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "175d6ebb",
   "metadata": {},
   "source": [
    "## 3. driver and device in fastNLP\n",
    "\n",
    "### 3.1 The design idea behind driver\n",
    "\n",
    "In `fastNLP 0.8`,&emsp; "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47100e7a",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0204a223",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "6e723b87",
   "metadata": {},
   "source": [
    "### 3.2 device and multi-GPU training\n",
    "\n",
    "In `fastNLP 0.8`,&emsp; "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ad81ac7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfb28b1b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}