Hand-coding neural networks gives you intuition, but PyTorch’s torch.nn module is the professional toolkit: it provides higher-level abstractions, readable code, and far less error-prone scaling to deep architectures. In practice, nearly every RL and ML practitioner uses nn.Module for defining models.
In this post, you’ll rebuild the two-layer network with nn.Module from scratch, with all definitions included. Let’s see why, and how, PyTorch’s object-oriented approach saves time and headaches.
Recall the two-layer (one hidden layer) network from before. In torch.nn you define each layer as a linear transformation, with weights and biases stored for you.
When you use nn.Module in PyTorch:
- You declare layers as attributes in __init__ (e.g. self.fc1 = nn.Linear(...)).
- You define the forward pass in a method (def forward(self, x): ...), chaining the operations in order.
- Parameters are registered automatically, so the whole model can be saved and reloaded (torch.save / torch.load).
You’ll see: switching from raw tensor code to nn.Module makes models more robust, reusable, and production-ready.
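To make the first point concrete, here is a minimal sketch using a single nn.Linear layer (not yet the full model):
import torch.nn as nn
layer = nn.Linear(2, 8)      # weight (8x2) and bias (8) tensors are created and registered for you
print(layer.weight.shape)    # torch.Size([8, 2])
print(layer.bias.shape)      # torch.Size([8])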
nn.Module
import torch
import torch.nn as nn
import torch.nn.functional as F
# Define a two-layer neural network fully inside a class
class SimpleNet(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()
        self.fc1: nn.Linear = nn.Linear(input_dim, hidden_dim)
        self.fc2: nn.Linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h: torch.Tensor = F.relu(self.fc1(x))
        logits: torch.Tensor = self.fc2(h)
        return logits
# Example: instantiate and print the model
model: SimpleNet = SimpleNet(2, 8, 2)
print(model)
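Because each nn.Linear registers its weights and biases with the module, you can inspect or count them directly. A small illustrative snippet:
for name, p in model.named_parameters():
    print(name, tuple(p.shape))              # fc1.weight (8, 2), fc1.bias (8,), fc2.weight (2, 8), fc2.bias (2,)
n_params: int = sum(p.numel() for p in model.parameters())
print("Trainable parameters:", n_params)     # 2*8 + 8 + 8*2 + 2 = 42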
Let’s train on some synthetic data and see how nn.Module
streamlines the
process.
import matplotlib.pyplot as plt
# Generate synthetic linearly separable data
torch.manual_seed(3)
N: int = 100
X: torch.Tensor = torch.randn(N, 2)
y: torch.Tensor = (X[:, 0] + X[:, 1] > 0).long()
model: SimpleNet = SimpleNet(2, 8, 2)
optimizer: torch.optim.Optimizer = torch.optim.Adam(model.parameters(), lr=0.07)
loss_fn: nn.Module = nn.CrossEntropyLoss()
losses: list[float] = []
for epoch in range(80):
    logits: torch.Tensor = model(X)
    loss: torch.Tensor = loss_fn(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if epoch % 20 == 0 or epoch == 79:
        print(f"Epoch {epoch}: Loss={loss.item():.3f}")
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("NN Training Loss (torch.nn)")
plt.grid(True)
plt.show()
It’s clear: training, parameter updates, and even device handling are now concise and readable, with no need to hand-manage weight tensors or their updates!
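Device handling, for example, becomes a single call because every registered parameter moves with the module. A minimal sketch (it falls back to CPU if no CUDA GPU is available):
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)                          # moves all registered parameters at once
logits_dev: torch.Tensor = model(X.to(device))    # forward pass runs on the selected device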
# Define a deeper feedforward network with two hidden layers
class DeepNet(nn.Module):
    def __init__(self, input_dim: int, hidden1: int, hidden2: int, output_dim: int) -> None:
        super().__init__()
        self.fc1: nn.Linear = nn.Linear(input_dim, hidden1)
        self.fc2: nn.Linear = nn.Linear(hidden1, hidden2)
        self.fc3: nn.Linear = nn.Linear(hidden2, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h1: torch.Tensor = F.relu(self.fc1(x))
        h2: torch.Tensor = F.relu(self.fc2(h1))
        logits: torch.Tensor = self.fc3(h2)
        return logits
model_deep: DeepNet = DeepNet(2, 16, 8, 2)
optimizer_deep: torch.optim.Optimizer = torch.optim.Adam(model_deep.parameters(), lr=0.05)
losses_deep: list[float] = []
for epoch in range(100):
    logits: torch.Tensor = model_deep(X)
    loss: torch.Tensor = loss_fn(logits, y)
    optimizer_deep.zero_grad()
    loss.backward()
    optimizer_deep.step()
    losses_deep.append(loss.item())
    if epoch % 25 == 0 or epoch == 99:
        print(f"[DeepNet] Epoch {epoch}: Loss={loss.item():.3f}")
plt.plot(losses_deep)
plt.xlabel("Epoch"); plt.ylabel("Loss")
plt.title("Deep NN Training Loss")
plt.grid(True)
plt.show()
# Save model weights to disk
torch.save(model_deep.state_dict(), "deepnet_weights.pth")
print("Weights saved to deepnet_weights.pth")
# Load weights into a new instance (architecture must match)
model_loaded: DeepNet = DeepNet(2, 16, 8, 2)
model_loaded.load_state_dict(torch.load("deepnet_weights.pth"))
print("Weights loaded. Sample output:", model_loaded(X[:5]))
Exercise: nn.Module basics
Define a class MyNet that subclasses torch.nn.Module and implement a forward() method that passes an input tensor through your network.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyNet(nn.Module):
    def __init__(self, input_dim: int = 2, hidden_dim: int = 6, output_dim: int = 2) -> None:
        super().__init__()
        self.fc1: nn.Linear = nn.Linear(input_dim, hidden_dim)
        self.fc2: nn.Linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h: torch.Tensor = F.relu(self.fc1(x))
        out: torch.Tensor = self.fc2(h)
        return out
# Create dummy input and check shape
model: MyNet = MyNet()
x_sample: torch.Tensor = torch.randn(4, 2)
logits: torch.Tensor = model(x_sample)
print("Logits shape:", logits.shape) # Should be (4, 2)
Next, run a single training step using MyNet as your base model.
# Using MyNet
optimizer: torch.optim.Optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
loss_fn: nn.Module = nn.CrossEntropyLoss()
x_batch: torch.Tensor = torch.randn(8, 2)
y_batch: torch.Tensor = torch.randint(0, 2, (8,))
logits: torch.Tensor = model(x_batch)
loss: torch.Tensor = loss_fn(logits, y_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Loss:", loss.item())
Try to code an equivalent manual approach and compare for yourself!
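For comparison, a hand-rolled version of the same step might look roughly like this (a sketch only; the sizes match MyNet’s defaults, and the update is plain SGD rather than Adam):
# Manual two-layer network: you own every tensor and every update
W1: torch.Tensor = torch.randn(2, 6, requires_grad=True)
b1: torch.Tensor = torch.zeros(6, requires_grad=True)
W2: torch.Tensor = torch.randn(6, 2, requires_grad=True)
b2: torch.Tensor = torch.zeros(2, requires_grad=True)
h_manual: torch.Tensor = torch.relu(x_batch @ W1 + b1)
logits_manual: torch.Tensor = h_manual @ W2 + b2
loss_manual: torch.Tensor = F.cross_entropy(logits_manual, y_batch)
loss_manual.backward()
with torch.no_grad():                  # hand-rolled SGD step
    for p in (W1, b1, W2, b2):
        p -= 0.05 * p.grad
        p.grad.zero_()
print("Manual loss:", loss_manual.item())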
Now extend MyNet into a deeper class with a second hidden layer of 5 units and ReLU activations between the layers.
class MyDeepNet(nn.Module):
    def __init__(self, input_dim: int = 2, h1: int = 6, h2: int = 5, output_dim: int = 2) -> None:
        super().__init__()
        self.fc1: nn.Linear = nn.Linear(input_dim, h1)
        self.fc2: nn.Linear = nn.Linear(h1, h2)
        self.fc3: nn.Linear = nn.Linear(h2, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h1: torch.Tensor = F.relu(self.fc1(x))
        h2: torch.Tensor = F.relu(self.fc2(h1))
        return self.fc3(h2)
# Data
torch.manual_seed(0)
N: int = 100
X: torch.Tensor = torch.randn(N, 2)
y: torch.Tensor = (X[:,0] + X[:,1] > 0).long()
net: MyDeepNet = MyDeepNet()
optim: torch.optim.Optimizer = torch.optim.Adam(net.parameters(), lr=0.06)
losses: list[float] = []
for epoch in range(60):
    logits: torch.Tensor = net(X)
    loss: torch.Tensor = nn.functional.cross_entropy(logits, y)
    optim.zero_grad()
    loss.backward()
    optim.step()
    losses.append(loss.item())
import matplotlib.pyplot as plt
plt.plot(losses)
plt.xlabel("Epoch"); plt.ylabel("Loss")
plt.title("DeepNet Training Loss")
plt.grid(True); plt.show()
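As a quick sanity check beyond the loss curve (optional, not part of the exercise), you can compute the training accuracy:
with torch.no_grad():
    preds: torch.Tensor = net(X).argmax(dim=1)
    accuracy: float = (preds == y).float().mean().item()
print(f"Training accuracy: {accuracy:.2%}")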
Finally, after training MyDeepNet, save its weights to disk. Then create a fresh MyDeepNet, load the weights, and verify that the predictions are identical.
# Save model
torch.save(net.state_dict(), "mydeepnet_weights.pth")
# Load into a new instance
net2: MyDeepNet = MyDeepNet()
net2.load_state_dict(torch.load("mydeepnet_weights.pth"))
# Check equality
out1: torch.Tensor = net(X[:5])
out2: torch.Tensor = net2(X[:5])
print("Predictions equal after reload:", torch.allclose(out1, out2))
In this part, you’ve experienced the power of torch.nn and nn.Module. With just a few lines, you can now define networks as classes, train them with built-in losses and optimizers, and save and reload their weights.
Up next: You’ll explore and visualize nonlinear neural network building blocks—activation functions—and see how these unlock expressivity and speed up learning.
You’re now building neural nets “the real way.” Take pride in your object-oriented power—see you in Part 3.4!