Distributed Training in PyTorch (Distributed Data Parallel)

Praneet Bomma · Published in Analytics Vidhya · 6 min read · Apr 17, 2021

Today we will be covering Distributed Data Parallel in PyTorch, which can be used to distribute data across GPUs and train a model on multiple GPUs.

Photo by Taylor Vick on Unsplash

Just like we use threading in our programs to carry out tasks in parallel and finish them in the shortest possible time, we can use a similar technique to train deep learning models in PyTorch in parallel across GPUs. All we need is a system with multiple GPUs, or multiple systems with multiple GPUs each. Let’s see how to achieve this parallelism when training our models.

Before we start, we need to get familiar with some key-words that will come up in the write-up ahead. We will be covering these points in the given order.

  1. Data Parallel in PyTorch
  2. Global Interpreter Lock (GIL)
  3. Distributed Data Parallel in PyTorch
  4. Key-words
  5. Implementation

Data Parallel in PyTorch

If we want to get started with distributed training quickly, we can use Data Parallel in PyTorch, which uses threading to achieve parallel training.

All we need to do is add one line, as given below, to our script and PyTorch will handle the parallelism for us. We are basically adding a wrapper over our model to let PyTorch know that it needs to be parallelized.

model = torch.nn.DataParallel(model)
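
To put that one line in context, here is a minimal sketch of how it might sit in an otherwise ordinary single-machine script (the tiny nn.Sequential model is just a placeholder for whatever model you are training):

import torch
import torch.nn as nn

# A throwaway model, only to show where the wrapper goes.
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

if torch.cuda.device_count() > 1:
    # Each forward pass scatters the input batch across the visible GPUs
    # and gathers the outputs back on the default GPU.
    model = nn.DataParallel(model)

model = model.to('cuda')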

As Data Parallel uses threading to achieve parallelism, it suffers from a major, well-known issue that arises due to the Global Interpreter Lock (GIL) in Python. Because of the way the Python interpreter is designed, it is not possible to achieve true parallelism in Python using threading. Let’s see what the GIL is.

Global Interpreter Lock (GIL)

As mentioned earlier, because of the way the Python interpreter is implemented, it is very difficult to achieve true parallelism using threading. This is due to something called the Global Interpreter Lock.

GIL

The Python Global Interpreter Lock, or GIL, in simple words, is a mutex (or a lock) that allows only one thread to hold control of the Python interpreter. Only one thread can be in a state of execution at any point in time.

Mutex

A mutex is a mutual exclusion object that synchronizes access to a resource. It is typically created with a unique name at the start of a program. A mutex is a locking mechanism that makes sure only one thread can acquire it at a time and enter the critical section.
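
Just to make the idea concrete, here is a tiny, self-contained illustration of a mutex in plain Python using threading.Lock (it has nothing to do with PyTorch; it only shows the locking behaviour described above):

import threading

counter = 0
lock = threading.Lock()  # the mutex

def increment():
    global counter
    for _ in range(100_000):
        with lock:  # only one thread at a time may enter this critical section
            counter += 1

threads = [threading.Thread(target=increment) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(counter)  # always 400000, because the lock serializes access to counter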

The GIL therefore basically defeats the whole purpose of using threading in the first place, which is why PyTorch provides something else that can be used to achieve true parallelism.

Distributed Data Parallel in PyTorch

DDP in PyTorch does the same thing, but in a much more efficient way, and it also gives us better control while achieving true parallelism. DDP uses multiprocessing instead of threading and runs the forward and backward passes through the model in a separate process for each GPU. DDP replicates the model across multiple GPUs, each of which is controlled by one process. A process here is simply an instance of the script running on your system. Usually we spawn processes such that there is a separate process for each GPU.

Each of these processes does identical work, but on a different batch of data. Each process communicates with the others to share gradients, which need to be all-reduced during the optimization step. At the end of an optimization step, every process holds the averaged gradients, ensuring that the model weights stay synchronized.
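
Conceptually, the gradient averaging amounts to something like the sketch below (assuming torch.distributed is imported as dist). Note that this is only an illustration; the real DDP implementation buckets gradients and overlaps the all-reduce calls with the backward pass for efficiency.

def average_gradients(model, world_size):
    # Sum each gradient across all processes, then divide by the world size.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size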

Key-words

“node” is a machine in your distributed setup. In layman’s terms, a single machine that has multiple GPUs can be called a node.

“global rank” is a unique identification number for each process across all the nodes in our setup.

“local rank” is a unique identification number for each process within a node.

“world” is the union of all of the above: it can have multiple nodes, where each node spawns multiple processes (ideally, one for each GPU).

“world_size” is the total number of processes and is equal to the number of nodes * the number of GPUs per node.
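
To make these terms concrete, here is how the numbers work out for a hypothetical cluster of 2 nodes with 2 GPUs each:

# Hypothetical cluster: 2 nodes, each with 2 GPUs.
nodes, gpus_per_node = 2, 2
world_size = nodes * gpus_per_node  # 4 processes in total

for node_rank in range(nodes):
    for local_rank in range(gpus_per_node):
        global_rank = node_rank * gpus_per_node + local_rank
        print(f"node {node_rank}, local rank {local_rank} -> global rank {global_rank}")

# node 0, local rank 0 -> global rank 0
# node 0, local rank 1 -> global rank 1
# node 1, local rank 0 -> global rank 2
# node 1, local rank 1 -> global rank 3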

Implementation

Let’s implement a simple example and walk through the important changes required to train a model across GPUs in a distributed setup! We will implement a simple classifier for handwritten digits, trained on the very famous MNIST dataset. I won’t go through all the details of writing a neural network and other basics, as this write-up assumes you have a basic idea of training a simple model in PyTorch. Without further ado, let’s start!

First and foremost, let’s import the necessary libraries.
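
Since the original code embed is not reproduced here, below is the set of imports such a script would typically need; torchvision is assumed for loading the MNIST dataset.

import os
import argparse

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms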

Let’s write a simple convnet that can classify a given input into ten classes.
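
One possible version of such a classifier is sketched below; the exact layer sizes are my own choice and not essential to the distributed-training part.

class ConvNet(nn.Module):
    # A small CNN for 28x28 grayscale MNIST digits (10 classes).
    def __init__(self, num_classes=10):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 28x28 -> 14x14
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14 -> 7x7
        )
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        return self.fc(out)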

Now let’s implement the main function, which takes in command-line arguments that define the number of nodes, the number of GPUs, the node’s rank and other settings. We need these arguments to make sure each process can communicate with the master node to reduce the gradients and get back the averaged gradients. Each process also needs to know which GPU to use and where it ranks among all the processes that are running.

We will be passing arguments such as the number of nodes with ‘-n’, the number of GPUs per node with ‘-g’, the rank of the current node with ‘-nr’ and the number of epochs with ‘--epochs’.

The world_size is calculated by multiplying the number of GPUs per node by the number of nodes. MASTER_ADDR and MASTER_PORT are set as environment variables, which can be directly accessed by all the processes spawned on the node.

Finally, we spawn the train function, which does the most important part of the script, i.e. training the model. nprocs defines the number of processes to spawn, which is equal to args.gpus. When the train function is spawned, it automatically receives a first argument that identifies the process among all the spawned processes. This is an integer ranging from 0 to args.gpus - 1, and it can also serve as the GPU/device index, since device indices lie in the same range.
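
Putting that together, a sketch of the main function could look like the following; the master address and port values are placeholders you would replace with the IP of your rank-0 node and a free port.

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--nodes', default=1, type=int,
                        help='number of nodes in the cluster')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of GPUs per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='rank of this node among all nodes (0-indexed)')
    parser.add_argument('--epochs', default=2, type=int,
                        help='number of training epochs')
    args = parser.parse_args()

    # Total number of processes across the whole cluster.
    args.world_size = args.gpus * args.nodes

    # Every process reads these to find the master for the initial rendezvous.
    os.environ['MASTER_ADDR'] = 'localhost'  # use the IP of the rank-0 node for multi-node runs
    os.environ['MASTER_PORT'] = '8888'       # any free port on the rank-0 node

    # Launch one training process per GPU on this node; each process receives
    # its index (0 .. args.gpus - 1) as the first argument of train().
    mp.spawn(train, nprocs=args.gpus, args=(args,))

if __name__ == '__main__':
    main()  # in the full script this guard sits at the very bottom, after train() is defined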

Now let’s implement the train function to see how these arguments are used.
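
Here is a rough sketch of what such a train function can look like; it reuses the imports and the ConvNet from above, and the batch size, optimizer and learning rate are arbitrary choices rather than anything prescribed by DDP.

def train(gpu, args):
    # gpu is the index of this process on the current node (0 .. args.gpus - 1), passed in by mp.spawn.
    rank = args.nr * args.gpus + gpu  # global rank of this process

    # Join the process group; 'env://' tells PyTorch to read MASTER_ADDR and
    # MASTER_PORT from the environment variables set in main().
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=args.world_size, rank=rank)

    torch.manual_seed(0)
    torch.cuda.set_device(gpu)

    model = ConvNet().cuda(gpu)
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Wrap the model; DDP now all-reduces gradients across processes automatically.
    model = DDP(model, device_ids=[gpu])

    train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                                   transform=transforms.ToTensor())

    # Give every process its own non-overlapping shard of the dataset.
    train_sampler = DistributedSampler(train_dataset, num_replicas=args.world_size,
                                       rank=rank)
    train_loader = DataLoader(train_dataset, batch_size=100, shuffle=False,
                              num_workers=0, pin_memory=True, sampler=train_sampler)

    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)  # reshuffle the shards every epoch
        for images, labels in train_loader:
            images = images.cuda(gpu, non_blocking=True)
            labels = labels.cuda(gpu, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()   # gradients are all-reduced across processes here
            optimizer.step()

        if rank == 0:
            print(f'epoch {epoch + 1}/{args.epochs} finished, last loss {loss.item():.4f}')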

Initially, we calculate the rank of the worker using args.nr * args.gpus + gpu. If this is the first worker on the second node, which has 2 GPUs, the numbers here will be: args.nr = 1 as it is the second node, args.gpus = 2 as there are 2 GPUs on the node, and gpu = 0 as this is the first worker, using the first GPU on the second node. The rank therefore works out to 2, which is nothing but our global rank among all the processes.

Now we initialize the process group with nccl, the NVIDIA Collective Communication Library, which implements multi-GPU and multi-node communication primitives optimized for NVIDIA GPUs and networking. We also pass in init_method = 'env://', which basically says that information like MASTER_ADDR and MASTER_PORT should be taken from the environment variables. We also specify the world_size and the global rank that we just calculated.

Next, we wrap our model with PyTorch’s DistributedDataParallel class, which takes care of the model replication and the parallel training.

We also initialize a DistributedSampler, which takes care of distributing the data across the GPUs so that each process gets its own subset without repeating any samples, and we pass this sampler to the DataLoader so that it is used when batching the data.

This is how we train our model with Distributed Data Parallel on multiple GPUs. You can get the complete code at this link! Please comment below if I need to elaborate on any points or if any corrections are needed.

You can also refer to a blog by neptune.ai here on various frameworks and tools available for distributed training.
