USER:
i need your help to write the next blog post in a series i am producing called “Learn the Training Loop with PyTorch”.
first, let me give you the outline of the series:
Goal: Build foundational intuition by hand-coding a simple training loop from scratch.
1. Introduction: What is a Training Loop?
2. The Simplest Case: Linear Regression
3. Batch vs. Stochastic Gradient Descent
4. Visualizing the Loss Landscape
5. Numerical vs. Analytical Gradients
6. Recap and Key Takeaways
Goal: Extend the basic loop to more powerful models; introduce neural network concepts.
1. From Linear to Nonlinear: Why Neural Networks?
2. Forward and Backward Passes
3. Implementing a Simple Neural Net from Scratch
4. The Role of Activations
5. Mini-batching and Data Pipelines
6. Regularization and Overfitting
7. Recap: Comparing Our Simple Network with Linear Regression
Goal: Connect foundational understanding to the large-scale models used in industry.
1. Optimization Algorithms Beyond SGD
2. Learning Rate Scheduling
3. Weight Initialization
4. Deeper Networks and Backprop Challenges
5. Large-Scale Training: Data Parallelism and Hardware
6. Monitoring and Debugging the Training Loop
7. Modern Regularization and Generalization Techniques
8. The Training Loop in Practice: Case Studies
9. Conclusion: What’s Next After the Training Loop?
great. now, for each blog post, i am hoping to follow the following structure:
that’s it, now can you please write blog post 3.4?
the description is:
4. Deeper Networks and Backprop Challenges
please respond with the full blog post only, and no other information.
ASSISTANT:
As we build deeper neural networks to solve more complex tasks, we often run into unexpected issues during training. Two of the most notorious challenges are vanishing gradients and exploding gradients. These problems can make it hard or even impossible to train deep networks successfully. Fortunately, the modern deep learning toolkit provides several effective solutions, such as normalization techniques and skip connections.
This post will demystify these challenges and solutions. We’ll break down the intuition, provide the mathematical background, and demonstrate practical solutions using PyTorch.
Imagine sending a message down a line of people, each whispering it to the next. If the message is very faint, it might get quieter and quieter—eventually no one hears it (vanishing gradients). If people keep shouting louder and louder, the message becomes a roar—too loud to make sense (exploding gradients).
In deep neural networks, information must flow through many layers. If the “signal” (the gradients used to update the model) becomes too small or too large as it moves backward through layers, the network either learns too slowly or becomes unstable.
Recall that during training, we update network parameters using the gradients of the loss with respect to each parameter, computed via backpropagation. For a deep feedforward network with $L$ layers, the output can be described recursively as:

$$a^{(l)} = f\left(W^{(l)} a^{(l-1)} + b^{(l)}\right), \qquad l = 1, \dots, L$$

where $a^{(0)} = x$ is the input, $W^{(l)}$ and $b^{(l)}$ are the weights and biases of layer $l$, and $f$ is the activation function (such as the sigmoid).
The gradient of the loss with respect to earlier weights (e.g., $W^{(1)}$) is obtained through repeated application of the chain rule:

$$\frac{\partial \mathcal{L}}{\partial W^{(1)}} = \frac{\partial \mathcal{L}}{\partial a^{(L)}} \cdot \frac{\partial a^{(L)}}{\partial a^{(L-1)}} \cdots \frac{\partial a^{(2)}}{\partial a^{(1)}} \cdot \frac{\partial a^{(1)}}{\partial W^{(1)}}$$

The problem arises in the product

$$\prod_{l=2}^{L} \frac{\partial a^{(l)}}{\partial a^{(l-1)}}$$

If these Jacobians contain values (or, for scalar activations, derivatives) consistently less than 1, their product shrinks rapidly as $L$ increases, leading to vanishing gradients. If the derivatives are consistently larger than 1, the product grows fast, leading to exploding gradients.
For example, suppose each derivative is approximately $0.25$ (the largest value the sigmoid's derivative can take), and your network is $20$ layers deep:

$$0.25^{20} \approx 9.1 \times 10^{-13}$$

The gradient becomes almost zero!
The repeated multiplication by the derivatives of each layer and the weights means any imbalance gets amplified exponentially with depth. The consequence: early layers either receive gradients that are effectively zero (so they barely learn) or astronomically large (so training becomes unstable and diverges).
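To make this concrete, here is a tiny numeric illustration in plain Python, using the per-layer factor of $0.25$ and the depth of $20$ from the example above (the factor of $1.5$ is just an arbitrary value larger than 1):
# Multiply 20 per-layer factors that are all below (or above) 1
vanishing = 0.25 ** 20  # e.g. sigmoid derivatives at their 0.25 maximum
exploding = 1.5 ** 20   # e.g. per-layer factors slightly larger than 1
print(f"20 factors of 0.25: {vanishing:.2e}")  # ~9.09e-13, effectively zero
print(f"20 factors of 1.5:  {exploding:.2e}")  # ~3.33e+03, blowing up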
Modern networks are made possible by two main strategies:
Normalization layers help regulate the scale of activations at each layer, making gradients more stable and preventing them from exploding or vanishing.
Skip (or residual) connections allow gradients to shortcut around layers, so information and gradients can “flow” more directly from later to earlier layers. The core idea: instead of producing $F(x)$ alone, a residual block outputs

$$y = F(x) + x$$
This simple addition helps fight vanishing gradients and enables the training of extremely deep networks (hundreds or thousands of layers).
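To see why the shortcut helps, here is a short sketch using the notation from the chain-rule discussion above. With a residual connection, each layer-to-layer Jacobian picks up an identity term:

$$a^{(l)} = a^{(l-1)} + F\bigl(a^{(l-1)}\bigr) \quad\Longrightarrow\quad \frac{\partial a^{(l)}}{\partial a^{(l-1)}} = I + \frac{\partial F}{\partial a^{(l-1)}}$$

The identity term gives the backward pass a direct path from the loss to the earliest layers, so the gradient no longer has to survive a long chain of purely multiplicative factors, which is exactly what made it vanish before.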
Let’s walk through a simple experiment in PyTorch to illustrate these concepts.
If you do not have torch or matplotlib installed, use:
uv pip install torch matplotlib
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
# Fix random seed for reproducibility
torch.manual_seed(0)
class DeepSigmoidNet(nn.Module):
def __init__(self, depth: int, hidden_size: int = 32):
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size) for _ in range(depth)
])
self.activation = nn.Sigmoid()
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = self.activation(layer(x))
return x
# Build a deep network
depth = 20
net = DeepSigmoidNet(depth=depth)
x = torch.randn(16, 32)
target = torch.randn(16, 32)
# Forward pass
output = net(x)
loss = (output - target).pow(2).mean()
# Backward pass
loss.backward()
# Record gradient norms for each layer
grad_norms = []
for layer in net.layers:
    norm = layer.weight.grad.norm().item()
    grad_norms.append(norm)
plt.plot(range(1, depth+1), grad_norms, marker='o')
plt.xlabel('Layer')
plt.ylabel('Weight Gradient Norm')
plt.title('Vanishing Gradients in Deep Sigmoid Network')
plt.show()
What you’ll see: The gradient norms shrink as you move toward the first layer. Early layers barely receive any signal—classic vanishing gradients.
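Why is the sigmoid such a strong offender here? Its derivative is bounded:

$$\sigma'(x) = \sigma(x)\bigl(1 - \sigma(x)\bigr) \le \frac{1}{4}$$

so the activation alone multiplies the backward signal by at most $0.25$ per layer, exactly the shrinking factor from our earlier back-of-the-envelope calculation.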
Let’s try again, adding nn.BatchNorm1d after each linear layer (before the sigmoid):
class DeepBatchNormSigmoidNet(nn.Module):
def __init__(self, depth: int, hidden_size: int = 32):
super().__init__()
self.layers = nn.ModuleList([
nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.BatchNorm1d(hidden_size)
) for _ in range(depth)
])
self.activation = nn.Sigmoid()
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = self.activation(layer(x))
return x
# Try again
net_bn = DeepBatchNormSigmoidNet(depth=depth)
x = torch.randn(16, 32)
target = torch.randn(16, 32)
output = net_bn(x)
loss = (output - target).pow(2).mean()
loss.backward()
# New gradient norms
grad_norms_bn = [layer[0].weight.grad.norm().item() for layer in net_bn.layers]
plt.plot(range(1, depth+1), grad_norms_bn, marker='s', label='BatchNorm', color='orange')
plt.plot(range(1, depth+1), grad_norms, marker='o', linestyle='--', label='Vanilla', alpha=0.6)
plt.xlabel('Layer')
plt.ylabel('Weight Gradient Norm')
plt.title('Gradient Norms: BatchNorm vs No BatchNorm')
plt.legend()
plt.show()
Observation: Gradient norms with BatchNorm stay much more stable, avoiding the collapse seen previously!
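For reference, this is the standard batch normalization transform (computed per feature over the batch, followed by a learned affine map):

$$\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y = \gamma \hat{x} + \beta$$

By keeping each layer's pre-activations on a consistent scale, it helps keep the sigmoid out of its saturated regions, where the derivative (and hence the gradient) would otherwise be tiny.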
Let’s see how a simple residual connection keeps the gradient healthy:
class DeepResidualNet(nn.Module):
def __init__(self, depth: int, hidden_size: int = 32):
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size) for _ in range(depth)
])
self.activation = nn.ReLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
z = layer(x)
x = self.activation(z) + x # Residual connection!
return x
# Build and test
net_res = DeepResidualNet(depth=depth)
x = torch.randn(16, 32)
target = torch.randn(16, 32)
output = net_res(x)
loss = (output - target).pow(2).mean()
loss.backward()
# Gradients with residuals
grad_norms_res = [layer.weight.grad.norm().item() for layer in net_res.layers]
plt.plot(range(1, depth+1), grad_norms_res, marker='^', color='green', label='Residual')
plt.plot(range(1, depth+1), grad_norms, marker='o', linestyle='--', label='Vanilla', alpha=0.6)
plt.xlabel('Layer')
plt.ylabel('Weight Gradient Norm')
plt.title('Gradient Norms: Residual vs No Residual')
plt.legend()
plt.show()
Result: Notice how the gradient norm remains strong, even for early layers. This means the network can actually learn from the start!
Task:
Modify the first demo to use a network where all weights are initialized to a large value (e.g., normal distribution with std=2.0). Plot the gradient norms for each layer. What happens?
Solution:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
depth = 20
class ExplodingNet(nn.Module):
def __init__(self, depth: int, hidden_size: int = 32):
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size) for _ in range(depth)
])
self.activation = nn.ReLU()
for layer in self.layers:
nn.init.normal_(layer.weight, std=2.0)
nn.init.zeros_(layer.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = self.activation(layer(x))
return x
net_expl = ExplodingNet(depth=depth)
x = torch.randn(16, 32)
target = torch.randn(16, 32)
output = net_expl(x)
loss = (output - target).pow(2).mean()
loss.backward()
# With std=2.0, the gradient norms blow up to values orders of magnitude larger than before: exploding gradients
grad_norms_expl = [layer.weight.grad.norm().item() for layer in net_expl.layers]
plt.plot(range(1, depth+1), grad_norms_expl, marker='x', color='red')
plt.xlabel('Layer')
plt.ylabel('Weight Gradient Norm')
plt.title('Exploding Gradients with Large Weight Initialization')
plt.show()
Task:
Modify the DeepSigmoidNet class to use nn.LayerNorm instead of BatchNorm and plot the gradient norms. How does it compare?
Solution:
class DeepLayerNormSigmoidNet(nn.Module):
def __init__(self, depth: int, hidden_size: int = 32):
super().__init__()
self.layers = nn.ModuleList([
nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.LayerNorm(hidden_size)
) for _ in range(depth)
])
self.activation = nn.Sigmoid()
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = self.activation(layer(x))
return x
net_ln = DeepLayerNormSigmoidNet(depth=depth)
x = torch.randn(16, 32)
target = torch.randn(16, 32)
output = net_ln(x)
loss = (output - target).pow(2).mean()
loss.backward()
# LayerNorm normalizes each sample (rather than each batch), and likewise keeps the gradient norms stable across layers
grad_norms_ln = [layer[0].weight.grad.norm().item() for layer in net_ln.layers]
plt.plot(range(1, depth+1), grad_norms_ln, marker='*', color='purple')
plt.xlabel('Layer')
plt.ylabel('Weight Gradient Norm')
plt.title('Gradient Norms with LayerNorm')
plt.show()
Task:
Write a simple ResidualBlock in PyTorch that takes in any module as a “core layer” and adds a residual connection.
Solution:
class ResidualBlock(nn.Module):
def __init__(self, core_layer: nn.Module):
super().__init__()
self.core_layer = core_layer
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.core_layer(x) + x # Residual connection
# Example usage
block = ResidualBlock(nn.Linear(32, 32))
x = torch.randn(8, 32)
output = block(x)
print("Output shape:", output.shape)
Training deep neural networks is hard—not because the models themselves are fundamentally mysterious, but because of how gradients behave as they flow backward through many layers. Vanishing and exploding gradients can cripple learning, but understanding why they happen guides us to solutions. Normalization layers and skip connections have revolutionized deep learning, letting us train networks that would have been impossible a decade ago.
Next up: We’ll look at how large-scale data and hardware make these powerful networks practical in the real world!
USER: