Softmax to MNIST: Building a Tiny Autograd and Classifier
The goal here is to compile everything in one place: the low-level math, so we can go deep down the rabbit hole, and the high-level intuition, so we can generalize more easily about the computation.
This post walks through the math and implementation of axis-aware, numerically stable softmax and log-softmax, their vector-Jacobian products (backward pass), and the Negative Log-Likelihood loss. We finish with a simple MLP trained on MNIST using our own autograd engine.
Notation
- $z$: logits (pre-softmax scores) from the final linear layer. In an MLP, $z = Wa + b$.
- $p$: softmax probabilities, $p = \operatorname{softmax}(z)$, with components $p_i$.
- $y$: target vector (assumed one-hot here), and $\sum_i y_i = 1$.
- `axis`: class axis (dimension along which probabilities sum to 1 and reductions occur).
- $m$: per-axis maximum for stability, $m = \max_j z_j$.
- $\ell$: log-probabilities (log-softmax outputs), $\ell = \log p$.
- $g$: upstream gradient from the loss (same shape as the quantity it differentiates, e.g., $p$ or $\ell$).
- $W, b, a$: weights, bias, and previous-layer activations feeding the final linear layer.
Softmax: definition, stability, and gradient
For logits $z$ and the chosen class axis, softmax is
$$p_i = \frac{e^{z_i}}{\sum_j e^{z_j}}.$$
Intuition note
Softmax turns arbitrary scores into a probability distribution: all entries become non-negative and sum to 1 along the class axis. Imagine classes "competing" for a fixed 100% of belief: raising one score necessarily lowers others.
To avoid overflow we use a per-axis shift (log-sum-exp trick): let $m = \max_j z_j$ along the class axis. Then
$$p_i = \frac{e^{z_i - m}}{\sum_j e^{z_j - m}}.$$
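Here is a minimal NumPy sketch of this stable forward pass; the function name and `axis` default are illustrative, not necessarily what the repository uses:

```python
import numpy as np

def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    m = np.max(z, axis=axis, keepdims=True)        # per-axis shift m = max_j z_j
    e = np.exp(z - m)                              # e^{z_i - m} never overflows
    return e / np.sum(e, axis=axis, keepdims=True)
```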
Vector-Jacobian product. With upstream gradient $g = \partial L / \partial p$ and $s = \sum_j g_j p_j$:
$$\frac{\partial L}{\partial z} = p \odot (g - s).$$
Here $\odot$ denotes elementwise multiplication. This captures the Jacobian $\partial p / \partial z$ without materializing it.
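A matching sketch of the backward pass, computing $p \odot (g - s)$ directly (again, illustrative names rather than the repository's exact code):

```python
import numpy as np

def softmax_vjp(p, g, axis=-1):
    """Backward pass: dL/dz = p * (g - sum_j g_j p_j), no Jacobian materialized."""
    s = np.sum(g * p, axis=axis, keepdims=True)    # s = sum_j g_j p_j per example
    return p * (g - s)
```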
Intuition note
Subtracting the per-axis max doesn't change the probabilities (the factor $e^{-m}$ cancels). It only prevents $e^{z_i}$ from blowing up.
Let $z$ denote the logits (pre-softmax scores) from the last linear layer. In an MLP, $z = Wa + b$ and the predicted class probabilities are $p = \operatorname{softmax}(z)$. We will later instantiate this with our MNIST MLP, but the notation applies to any model producing pre-softmax scores.
Softmax partial derivatives (component-wise sensitivities):
$$\frac{\partial p_i}{\partial z_j} = p_i(\delta_{ij} - p_j).$$
Intuition note
Increasing one logit $z_i$ boosts its own probability $p_i$, but because all probabilities must still sum to 1, the other $p_j$ must decrease, hence the negative cross terms.
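To make the cross terms concrete, here is a quick NumPy check (with made-up logits and upstream gradient) that the materialized Jacobian $J_{ij} = p_i(\delta_{ij} - p_j)$ and the closed-form VJP agree:

```python
import numpy as np

z = np.array([0.5, -1.0, 2.0])
p = np.exp(z - z.max()); p /= p.sum()              # softmax probabilities

J = np.diag(p) - np.outer(p, p)                    # J[i, j] = p_i * (delta_ij - p_j)

g = np.array([0.3, -0.1, 0.5])                     # some upstream gradient dL/dp
vjp_full = g @ J                                   # using the full Jacobian
vjp_fast = p * (g - np.sum(g * p))                 # closed-form VJP from above
assert np.allclose(vjp_full, vjp_fast)             # rows of J sum to 0, so the VJP does too
```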
Derivation
Using the softmax partials above and cross-entropy $L = -\sum_i y_i \log p_i$ (one-hot $y$), the output-layer error simplifies neatly. For one example with $p = \operatorname{softmax}(z)$:
$$\frac{\partial L}{\partial z_i} = p_i - y_i.$$
Intuition note
If the model assigns too much probability to a wrong class, the gradient there is positive (push it down). If it assigns too little probability to the true class, the gradient there is negative (push it up). The signal is literally "what you predicted minus what it should be."
Small numeric example: for logits $z = (2, 1, 0)$, $p \approx (0.665, 0.245, 0.090)$.
If the one-hot target is $y = (0, 1, 0)$, then $\partial L / \partial z = p - y \approx (0.665, -0.755, 0.090)$.
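A few lines of NumPy reproduce these numbers:

```python
import numpy as np

z = np.array([2.0, 1.0, 0.0])
p = np.exp(z - z.max()); p /= p.sum()
print(np.round(p, 3))            # [0.665 0.245 0.09 ]

y = np.array([0.0, 1.0, 0.0])    # one-hot target: class 1 is the true class
print(np.round(p - y, 3))        # [ 0.665 -0.755  0.09 ] -> push 0 and 2 down, 1 up
```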
Log-softmax: stable forward, simple backward
Log-softmax normalizes in log space:
$$\ell_i = \log p_i = z_i - \log \sum_j e^{z_j}.$$
Stable form with $m = \max_j z_j$:
$$\ell_i = (z_i - m) - \log \sum_j e^{z_j - m}.$$
Let $g = \partial L / \partial \ell$ and $S = \sum_j g_j$. The vector-Jacobian product is:
$$\frac{\partial L}{\partial z_i} = g_i - p_i S.$$
Intuition: we subtract the component of $g$ that lies along the all-ones direction on the simplex (hence the $p_i S$ term).
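A NumPy sketch of both directions, recovering $p = e^{\ell}$ in the backward pass (function names are illustrative):

```python
import numpy as np

def log_softmax(z, axis=-1):
    """Stable log-softmax: (z - m) - log(sum(exp(z - m))) along `axis`."""
    m = np.max(z, axis=axis, keepdims=True)
    shifted = z - m
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

def log_softmax_vjp(ell, g, axis=-1):
    """Backward pass: dL/dz_i = g_i - p_i * sum_j g_j, with p recovered as exp(ell)."""
    p = np.exp(ell)
    return g - p * np.sum(g, axis=axis, keepdims=True)
```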
Negative Log-Likelihood (NLL) and Cross-Entropy
For one-hot targets $y$ and log-probabilities $\ell = \operatorname{logsoftmax}(z)$,
$$\mathrm{NLL}(\ell, y) = -\sum_i y_i \ell_i.$$
Cross-Entropy with logits $z$ is just
$$\mathrm{CE}(z, y) = \mathrm{NLL}(\operatorname{logsoftmax}(z), y),$$
which yields the classic gradient
$$\frac{\partial \mathrm{CE}}{\partial z_i} = p_i - y_i.$$
All formulas above are axis-aware (we reduce along the chosen class axis with `keepdims=True`), which makes broadcasting correct in both forward and backward passes.
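Putting the loss together, here is a self-contained sketch of cross-entropy with logits and its gradient, assuming a (batch, classes) layout and reducing along the last axis; the function names are illustrative, not the repository's API:

```python
import numpy as np

def cross_entropy_with_logits(z, y_onehot, axis=-1):
    """CE(z, y) = NLL(log_softmax(z), y) for one-hot y, averaged over the batch."""
    m = np.max(z, axis=axis, keepdims=True)
    ell = (z - m) - np.log(np.sum(np.exp(z - m), axis=axis, keepdims=True))  # log-softmax
    return -np.sum(y_onehot * ell, axis=axis).mean()

def cross_entropy_grad(z, y_onehot, axis=-1):
    """The classic gradient dCE/dz_i = p_i - y_i, scaled by 1/batch for the mean loss."""
    m = np.max(z, axis=axis, keepdims=True)
    e = np.exp(z - m)
    p = e / np.sum(e, axis=axis, keepdims=True)    # softmax probabilities
    return (p - y_onehot) / z.shape[0]
```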
The actual code for these operations can be found in my autograd repository.
Tests-first: numerical checks and invariants
During development I tried my best to follow TDD. I tested quite a few cases, among them:
- Finite-difference checks for softmax/log-softmax/NLL along both axes (a minimal version is sketched after this list).
- Stability: probabilities sum to 1; log-softmax stays finite for extreme logits.
- Invariants: the softmax/log-softmax gradients sum to zero along the class axis.
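As an illustration of the first point, here is a minimal central-difference check of the $p - y$ gradient on random data (not the actual test suite):

```python
import numpy as np

def finite_diff_grad(f, z, eps=1e-5):
    """Central-difference estimate of df/dz for a scalar-valued f."""
    grad = np.zeros_like(z)
    for idx in np.ndindex(*z.shape):
        zp, zm = z.copy(), z.copy()
        zp[idx] += eps
        zm[idx] -= eps
        grad[idx] = (f(zp) - f(zm)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
z = rng.normal(size=(4, 10))                       # a small random batch of logits
y = np.eye(10)[rng.integers(0, 10, size=4)]        # random one-hot targets

def ce(z):
    m = z.max(axis=-1, keepdims=True)
    ell = (z - m) - np.log(np.exp(z - m).sum(axis=-1, keepdims=True))
    return -(y * ell).sum(axis=-1).mean()

p = np.exp(z - z.max(axis=-1, keepdims=True))
p /= p.sum(axis=-1, keepdims=True)
analytic = (p - y) / z.shape[0]                    # p - y, averaged over the batch
assert np.allclose(finite_diff_grad(ce, z), analytic, atol=1e-6)
```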
You can find all of them here.
MNIST: a minimal MLP baseline
Off to training. I decided to try it on the MNIST dataset. I trained a one-hidden-layer MLP (784→256→10, tanh) with cross-entropy (log-softmax + NLL). I tried many configurations, including tweaking the learning rate and switching the activation to ReLU, but the MLP with tanh activation got the best results.
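For reference, here is a condensed plain-NumPy sketch of what that training loop computes, with dummy data and hand-written backprop standing in for the autograd engine; the learning rate and batch size here are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
# Dummy stand-ins; the real demo loads flattened 28x28 MNIST images and one-hot labels.
X_all = rng.random((512, 784))
Y_all = np.eye(10)[rng.integers(0, 10, size=512)]

W1 = rng.normal(0, 0.01, (784, 256)); b1 = np.zeros(256)
W2 = rng.normal(0, 0.01, (256, 10));  b2 = np.zeros(10)
lr = 0.1                                           # illustrative learning rate

for start in range(0, len(X_all), 64):             # mini-batches of 64
    X, Y = X_all[start:start + 64], Y_all[start:start + 64]
    # Forward: 784 -> 256 (tanh) -> 10 logits -> softmax probabilities.
    h = np.tanh(X @ W1 + b1)
    z = h @ W2 + b2
    p = np.exp(z - z.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)
    # Backward: dCE/dz = (p - y) / batch, then chain through the layers.
    dz = (p - Y) / X.shape[0]
    dW2 = h.T @ dz;  db2 = dz.sum(axis=0)
    dh1 = dz @ W2.T * (1 - h ** 2)                 # backprop through tanh
    dW1 = X.T @ dh1; db1 = dh1.sum(axis=0)
    # Plain SGD step.
    W1 -= lr * dW1; b1 -= lr * db1
    W2 -= lr * dW2; b2 -= lr * db2
```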
A Kaggle submission from this MLP achieved 0.955 accuracy, which is a good result for such a small model.
You can find the training demo in my autograd repository.