After seeing bandits, let’s step up to environments where actions have consequences across time! In gridworlds and many RL classics, the key concept is the Q-value: estimating the future value of every action in every state.
In this post, you'll define Q-values, implement Q-learning and SARSA on a tiny Gridworld, compare their learning curves, and train a tabular agent on the classic Taxi-v3 environment.
The Q-value $Q^\pi(s, a)$ is the expected sum of discounted rewards from taking action $a$ in state $s$, and then following policy $\pi$ thereafter.
Q-learning update rule:
$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$
SARSA update rule:
$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \, Q(s', a') - Q(s, a) \right]$$
where $a'$ is the actual next action taken by the agent according to its policy. This is the key difference from Q-learning, which bootstraps from the greedy (max) next action regardless of what the agent actually does.
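To make these updates concrete, here is a single Q-learning step computed by hand; all of the numbers below (the step size, discount, current estimate, reward, and next-state value) are made up purely for illustration:
alpha, gamma = 0.1, 0.9
q_sa = 0.5                          # current estimate Q(s, a)
r = 1.0                             # observed reward
max_q_next = 0.8                    # max over a' of Q(s', a')
td_target = r + gamma * max_q_next  # 1.72
td_error = td_target - q_sa         # 1.22
q_sa_new = q_sa + alpha * td_error  # 0.5 + 0.1 * 1.22 = 0.622
print(q_sa_new)
# A SARSA step is identical except that Q(s', a') of the action actually
# taken next replaces the max term.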
Let’s use a tiny Gridworld with 4 states and 2 actions (LEFT, RIGHT):
import numpy as np
n_states = 4
n_actions = 2
Q = np.zeros((n_states, n_actions)) # Q[state, action]
print("Initial Q-table:")
print(Q)
Let's define a small Gridworld with deterministic left/right transitions and a reward of +1 for reaching the rightmost state:
import random
class SimpleGridEnv:
def __init__(self, n=4):
self.n = n
self.reset()
def reset(self) -> int:
self.s = 0
return self.s
def step(self, a: int) -> tuple[int, float, bool]:
# a: 0=LEFT, 1=RIGHT
if a == 1:
self.s += 1
else:
self.s -= 1
self.s = np.clip(self.s, 0, self.n-1)
reward = 1.0 if self.s == self.n-1 else 0.0
done = self.s == self.n-1
return self.s, reward, done
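# Quick sanity check of the dynamics above (illustration only; the action
# sequence is arbitrary): walking RIGHT from state 0 should eventually reach
# state 3 and pay out a reward of 1.0.
check_env = SimpleGridEnv(4)
cs = check_env.reset()
for ca in [1, 0, 1, 1, 1]:   # RIGHT, LEFT, RIGHT, RIGHT, RIGHT
    cs2, cr, cdone = check_env.step(ca)
    print(f"s={cs} a={ca} -> s'={cs2} r={cr} done={cdone}")
    cs = cs2
    if cdone:
        break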
def epsilon_greedy(Q, s, eps=0.2):
    # With probability eps explore a random action; otherwise exploit the greedy one.
    if random.random() < eps:
        return random.randrange(Q.shape[1])  # number of actions read from the table itself
    return int(np.argmax(Q[s]))
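# Quick check of the exploration rate (dummy_Q is made up just for this test):
# with eps=0.2 and 2 actions, the greedy action should be chosen about
# 1 - eps + eps/2 = 90% of the time.
dummy_Q = np.array([[0.0, 1.0]])   # one state, action 1 is clearly best
picks = [epsilon_greedy(dummy_Q, 0, eps=0.2) for _ in range(10_000)]
print("Fraction of greedy picks:", np.mean(np.array(picks) == 1))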
def train_q(env, Q, episodes=250, alpha=0.1, gamma=0.9, eps=0.2):
returns = []
for ep in range(episodes):
s = env.reset()
done = False
total = 0.
while not done:
a = epsilon_greedy(Q, s, eps)
s2, r, done = env.step(a)
Q[s, a] += alpha * (r + gamma * np.max(Q[s2]) - Q[s, a]) # Q-learning update
s = s2
total += r
returns.append(total)
return np.array(returns)
def train_sarsa(env, Q, episodes=250, alpha=0.1, gamma=0.9, eps=0.2):
returns = []
for ep in range(episodes):
s = env.reset()
a = epsilon_greedy(Q, s, eps)
done = False
total = 0.
while not done:
s2, r, done = env.step(a)
a2 = epsilon_greedy(Q, s2, eps)
Q[s, a] += alpha * (r + gamma * Q[s2, a2] - Q[s, a]) # SARSA update
s = s2
a = a2
total += r
returns.append(total)
return np.array(returns)
env = SimpleGridEnv(n_states)
Q1 = np.zeros((n_states, n_actions))
Q2 = np.zeros((n_states, n_actions))
rets1 = train_q(env, Q1)
rets2 = train_sarsa(env, Q2)
import matplotlib.pyplot as plt
plt.plot(np.cumsum(rets1) / (np.arange(len(rets1))+1), label="Q-Learning")
plt.plot(np.cumsum(rets2) / (np.arange(len(rets2))+1), label="SARSA")
plt.ylabel("Mean Return"); plt.xlabel("Episode")
plt.legend(); plt.title("Q-Learning vs SARSA (Gridworld)")
plt.grid(); plt.show()
plt.imshow(Q1, cmap='cool', interpolation='nearest')
plt.colorbar(label="Q-value")
plt.title("Q-table for Q-learning (States x Actions)")
plt.xlabel("Action (0=Left, 1=Right)")
plt.ylabel("State")
plt.show()
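# The greedy policy can also be read straight off the table with an argmax
# over the action axis (a sketch using the Q1 table learned above; note that
# the terminal state's row is never updated and stays at zero).
action_names = ["LEFT", "RIGHT"]
for s_idx, a_idx in enumerate(np.argmax(Q1, axis=1)):
    print(f"state {s_idx}: {action_names[a_idx]} (Q = {Q1[s_idx, a_idx]:.3f})")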
import time
def run_policy(env, Q, max_steps=10, delay=0.4) -> None:
s = env.reset()
traj = [s]
for _ in range(max_steps):
a = int(np.argmax(Q[s]))
s2, r, done = env.step(a)
traj.append(s2)
print(f"State: {s} -> Action: {a} -> State: {s2} | Reward: {r}")
s = s2
if done:
break
time.sleep(delay)
print("Trajectory:", traj)
print("Animating Q-learning policy:")
run_policy(env, Q1)
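The same helper works for the SARSA table, so you can compare the greedy behaviour of the two learners:
print("Animating SARSA policy:")
run_policy(env, Q2)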
Taxi-v3 is a classic Gym environment, now shipped with Gymnasium. Let's use Q-learning to learn a near-optimal policy for it!
import gymnasium as gym
import numpy as np
env = gym.make("Taxi-v3", render_mode="ansi")
n_states = env.observation_space.n
n_actions = env.action_space.n
Q = np.zeros((n_states, n_actions))
episodes = 1800
alpha = 0.1
gamma = 0.98
eps = 0.15
lengths = []
for ep in range(episodes):
s, _ = env.reset()
done = False
count = 0
while not done:
if np.random.rand() < eps:
a = np.random.randint(n_actions)
else:
a = np.argmax(Q[s])
s2, r, terminated, truncated, _ = env.step(a)
done = terminated or truncated
Q[s, a] += alpha * (r + gamma * np.max(Q[s2]) - Q[s, a])
s = s2
count += 1
lengths.append(count)
import matplotlib.pyplot as plt
plt.plot(np.convolve(lengths, np.ones(50)/50, mode='valid'))
plt.xlabel("Episode")
plt.ylabel("Episode Length")
plt.title("Taxi-v3: Q-Learning Episode Length (lower is better)")
plt.show()
print("Mean episode length (last 100 episodes):", np.mean(lengths[-100:]))
Exercises:
1. Create a Q-table Q[s, a] filled with zeros.
2. Train Q-learning and SARSA on the Gridworld and plot their learning curves.
3. Visualize Q (states by actions) as an imshow/heatmap.
4. Run the learned greedy policy through the Gridworld.
5. Repeat with the Taxi-v3 environment and Q-learning. Train for 1–2k episodes.
Solutions:
import numpy as np
import random
import matplotlib.pyplot as plt
# Exercise 1
n_states, n_actions = 4, 2
Q = np.zeros((n_states, n_actions))
print(Q)
# Exercise 2/3
# (See above code!)
# Use SimpleGridEnv, epsilon_greedy, train_q, train_sarsa
# Plot learning curves and Q-table heatmap
# Exercise 4
env = SimpleGridEnv(n_states)
run_policy(env, Q1)
You now know how to define Q-values, update them with Q-learning and SARSA, read a greedy policy out of a Q-table, and train a tabular agent on Taxi-v3.
Next, you’ll see how random sampling and partial returns let us estimate value without knowing the environment—Monte Carlo and TD learning.
See you in Part 4.5!