Browse Source

在文档中增加了分布式训练的介绍(by yunfan)

tags/v0.5.5
ChenXin 5 years ago
parent
commit
7cdd2ee862
4 changed files with 19 additions and 21 deletions
  1. +16
    -19
      docs/source/tutorials/extend_2_dist.rst
  2. +0
    -0
      docs/source/tutorials/extend_3_fitlog.rst
  3. +2
    -1
      docs/source/user/tutorials.rst
  4. +1
    -1
      fastNLP/core/callback.py

tutorials/dist.rst → docs/source/tutorials/extend_2_dist.rst View File

@@ -5,36 +5,34 @@ Distributed Parallel Training
----

随着深度学习模型越来越复杂,单个GPU可能已经无法满足正常的训练。比如BERT等预训练模型,更是在多个GPU上训练得到的。为了使用多GPU训练,Pytorch框架已经提供了
```nn.DataParallel`` <https://pytorch.org/docs/stable/nn.html#dataparallel>`__
以及
```nn.DistributedDataParallel`` <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`__
两种方式的支持。
```nn.DataParallel`` <https://pytorch.org/docs/stable/nn.html#dataparallel>`__
`nn.DataParallel <https://pytorch.org/docs/stable/nn.html#dataparallel>`_ 以及
`nn.DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`_ 两种方式的支持。
`nn.DataParallel <https://pytorch.org/docs/stable/nn.html#dataparallel>`_
很容易使用,但是却有着GPU负载不均衡,单进程速度慢等缺点,无法发挥出多GPU的全部性能。因此,分布式的多GPU训练方式
```nn.DistributedDataParallel`` <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`__
`nn.DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`_
是更好的选择。然而,因为分布式训练的特点,
```nn.DistributedDataParallel`` <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`__
`nn.DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`_
常常难以理解和使用,也很难debug。所以,在使用分布式训练之前,需要理解它的原理。

在使用
```nn.DistributedDataParallel`` <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`__
`nn.DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`_
时,模型会被复制到所有使用的GPU,通常每个GPU上存有一个模型,并被一个单独的进程控制。这样有N块GPU,就会产生N个进程。当训练一个batch时,这一batch会被分为N份,每个进程会使用batch的一部分进行训练,然后在必要时进行同步,并通过网络传输需要同步的数据。这时,只有模型的梯度会被同步,而模型的参数不会,所以能缓解大部分的网络传输压力,网络传输不再是训练速度的瓶颈之一。你可能会好奇,不同步模型的参数,怎么保证不同进程所训练的模型相同?只要每个进程初始的模型是同一个,具有相同的参数,而之后每次更新,都使用相同的梯度,就能保证梯度更新后的模型也具有相同的参数了。

为了让每个进程的模型初始化完全相同,通常这N个进程都是由单个进程复制而来的,这时需要对分布式的进程进行初始化,建立相互通信的机制。在
Pytorch 中,我们用
```distributed.init_process_group`` <https://pytorch.org/docs/stable/distributed.html#initialization>`__
`distributed.init_process_group <https://pytorch.org/docs/stable/distributed.html#initialization>`_
函数来完成,需要在程序开头就加入这一步骤。初始化完成后,每一个进程用唯一的编号
``rank`` 进行区分,从 0 到 N-1递增,一般地,我们将 ``rank`` 为 0
的进程当作主进程,而其他 ``rank`` 的进程为子进程。每个进程还要知道
``world_size`` ,即分布式训练的总进程数
N。训练时,每个进程使用batch的一部分,互相不能重复,这里通过
```nn.utils.data.DistributedSampler`` <https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html>`__
`nn.utils.data.DistributedSampler <https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html>`_
来实现。

使用方式
--------

Pytorch的分布式训练使用起来非常麻烦,难以理解,可以从给出的\ `官方教程 <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__\ 中看到。而\ ``fastNLP``
Pytorch的分布式训练使用起来非常麻烦,难以理解,可以从给出的\ `官方教程 <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_ \ 中看到。而\ ``fastNLP``
提供了
``DistTrainer``\ ,将大部分的分布式训练的细节进行了封装,只需简单的改动训练代码,就能直接用上分布式训练。那么,具体怎么将普通的训练代码改成支持分布式训练的代码呢。下面我们来讲一讲分布式训练的完整流程。通常,分布式程序的多个进程是单个进程的复制。假设我们用N个GPU进行分布式训练,我们需要启动N个进程,这时,在命令行使用:

@@ -110,9 +108,8 @@ Pytorch的分布式训练使用起来非常麻烦,难以理解,可以从给
初始化进程
^^^^^^^^^^

在获取了\ ``local_rank``\ 等重要参数后,在开始训练前,我们需要建立不同进程的通信和同步机制。这时我们使用\ ```torch.distributed.init_process_group`` <https://pytorch.org/docs/stable/distributed.html#initialization>`__
来完成。通常,我们只需要
``torch.distributed.init_process_group('nccl')``
在获取了\ ``local_rank``\ 等重要参数后,在开始训练前,我们需要建立不同进程的通信和同步机制。这时我们使用\ `torch.distributed.init_process_group <https://pytorch.org/docs/stable/distributed.html#initialization>`_
来完成。通常,我们只需要 ``torch.distributed.init_process_group('nccl')``
来指定使用\ ``nccl``\ 后端来进行同步即可。其他参数程序将读取环境变量自动设置。如果想手动设置这些参数,比如,使用TCP进行通信,可以设置:

.. code:: python
@@ -128,7 +125,7 @@ Pytorch的分布式训练使用起来非常麻烦,难以理解,可以从给
world_size=N, rank=args.rank)

注意,此时必须显式指定\ ``world_size``\ 和\ ``rank``\ ,具体可以参考
```torch.distributed.init_process_group`` <https://pytorch.org/docs/stable/distributed.html#initialization>`__
`torch.distributed.init_process_group <https://pytorch.org/docs/stable/distributed.html#initialization>`_
的使用文档。

在初始化分布式通信后,再初始化\ ``DistTrainer``\ ,传入数据和模型,就完成了分布式训练的代码。代码修改完成后,使用上面给出的命令行启动脚本,就能成功运行分布式训练。但是,如果数据处理,训练中的自定义操作比较复杂,则可能需要额外的代码修改。下面列出一些需要特别注意的地方,在使用分布式训练前,请仔细检查这些事项。
@@ -137,12 +134,12 @@ Pytorch的分布式训练使用起来非常麻烦,难以理解,可以从给
--------

在执行完
```torch.distributed.init_process_group`` <https://pytorch.org/docs/stable/distributed.html#initialization>`__
后,我们就可以在不同进程间完成传输数据,进行同步等操作。这些操作都可以在\ ```torch.distributed`` <https://pytorch.org/docs/stable/distributed.html#>`__
`torch.distributed.init_process_group <https://pytorch.org/docs/stable/distributed.html#initialization>`_
后,我们就可以在不同进程间完成传输数据,进行同步等操作。这些操作都可以在\ `torch.distributed <https://pytorch.org/docs/stable/distributed.html#>`_
中找到。其中,最重要的是
```barrier`` <https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier>`__
`barrier <https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier>`_
以及
```get_rank`` <https://pytorch.org/docs/stable/distributed.html#torch.distributed.get_rank>`__
`get_rank <https://pytorch.org/docs/stable/distributed.html#torch.distributed.get_rank>`_
操作。对于训练而言,我们关心的是读入数据,记录日志,模型初始化,模型参数更新,模型保存等操作。这些操作大多是读写操作,在多进程状态下,这些操作都必须小心进行,否则可能出现难以预料的bug。而在\ ``fastNLP``\ 中,大部分操作都封装在
``DistTrainer`` 中,只需保证数据读入和模型初始化正确即可完成训练。


docs/source/tutorials/extend_2_fitlog.rst → docs/source/tutorials/extend_3_fitlog.rst View File


+ 2
- 1
docs/source/user/tutorials.rst View File

@@ -21,4 +21,5 @@ fastNLP 详细使用教程
:maxdepth: 1

拓展阅读1:BertEmbedding的各种用法 </tutorials/extend_1_bert_embedding>
拓展阅读2:使用fitlog 辅助 fastNLP 进行科研 </tutorials/extend_2_fitlog>
拓展阅读2:分布式训练简介 </tutorials/extend_2_dist>
拓展阅读3:使用fitlog 辅助 fastNLP 进行科研 </tutorials/extend_3_fitlog>

+ 1
- 1
fastNLP/core/callback.py View File

@@ -809,7 +809,7 @@ class TensorboardCallback(Callback):
.. warning::
fastNLP 已停止对此功能的维护,请等待 fastNLP 兼容 PyTorch1.1 的下一个版本。
或者使用和 fastNLP 高度配合的 fitlog(参见 :doc:`/tutorials/extend_2_fitlog` )。
或者使用和 fastNLP 高度配合的 fitlog(参见 :doc:`/tutorials/extend_3_fitlog` )。
"""


Loading…
Cancel
Save