Welcome! Thus far, you’ve used logistic regression for binary classification. But in real-world RL and ML, problems usually involve choosing between several possible actions or classes—not just two. This blog post will introduce you to the softmax function and multiclass (a.k.a. multinomial) classification.
You will:
- define the softmax function and see how it turns logits into class probabilities,
- implement softmax and cross-entropy loss by hand and compare them to PyTorch's built-ins,
- train a linear multiclass classifier on synthetic 2D data with gradient descent, and
- visualize the decision boundaries it learns.
Given an input vector $\mathbf{z} = (z_1, \ldots, z_K)$ (the logits for $K$ classes), the softmax function outputs a probability vector $\mathbf{p} = (p_1, \ldots, p_K)$:

$$p_i = \mathrm{softmax}(\mathbf{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}},$$

where $z_i$ is the logit for class $i$.
One-hot true targets: $\mathbf{y} \in \{0, 1\}^K$ with exactly one entry equal to 1 (marking the true class) for each sample.
The cross-entropy loss for a sample with logits $\mathbf{z}$ and true class $c$:

$$\mathcal{L}(\mathbf{z}, c) = -\log p_c = -\log \frac{e^{z_c}}{\sum_{j=1}^{K} e^{z_j}}.$$
This penalizes the model for assigning low probability to the true class.
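To make this concrete, here is a small worked example (numbers chosen purely for illustration): for logits $\mathbf{z} = (2, 0, -1)$,

$$\mathrm{softmax}(\mathbf{z}) = \frac{1}{e^{2} + e^{0} + e^{-1}}\left(e^{2},\ e^{0},\ e^{-1}\right) \approx (0.844,\ 0.114,\ 0.042).$$

If the true class is $c = 0$, the loss is $-\log 0.844 \approx 0.17$; if it were $c = 2$, the loss would be $-\log 0.042 \approx 3.2$, a much harsher penalty for a confident miss.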
Multiclass classification means the model estimates a probability for each class, not just for “yes” or “no”. Neural networks (and logistic regression extensions) do this by outputting a vector of logits.
In practice, you use nn.CrossEntropyLoss, which both applies softmax and computes the cross-entropy loss, working from the logits directly for numerical stability. Let's see all of this in code, starting by generating some synthetic 3-class data:
import torch
import matplotlib.pyplot as plt
torch.manual_seed(42)
N = 300
cov = torch.tensor([[1.2, 0.8], [0.8, 1.2]])
L = torch.linalg.cholesky(cov)
means = [torch.tensor([-2., 0.]), torch.tensor([2., 2.]), torch.tensor([0., -2.])]
X_list = []
y_list = []
for i, mu in enumerate(means):
    Xi = torch.randn(N//3, 2) @ L.T + mu
    X_list.append(Xi)
    y_list.append(torch.full((N//3,), i))
X = torch.cat(X_list)
y = torch.cat(y_list).long()
colors = ['b', 'r', 'g']
for i in range(3):
    plt.scatter(X_list[i][:,0], X_list[i][:,1], color=colors[i], alpha=0.5, label=f"Class {i}")
plt.legend(); plt.xlabel("x1"); plt.ylabel("x2")
plt.title("Synthetic 3-Class Data")
plt.show()
import torch.nn.functional as F
def softmax(logits: torch.Tensor) -> torch.Tensor:
    # For numerical stability, subtract max
    logits = logits - logits.max(dim=1, keepdim=True).values
    exp_logits = torch.exp(logits)
    return exp_logits / exp_logits.sum(dim=1, keepdim=True)
def cross_entropy_manual(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # logits: (N, K), targets: (N,)
    N = logits.shape[0]
    log_probs = F.log_softmax(logits, dim=1)
    return -log_probs[torch.arange(N), targets].mean()
# Toy example
logits = torch.tensor([[2.0, 0.5, -1.0],[0.0, 3.0, 0.5]])
targets = torch.tensor([0, 1])
probs = softmax(logits)
manual_loss = cross_entropy_manual(logits, targets)
print("Probabilities:\n", probs)
print("Manual cross-entropy loss:", manual_loss.item())
For training, we'll use the module form, nn.CrossEntropyLoss.
Let’s fit a linear classifier to the data above.
# Model: simple linear classifier with weight matrix W and bias b
W = torch.zeros(2, 3, requires_grad=True) # (features, classes)
b = torch.zeros(3, requires_grad=True)
lr = 0.05
loss_fn = torch.nn.CrossEntropyLoss()
losses = []
for epoch in range(400):
    logits = X @ W + b  # (N, 3)
    loss = loss_fn(logits, y)
    loss.backward()
    with torch.no_grad():
        W -= lr * W.grad
        b -= lr * b.grad
        W.grad.zero_(); b.grad.zero_()
    losses.append(loss.item())
    if epoch % 100 == 0 or epoch == 399:
        print(f"Epoch {epoch}: Cross-entropy loss={loss.item():.3f}")
plt.plot(losses)
plt.title("Multiclass Classifier Training Loss")
plt.xlabel("Epoch"); plt.ylabel("Cross-Entropy Loss"); plt.grid(True)
plt.show()
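Before visualizing the boundaries, it's worth checking how well the classifier fits the data. This small addition (not shown in the loop above) uses the trained W and b to compute training accuracy:

with torch.no_grad():
    # Predicted class = argmax over the 3 logits for each point
    train_acc = ((X @ W + b).argmax(dim=1) == y).float().mean().item()
print(f"Training accuracy: {train_acc:.3f}")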
import numpy as np
with torch.no_grad():
    x1g, x2g = torch.meshgrid(torch.linspace(-6,6,200), torch.linspace(-6,6,200), indexing='ij')
    Xg = torch.stack([x1g.reshape(-1), x2g.reshape(-1)], dim=1)  # (n_grid, 2)
    logits_grid = Xg @ W + b
    y_pred_grid = logits_grid.argmax(dim=1).reshape(200,200)
plt.contourf(x1g, x2g, y_pred_grid.numpy(), levels=[-0.5,0.5,1.5,2.5], colors=['b','r','g'], alpha=0.15)
for i in range(3):
    plt.scatter(X_list[i][:,0], X_list[i][:,1], color=colors[i], alpha=0.6, label=f"Class {i}")
plt.title("Learned Class Boundaries (Linear)")
plt.xlabel("x1"); plt.ylabel("x2"); plt.legend(); plt.show()
Exercises:
1. Generate your own synthetic 2D dataset with three classes.
2. Implement softmax and cross-entropy by hand and check them against F.cross_entropy.
3. Train a linear classifier (learnable W and b) with F.cross_entropy or nn.CrossEntropyLoss.
4. Use contourf or imshow to shade the 2D plane by predicted class and show the decision boundaries.

Solutions:

import torch
import matplotlib.pyplot as plt
import numpy as np
# EXERCISE 1
torch.manual_seed(0)
N = 150
means = [torch.tensor([-2.0, 0.]), torch.tensor([2.0, 2.5]), torch.tensor([0., -2.])]
cov = torch.tensor([[1.1, 0.6], [0.6, 1.0]])
L = torch.linalg.cholesky(cov)
X_list = []
y_list = []
for i, mu in enumerate(means):
    Xi = torch.randn(N//3, 2) @ L.T + mu
    X_list.append(Xi)
    y_list.append(torch.full((N//3,), i))
X = torch.cat(X_list)
y = torch.cat(y_list).long()
for i, c in enumerate(['b', 'g', 'r']):
    plt.scatter(X_list[i][:,0], X_list[i][:,1], color=c, alpha=0.5, label=f"Class {i}")
plt.legend(); plt.title("Synthetic Data"); plt.show()
# EXERCISE 2
def softmax(logits):
    logits = logits - logits.max(dim=1, keepdim=True).values
    exp = torch.exp(logits)
    return exp / exp.sum(dim=1, keepdim=True)

def cross_entropy(logits, targets):
    N = logits.shape[0]
    log_probs = torch.nn.functional.log_softmax(logits, dim=1)
    return -log_probs[torch.arange(N), targets].mean()
# Test
logits = torch.tensor([[2.0, -1.0, 0.5], [0.2, 1.0, -2.0]])
targets = torch.tensor([0, 1])
probs = softmax(logits)
print("Softmax probabilities:\n", probs)
print("Manual cross-entropy:", cross_entropy(logits, targets).item())
print("PyTorch cross-entropy:", torch.nn.functional.cross_entropy(logits, targets).item())
# EXERCISE 3
W = torch.zeros(2, 3, requires_grad=True)
b = torch.zeros(3, requires_grad=True)
lr = 0.05
losses = []
for epoch in range(400):
    logits = X @ W + b
    loss = torch.nn.functional.cross_entropy(logits, y)
    loss.backward()
    with torch.no_grad():
        W -= lr * W.grad
        b -= lr * b.grad
        W.grad.zero_(); b.grad.zero_()
    losses.append(loss.item())
plt.plot(losses); plt.title("Classifier Training Loss"); plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.grid(); plt.show()
# EXERCISE 4
with torch.no_grad():
    grid_x, grid_y = torch.meshgrid(torch.linspace(-5, 5, 200), torch.linspace(-5, 5, 200), indexing='ij')
    Xg = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)
    logits_grid = Xg @ W + b
    pred_grid = logits_grid.argmax(dim=1).reshape(200,200)
# Shading colors ordered to match the per-class scatter colors below
plt.contourf(grid_x, grid_y, pred_grid.numpy(), levels=[-0.5,0.5,1.5,2.5], colors=['b','g','r'], alpha=0.15)
for i, c in enumerate(['b', 'g', 'r']):
    plt.scatter(X_list[i][:,0], X_list[i][:,1], color=c, alpha=0.6, label=f"Class {i}")
plt.title("Decision Boundaries"); plt.legend(); plt.show()
You’ve now experienced multiclass classification end-to-end:
- defined the softmax function and saw how it turns logits into class probabilities,
- implemented cross-entropy loss by hand and matched it against PyTorch’s built-ins,
- trained a linear multiclass classifier with gradient descent, and
- visualized the decision boundaries it learned.
Next: In the next post, you’ll use neural networks to model even more complex (curved!) boundaries and classify data that linear models can’t handle. This is the last stop on the road to deep RL.
Practice tweaking your classes, noise, and model—softmax is the backbone of every multiclass RL agent! See you in Part 2.8!