Mark III Systems Blog

Deep Learning Model Multi-Node, Distributed Training Strategies (Primer)

Several different strategies have been developed for effectively pretraining and fine-tuning large models in multi-GPU and multi-node environments. In this blog you will find a high-level overview of some of the most popular strategies, including DDP, DeepSpeed ZeRO, and FSDP. All of these methods have been implemented as “strategies” for the PyTorch Lightning Trainer, and links to the documentation for using them in Lightning are included in each section.
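To make the later sections concrete, here is a minimal sketch of a Lightning training script. The tiny model, random data, and device count are illustrative assumptions, not from any specific workload; the only thing that changes between the approaches described below is the strategy argument passed to the Trainer.

```python
# Minimal Lightning training skeleton (illustrative): the `strategy` argument is
# the only thing that changes between DDP, DeepSpeed ZeRO, and FSDP.
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    # Random placeholder data in place of a real dataset.
    dataset = TensorDataset(torch.randn(1024, 32), torch.randint(0, 2, (1024,)))
    train_loader = DataLoader(dataset, batch_size=64)

    # Swap the `strategy` value to change the distributed approach (assumes 2 GPUs).
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
    trainer.fit(ToyModel(), train_loader)
```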

DDP - Distributed Data Parallel

https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html#distributed-data-parallel 

DDP allows one to utilize multiple GPUs for model training. When the training process is initialized, the model is loaded onto one GPU and then replicated to every other GPU you wish to use for training. For each training step, each GPU receives a different mini-batch of the data and runs the forward and backward passes with its own replica of the model. The local gradients are then averaged across all GPUs, every replica applies the same averaged update, and the cycle begins again.

This is one of the simpler training strategies available for multi-device training, but the entire model (along with its gradients and optimizer states) must fit on a single GPU. This can be limiting when training very large models or using small GPUs.
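A minimal sketch of scaling DDP out to multiple nodes with the Lightning Trainer, reusing the ToyModel and train_loader from the skeleton above; the node and device counts are illustrative assumptions.

```python
# DDP across 2 nodes with 4 GPUs each (8 replicas total): each replica processes a
# different mini-batch and gradients are averaged across all replicas every step.
import lightning.pytorch as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,      # GPUs per node
    num_nodes=2,    # the same script is launched on each node
    strategy="ddp",
)
# trainer.fit(ToyModel(), train_loader)
```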

DeepSpeed ZeRO-DP

https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.DeepSpeedStrategy.html#lightning.pytorch.strategies.DeepSpeedStrategy 

ZeRO-DP (Zero Redundancy Optimizer Data Parallelism) is a training strategy introduced in the DeepSpeed open source library by Microsoft. The goal of ZeRO was to develop a strategy that could handle very large models with billions or even a trillion parameters. It builds on the idea we have seen in DDP while introducing new methods to keep memory usage in check. This is accomplished by taking the parts of the training process that use up GPU memory and sharding, or distributing, them among all the available GPUs. This means that we can still accomplish our training even if the combined training state (parameters, gradients, and optimizer states) is larger than the VRAM of any single GPU. There are three different levels of ZeRO implemented in most frameworks, known as stages 1, 2, and 3. Each of these stages varies in what it distributes over all the GPUs.

Stage 1 ZeRO is limited to partitioning only the optimizer states between the available devices. While this is helpful, it is generally recommended to skip stage 1 and go right to stage 2 or 3.

Stage 2 ZeRO builds on stage 1 by also partitioning the gradients between available devices. This is most likely the optimal stage to use if your model is small enough to fit on one GPU.

Stage 3 ZeRO builds on stage 2 by also partitioning the model parameters between available devices. This is what allows ZeRO to train extremely large models that don’t fit on one GPU. However, it is important to remember that there is a performance hit because of the increased communication required to share all this information between devices. Network speed between your devices can become a major bottleneck with ZeRO. For example, I ran a two-node ZeRO 3 training job in an environment that had only 1000 MB/s of network bandwidth between the two nodes. This training saw an enormous slowdown compared to running the same job on just one of the nodes, due to the limited network speed and all the information being passed around with ZeRO 3. See the blog post on benchmarking Falcon-7B for more details.
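A sketch of selecting a ZeRO stage in Lightning, either through the shorthand strategy strings or an explicit DeepSpeedStrategy object; the device counts and mixed-precision setting are illustrative assumptions, and the deepspeed package must be installed.

```python
# Choosing a ZeRO stage in Lightning (requires the `deepspeed` package).
import lightning.pytorch as pl
from lightning.pytorch.strategies import DeepSpeedStrategy

# Shorthand strings: "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3".
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    strategy="deepspeed_stage_2",
    precision="16-mixed",
)

# Or configure the strategy object directly, e.g. stage 3 across two nodes,
# which also shards the parameters (and incurs the communication cost noted above).
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    num_nodes=2,
    strategy=DeepSpeedStrategy(stage=3),
    precision="16-mixed",
)
```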

FSDP (Fully Sharded Data Parallel) 

https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#fully-sharded-training 

As the name implies, FSDP serves the same purpose as DeepSpeed ZeRO stage 3. It fully shards the model’s parameters, gradients, and optimizer states so that a large model can be trained in a multi-device environment without hitting OOM (out of memory) errors. FSDP was developed by Meta and is meant to be easier to get started with than ZeRO 3 if you are already using the PyTorch framework. However, it is currently considered an experimental feature in PyTorch Lightning.
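A minimal sketch of enabling FSDP in Lightning, again reusing the skeleton above; the shorthand string and the FSDPStrategy object are both shown, and the device counts are illustrative assumptions.

```python
# FSDP in Lightning: shorthand string "fsdp" or an explicit FSDPStrategy object.
import lightning.pytorch as pl
from lightning.pytorch.strategies import FSDPStrategy

trainer = pl.Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision="16-mixed")

# Or with the strategy object, e.g. across two nodes:
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    num_nodes=2,
    strategy=FSDPStrategy(),  # options such as the auto-wrap policy can be set here
    precision="16-mixed",
)
# trainer.fit(ToyModel(), train_loader)
```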