Learning by doing is a joyful and effective experience, so let's build AlexNet from scratch ⚡️
You can use the Studio template I've created for the Lightning.AI platform to replicate the code in this post with a single click, build and train your own AlexNet, and tweak the internals to gain a better understanding of convolutional neural networks. For your convenience, you will find an AlexNet adapted to the MNIST dataset, so you can start right away!
AlexNet is a convolutional neural network architecture created by Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton at the University of Toronto. It was published in the NeurIPS proceedings in 2012 and, trained on the ImageNet dataset, became the state-of-the-art model in image recognition
To bootstrap the project, we will use ashleve/lightning-hydra-template, an easy-to-use template for starting PyTorch Lightning projects
Architecture overview
AlexNet consists of 5 convolutional layers followed by 3 fully connected layers
ReLU is applied after each layer
Local response normalization (LRN) is applied to the first and second convolutional layers. You can find a good explanation of this technique here (it's easier to grasp than the formula, imho)
Each LRN is followed by a max-pooling layer, and max pooling is also applied after the 5th convolutional layer. In this operation, we pick the highest value from each patch. Here's a short explanation of how it works
The first two fully-connected layers have dropout applied to their output, a regularization method that prevents overfitting. An explanation and a diagram are available at paperswithcode.com
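To make the max-pooling step concrete, here is a toy pure-Python sketch (the helper function and sample values are illustrative, not part of AlexNet's code) using AlexNet's overlapping 3x3 window with stride 2:

```python
def max_pool2d(x, kernel, stride):
    """Slide a kernel x kernel window over a 2D list, keeping the max of each patch."""
    h, w = len(x), len(x[0])
    out_h = (h - kernel) // stride + 1
    out_w = (w - kernel) // stride + 1
    return [
        [
            max(
                x[i * stride + di][j * stride + dj]
                for di in range(kernel)
                for dj in range(kernel)
            )
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]

grid = [
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15],
    [16, 17, 18, 19, 20],
    [21, 22, 23, 24, 25],
]
print(max_pool2d(grid, kernel=3, stride=2))  # [[13, 15], [23, 25]]
```

Note how a 5x5 input shrinks to 2x2, and how stride 2 with a 3x3 window makes adjacent patches overlap by one row/column.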
Step by step implementation
Module
Let's start with an empty PyTorch Module. nn.Module is the base class for all neural networks in PyTorch
from torch import nn


class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return None
The forward method is where we define how the model should perform computation on its input
Sequential
Instead of manually connecting layers one to another, we can use Sequential to create the model. It passes the output of each layer to the input of the next one
We bind the model to the self object, so we can define the forward method now. It's fairly basic: we retrieve the model from self.model, pass the input data to it, and return its output
from torch import nn


class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential()

    def forward(self, x):
        x = self.model(x)
        return x
First convolutional layer
As we saw before, the first conv layer accepts an input image. We use 3 input channels, which represent the red, green, and blue (RGB) values of the image pixels
As described in the paper, we use a kernel of size 11x11 and a stride of 4, which means that at each step the kernel moves by 4 pixels
The output goes through the ReLU activation function, is then response-normalized using LocalResponseNorm, and pooled with MaxPool2d using a 3x3 kernel and a stride of 2
from torch import nn


class AlexNet(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(
                in_channels=channels,
                out_channels=96,
                kernel_size=(11, 11),
                stride=4,
                padding=0,
            ),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
        )

    def forward(self, x):
        x = self.model(x)
        return x
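We can sanity-check the shapes this block produces: for input width W, kernel K, padding P and stride S, the output width is floor((W - K + 2P) / S) + 1. Assuming a 227x227 input (see the 224 vs 227 note at the end of this post), plain arithmetic gives:

```python
# conv1: 11x11 kernel, stride 4, no padding, on a 227-pixel-wide input
print((227 - 11) // 4 + 1)  # 55
# pool1: 3x3 window, stride 2, on the 55-pixel-wide feature maps
print((55 - 3) // 2 + 1)    # 27
```

So after this first block the feature maps are 27x27 with 96 channels.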
Second convolutional layer
It becomes less exciting at this point: the second layer is similar to the first one; the difference is in the input and output values
from torch import nn


class AlexNet(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        self.model = nn.Sequential(
            # First conv layer
            nn.Conv2d(
                in_channels=channels,
                out_channels=96,
                kernel_size=(11, 11),
                stride=4,
                padding=0,
            ),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Second conv layer
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
        )

    def forward(self, x):
        x = self.model(x)
        return x
Third convolutional layer
No LRN and no pooling here, just a regular convolutional layer with more channels. Following the paper, we use a 3x3 kernel with padding of 1, which keeps the spatial size of the feature maps unchanged
from torch import nn


class AlexNet(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        self.model = nn.Sequential(
            # First conv layer
            nn.Conv2d(
                in_channels=channels,
                out_channels=96,
                kernel_size=(11, 11),
                stride=4,
                padding=0,
            ),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Second conv layer
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Third conv layer
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.model(x)
        return x
Fourth convolutional layer
Similar to the third one; following the paper, it also uses 384 kernels of size 3x3 with padding of 1
from torch import nn


class AlexNet(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        self.model = nn.Sequential(
            # First conv layer
            nn.Conv2d(
                in_channels=channels,
                out_channels=96,
                kernel_size=(11, 11),
                stride=4,
                padding=0,
            ),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Second conv layer
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Third conv layer
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # Fourth conv layer
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.model(x)
        return x
Fifth convolutional layer
The final convolutional layer, again with pooling. Its output is then fed into the first fully connected layer
from torch import nn


class AlexNet(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        self.model = nn.Sequential(
            # First conv layer
            nn.Conv2d(
                in_channels=channels,
                out_channels=96,
                kernel_size=(11, 11),
                stride=4,
                padding=0,
            ),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Second conv layer
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Third conv layer
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # Fourth conv layer
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # Fifth conv layer
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
        )

    def forward(self, x):
        x = self.model(x)
        return x
First fully-connected layer
Now that's something new. Before the first Linear layer, which applies an affine transformation to its input, we flatten the pooled feature maps into a one-dimensional vector and calculate the in_features accordingly
Then we use dropout to reduce overfitting, and ReLU as the activation function
from torch import nn


class AlexNet(nn.Module):
    def __init__(self, channels=3, first_fc_in_features=9216):
        super().__init__()
        self.model = nn.Sequential(
            # First conv layer
            nn.Conv2d(
                in_channels=channels,
                out_channels=96,
                kernel_size=(11, 11),
                stride=4,
                padding=0,
            ),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Second conv layer
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Third conv layer
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # Fourth conv layer
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # Fifth conv layer
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Flatten the pooled feature maps into a vector for the linear layers
            nn.Flatten(),
            # First fully-connected layer with dropout
            nn.Linear(in_features=first_fc_in_features, out_features=4096),
            nn.Dropout(p=0.5),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.model(x)
        return x
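The 9216 value can be derived by tracing the spatial size of the feature maps through the network with the standard output-size formula, floor((W - K + 2P) / S) + 1, assuming a 227x227 input and the paper's padding of 1 on the three 3x3 convolutions:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv/pool step: floor((W - K + 2P) / S) + 1."""
    return (size - kernel + 2 * padding) // stride + 1

s = 227
s = conv_out(s, 11, stride=4)   # conv1 -> 55
s = conv_out(s, 3, stride=2)    # pool1 -> 27
s = conv_out(s, 5, padding=2)   # conv2 -> 27
s = conv_out(s, 3, stride=2)    # pool2 -> 13
s = conv_out(s, 3, padding=1)   # conv3 -> 13
s = conv_out(s, 3, padding=1)   # conv4 -> 13
s = conv_out(s, 3, padding=1)   # conv5 -> 13
s = conv_out(s, 3, stride=2)    # pool3 -> 6
print(256 * s * s)              # 9216 features into the first Linear
```

So the flattened vector has 256 channels x 6 x 6 = 9216 features, which is exactly first_fc_in_features.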
Second fully-connected layer
The same thing, but the number of input features is now fixed at 4096
from torch import nn


class AlexNet(nn.Module):
    def __init__(self, channels=3, first_fc_in_features=9216):
        super().__init__()
        self.model = nn.Sequential(
            # First conv layer
            nn.Conv2d(
                in_channels=channels,
                out_channels=96,
                kernel_size=(11, 11),
                stride=4,
                padding=0,
            ),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Second conv layer
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Third conv layer
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # Fourth conv layer
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # Fifth conv layer
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Flatten the pooled feature maps into a vector for the linear layers
            nn.Flatten(),
            # First fully-connected layer with dropout
            nn.Linear(in_features=first_fc_in_features, out_features=4096),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            # Second fully-connected layer with dropout
            nn.Linear(in_features=4096, out_features=4096),
            nn.Dropout(p=0.5),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.model(x)
        return x
Put all the code together
Our implementation finishes with the last, third fully-connected layer, which maps to 1000 classes, with LogSoftmax applied on top of it
And we did it! The full code is here
from torch import nn


class AlexNet(nn.Module):
    def __init__(self, channels=3, first_fc_in_features=9216):
        super().__init__()
        self.model = nn.Sequential(
            # First conv layer
            nn.Conv2d(
                in_channels=channels,
                out_channels=96,
                kernel_size=(11, 11),
                stride=4,
                padding=0,
            ),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Second conv layer
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Third conv layer
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # Fourth conv layer
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # Fifth conv layer
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # Flatten the pooled feature maps into a vector for the linear layers
            nn.Flatten(),
            # First fully-connected layer with dropout
            nn.Linear(in_features=first_fc_in_features, out_features=4096),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            # Second fully-connected layer with dropout
            nn.Linear(in_features=4096, out_features=4096),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            # Third fully-connected layer
            nn.Linear(in_features=4096, out_features=1000),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        x = self.model(x)
        return x
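Note that the head uses LogSoftmax rather than plain Softmax, so the model outputs log-probabilities (which pair with nn.NLLLoss during training). The underlying math can be sketched in plain Python, including the max-subtraction trick used for numerical stability (the toy scores below are illustrative):

```python
import math

def log_softmax(xs):
    # Subtract the max before exponentiating, for numerical stability.
    m = max(xs)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_sum_exp for x in xs]

scores = [2.0, 1.0, 0.1]               # toy class scores
log_probs = log_softmax(scores)
probs = [math.exp(v) for v in log_probs]
print(probs)                            # exponentiating recovers the softmax probabilities
```

The ordering of the scores is preserved, so taking argmax over the log-probabilities still yields the predicted class.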
Thanks for joining me on this journey!
While we could end here, let's consider a few more things that you might find interesting
224 vs 227
There has been some discussion about the size of the input image. The original paper uses 224 x 224 x 3, but that doesn't fit the math, so 227 is often used instead
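The arithmetic makes the issue visible: with the output-size formula (W - K + 2P) / S + 1, a 224-pixel input doesn't produce an integer for the 11x11 kernel with stride 4, while 227 does:

```python
print((224 - 11) / 4 + 1)  # 54.25 -- not a valid feature-map size
print((227 - 11) / 4 + 1)  # 55.0  -- matches the 55x55 maps reported in the paper
```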
AlexNet for MNIST
To make it easier to experiment with AlexNet, I prepared a version tailored to the MNIST dataset. Here's the implementation, and below you will find the commands you can run to train your model. As you can see, the number of channels drops to 1, since MNIST images are grayscale. The images are also smaller than the original ImageNet images, which is reflected in the network as well. Note also the Flatten after the last convolutional layer and the smaller number of output classes (since we have only 10 digits to classify)
from torch import nn


class AlexNetForMNIST(nn.Module):
    """Adapted version of AlexNet for MNIST dataset images resized to 64x64."""

    def __init__(self, channels=1):
        super().__init__()
        first_fc_in_features = 256 * 2 * 2
        self.model = nn.Sequential(
            # 1st conv layer
            nn.Conv2d(
                in_channels=channels,
                out_channels=96,
                kernel_size=(11, 11),
                stride=2,
                padding=0,
            ),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # 2nd conv layer
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # 3rd conv layer
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # 4th conv layer
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # 5th conv layer
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            nn.Flatten(),
            # 1st fc layer with dropout
            nn.Linear(in_features=first_fc_in_features, out_features=4096),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            # 2nd fc layer with dropout
            nn.Linear(in_features=4096, out_features=4096),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            # 3rd fc layer
            nn.Linear(in_features=4096, out_features=10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        x = self.model(x)
        return x
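The first_fc_in_features = 256 * 2 * 2 value can be checked the same way as before, by tracing a 64x64 input through the layers above with the output-size formula floor((W - K + 2P) / S) + 1:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv/pool step: floor((W - K + 2P) / S) + 1."""
    return (size - kernel + 2 * padding) // stride + 1

s = 64                           # MNIST images resized to 64x64
s = conv_out(s, 11, stride=2)    # conv1 -> 27
s = conv_out(s, 3, stride=2)     # pool1 -> 13
s = conv_out(s, 5, padding=2)    # conv2 -> 13
s = conv_out(s, 3, stride=2)     # pool2 -> 6
s = conv_out(s, 3, padding=1)    # conv3 -> 6
s = conv_out(s, 3, padding=1)    # conv4 -> 6
s = conv_out(s, 3, padding=1)    # conv5 -> 6
s = conv_out(s, 3, stride=2)     # pool3 -> 2
print(256 * s * s)               # 1024 == first_fc_in_features
```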
Run with python src/train.py experiment=mnist_alexnet