“In a Neural ODE, the network defines the velocity. Learning is learning the flow.”
An ordinary differential equation (ODE) specifies the time evolution of a system via:
$$\frac{dh(t)}{dt} = f(h(t), t)$$
where $h(t)$ is the state at time $t$, and $f$ is a vector field.
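Let’s solve the “spiral” ODE:
$$\frac{dx}{dt} = \alpha x - \beta y, \qquad \frac{dy}{dt} = \beta x + \alpha y$$
for $\alpha = -0.3$, $\beta = 1.1$, which produces spirals to the origin.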
# calc-12-ode/spiral_ode_numpy.py
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
# Spiral coefficients: alpha scales each coordinate's own contribution
# (negative, so the trajectory decays toward the origin); beta couples x and
# y and produces the rotation.
alpha = -0.3
beta = 1.1


def f(t, h):
    """Return the spiral vector field evaluated at state ``h``.

    Args:
        t: Current time. Unused by the field itself, but required by the
            ``fun(t, y)`` callback signature that ``solve_ivp`` expects.
        h: Two-element state ``(x, y)``.

    Returns:
        ``[dx/dt, dy/dt]`` as a two-element list.
    """
    x, y = h
    dx = alpha * x - beta * y
    dy = beta * x + alpha * y
    return [dx, dy]
# Integrate from t=0 to t=8 and sample the solution at 400 evenly spaced times.
t_span = [0, 8]
h0 = [2, 0.5]
t_eval = np.linspace(*t_span, 400)
sol = solve_ivp(f, t_span, h0, t_eval=t_eval)

# sol.y holds one row per state component: row 0 is x(t), row 1 is y(t).
plt.plot(sol.y[0], sol.y[1], label="trajectory")
plt.title("Spiral ODE trajectory")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()  # without this the label passed to plot() is never displayed
plt.grid()
plt.axis("equal")
plt.show()
We’ll train a Neural ODE to fit spiral trajectories using `torchdiffeq`. You need to install it first:
pip install torchdiffeq
# calc-12-ode/neural_ode_spiral.py
import torch
import matplotlib.pyplot as plt
from torchdiffeq import odeint
# CPU device handle. NOTE(review): nothing in the visible script passes
# `device` to a tensor or module, so it currently has no effect.
device = torch.device("cpu")
# Seed torch's global RNG so runs are repeatable.
torch.manual_seed(42)
# --- Spiral data generation
alpha, beta = -0.4, 1.2


def true_field(t, h):
    """Evaluate the ground-truth spiral vector field.

    Args:
        t: Current time; unused, accepted only to match the ``func(t, y)``
            callback signature that ``odeint`` calls.
        h: Tensor whose last dimension holds ``(x, y)``. Any leading
            (batch) dimensions are preserved.

    Returns:
        Tensor shaped like ``h`` whose last dimension holds
        ``(dx/dt, dy/dt)``.
    """
    x, y = h[..., 0], h[..., 1]
    dx = alpha * x - beta * y
    dy = beta * x + alpha * y
    return torch.stack([dx, dy], -1)
# 160 evenly spaced sample times on [0, 7] and the initial state.
t = torch.linspace(0, 7, 160)
h0 = torch.tensor([2.0, 0.7])

# The reference trajectory is fixed training data, so no autograd graph is
# needed while integrating it.
with torch.no_grad():
    true_traj = odeint(true_field, h0, t)
# --- Neural ODE model
class ODEFunc(torch.nn.Module):
    """Small MLP that parameterises the learned vector field ``dh/dt``.

    Args:
        state_dim: Size of the state vector (network input and output width).
        hidden_dim: Width of the single hidden ``tanh`` layer.
    """

    def __init__(self, state_dim=2, hidden_dim=32):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(state_dim, hidden_dim),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden_dim, state_dim),
        )

    def forward(self, t, h):
        """Return the predicted derivative at state ``h``.

        ``t`` is accepted because the solver calls ``func(t, y)``, but the
        network ignores it, so the learned field is time-independent.
        """
        return self.net(h)
odefunc = ODEFunc()
optimizer = torch.optim.Adam(odefunc.parameters(), lr=0.01)

# --- Training loop
for epoch in range(250):
    # Integrate the learned field from h0 and compare it with the reference
    # trajectory at every sample time (mean squared error).
    pred_traj = odeint(odefunc, h0, t)
    loss = ((pred_traj - true_traj) ** 2).mean()

    # Backpropagate through the ODE solve and update the network weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 40 == 0:
        print(f"Epoch {epoch:3d} | Loss: {loss.item():.5f}")
# --- Plot
# Detach once so the prediction carries no autograd graph into matplotlib.
learned = pred_traj.detach()
true_x, true_y = true_traj[:, 0], true_traj[:, 1]
learned_x, learned_y = learned[:, 0], learned[:, 1]

plt.plot(true_x, true_y, label="True", lw=3)
plt.plot(learned_x, learned_y, "--", label="Neural ODE", lw=2)
plt.legend()
plt.xlabel("x")
plt.ylabel("y")
plt.title("Neural ODE fits Spiral Trajectory")
plt.axis("equal")
plt.tight_layout()
plt.show()
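Hint: to give the network access to time, use `torch.cat([h, t*torch.ones_like(h[..., :1])], dim=-1)` as its input (the first linear layer then takes 3 features instead of 2).
Put solutions in `calc-12-ode/` and tag `v0.1`.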
Next: Calculus 13 — Partial Differential Equations, SDEs, and Diffusion Models in Machine Learning.