
How parallel training works in PyTorch and deep learning: a comprehensive guide

Why do you need parallel training?

In the world of machine learning, handling big chunks of data is crucial, especially for tasks like processing images and text. Imagine you’re working on a project with a massive model, such as a Large Language Model (LLM), and it takes a whopping 64 days to train it on a single GPU. But you’ve got a looming deadline, and you need to cut that training time down to just 1 day or even 1 hour. Is that even possible? Yes, it is, thanks to parallel training. In this article, we’ll explore how you can use PyTorch, a popular deep learning framework, to train your models in parallel, saving you valuable time for your next big project.

To make things simpler, let’s break down the terms.

– First, imagine a single computer as a “node.”

– In our example, if we want to cut the training time from 64 days to 32 days, we’ll need 2 GPUs (Graphics Processing Units).

– But, if we need to reduce the training to just 1 hour, we’d require more than 32 GPUs, which can’t fit in a single node. In fact, depending on the complexity of the model and the size of the dataset, we would likely need far more than 32 GPUs to finish within an hour (see the quick calculation after this list).

– So, we’d use multiple nodes, e.g. 4 nodes with around 8 GPUs each.

– When we train our model using multiple nodes like this, it’s called “distributed training.”
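Here is the quick calculation referenced above: a back-of-the-envelope estimate that assumes perfect linear scaling, where doubling the GPUs halves the training time. Real systems never quite achieve this because of communication overhead, so treat the numbers as lower bounds.

```python
# Back-of-the-envelope estimate of GPUs needed under perfect linear scaling.
# Real scaling is worse due to communication overhead, so these are lower bounds.
single_gpu_hours = 64 * 24  # 64 days on one GPU

for target_hours in (32 * 24, 24, 1):  # 32 days, 1 day, 1 hour
    gpus_needed = single_gpu_hours / target_hours
    print(f"target {target_hours:>4} h -> ~{gpus_needed:.0f} GPUs")

# target  768 h -> ~2 GPUs
# target   24 h -> ~64 GPUs
# target    1 h -> ~1536 GPUs
```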

Parallel training in PyTorch involves distributing the computation across multiple GPUs to accelerate the training process. This can be achieved using PyTorch’s DataParallel or DistributedDataParallel (DDP) modules.

DataParallel

DataParallel in PyTorch is a module designed to distribute data across multiple GPUs during training. It replicates the model onto each GPU and splits the input batches into smaller batches to be processed independently on each GPU. After processing, the results are gathered and combined, and the backward pass is performed to update the model parameters.

Here’s a simplified overview of how DataParallel works internally:

Replication of the Model:

The original model is replicated onto each available GPU. Each replica has its own copy of the parameters, and the computation is performed independently on each GPU. It is important to set the same random seed on every GPU; otherwise, models constructed separately on each device would start with different parameters.
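A minimal sketch of how one might fix the seeds before constructing the model (the value 42 is an arbitrary choice):

```python
import torch

# Fix the random seeds so model initialisation and other random operations
# are reproducible across runs and across devices; 42 is arbitrary.
seed = 42
torch.manual_seed(seed)           # seeds the CPU RNG
torch.cuda.manual_seed_all(seed)  # seeds the RNG on every visible GPU
```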

Input Splitting:

The input batch is split into smaller batches, with each smaller batch sent to a different GPU. If we have 4 GPUs and a total batch of 64 images, each GPU receives an effective mini-batch of 16 images.
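As an illustration of the arithmetic (this is not DataParallel’s internal code, just a sketch of the split):

```python
import torch

# A batch of 64 images split across 4 GPUs: each replica processes 16 images.
batch = torch.randn(64, 3, 224, 224)          # (batch, channels, height, width)
per_gpu = batch.chunk(4, dim=0)               # 4 tensors of shape (16, 3, 224, 224)
print([chunk.shape[0] for chunk in per_gpu])  # [16, 16, 16, 16]
```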

Parallel Forward Pass:

Each GPU computes the forward pass independently on its portion of the input data, using its own replica of the model parameters.

Gradient Computation:

After the forward pass, each GPU independently computes the gradients of the loss with respect to the parameters of its replica, based on its portion of the data.

Gradient All-Reduce:

The gradients computed on each GPU are synchronised across all GPUs through an operation called all-reduce. All-reduce ensures that the gradients from each GPU are summed and averaged before the parameters are updated. If the gradient produced on GPU 1 is dl1 and the gradient produced on GPU 2 is dl2, then both GPUs end up with the same averaged gradient, (dl1 + dl2)/2.
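PyTorch handles this synchronisation for you (DataParallel gathers gradients internally within a single process, while DistributedDataParallel performs an all-reduce across processes). The sketch below only illustrates what the averaging amounts to, assuming a torch.distributed process group has already been initialised, for example by torchrun:

```python
import torch.distributed as dist

def average_gradients(model):
    """Illustrative all-reduce: every rank ends up with the same averaged gradients."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # Sum this gradient across all GPUs, then divide by the number of GPUs,
            # so every replica holds (dl1 + dl2 + ... + dlN) / N.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size
```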

Backward Pass:

With the synchronised gradients, the backward pass is completed and each GPU updates its replica of the parameters using the same aggregated gradients.

Update Model Parameters:

The updated parameters are then shared among all GPUs, ensuring that the model parameters are consistent across all replicas.

Repeat for the Next Batch:

Steps 2–7 are repeated for each batch in the training dataset.

It’s important to note that DataParallel is straightforward to use, but it may not be the most efficient solution in some cases, especially when dealing with very large models or datasets. For more advanced scenarios, such as distributed training across multiple nodes, DistributedDataParallel is usually more appropriate: it builds on the principles of DataParallel but offers more flexibility and control over distributed training.
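Before moving on, here is what using DataParallel typically looks like in practice. This is a minimal sketch; the model, the data, and the hyperparameters are placeholders, not anything prescribed by PyTorch.

```python
import torch
import torch.nn as nn

# Placeholder model; swap in your own architecture.
model = nn.Sequential(nn.Linear(1024, 256), nn.ReLU(), nn.Linear(256, 10))

if torch.cuda.device_count() > 1:
    # Replicates the model onto every visible GPU and splits each batch across them.
    model = nn.DataParallel(model)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# One illustrative training step on random data.
inputs = torch.randn(64, 1024, device=device)
targets = torch.randint(0, 10, (64,), device=device)

optimizer.zero_grad()
loss = criterion(model(inputs), targets)  # forward pass is split across GPUs
loss.backward()                           # gradients are combined on the default GPU
optimizer.step()
```

Note that the only DataParallel-specific line is the nn.DataParallel(model) wrapper; everything else is an ordinary training step.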

It is fundamental to understand that DataParallel in PyTorch is designed to distribute data and computation across multiple GPU devices. If you have a single GPU, even a powerful one with many cores, DataParallel has only one device to distribute work to, so in practice there is little benefit compared to running the model on that GPU directly.

Within a single GPU, the many cores already share the workload of each operation: a GPU is itself a highly parallelised processor designed to handle parallel tasks efficiently. For deep learning, DataParallel’s parallelisation happens at a higher level, distributing batches of data across different GPUs. So when you only have one GPU, the benefit of using DataParallel is limited compared to the case where you have multiple GPUs.

It’s worth noting that the performance gains from parallelization are often more apparent when working with multiple GPUs or distributed computing environments. In a single GPU scenario, you may still see some benefits, but the magnitude of improvement might not be as significant as when using multiple GPUs.

If you have a single GPU and you want to make the most out of it, it’s essential to ensure that your model and training pipeline are optimised for it. This might involve optimising your model architecture, using efficient batch sizes, and leveraging GPU-specific optimisations.

In summary, while DataParallel can technically be used with a single GPU having multiple cores, the benefits may not be as pronounced as in a multi-GPU setup. It’s recommended to profile and experiment to find the optimal configuration for your specific model and hardware.

It is also essential for one to understand what we mean by parallel processing in GPUs. Parallel processing is a computing paradigm where multiple tasks or processes are executed simultaneously. The goal is to divide a complex problem into smaller sub-problems and solve them concurrently, leveraging multiple processing units, such as CPU cores or GPUs, to improve overall computational efficiency. Parallel processing can be applied to various types of computations, from simple arithmetic operations to complex simulations and data processing tasks.

There are two primary types of parallel processing:

Task Parallelism:

In task parallelism, different tasks or processes are executed concurrently. Each task is independent and can be executed simultaneously on different processors. This approach is suitable for problems where the workload can be divided into separate, independent tasks that do not rely heavily on each other.
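A toy sketch of task parallelism using Python’s multiprocessing module; the two task functions are hypothetical stand-ins for genuinely independent pieces of work:

```python
from multiprocessing import Process

# Two unrelated, independent tasks run at the same time in separate processes.
def resize_images():
    print("resizing images...")

def tokenize_text():
    print("tokenizing text...")

if __name__ == "__main__":
    tasks = [Process(target=resize_images), Process(target=tokenize_text)]
    for p in tasks:
        p.start()
    for p in tasks:
        p.join()
```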

Data Parallelism:

In data parallelism, a single task is divided into multiple subtasks, each operating on different sets of data. These subtasks are then executed simultaneously on different processors. This approach is commonly used in parallel computing for machine learning, image processing, and other data-intensive tasks.
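And a toy sketch of data parallelism with the same module: one function applied to different chunks of data in parallel worker processes (the normalisation step is just a placeholder workload):

```python
from multiprocessing import Pool

# The same operation is applied to each chunk of data in a separate worker process.
def scale_by_max(chunk):
    peak = max(chunk)
    return [x / peak for x in chunk]

if __name__ == "__main__":
    data_chunks = [[1, 2, 4], [10, 20, 40], [3, 6, 9]]
    with Pool(processes=3) as pool:
        results = pool.map(scale_by_max, data_chunks)
    print(results)  # [[0.25, 0.5, 1.0], [0.25, 0.5, 1.0], [0.333..., 0.666..., 1.0]]
```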

Parallel processing provides several advantages:

Increased Performance: By executing multiple tasks concurrently, the overall computation time can be significantly reduced, leading to improved performance.

Efficient Resource Utilisation: Utilising multiple processors allows for better utilisation of available computational resources, leading to faster completion of tasks.

Scalability: As the size of the problem or the amount of data increases, parallel processing provides a scalable solution by distributing the workload across multiple processors.

Improved Throughput: Parallel processing can lead to higher throughput, allowing a system to handle more tasks within a given time frame.

However, implementing parallel processing comes with its challenges:

Synchronisation: Ensuring that parallel tasks or processes synchronise correctly and share data without conflicts can be complex.

Communication Overhead: Coordinating communication between parallel processes can introduce overhead and impact overall performance.

Load Balancing: Distributing the workload evenly across processors is crucial for optimal performance, and load imbalances can lead to suboptimal results.

Parallel Algorithm Design: Developing algorithms that effectively exploit parallelism requires careful consideration and design.

Various programming models and frameworks, such as OpenMP, MPI (Message Passing Interface), CUDA (for GPUs), and parallel extensions in high-level languages like Python’s multiprocessing module, are available to facilitate the implementation of parallel processing in different contexts.

You can check the official PyTorch documentation for DataParallel here.

Distributed Data Parallel (DDP)

Distributed data parallelism is a method for training deep learning models in which the model is replicated on every processing unit or node and the data is partitioned across them. Rather than letting each processing unit operate independently on its subset of the data, distributed data parallelism synchronises the updates to the model parameters across all nodes, enabling efficient coordination and scaling of training tasks.

How Does Distributed Data Parallelism Work?

At its core, distributed data parallelism operates on the principle of sharding the dataset into smaller chunks and distributing them across multiple computing nodes, each of which holds a full replica of the model. Each node is responsible for processing its portion of the data and computing gradients for the model parameters on that portion. These gradients are then aggregated (all-reduced) across all nodes, and the model parameters are updated synchronously to reflect the collective information from the entire dataset.

Benefits of Distributed Data Parallelism:

1. Scalability: By leveraging multiple computing nodes, distributed data parallelism enables the training of large-scale models on massive datasets that exceed the memory capacity of a single machine. This scalability is essential for tackling complex machine learning tasks, such as image recognition, natural language processing, and reinforcement learning, where large amounts of data are ubiquitous.

2. Speed: Parallelising the training process across multiple nodes accelerates the convergence of deep learning models, leading to faster training times and improved productivity for data scientists and machine learning practitioners. This speedup is particularly crucial in time-sensitive applications and scenarios where rapid experimentation and model iteration are essential.

3. Fault Tolerance: Distributed data parallelism can enhance the resilience of deep learning systems in the face of hardware failures or network disruptions. Because every node holds a replica of the model parameters, training combined with regular checkpointing can resume after a node failure instead of starting over, helping ensure robustness and reliability in large-scale distributed environments.

4. Resource Utilisation: By efficiently distributing the computational workload across multiple nodes, distributed data parallelism maximises the utilisation of available resources, minimising idle time and optimising cost-effectiveness in cloud-based or on-premises computing infrastructures.

Implementing Distributed Data Parallelism with PyTorch:

PyTorch, a popular deep learning framework, provides robust support for distributed training through its DistributedDataParallel module. Leveraging PyTorch’s built-in functionalities, developers can seamlessly distribute deep learning tasks across multiple GPUs or even across distributed clusters, harnessing the full potential of parallel computing for accelerated model training. One can also use the PyTorch Lightning package, which simplifies distributed training for beginners. Refer to this video for more information on PyTorch Lightning.
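Below is a minimal single-node, multi-GPU DDP sketch, assuming it is launched with torchrun; the model, the synthetic dataset, and the hyperparameters are placeholders:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def main():
    # torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for every process it launches.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Placeholder model and synthetic dataset.
    model = nn.Linear(1024, 10).to(local_rank)
    model = DDP(model, device_ids=[local_rank])

    dataset = TensorDataset(torch.randn(4096, 1024), torch.randint(0, 10, (4096,)))
    sampler = DistributedSampler(dataset)               # gives each rank its own shard
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(2):
        sampler.set_epoch(epoch)                        # reshuffle shards each epoch
        for inputs, targets in loader:
            inputs, targets = inputs.to(local_rank), targets.to(local_rank)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()                             # DDP all-reduces gradients here
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()  # launch with: torchrun --nproc_per_node=<num_gpus> this_script.py
```

Compared with DataParallel, each GPU here runs its own Python process, the DistributedSampler gives every process a distinct shard of the data, and DDP overlaps the gradient all-reduce with the backward pass.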
