Learn the Training Loop with PyTorch, Part 3.4: Deeper Networks and Backprop Challenges

2025-06-18 · Artintellica

USER:

i need your help to write the next blog post in a series i am producing called “Learn the Training Loop with PyTorch”.

first, let me give you the outline of the series:


Module 1: The Elementary Training Loop

Goal: Build foundational intuition by hand-coding a simple training loop from scratch.

1. Introduction: What is a Training Loop?

  • Definition and importance
  • Loss, optimization, and learning in ML

2. The Simplest Case: Linear Regression

  • What is linear regression?
  • The mean squared error loss
  • Deriving the update rules manually

3. Batch vs. Stochastic Gradient Descent

  • Differences and practical implications
  • Hands-on: See both in Python/numpy

4. Visualizing the Loss Landscape

  • Plotting loss as a function of parameters
  • Demo: Watching convergence by hand

5. Numerical vs. Analytical Gradients

  • How are gradients computed?
  • Manual calculation and confirmation with numerical gradient

6. Recap and Key Takeaways


Module 2: The Training Loop for Neural Networks

Goal: Extend the basic loop to more powerful models; introduce neural network concepts.

1. From Linear to Nonlinear: Why Neural Networks?

  • Multi-layer perceptrons (MLPs)
  • Activation functions (ReLU, sigmoid, tanh)

2. Forward and Backward Passes

  • The chain rule of differentiation
  • Backpropagation intuition (no need for full rigor yet)

3. Implementing a Simple Neural Net from Scratch

  • Single hidden layer network in numpy/Python
  • Coding forward, loss, backward, and parameter updates

4. The Role of Activations

  • Intuitive explanation and code examples
  • Effect on learning and expressiveness

5. Mini-batching and Data Pipelines

  • How batching speeds up and stabilizes learning

6. Regularization and Overfitting

  • Techniques: L2, dropout (conceptual intro)

7. Recap: Comparing Our Simple Network with Linear Regression


Module 3: Advanced Training Loops and Modern Tricks

Goal: Connect foundational understanding to the large-scale models used in industry.

1. Optimization Algorithms Beyond SGD

  • Adam, RMSProp, etc.: How they work, why they’re useful

2. Learning Rate Scheduling

  • How/why learning rates are changed during training
  • Demo: Effect of learning rate on convergence and results

3. Weight Initialization

  • Why it matters, common strategies

4. Deeper Networks and Backprop Challenges

  • Vanishing/exploding gradients
  • Solutions: normalization, skip connections

5. Large-Scale Training: Data Parallelism and Hardware

  • How the same training loop runs efficiently on GPUs and across many machines

6. Monitoring and Debugging the Training Loop

  • Visualizing loss/accuracy, diagnosing common problems

7. Modern Regularization and Generalization Techniques

  • Early stopping, batch/layer normalization, data augmentation (conceptual intro)

8. The Training Loop in Practice: Case Studies

  • High-level overview of how OpenAI/Google/Meta might train LLMs: what’s the same, what’s different.
  • What “scaling up” really means: dataset, parameters, hardware, tricks.

9. Conclusion: What’s Next After the Training Loop?

  • Discussion: fine-tuning, transfer learning, and unsupervised/self-supervised methods.

General Tips for Each Module:

  • Start simple, build incrementally: Keep code simple at first, then show how to layer complexity.
  • Accompany every lesson with practical code examples.
  • Focus on intuition: Use diagrams, plots, and hands-on demo to make concepts tangible.
  • Summarize and review regularly: Add checklists or quizzes at the end of each module.

great. now, for each blog post, i am hoping to follow the following structure:

  • the course is called “Learn the Training Loop with PyTorch”
  • always provide a brief introduction to the topic at hand
  • then always provide a very simple “ELI5” explanation of the topic, that theortically anyone could understand
  • then provide the mathematical foundations. for math, always use dollar delimited LaTeX, and always provide a brief explanation of the math. for instance, y=Xw+b\mathbf{y} = \mathbf{X}\mathbf{w} + \mathbf{b} is the equation for linear regression, where y\mathbf{y} is the output vector, X\mathbf{X} is the input matrix, w\mathbf{w} is the weight vector, and b\mathbf{b} is the bias vector.
  • for block-level math, always use two dollar signs, like this: y=Xw+b\mathbf{y} = \mathbf{X}\mathbf{w} + \mathbf{b}
  • after the mathematical foundations, such as definitions, then always provide an in-depth explanation, covering the topic in detail, and expanding on the explanation of the math if necessary.
  • for other formatting, use markdown like headings, bullet points, and code blocks to make the content readable. do not use HTML tags.
  • after the math is complete, then provide one or more demonstrations of the topic using PyTorch. if possible, always use PyTorch, not numpy, just to keep it simple and consistent. you can also use matplotlib to visualize the results.
  • for all python code, always use type hints, and, if necessary, describe installation instructions using “uv” (not pip).
  • then, create a series of exercises, somewhere from 1 - 4 total exercises that cover the material. please provide the soruce code for each exercise to make sure the student can easily compare their solution to the correct answer. do not hide the exercises and do not put HTML in the exercises.
  • finally, provide a summary of the topic, and a list of key takeaways.

that’s it, now can you please write blog post 3.4?

the description is:

4. Deeper Networks and Backprop Challenges

  • Vanishing/exploding gradients
  • Solutions: normalization, skip connections

please respond with the full blog post only, and no other information.

ASSISTANT:

Learn the Training Loop with PyTorch

Module 3.4: Deeper Networks and Backprop Challenges


Introduction

As we build deeper neural networks to solve more complex tasks, we often run into unexpected issues during training. Two of the most notorious challenges are vanishing gradients and exploding gradients. These problems can make it hard or even impossible to train deep networks successfully. Fortunately, the modern deep learning toolkit provides several effective solutions, such as normalization techniques and skip connections.

This post will demystify these challenges and solutions. We’ll break down the intuition, provide the mathematical background, and demonstrate practical solutions using PyTorch.


ELI5: Why Does Deep Training Sometimes Fail?

Imagine sending a message down a line of people, each whispering it to the next. If the message is very faint, it might get quieter and quieter—eventually no one hears it (vanishing gradients). If people keep shouting louder and louder, the message becomes a roar—too loud to make sense (exploding gradients).

In deep neural networks, information must flow through many layers. If the “signal” (the gradients used to update the model) becomes too small or too large as it moves backward through layers, the network either learns too slowly or becomes unstable.


Mathematical Foundations

Backpropagation and Gradients

Recall that during training, we update network parameters using the gradients of the loss with respect to each parameter, computed via backpropagation. For a deep feedforward network with LL layers, the output can be described recursively as:

x(l)=f(l)(W(l)x(l1)+b(l))\mathbf{x}^{(l)} = f^{(l)}(\mathbf{W}^{(l)} \mathbf{x}^{(l-1)} + \mathbf{b}^{(l)})

where:

  • x(l)\mathbf{x}^{(l)} is the output of layer ll
  • f(l)f^{(l)} is the activation function (e.g., ReLU, sigmoid)
  • W(l)\mathbf{W}^{(l)}, b(l)\mathbf{b}^{(l)} are the weights and biases of layer ll

The gradient of the loss LL with respect to earlier weights (e.g., W(1)\mathbf{W}^{(1)}) is obtained through repeated application of the chain rule:

LW(1)=Lx(L)l=2Lx(l)x(l1)x(1)W(1)\frac{\partial L}{\partial \mathbf{W}^{(1)}} = \frac{\partial L}{\partial \mathbf{x}^{(L)}} \cdot \prod_{l=2}^{L} \frac{\partial \mathbf{x}^{(l)}}{\partial \mathbf{x}^{(l-1)}} \cdot \frac{\partial \mathbf{x}^{(1)}}{\partial \mathbf{W}^{(1)}}

Vanishing & Exploding Gradients

The problem arises in the product

l=2Lx(l)x(l1)\prod_{l=2}^{L} \frac{\partial \mathbf{x}^{(l)}}{\partial \mathbf{x}^{(l-1)}}

If these Jacobians contain values (or, for scalar activations, derivatives) consistently less than 1, their product shrinks rapidly as LL increases – leading to vanishing gradients. If the derivatives are consistently larger than 1, the product grows fast—leading to exploding gradients.

For example, suppose each derivative is approximately 0.50.5, and your network is 1010 layers deep:

0.510=0.000976560.5^{10} = 0.00097656

The gradient becomes almost zero!


In-depth Explanation

Why Does This Happen?

The repeated multiplication by the derivatives of each layer and the weights means any issues get amplified exponentially with depth:

  • Sigmoid/tanh activations: Their derivatives are at most 0.25/1.00.25/1.0, so repeated multiplication quickly shrinks the gradient.
  • Poor weight initialization: Weights that are too large or too small can make products of gradients shrink to zero or explode to infinity.

Consequence:

  • Vanishing gradients: Early layers learn extremely slowly (or not at all).
  • Exploding gradients: Training diverges—loss becomes NaN or the model “blows up.”

How Do We Fix It?

Modern networks are made possible by two main strategies:

1. Normalization (e.g., BatchNorm)

Normalization layers help regulate the scale of activations at each layer, making gradients more stable and preventing them from exploding or vanishing.

  • Batch Normalization normalizes activations within a mini-batch to have zero mean and unit variance: x^(l)=x(l)μ(l)(σ(l))2+ϵ\hat{x}^{(l)} = \frac{x^{(l)} - \mu^{(l)}}{\sqrt{(\sigma^{(l)})^2 + \epsilon}} where μ(l)\mu^{(l)} and σ(l)\sigma^{(l)} are mean and std of x(l)x^{(l)} in the batch.

2. Skip Connections (e.g., Residual Networks)

Skip (or residual) connections allow gradients to shortcut around layers, so information and gradients can “flow” more directly from later to earlier layers. The core idea:

x(l+1)=f(W(l)x(l)+b(l))+x(l)\mathbf{x}^{(l+1)} = f(\mathbf{W}^{(l)} \mathbf{x}^{(l)} + \mathbf{b}^{(l)}) + \mathbf{x}^{(l)}

This simple addition helps fight vanishing gradients and enables the training of extremely deep networks (hundreds or thousands of layers).


Demonstration: Seeing Gradient Problems and Solutions in PyTorch

Let’s walk through a simple experiment in PyTorch to illustrate these concepts.

Installation

If you do not have torch or matplotlib installed, use:

uv pip install torch matplotlib

1. Vanishing Gradient Demo: Deep Network with Sigmoid

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Fix random seed for reproducibility
torch.manual_seed(0)

class DeepSigmoidNet(nn.Module):
    def __init__(self, depth: int, hidden_size: int = 32):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size) for _ in range(depth)
        ])
        self.activation = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = self.activation(layer(x))
        return x

# Build a deep network
depth = 20
net = DeepSigmoidNet(depth=depth)
x = torch.randn(16, 32)
target = torch.randn(16, 32)

# Forward pass
output = net(x)
loss = (output - target).pow(2).mean()

# Backward pass
loss.backward()

# Record gradient norms for each layer
grad_norms = []
for i, layer in enumerate(net.layers):
    norm = layer.weight.grad.norm().item()
    grad_norms.append(norm)

plt.plot(range(1, depth+1), grad_norms, marker='o')
plt.xlabel('Layer')
plt.ylabel('Weight Gradient Norm')
plt.title('Vanishing Gradients in Deep Sigmoid Network')
plt.show()

What you’ll see: The gradient norms shrink as you move toward the first layer. Early layers barely receive any signal—classic vanishing gradients.


2. Fixing with BatchNorm

Let’s try again, adding nn.BatchNorm1d after each layer:

class DeepBatchNormSigmoidNet(nn.Module):
    def __init__(self, depth: int, hidden_size: int = 32):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.BatchNorm1d(hidden_size)
            ) for _ in range(depth)
        ])
        self.activation = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = self.activation(layer(x))
        return x

# Try again
net_bn = DeepBatchNormSigmoidNet(depth=depth)
x = torch.randn(16, 32)
target = torch.randn(16, 32)
output = net_bn(x)
loss = (output - target).pow(2).mean()
loss.backward()

# New gradient norms
grad_norms_bn = [layer[0].weight.grad.norm().item() for layer in net_bn.layers]

plt.plot(range(1, depth+1), grad_norms_bn, marker='s', label='BatchNorm', color='orange')
plt.plot(range(1, depth+1), grad_norms, marker='o', linestyle='--', label='Vanilla', alpha=0.6)
plt.xlabel('Layer')
plt.ylabel('Weight Gradient Norm')
plt.title('Gradient Norms: BatchNorm vs No BatchNorm')
plt.legend()
plt.show()

Observation: Gradient norms with BatchNorm stay much more stable, avoiding the collapse seen previously!


3. Skip Connections: Residual Network in PyTorch

Let’s see how a simple residual connection keeps the gradient healthy:

class DeepResidualNet(nn.Module):
    def __init__(self, depth: int, hidden_size: int = 32):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size) for _ in range(depth)
        ])
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            z = layer(x)
            x = self.activation(z) + x  # Residual connection!
        return x

# Build and test
net_res = DeepResidualNet(depth=depth)
x = torch.randn(16, 32)
target = torch.randn(16, 32)
output = net_res(x)
loss = (output - target).pow(2).mean()
loss.backward()

# Gradients with residuals
grad_norms_res = [layer.weight.grad.norm().item() for layer in net_res.layers]

plt.plot(range(1, depth+1), grad_norms_res, marker='^', color='green', label='Residual')
plt.plot(range(1, depth+1), grad_norms, marker='o', linestyle='--', label='Vanilla', alpha=0.6)
plt.xlabel('Layer')
plt.ylabel('Weight Gradient Norm')
plt.title('Gradient Norms: Residual vs No Residual')
plt.legend()
plt.show()

Result: Notice how the gradient norm remains strong, even for early layers. This means the network can actually learn from the start!


Exercises

1. Exercise: Visualize Exploding Gradients

Task:
Modify the first demo to use a network where all weights are initialized to a large value (e.g., normal distribution with std=2.0). Plot the gradient norms for each layer. What happens?

Solution:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

depth = 20

class ExplodingNet(nn.Module):
    def __init__(self, depth: int, hidden_size: int = 32):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size) for _ in range(depth)
        ])
        self.activation = nn.ReLU()
        for layer in self.layers:
            nn.init.normal_(layer.weight, std=2.0)
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = self.activation(layer(x))
        return x

net_expl = ExplodingNet(depth=depth)
x = torch.randn(16, 32)
target = torch.randn(16, 32)
output = net_expl(x)
loss = (output - target).pow(2).mean()
loss.backward()
grad_norms_expl = [layer.weight.grad.norm().item() for layer in net_expl.layers]

plt.plot(range(1, depth+1), grad_norms_expl, marker='x', color='red')
plt.xlabel('Layer')
plt.ylabel('Weight Gradient Norm')
plt.title('Exploding Gradients with Large Weight Initialization')
plt.show()

2. Exercise: Add Layer Normalization

Task:
Modify the DeepSigmoidNet class to use nn.LayerNorm instead of BatchNorm and plot gradient norms. How does it compare?

Solution:

class DeepLayerNormSigmoidNet(nn.Module):
    def __init__(self, depth: int, hidden_size: int = 32):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.LayerNorm(hidden_size)
            ) for _ in range(depth)
        ])
        self.activation = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = self.activation(layer(x))
        return x

net_ln = DeepLayerNormSigmoidNet(depth=depth)
x = torch.randn(16, 32)
target = torch.randn(16, 32)
output = net_ln(x)
loss = (output - target).pow(2).mean()
loss.backward()
grad_norms_ln = [layer[0].weight.grad.norm().item() for layer in net_ln.layers]

plt.plot(range(1, depth+1), grad_norms_ln, marker='*', color='purple')
plt.xlabel('Layer')
plt.ylabel('Weight Gradient Norm')
plt.title('Gradient Norms with LayerNorm')
plt.show()

3. Exercise: Implement Your Own Residual Block

Task:
Write a simple ResidualBlock in PyTorch that takes in any module as a “core layer” and adds a residual connection.

Solution:

class ResidualBlock(nn.Module):
    def __init__(self, core_layer: nn.Module):
        super().__init__()
        self.core_layer = core_layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.core_layer(x) + x  # Residual connection

# Example usage
block = ResidualBlock(nn.Linear(32, 32))
x = torch.randn(8, 32)
output = block(x)
print("Output shape:", output.shape)

Summary

Training deep neural networks is hard—not because the models themselves are fundamentally mysterious, but because of how gradients behave as they flow backward through many layers. Vanishing and exploding gradients can cripple learning, but understanding why they happen guides us to solutions. Normalization layers and skip connections have revolutionized deep learning, letting us train networks that would have been impossible a decade ago.


Key Takeaways

  • Vanishing gradients occur when repeated multiplication by small derivatives causes gradient signals to shrink, stalling learning.
  • Exploding gradients result from derivatives (or weights) larger than 11, causing gradients to blow up.
  • Batch Normalization (and Layer Normalization) standardize activations at each layer, stabilizing gradients.
  • Residual (skip) connections allow gradients to flow more easily through deep networks, making training possible for extremely deep models.
  • These two innovations—normalization and skip connections—are central to modern deep learning and have enabled much of recent progress.

Next up: We’ll look at how large-scale data and hardware make these powerful networks practical in the real world!

USER:



Next Blog Posts

Earlier Blog Posts


Back to Blog

Copyright © 2025 Identellica LLC
Home · Blog · Source Code