Welcome back to Learn Reinforcement Learning with PyTorch! Up to now, we’ve explored key reinforcement learning concepts—states, actions, rewards—and implemented classic “tabular” algorithms such as Q-Learning and SARSA. These approaches store the value estimates for every possible state (or state-action pair) in a table. This method works well for simple environments, but quickly becomes infeasible as problems scale.
In this post, we’ll uncover the limitations of tabular reinforcement learning (RL), especially the “curse of dimensionality”, and explain why function approximation (especially deep neural networks) is critical for modern RL. You’ll learn the mathematics behind the state explosion, see hands-on demos, and begin thinking about designing neural architectures to replace (or generalize) the idea of the Q-table.
In tabular RL, we represent the action-value function $Q(s, a)$ (or the state-value function $V(s)$) as a table, explicitly storing an entry for each state (or state-action pair).
Suppose an environment’s state is described by $k$ discrete variables, with each variable ranging over $n$ possible values. Then, the total number of possible states is

$$|\mathcal{S}| = n^k.$$

For each action $a$ from a set $\mathcal{A}$, we track $Q(s, a)$, so the table needs

$$|\mathcal{S}| \times |\mathcal{A}| = n^k \times |\mathcal{A}|$$

entries. This grows exponentially with the number of state variables (“dimensions”).
If an agent observes a $5 \times 5$ grid (think of a small Gridworld), with its location as 2 variables $x$ and $y$ (each with 5 possible values), then:

$$|\mathcal{S}| = 5 \times 5 = 25$$

—no problem! But if you increase to an $N \times N$ grid (large map, $|\mathcal{S}| = N^2$), or the agent observes more features (positions of several objects), $|\mathcal{S}|$ explodes.
If each variable has 100 possible values, and there are 4 variables:

$$|\mathcal{S}| = 100^4 = 100{,}000{,}000.$$
Modern RL environments (e.g., Atari games, robotics with images) have astronomically many possible states—far too many to enumerate, let alone tabulate!
Rather than represent $Q(s, a)$ as a lookup table, we learn a function:

$$Q(s, a) \approx \hat{Q}(s, a; \theta),$$

where $\hat{Q}$ is a parameterized model (e.g., a neural network with weights $\theta$), mapping an input state $s$ and action $a$ to an estimated value $\hat{Q}(s, a; \theta)$. This enables compact storage (the number of parameters does not grow with the number of states) and generalization to states we have never seen.
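As a minimal sketch of what such a parameterized model can be (it doesn’t have to be a neural network), here is a hypothetical linear approximator with a hand-crafted feature map; the names phi, theta, and q_hat are purely illustrative:

import numpy as np

def phi(state: np.ndarray) -> np.ndarray:
    """Toy feature map: the raw state features plus a constant bias term."""
    return np.append(state, 1.0)

n_features: int = 7                                     # 6 state features + bias
n_actions: int = 4
theta: np.ndarray = np.zeros((n_actions, n_features))   # one weight vector per action

def q_hat(state: np.ndarray, action: int) -> float:
    """Estimate Q(s, a) as theta[action] . phi(state): the parameter count is fixed, independent of |S|."""
    return float(theta[action] @ phi(state))

print(q_hat(np.array([3, 1, 4, 1, 5, 9], dtype=float), action=2))   # 0.0 with all-zero weights

Whether the approximator is linear or a deep network, the key property is the same: a fixed set of parameters stands in for an arbitrarily large table.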
Let’s see what state and Q-table explosion look like in code, and why neural networks are needed.
Let’s illustrate how quickly table size explodes as the environment becomes more complex.
from typing import List, Tuple
import torch
def compute_table_size(state_sizes: List[int], n_actions: int) -> int:
"""
Compute total entries for a Q-table given state variable sizes and number of actions.
"""
from functools import reduce
from operator import mul
n_states = reduce(mul, state_sizes, 1)
return n_states * n_actions
# Example: 4 state variables, each with 100 values; 5 actions
state_sizes: List[int] = [100, 100, 100, 100]
n_actions: int = 5
table_entries: int = compute_table_size(state_sizes, n_actions)
print(f"Total Q-table entries: {table_entries:,}")
Output:
Total Q-table entries: 500,000,000
Allocating a table this size—with, say, 32-bit floats—would require ~2 GB just for Q-values! With 10 such variables, the table could not fit in any machine’s memory.
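As a quick sanity check on that figure, here is the back-of-the-envelope arithmetic (variable names are just for illustration):

entries: int = 100 ** 4 * 5                 # 500,000,000 (state, action) values
bytes_per_entry: int = 4                    # one 32-bit float each
print(f"~{entries * bytes_per_entry / 1024 ** 3:.2f} GiB for the Q-values alone")   # ~1.86 GiB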
Let’s visualize how the number of states grows as we increase the number of variables (“dimensions”).
import matplotlib.pyplot as plt
def plot_state_space_growth(k: int, max_vars: int) -> None:
"""
    Plot the number of possible states vs. the number of state variables, each with k possible values.
"""
num_vars = list(range(1, max_vars+1))
num_states = [k ** n for n in num_vars]
plt.figure(figsize=(7,4))
plt.semilogy(num_vars, num_states, marker='o')
plt.title(f"Exponential Growth of State Space (k={k})")
plt.xlabel("Number of state variables (dimensions)")
plt.ylabel("Number of possible states (log scale)")
plt.grid(True, which='both')
plt.show()
plot_state_space_growth(k=10, max_vars=10)
Try it! As the number of variables goes up, the state space grows exponentially (note the log-scale y-axis!).
Let’s try running tabular Q-learning in a simple “large” toy environment—and see what happens.
from typing import Tuple, Dict
import random
import numpy as np
# Our toy environment: 6 discrete state variables, each with 10 possible values
state_sizes: List[int] = [10] * 6 # 10^6 = 1,000,000 states
n_actions: int = 4
# Simulate the environment as a tuple of 6 ints
def random_state() -> Tuple[int, ...]:
return tuple(random.randint(0, 9) for _ in range(6))
# Initialize a Q-table
Q_table: Dict[Tuple[int,...], np.ndarray] = {}
# Try updating Q-table for 100,000 steps
for step in range(100_000):
state = random_state()
action = random.randint(0, n_actions-1)
next_state = random_state()
reward = random.random()
    # Q-update on this random (s, a, r, s') transition (fixed alpha/gamma for this demo)
q_old = Q_table.get(state, np.zeros(n_actions))
alpha = 0.1
gamma = 0.99
# Q-learning update
q_target = reward + gamma * np.max(Q_table.get(next_state, np.zeros(n_actions)))
q_new = q_old.copy()
q_new[action] = (1 - alpha) * q_old[action] + alpha * q_target
Q_table[state] = q_new
print(f"Q-table size after 100,000 transitions: {len(Q_table):,} states visited")
print(f"Fraction of possible states visited: {len(Q_table)/1_000_000:.2%}")
What do you think will happen? We’ll find that after 100,000 transitions, only about 10% of the state space was ever visited—most of the state-action values were never updated at all!
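A rough back-of-the-envelope model explains why, assuming (as in this demo) that states are drawn uniformly at random:

N: int = 1_000_000    # possible states
T: int = 100_000      # random transitions
expected_fraction: float = 1 - (1 - 1 / N) ** T
print(f"Expected coverage: {expected_fraction:.2%}")   # ~9.52%

With a real exploration policy the pattern differs, but the core problem—most table entries never receive a single update—remains.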
Suppose we want $\hat{Q}(s, a; \theta)$ for arbitrary states (which could be vectors, images, etc.). Here’s a minimal PyTorch MLP architecture that maps a state $s$ (and action $a$) to $\hat{Q}(s, a; \theta)$. For discrete actions, we often have our network output a vector of Q-values, one per action.
import torch
import torch.nn as nn
from typing import Any
class QNetwork(nn.Module):
def __init__(self, state_dim: int, n_actions: int, hidden_dim: int = 128) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, n_actions)
)
def forward(self, state: torch.Tensor) -> torch.Tensor:
"""
Input: state tensor of shape (batch_size, state_dim)
Output: Q-values, tensor of shape (batch_size, n_actions)
"""
return self.net(state)
# Example: state has 6 features, 4 possible actions
state_dim: int = 6
n_actions: int = 4
q_net = QNetwork(state_dim, n_actions)
# Batch of states
states: torch.Tensor = torch.rand(8, state_dim)
q_values: torch.Tensor = q_net(states)
print(q_values.shape) # (8, 4)
This neural net’s parameters don’t grow with the number of possible states—they are fixed! The learned function generalizes to new, even unseen, states.
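To make that concrete, here is a quick comparison (reusing q_net and n_actions from above, and assuming the earlier toy environment with 10^6 states):

n_params: int = sum(p.numel() for p in q_net.parameters())
table_entries: int = (10 ** 6) * n_actions   # dense Q-table for the earlier 10^6-state toy environment
print(f"Network parameters: {n_params:,}")               # 1,412 with the defaults above
print(f"Equivalent Q-table entries: {table_entries:,}")  # 4,000,000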
Task:
Write code to define an environment where there are 8 discrete state variables,
each with 20 possible values. Compute how many total states there are, and how
large (in MB) a Q-table would be if you stored a 32-bit float for each
(state, action) pair, with 6 discrete actions.
Solution:
from functools import reduce
from operator import mul
from typing import List, Tuple

import numpy as np

def q_table_size(state_sizes: List[int], n_actions: int, dtype: np.dtype = np.float32) -> Tuple[int, int, float]:
    """Return (number of states, table entries, table size in MB) for a dense Q-table."""
    n_states = reduce(mul, state_sizes, 1)
    total_entries = n_states * n_actions
    bytes_total = total_entries * np.dtype(dtype).itemsize
    mb_total = bytes_total / (1024 ** 2)
    return n_states, total_entries, mb_total
state_sizes: List[int] = [20] * 8 # 20^8
n_actions: int = 6
n_states, n_entries, size_mb = q_table_size(state_sizes, n_actions)
print(f"Number of states: {n_states:,}")
print(f"Q-table entries: {n_entries:,}")
print(f"Q-table size: {size_mb:,.2f} MB (float32)")
Task:
For a sequence of 1 to 12 state variables, each having 10 possible values, plot
(on a log scale) the number of possible states versus the number of variables.
Annotate the plot to show where the state count exceeds 1 million.
Solution:
import numpy as np
import matplotlib.pyplot as plt
def plot_state_growth(base: int = 10, max_vars: int = 12) -> None:
num_vars = np.arange(1, max_vars+1)
num_states = base ** num_vars
plt.figure(figsize=(7,4))
plt.semilogy(num_vars, num_states, marker='o')
plt.axhline(1_000_000, color='red', linestyle='--', label="1 million states")
plt.title(f"Exponential State Space Growth (base={base})")
plt.xlabel("Number of state variables")
plt.ylabel("Number of possible states (log scale)")
plt.legend()
plt.grid(True, which='both', ls='--')
plt.tight_layout()
plt.show()
plot_state_growth()
Task:
Simulate tabular Q-learning in an environment with 6 discrete state variables
(each with 10 values), updating the Q-table for 250,000 random transitions.
Afterward, report the number of unique states visited and the fraction of the table that remains unused.
Solution:
from typing import Tuple, Dict, List
import random
import numpy as np
state_sizes: List[int] = [10] * 6 # 10^6 possible states
n_actions: int = 4
total_states: int = 10 ** 6
def random_state() -> Tuple[int, ...]:
return tuple(random.randint(0, 9) for _ in range(6))
Q_table: Dict[Tuple[int,...], np.ndarray] = {}
# Apply 250,000 Q-learning updates on random (s, a, r, s') transitions
for _ in range(250_000):
state = random_state()
action = random.randint(0, n_actions-1)
next_state = random_state()
reward = random.random()
q_old = Q_table.get(state, np.zeros(n_actions))
alpha = 0.1
gamma = 0.99
q_target = reward + gamma * np.max(Q_table.get(next_state, np.zeros(n_actions)))
q_new = q_old.copy()
q_new[action] = (1 - alpha) * q_old[action] + alpha * q_target
Q_table[state] = q_new
visited = len(Q_table)
unused = total_states - visited
print(f"Unique states visited: {visited:,} / {total_states:,} ({visited/total_states:.2%})")
print(f"Fraction of table remaining unused: {unused/total_states:.2%}")
Task:
Specify (in PyTorch code) a neural network that takes an input state vector of
size 8 and outputs Q-values for 6 actions. Use at least two hidden layers. Print
the number of parameters.
Solution:
import torch
import torch.nn as nn
class BigQNetwork(nn.Module):
def __init__(self, state_dim: int = 8, n_actions: int = 6, hidden_dim: int = 128) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, n_actions)
)
def forward(self, state: torch.Tensor) -> torch.Tensor:
return self.net(state)
qnet = BigQNetwork()
n_params = sum(p.numel() for p in qnet.parameters())
print(f"Q-network has {n_params:,} parameters (trainable)")
# Example usage: for a batch of 10 states
state_batch = torch.rand(10, 8)
q_outputs = qnet(state_batch)
print(f"Output shape: {q_outputs.shape}") # (10, 6)
Tabular RL is a great tool for learning, but quickly reaches its limits in practical problems, falling victim to the curse of dimensionality. The exponential growth in the number of states makes storing (and learning) a separate value for each state-action pair impossible.
Function approximation—using neural networks—is how modern RL methods scale to large, complex environments. Neural nets can generalize from observed to unobserved states, store Q or value functions compactly, and enable deep RL to work with high-dimensional states like images and physics sensors.
Up next: We’ll dive into implementing Deep Q-Networks (DQN) in PyTorch, your first step towards high-performance, scalable RL agents!