A Short Guide to PyTorch DDP¶
In this blog post, we explore what torchrun and DistributedDataParallel are and how they can be used to speed up your neural network training by using multiple GPUs.
Neural networks, or even deep neural networks, are popular models for machine learning. Mathematically, they can be interpreted as nested functions with millions of parameters. If the parameters are tuned well, the network can make predictions, such as identifying what a photo contains when given that photo. A famous example is Google Lens. The parameters are tuned using data. For example, if you show the network a photo of a dog and point out that it's a dog, the parameters are adjusted so that the next time it sees the same photo, it is more likely to recognise that it's a photo of a dog. This is done for millions of photos.
The Python package PyTorch can be used to specify a neural network and train it with data on a GPU. Even better, it can be trained using multiple GPUs, speeding up the training process.
Basic Concept of torchrun¶
After installing torch in your virtual environment (installation instructions are available on PyTorch's website), the command torchrun is usually available. It allows you to launch multiple workers, each running a copy of your Python script, and the workers can interact with each other.
You can try a simple example on your own computer to understand this. Consider a hello world script called hello_world.py
import os

RANK = int(os.environ["RANK"])

def say_hello(name):
    print(f"Hello {name}")

if __name__ == "__main__":
    names = [
        "Barry",
        "Alice",
        "Barbara",
        "Tom",
    ]
    say_hello(names[RANK])
You can launch four instances of hello_world.py, each with a different value of RANK, thus greeting a different person. This can be done with
torchrun --nproc-per-node 4 hello_world.py
This is pretty much the same as mpirun, as described in a previous blog post.
Useful environment variables
Information about a worker can be obtained by using os.environ[] as shown above. Here are more useful environment variables:
- RANK - The rank of the worker within a worker group.
- WORLD_SIZE - The total number of workers in a worker group.
- LOCAL_RANK - The rank of the worker within a local worker group.
There are more variables in PyTorch's documentation.
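As a quick check, you can print these variables from every worker. Below is a minimal sketch; the file name check_env.py is just a hypothetical example.
import os

RANK = int(os.environ["RANK"])
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])

# each worker reports its position in the worker group
print(f"Worker {RANK} of {WORLD_SIZE} (local rank {LOCAL_RANK})")
Launching it with torchrun --nproc-per-node 4 check_env.py should print four lines, one per worker.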
To use torchrun on Apocrita, launching a worker for each requested GPU, use the following in your job script
torchrun --nproc-per-node gpu --rdzv-backend=c10d --rdzv-endpoint=localhost:0 \
training_script training_script_args
where the positional arguments are
positional arguments:
  training_script       Full path to the (single GPU) training program/script
                        to be launched in parallel, followed by all the
                        arguments for the training script.
  training_script_args  Arguments to pass to your script; these can be
                        handled using, for example, argparse (see the sketch
                        below).
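For illustration, here is a minimal sketch of how a training script could consume training_script_args using argparse; the argument names below are hypothetical examples, not options required by torchrun.
import argparse

# hypothetical example arguments for a training script
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=100)
parser.add_argument("--n-epoch", type=int, default=5)
args = parser.parse_args()

print(f"batch size: {args.batch_size}, epochs: {args.n_epoch}")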
Ensure you use a free port
The options --rdzv-backend=c10d and --rdzv-endpoint=localhost:0 should be used on Apocrita to ensure there are no port clashes should multiple users be running torchrun on the same node. This automatically assigns a free port to your job; otherwise, you may encounter a torch.distributed.DistNetworkError exception. See PyTorch's documentation for more information.
Assign a worker for each GPU
The option --nproc-per-node gpu will automatically launch a process for each requested GPU.
Single node multi-GPU jobs on Apocrita
At the time of writing, only single-node multi-GPU jobs are available on Apocrita. Up to 4 GPUs can be requested this way. torchrun does work with multiple nodes of GPUs, but this is beyond the scope of this blog.
With torchrun, you can use PyTorch's features such as DistributedDataParallel. This is a way to parallelise your machine learning code across multiple GPUs, and even multiple nodes of GPUs.
GPU Architecture for Machine Learning¶
In this section, we recap the architecture of a GPU and how it applies to typical PyTorch code. Figure 1 shows a photo of a commercial grade graphics card, which houses many components such as a GPU chip, memory, a heat sink and fans. When there is no need to distinguish between the components, the entire card is quite commonly referred to as the GPU.
For this blog, we will focus on the GPU and the memory housed on the card, commonly referred to as video RAM (VRAM). Figure 2 shows an illustration of this. Commercial grade cards, such as the GTX and RTX series, typically have about 2-24 GB of VRAM, whereas enterprise cards such as our A100 and H100 cards have either 40 GB or 80 GB of VRAM. The amount of VRAM on the card you're using is important, as it will need to contain the parameters of the neural network you want to train. Small and older neural networks such as AlexNet can fit on a commercial grade card, but larger, newer models, like RegNet and ConvNeXt, typically require enterprise cards.
It is also worth pointing out that datasets used in machine learning, such as ImageNet, are in the order of hundreds or thousands of GB, so it just isn't possible to fit them in VRAM. Thus, typically, batches of data are loaded to VRAM one at a time when training a neural network. Optimisation methods, such as stochastic gradient descent, can use batches of data to train the neural network. Figure 2 illustrates how VRAM contains the parameters of the neural network and a batch of data. Figure 3 illustrates how a dataset can be split into batches; the GPU trains the neural network on one batch of data at a time.
Assessing how much VRAM you need
Use tools such as nvidia-smi, nvtop and nvitop to monitor your GPU utilisation and how much VRAM is being used. They are available on Apocrita; please see the documentation.
If you encounter memory errors, you may need a GPU with more VRAM. It is also possible to fit your model across multiple GPUs but this is beyond the scope of this blog.
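You can also query VRAM from within PyTorch itself. The following is a minimal sketch, assuming at least one CUDA device is visible to your script.
import torch

device = torch.device("cuda:0")
properties = torch.cuda.get_device_properties(device)

# total VRAM on the card, and how much this process is currently using
total = properties.total_memory
allocated = torch.cuda.memory_allocated(device)
reserved = torch.cuda.memory_reserved(device)

print(f"{properties.name}: {total / 1e9:.1f} GB total, "
      f"{allocated / 1e9:.1f} GB allocated, "
      f"{reserved / 1e9:.1f} GB reserved")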
How DistributedDataParallel Uses Multiple GPUs¶
The idea behind DistributedDataParallel is that each GPU trains the same neural network using different batches of data in parallel, getting through all of the batches faster. Figure 4 illustrates this.
It should be noted that this is not a pleasingly parallel task. In a previous blog post, we looked at pleasingly parallel GPU tasks. Those tasks are pleasingly parallel because each GPU can work independently, for example on different models or parameters.
However, in the case of DistributedDataParallel, each GPU is training the same model and the GPUs work together, not independently. After each GPU works out the gradient for a batch, the gradients are collated and used to update the parameters of the neural network. The updated neural network is sent to all GPUs, ready to work on the next batch. This is illustrated in Figure 5. This cycle continues until some stopping condition is met. This type of parallelisation can be described as tightly coupled.
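To make the gradient collation step concrete, the sketch below shows roughly what DistributedDataParallel does for you behind the scenes; you would not normally write this yourself. It assumes init_process_group() has already been called and that the model sits on this worker's GPU.
import torch.distributed as dist

def average_gradients(model, world_size):
    # after loss.backward(), each worker holds gradients for its own batch
    for param in model.parameters():
        if param.grad is not None:
            # sum the gradients across all workers...
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            # ...then divide by the number of workers to get the average
            param.grad /= world_size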
Using DistributedDataParallel With Your Code¶
For reference, PyTorch has documentation on DistributedDataParallel such as their API documentation, their beginner's tutorial and their intermediate tutorial. PyTorch also has example code on their GitHub. There are also blog posts such as one by Kevin Kaichuang Yang and one by Jackson Kek, the latter being my recommendation to read.
When reading these posts, it may be useful to note the following:
- The alternative to DistributedDataParallel is DataParallel. For performance reasons, DistributedDataParallel should be favoured over DataParallel. See PyTorch's comparison for further information.
- torchrun was introduced at around v1.10. Previously, you would run the library module python -m torch.distributed.launch.
- On Apocrita, where GPU jobs are limited to a single node, you do not need to set the variables os.environ['MASTER_ADDR'] and os.environ['MASTER_PORT'].
In this blog post, I'll provide some steps to help rewrite your existing PyTorch code to use DistributedDataParallel, and then a full example script.
Steps To Use DistributedDataParallel¶
Get information about each worker¶
We first get information about each worker, such as the rank and the world size.
RANK = int(os.environ["RANK"])
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])
MASTER_RANK = 0
I personally like to set them as global variables to make them easier to access in the middle of the code. They shouldn't need to change value, so their content should stay crystal clear during the run time of the software.
I've also defined MASTER_RANK to designate a worker to be the master.
When to use the master rank
For tasks which only require one worker, such as printing or logging, only one worker should do the task while the rest do not. This can be done by using a conditional, for example
if RANK == MASTER_RANK:
    print("Here's some information")
The difference between LOCAL_RANK and RANK
RANK will be unique for each worker across the whole job, whereas LOCAL_RANK will only be unique for each worker within a node. For example, with two nodes of four GPUs each, RANK runs from 0 to 7 while LOCAL_RANK runs from 0 to 3 on each node. Typically in multi-node jobs, LOCAL_RANK is used to identify each GPU on a node. Apocrita does not support multi-node GPU jobs; however, the distinction is important should your code need to scale up.
Set the GPU¶
We set the GPU, or device, we want a worker to use. Typically we want one GPU per worker, so we use torch.device(f"cuda:{LOCAL_RANK}") to assign a unique GPU to each worker. This is then followed by init_process_group() to make the workers aware of each other.
device = torch.device(f"cuda:{LOCAL_RANK}")
torch.cuda.set_device(device)
torch.distributed.init_process_group(
    backend="nccl", world_size=WORLD_SIZE, rank=RANK)
torch.distributed.barrier()
The nccl backend
We choose the nccl backend as we find this works best for multiple GPUs. See PyTorch's documentation for further information.
Move the model to the GPU and wrap it with DistributedDataParallel
model.to(device)
model = nn.parallel.DistributedDataParallel(model, device_ids=[device])
Running code on a CPU instead
It is possible to run your code on multiple CPU cores instead of multiple GPUs by using device = torch.device("cpu") and tinkering with the code a bit more. You may want your code to be runnable on either CPU or GPU so that it can be compatible with different systems. However, this can bloat your code. Plan accordingly.
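As a minimal sketch of one way to do this, you could pick the device and backend at run time; this assumes the RANK, LOCAL_RANK and WORLD_SIZE globals defined earlier.
import torch

if torch.cuda.is_available():
    device = torch.device(f"cuda:{LOCAL_RANK}")
    torch.cuda.set_device(device)
    backend = "nccl"  # recommended backend for GPUs
else:
    device = torch.device("cpu")
    backend = "gloo"  # a backend that supports CPU workers

torch.distributed.init_process_group(
    backend=backend, world_size=WORLD_SIZE, rank=RANK)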
Set the DataLoader and DistributedSampler¶
Use a DistributedSampler; this ensures each GPU is allocated different data points. This is provided to your usual DataLoader.
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)

# set data loader
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=num_workers,
    pin_memory=True
)
Setting num_workers
The argument num_workers sets how many subprocesses to use for data loading. Remember that when using torchrun, multiple copies of your script are executed. Thus, to avoid overthreading, num_workers should be no bigger than the number of CPU cores per worker (or GPU) minus one. At the time of writing, on Apocrita, set this argument to 11 or less.
I recommend doing this programmatically. For example, num_workers can be an argument for your script. Or, specifically for Apocrita
num_workers = int(os.getenv("NSLOTS")) // torch.cuda.device_count() - 1
Train your model¶
Train your model as usual. By default, the DistributedSampler randomises the data. We pass the current epoch to set_epoch() so that, for every epoch, different randomised data are distributed to each GPU.
model.train()
for epoch in range(n_epoch):
    train_sampler.set_epoch(epoch)
    for image, target in data_loader:
        image = image.to(device)
        target = target.to(device)
        output = model(image)
        ...
Run your code¶
Once your script has been modified, run your script with torchrun as explained in the previous section.
torchrun --nproc-per-node gpu --rdzv-backend=c10d --rdzv-endpoint=localhost:0 \
training_script training_script_args
Example Script¶
If you're still stuck, you can study and run the example script below. We wrote our own neural network model and trained it on the MNIST dataset (the dataset was studied previously in a blog post).
The code demonstrates how to use DistributedDataParallel and DistributedSampler and can run on Apocrita with one or more GPUs. We've also restricted downloading the MNIST dataset and printing to the master process, to avoid downloading the dataset more than once. The function torch.distributed.reduce is used to gather results from all GPUs.
To run the script, call
torchrun --nproc-per-node gpu --rdzv-backend=c10d --rdzv-endpoint=localhost:0 \
training_script
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import os
RANK = int(os.environ["RANK"])
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])
MASTER_RANK = 0
N_EPOCH = 5
BATCH_SIZE = 100
# Our own custom neural network on the MNIST dataset
class Net(nn.Module):

    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
def main():
    device = torch.device(f"cuda:{LOCAL_RANK}")
    torch.cuda.set_device(device)
    torch.distributed.init_process_group(backend="nccl", world_size=WORLD_SIZE,
                                         rank=RANK)
    torch.distributed.barrier()

    model = Net()
    model.to(device)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[device])

    # place this condition so only one process downloads the mnist dataset
    if RANK == MASTER_RANK:
        dataset = torchvision.datasets.MNIST(
            root='.', train=True, transform=transforms.ToTensor(),
            download=True)
        torch.distributed.barrier()
    else:
        # all remaining processes can read the mnist dataset once the master
        # process finished downloading the mnist dataset
        torch.distributed.barrier()
        dataset = torchvision.datasets.MNIST(
            root='.', train=True, transform=transforms.ToTensor(),
            download=True)

    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        sampler=sampler,
        num_workers=11,
        pin_memory=True
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)

    model.train()  # set the model to training mode
    for epoch in range(N_EPOCH):
        sampler.set_epoch(epoch)
        total_loss = torch.zeros(1).to(device)
        for image, target in data_loader:
            image = image.to(device)
            target = target.to(device)

            output = model(image)
            loss = criterion(output, target)
            # detach so the running total doesn't keep the computation graph
            total_loss += loss.detach()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # gather and sum the loss from each worker
        torch.distributed.reduce(total_loss, MASTER_RANK,
                                 torch.distributed.ReduceOp.SUM)

        # only the master rank prints the loss at every epoch
        if RANK == MASTER_RANK:
            print(f"Total loss: {total_loss[0]}")


if __name__ == "__main__":
    main()
Summary¶
We have explained what DistributedDataParallel is and how it can be used with torchrun and multiple GPUs to speed up your machine learning training scripts. We've provided some guidelines and tutorials to get you started using it on Apocrita. As you progress in your research, you may need additional features such as checkpointing, to save and resume training your model, and seeding, to make your training reproducible.
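For example, a common checkpointing pattern with DistributedDataParallel is to save only from the master rank, and to save the state dict of the underlying model rather than the wrapper. The following is a minimal sketch, assuming the model, optimizer, epoch and rank variables from the example script above; the file name checkpoint.pt is just a hypothetical example.
if RANK == MASTER_RANK:
    torch.save(
        {
            "model": model.module.state_dict(),  # unwrap DistributedDataParallel
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        },
        "checkpoint.pt",
    )
# wait so no worker moves on before the checkpoint has been written
torch.distributed.barrier()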
In the next blog post, we will benchmark DistributedDataParallel on different GPUs and neural networks to see how much of a performance gain we get from using it.
Acknowledgement¶
We would like to thank Niki Foteinopoulou for raising a ticket and discussing with us how to get DistributedDataParallel working on Apocrita.
The GPU illustration is from vecteezy.com.