The inspiration for this post came from Dave’s Garage YouTube Channel. It featured dbrll/ATTN-11, which used machine learning on a PDP-11 to reverse a sequence of numbers. I’m not (quite) old enough to remember the PDP-11, but the idea of a minimal machine learning example was intriguing, particularly as an aid to understanding.
This post takes you through the mtelvers/attn-ox OxCaml program that learns to add single hex digits. The exercise here was for me to understand the building blocks of machine learning. Claude wrote the code. The system is small enough that everything is concrete: every line of code maps to one of the boxes in the diagrams below.
1. The problem
Write a function that takes two single hex digits x and y and returns the sum x + y, e.g. f + f = 1e.
Rather than a single line of code, the machine-learning approach starts with random floats, which are updated by gradient descent until the function gets all 256 inputs right.
The idea is to see the smallest possible version of “a model learns from examples” running end-to-end. The architecture (transformer, attention, layer norm, and residual connections) is the same as that used in much larger models.
2. Overview
The model is not the kind of function you would usually write for addition (something like int -> int -> int). It is a function on token sequences:
input : an array of 8 integers (token IDs from a 32-token vocabulary)
output : an array of 8 probability distributions, each over the 32 tokens
We encode the question x + y = ? as a sequence of 8 token IDs:
positions: 0 1 2 3 4 5 6 7
ids: BOS x + y = c1 c2 PAD
The 16 hex digits get token IDs 0..15 (so the digit 0 is token 0 and the digit f is token 15). Four reserved tokens take the next four IDs: the + token is 16, the = token is 17, BOS (beginning of sequence) is 18, and PAD is 19. That makes 20 used tokens, but we round the vocabulary size up to the next power of 2, which is 32. The unused slots 20..31 never appear in inputs or as answers.
For input 8 + a, we already know the answer is 12 (= 8 + 10 = 18, and in hex 18 is written 12). The training sequence is:
ids = [18; 8; 16; 10; 17; 1; 2; 19] (BOS, 8, +, a, =, 1, 2, PAD)
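To make the encoding concrete, here is a minimal OCaml sketch that builds the 8-token sequence for any pair of hex digits. The constant names (tok_plus, tok_eq, tok_bos, tok_pad) and encode_example are hypothetical, not the repo’s identifiers; only the ID values come from the text.

```ocaml
(* Token IDs as described above. The constant names are illustrative;
   the repo may name them differently. *)
let tok_plus = 16
let tok_eq = 17
let tok_bos = 18
let tok_pad = 19

(* Build the training sequence BOS x + y = c1 c2 PAD, where c1 and c2
   are the high and low hex digits of x + y. *)
let encode_example x y =
  if x < 0 || x > 15 || y < 0 || y > 15 then invalid_arg "encode_example";
  let sum = x + y in
  let c1 = sum / 16 and c2 = sum mod 16 in
  [| tok_bos; x; tok_plus; y; tok_eq; c1; c2; tok_pad |]

(* The whole dataset: all 16 x 16 = 256 (x, y) pairs. *)
let dataset = List.init 256 (fun n -> encode_example (n / 16) (n mod 16))
```

For example, encode_example 8 10 reproduces the 8 + a sequence above, and encode_example 15 15 ends in the digits 1 and 14 (f + f = 1e).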
We hand the whole sequence (including the answer) to the model during training. The model produces 8 probability distributions, one at each position, each predicting “what comes next?”. A probability distribution here is a row of 32 non-negative floats that sum to 1. There is one float per vocabulary token, expressing the model’s confidence that this token is next.
Here is the actual 8 × 32 output of a trained model on our 8 + a example. Each row is a probability distribution; each row sums to 1.
0 1 2 3 4 5 6 7 8 9 a b c d e f + = BOS PAD 20 .. 31
pos 0 (BOS) 0.99 · · · · · · · · · · · · · · · · · · · · .. ·
pos 1 (8 ) 0.99 · · · · · 0.01 · · · · · · · · · · · · · · .. ·
pos 2 (+ ) 0.86 0.02 · · · · · 0.08 0.03 · · · · · · · · · · · · .. ·
pos 3 (a ) 1.00 · · · · · · · · · · · · · · · · · · · · .. ·
pos 4 (= ) · 1.00 · · · · · · · · · · · · · · · · · · · .. ·
pos 5 (c1 ) · · 0.99 · · · · · · · · · · · · · · · · · · .. ·
pos 6 (c2 ) 1.00 · · · · · · · · · · · · · · · · · · · · .. ·
pos 7 (PAD) 1.00 · · · · · · · · · · · · · · · · · · · · .. ·
At position 4, the model assigns a probability of 1.00 to token 1 (the high digit of 12), and at position 5, it assigns 0.99 to token 2 (the low digit). Predictions at other positions can be arbitrary, as these are not scored during training.
Predicting the next token at every position is how GPT-style language models are trained; scoring only the positions we care about is a form of loss masking.
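A small sketch of how the scored targets fall out of the sequence: the target at position i is simply the token at i + 1, and a mask keeps only positions 4 and 5. The names targets_of, mask and scored are illustrative assumptions, not functions from attn-ox.

```ocaml
let seq_len = 8

(* Next-token targets: position i is scored against the token at i + 1.
   The last position has no next token, so it gets -1 (never scored). *)
let targets_of ids =
  Array.init seq_len (fun i -> if i + 1 < seq_len then ids.(i + 1) else -1)

(* Only position 4 (=, predicting c1) and 5 (c1, predicting c2) count. *)
let mask = Array.init seq_len (fun i -> i = 4 || i = 5)

(* The (position, target) pairs the loss actually sees for one example. *)
let scored ids =
  let targets = targets_of ids in
  List.init seq_len Fun.id
  |> List.filter_map (fun i -> if mask.(i) then Some (i, targets.(i)) else None)
```

For the 8 + a sequence, scored returns [(4, 1); (5, 2)]: the two answer digits, and nothing else.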
3. Pipeline
At the top level, the model is a pipeline of three stages: embedding, a transformer block, and an output projection.
ids (length 8)
│
▼
┌────────────────┐
│ Embedding │ integers → vectors (§5)
└────────┬───────┘
│ X (8 × 32)
▼
┌────────────────┐
│ Transformer │ vectors → vectors
│ block │ (we'll open this up in §7)
└────────┬───────┘
│ Y (8 × 32)
▼
┌────────────────┐
│ Output │ vectors → predictions (§6)
│ projection │
└────────┬───────┘
│ logits (8 × 32)
▼
softmax + cross-entropy at positions 4 and 5 (§11)
Two numbers govern almost every shape inside.
- seq_len = 8 is the length of every input sequence.
- d_model = 32 is the width of the model’s internal representation.
Every token, once embedded, is described by a vector of 32 floats. Every intermediate matrix inside the model is seq_len × d_model = 8 × 32. Larger d_model means more capacity but more compute and memory; 32 is enough for this task.
Hyperparameters are settings we pick before training; parameters are the floats that gradient descent adjusts. The full set of hyperparameters lives in Train.default_config.
| Hyperparameter | Value |
|---|---|
| vocab_size | 32 |
| seq_len | 8 |
| d_model | 32 |
| n_heads | 2 |
| head_dim | 16 |
| d_ff | 128 |
We will see the use of each of these in the coming sections.
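As a sketch, the table could be held in an OCaml record. This is a hypothetical mirror of Train.default_config, not its actual definition; in particular, deriving head_dim as d_model / n_heads is an assumption based on §9.2 splitting the 32 columns down the middle.

```ocaml
(* Hypothetical mirror of the hyperparameter table; the real
   Train.default_config may use different fields. *)
type config = {
  vocab_size : int;
  seq_len : int;
  d_model : int;
  n_heads : int;
  d_ff : int;
}

let default_config =
  { vocab_size = 32; seq_len = 8; d_model = 32; n_heads = 2; d_ff = 128 }

(* Assumed to be derived: each head gets an equal slice of the
   d_model columns. *)
let head_dim c =
  if c.d_model mod c.n_heads <> 0 then
    invalid_arg "d_model must be divisible by n_heads";
  c.d_model / c.n_heads
```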
The three boxes map to OCaml modules:
| Box | File |
|---|---|
| Embedding | lib/embedding.ml |
| Transformer block | lib/transformer_block.ml, which composes layernorm.ml, attention.ml, ffn.ml |
| Output projection | lib/model.ml |
4. Tensors
Every intermediate value is a 2D matrix of floats, a Tensor.t:
type t = { rows : int; cols : int; data : float array }
data holds the cells row by row, so element (i, j) lives at data.(i * cols + j). Array.make_matrix or Bigarray.Array2 would also work; we use a flat float array. The fields are exposed so the kernels in ops.ml can index into data directly.
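A minimal sketch of the type with row-major accessors follows. The helpers create, get, set and row are hypothetical; the text only says that the kernels index data directly.

```ocaml
(* A minimal sketch of the row-major tensor described above. *)
type t = { rows : int; cols : int; data : float array }

let create rows cols = { rows; cols; data = Array.make (rows * cols) 0.0 }

(* Element (i, j) lives at data.(i * cols + j). *)
let get t i j = t.data.((i * t.cols) + j)
let set t i j v = t.data.((i * t.cols) + j) <- v

(* Row i occupies the contiguous slice [i * cols, (i + 1) * cols). *)
let row t i = Array.sub t.data (i * t.cols) t.cols
```

Because each row is contiguous, a kernel that walks one row at a time reads memory sequentially, which is the access pattern the flat layout is chosen for.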
lib/ops.ml provides the building-block kernels:
| Function | Computes | Used by |
|---|---|---|
| matmul | C = A @ B | every linear layer |
| matmul_at_b_t | C = A @ B^T | attention scores, output projection |
| accum_xt_y | dW += X^T @ dY | every linear layer’s backward |
| accum_y_wt | dX += dY @ W^T | every linear layer’s backward |
| add_inplace | dst += src | residual connections |
| softmax_rows | row-wise softmax | attention |
| apply_causal_mask | mask out the upper triangle (§9.1) | attention |
Section 19 (SIMD) shows how these are implemented.
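Before any SIMD, the kernels are just nested loops over the flat arrays. The sketch below gives plain scalar versions of matmul, matmul_at_b_t and add_inplace, using the labelled arguments seen in the code excerpts (~a, ~b, ~c, ~dst, ~src); the bodies are an illustration, not the contents of ops.ml.

```ocaml
type t = { rows : int; cols : int; data : float array }

let create rows cols = { rows; cols; data = Array.make (rows * cols) 0.0 }

(* C = A @ B, with A : m x k and B : k x n. Overwrites c. *)
let matmul ~a ~b ~c =
  assert (a.cols = b.rows && c.rows = a.rows && c.cols = b.cols);
  for i = 0 to a.rows - 1 do
    for j = 0 to b.cols - 1 do
      let s = ref 0.0 in
      for p = 0 to a.cols - 1 do
        s := !s +. (a.data.((i * a.cols) + p) *. b.data.((p * b.cols) + j))
      done;
      c.data.((i * c.cols) + j) <- !s
    done
  done

(* C = A @ B^T, with A : m x k and B : n x k. Both operands are read
   row by row, which suits row-major storage. *)
let matmul_at_b_t ~a ~b ~c =
  assert (a.cols = b.cols && c.rows = a.rows && c.cols = b.rows);
  for i = 0 to a.rows - 1 do
    for j = 0 to b.rows - 1 do
      let s = ref 0.0 in
      for p = 0 to a.cols - 1 do
        s := !s +. (a.data.((i * a.cols) + p) *. b.data.((j * b.cols) + p))
      done;
      c.data.((i * c.cols) + j) <- !s
    done
  done

(* dst += src, element-wise: the residual connection. *)
let add_inplace ~dst ~src =
  assert (dst.rows = src.rows && dst.cols = src.cols);
  Array.iteri (fun i v -> dst.data.(i) <- dst.data.(i) +. v) src.data
```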
5. Embedding: turn integers into vectors
The first box turns each integer token into a 32-element float vector (the d_model = 32 from §3).
The embedding has two learned matrices:
E (32 × 32) one row per token in the vocabulary, 32 floats per row
P (8 × 32) one row per sequence position
Both are initialised to small random numbers (Gaussian, scale 0.02).
For input ids = [18; 8; 16; 10; 17; 1; 2; 19] (our 8 + a example), the embedded matrix X (the same X that leaves the embedding box in §3’s diagram) is built row by row:
X[i] = E[ids[i]] + P[i] for each i
The model needs both E (which token is at position i) and P (which position this is). Without P, the model would treat [8, +, a] and [a, +, 8] identically. There is no other source of order information in the architecture.
E and P are learned. They start as random noise. After training, rows of E corresponding to digit tokens end up encoding the token’s numeric value in some 32-dimensional way the rest of the model can use.
(Implementation: Embedding.forward and Embedding.backward in lib/embedding.ml.)
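The lookup and its backward pass are short enough to sketch directly. This version works on plain row-major float arrays rather than Tensor.t, and embed and embed_backward are hypothetical names. Forward computes X[i] = E[ids[i]] + P[i]; backward scatter-adds each row of dX into the E row that was looked up and into P[i].

```ocaml
(* e is vocab_size x d_model, p is seq_len x d_model, both row-major. *)
let embed ~e ~p ~d_model ids =
  let seq_len = Array.length ids in
  let x = Array.make (seq_len * d_model) 0.0 in
  for i = 0 to seq_len - 1 do
    for j = 0 to d_model - 1 do
      (* X[i] = E[ids[i]] + P[i] *)
      x.((i * d_model) + j) <-
        e.((ids.(i) * d_model) + j) +. p.((i * d_model) + j)
    done
  done;
  x

(* Each row of dX flows back to the E row that was looked up (so a
   repeated token accumulates) and to P[i]. *)
let embed_backward ~de ~dp ~d_model ids dx =
  Array.iteri
    (fun i tok ->
      for j = 0 to d_model - 1 do
        let g = dx.((i * d_model) + j) in
        de.((tok * d_model) + j) <- de.((tok * d_model) + j) +. g;
        dp.((i * d_model) + j) <- dp.((i * d_model) + j) +. g
      done)
    ids
```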
6. Output projection
At the other end of the model, after the transformer block, we have a matrix Y' of shape (8, 32): one 32-element vector per sequence position. The output projection turns each 32-element vector into 32 logits, one per vocabulary token.
logits[i] = Y'[i] · E^T (using E from §5, transposed)
A logit is just an unnormalised score: any real number, positive or negative, big or small. By itself, it doesn’t mean a probability. To turn a row of 32 logits into a probability distribution (32 non-negative values summing to 1), we apply softmax. We will look at softmax in more detail in §11, where the model’s error is computed; it’s also what produced the 8 × 32 matrix you saw right at the start in §2.
We use the embedding matrix E as the output projection, which is called weight tying. It saves parameters and ties the input and output representations of each token.
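Here is a sketch of the tied projection for one position, plus a numerically stable softmax, on plain arrays. logits_of_row is a hypothetical name; per the kernel table in §4, the real code handles the whole 8 × 32 matrix at once with Ops.matmul_at_b_t.

```ocaml
(* Logit t is the dot product of Y'[i] with row t of E, i.e. one row
   of Y' @ E^T. e is vocab_size x d_model, row-major. *)
let logits_of_row ~e ~d_model ~vocab_size y_row =
  Array.init vocab_size (fun t ->
      let s = ref 0.0 in
      for j = 0 to d_model - 1 do
        s := !s +. (y_row.(j) *. e.((t * d_model) + j))
      done;
      !s)

(* Subtracting the row maximum leaves the result unchanged but keeps
   exp from overflowing on large logits. *)
let softmax logits =
  let m = Array.fold_left Float.max neg_infinity logits in
  let exps = Array.map (fun z -> exp (z -. m)) logits in
  let total = Array.fold_left ( +. ) 0.0 exps in
  Array.map (fun v -> v /. total) exps
```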
The backward pass for the embedding has to add contributions from both its uses (lookup and unembedding):
(* Model.backward, model.ml *)
Ops.matmul ~a:d_logits ~b:e ~c:d_y_norm;              (* dY' = d_logits @ E *)
Ops.accum_xt_y ~x:d_logits ~dy:t.cache_y_norm ~dw:de; (* dE += d_logits^T @ Y' *)
(* ... later ... *)
Embedding.backward t.embed ~d_out:d_x;                (* dE += scatter from inputs *)
Having covered the input and output, we now turn to the transformer block itself.
7. Transformer
Every operation in the transformer operates on an 8 × 32 matrix and produces a new 8 × 32 matrix. X is generated by the embedding stage and Y' is used by the output projection.
X (8 × 32)
│
┌───────┤
│ ▼
│ LayerNorm
│ │
│ ▼
│ Attention
│ │
│ ▼
└─────► ⊕ X' = X + Attention(LayerNorm(X))
│
┌───────┤
│ ▼
│ LayerNorm
│ │
│ ▼
│ Feed-forward
│ │
│ ▼
└─────► ⊕ Y = X' + Feed-forward(LayerNorm(X'))
│
▼
Y (8 × 32)
│
--- end of transformer block ---
│
▼
LayerNorm one more stabilising step
│ before the output projection
▼
Y' (8 × 32) this is what §6 receives
X is processed through LayerNorm and Attention, and then the original X is added back in, creating a residual connection. The input to the second stage is X' = X + Attention(LayerNorm(X)). The pattern then repeats: X' goes through LayerNorm and Feed-forward, and X' is added back again, giving Y = X' + Feed-forward(LayerNorm(X')).
The residual connection matters because it provides an identity highway that carries X straight through, with the operation’s output added on top. The model can effectively switch an operation off by training it to output near-zero. Gradients also flow back along the highway, not only through each operation, which makes optimisation much easier.
The two ⊕ nodes in the diagram are literally these calls in the code:
(* Transformer_block.forward, transformer_block.ml *)
Ops.add_inplace ~dst:t.cache_x_mid ~src:t.cache_attn_out; (* X' = X + Attn(LN1(X)) *)
Ops.add_inplace ~dst:out ~src:t.cache_ffn_out; (* Y = X' + FFN(LN2(X')) *)
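The same data flow can be written as plain function composition. In this sketch (an illustration, not transformer_block.ml) the four sub-blocks are passed in as functions on a flattened 8 × 32 matrix, which makes the identity-highway claim easy to check: if Attention and Feed-forward output zeros, the block returns X unchanged.

```ocaml
let add a b = Array.map2 ( +. ) a b

(* ln1, attn, ln2 and ffn are stand-ins supplied by the caller. *)
let block_forward ~ln1 ~attn ~ln2 ~ffn x =
  let x' = add x (attn (ln1 x)) in  (* X' = X + Attention(LayerNorm(X)) *)
  add x' (ffn (ln2 x'))             (* Y = X' + FFN(LayerNorm(X')) *)
```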
The next three sections explain LayerNorm, Attention, and Feed-forward one at a time.
8. LayerNorm
LayerNorm operates independently on each row of its input matrix. Let r be one such row, a 32-element vector (i.e. X[i] for some i). The output row r' is:
mean = sum(r) / 32
var = sum((r - mean)^2) / 32
r'[j] = γ[j] * (r[j] - mean) / sqrt(var + 1e-5) + β[j]
γ and β are learned 32-element vectors (one of each, initialised to 1 and 0 respectively), shared across every row of the input. So a LayerNorm has 64 trainable floats in total, regardless of how many rows it sees. The rest of the operation has no learned parameters: it rescales each row to a mean of 0 and a variance of (almost exactly) 1, then applies the same learned per-feature scale and shift to every row.
LayerNorm is a stabiliser that keeps the floats inside the model from drifting too large or too small during training.
LayerNorm is used at three places in the model, all visible in the §7 diagram: before Attention, before Feed-forward, and once more between the transformer block’s output Y and the output projection.
(Implementation: Layernorm.forward and Layernorm.backward, with backward derived from the standard chain rule.)
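A direct transcription of the three LayerNorm formulas for a single row, as a sketch. layernorm_row is a hypothetical name; as §12 notes, the real Layernorm.forward also caches the mean and 1/std for the backward pass.

```ocaml
(* gamma and beta are the learned per-feature scale and shift. *)
let layernorm_row ?(eps = 1e-5) ~gamma ~beta r =
  let n = float_of_int (Array.length r) in
  let mean = Array.fold_left ( +. ) 0.0 r /. n in
  let var =
    Array.fold_left (fun acc v -> acc +. ((v -. mean) *. (v -. mean))) 0.0 r /. n
  in
  let inv_std = 1.0 /. sqrt (var +. eps) in
  Array.mapi (fun j v -> (gamma.(j) *. (v -. mean) *. inv_std) +. beta.(j)) r
```

With γ = 1 and β = 0 (the initial values), the output row has mean 0 and variance just under 1, the shortfall coming from the 1e-5 in the denominator.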
9. Attention
Attention is what makes a transformer a transformer!
9.1 The mechanism
You start with X, the 8 × 32 matrix from the output of LayerNorm. You compute three linear projections of it:
Q = X @ Wq (8 × 32)
K = X @ Wk (8 × 32)
V = X @ Wv (8 × 32)
Wq, Wk, Wv are 32 × 32 learned matrices. Q, K, V are three different views of the same input. The names come from a database analogy:
- Q[i] is the query at position i. Think of it as “what is position i looking for?”
- K[j] is the key at position j. Think of it as “what does position j advertise?”
- V[j] is the value at position j. Think of it as “what would position j contribute if attended to?”
We compute a scalar score for every (query position, key position) pair:
scores[i, j] = Q[i] · K[j] / sqrt(d_k) (8 × 8)
d_k is the length of the vectors. From the shapes above, Q[i] and K[j] both have 32 elements, so d_k = 32.
Each scores[i, j] is a single number, the dot product of row i of Q with row j of K, then divided by sqrt(d_k). Collecting them gives an 8 × 8 matrix scores, with one cell per pair of positions. Higher score means “row i cares more about row j”.
For any pair where j > i, set the score to -∞, creating a causal mask. This prevents position i from looking at anything that comes after it, which is required for next-token prediction (otherwise the model could just look ahead at the answer).
Apply softmax to each row of the scores matrix to convert it into a probability distribution:
probs[i, j] = exp(scores[i, j]) / Σ_k exp(scores[i, k])
After softmax, each row of probs sums to 1. Masked entries (where the score was -∞) become exactly 0. The rest are non-negative and form a probability distribution over the earlier-or-equal positions.
Finally, use those weights to take a weighted average of the values:
out[i] = Σ_j probs[i, j] * V[j] (8 × 32)
This Q/K/V/scores/softmax process produces an output matrix in which row i is a weighted average of the value rows at position i and earlier, with the weights coming from the query-key dot products.
This allows position i to pull information from any earlier position by assigning it a high probability.
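The four steps above fit in one short function. This sketch handles a single head on arrays of rows, and it skips the masked entries instead of writing -∞ into them, which gives the same result because exp(-∞) = 0. causal_attention and dot are illustrative names, not attention.ml’s API.

```ocaml
let dot a b =
  let s = ref 0.0 in
  Array.iteri (fun j x -> s := !s +. (x *. b.(j))) a;
  !s

(* q, k, v: seq_len rows of d_k floats each. Returns out, row by row. *)
let causal_attention q k v =
  let n = Array.length q in
  let d_k = float_of_int (Array.length q.(0)) in
  Array.init n (fun i ->
      (* scores[i, j] for j <= i only: the causal mask *)
      let scores = Array.init (i + 1) (fun j -> dot q.(i) k.(j) /. sqrt d_k) in
      (* row-wise softmax, max-subtracted for stability *)
      let m = Array.fold_left Float.max neg_infinity scores in
      let exps = Array.map (fun s -> exp (s -. m)) scores in
      let total = Array.fold_left ( +. ) 0.0 exps in
      let probs = Array.map (fun e -> e /. total) exps in
      (* out[i] = sum_j probs[i, j] * V[j] *)
      let out = Array.make (Array.length v.(0)) 0.0 in
      Array.iteri
        (fun j p -> Array.iteri (fun c x -> out.(c) <- out.(c) +. (p *. x)) v.(j))
        probs;
      out)
```

Position 0 can only see itself, so its output is exactly V[0]; and changing a later V row never changes an earlier output row.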
9.2 Multiple heads
A head is one independent run of the whole §9.1 mechanism (Q/K/V projections, scores, softmax, weighted sum of V). Multi-head attention splits the columns of Q, K and V into slices and runs the mechanism on each slice independently, which lets the model perform several kinds of routing at once. We could probably train this model with n_heads = 1 and it would still converge, but we use 2 because multiple heads are the standard transformer pattern, and because §17 is more instructive with two heads to compare side by side.
In our model n_heads = 2, so we split the 32 columns of Q, K, V down the middle:
- Head 0 uses columns 0..15 of Q, K, V (head_dim = 16).
- Head 1 uses columns 16..31.
Each head dot-products and softmaxes only within its own 16-element slice, then takes its own weighted sum of V. So d_k from §9.1 is now 16 (the per-head slice width), which is why the formula uses sqrt(16) in practice.
Each head produces its own 8 × 16 output. The two are concatenated back to 8 × 32 (head 0’s columns on the left, head 1’s on the right):
concat(out_head_0, out_head_1) (8 × 32)
But that’s not the final answer. Without one more step, head 0’s output would always live in columns 0..15 and head 1’s in columns 16..31; the two would never combine. So we apply a learned 32 × 32 matrix Wo that mixes the columns:
Y = concat(out_head_0, out_head_1) @ Wo (8 × 32)
Wo lets the model learn how to combine the heads’ outputs. It’s the fourth and last learned matrix in attention (alongside Wq, Wk, Wv).
The total parameter count for the attention block is 4 × (32 × 32) = 4096 floats (Wq, Wk, Wv, Wo).
(Implementation: Attention.forward and Attention.backward in lib/attention.ml. Per-head slicing is done by column-offset arithmetic into the same flat tensors, with no explicit reshape.)
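To illustrate the column-offset arithmetic, here is a sketch of one head’s score computation on flat row-major Q and K. head_score is hypothetical; it shows that head h only ever touches columns h × head_dim up to (h + 1) × head_dim - 1, and that it divides by sqrt(head_dim) rather than sqrt(d_model).

```ocaml
(* Score between query row i and key row j for one head, reading only
   that head's slice of the shared flat arrays. No reshape or copy. *)
let head_score ~q ~k ~d_model ~head_dim ~head i j =
  let off = head * head_dim in
  let s = ref 0.0 in
  for c = off to off + head_dim - 1 do
    s := !s +. (q.((i * d_model) + c) *. k.((j * d_model) + c))
  done;
  !s /. sqrt (float_of_int head_dim)
```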
10. Feed-forward (FFN)
The simpler of the two operations inside the transformer block is the Feed-Forward Network (FFN), where “network” is just a generic word for a stack of learnable layers (as in “neural network”).
For each row of the input matrix independently:
hidden = X @ W1 (32 → 128)
hidden = GELU(hidden)
out = hidden @ W2 (128 → 32)
- W1 is a learned 32 × 128 matrix and W2 a learned 128 × 32 matrix, learned in the same way as Wq, Wk, Wv, Wo from §9.
- X @ W1 widens each 32-element row to 128 floats; hidden @ W2 narrows it back to 32.
- GELU is an element-wise non-linear function: applied to a vector, it transforms each element independently. It passes large positive values through almost unchanged, squashes large negative values toward zero, and interpolates smoothly in between. The “G” stands for Gaussian: GELU(x) = x · Φ(x), where Φ is the cumulative distribution function of the standard Gaussian.
The non-linearity is the whole point of the Feed-forward block. Without something non-linear between W1 and W2, the full computation (X @ W1) @ W2 would collapse to X @ (W1 @ W2) as just one bigger matrix multiplication. GELU between them is what stops that collapse and gives this block expressive power that attention’s linear projections alone don’t have.
The 128 comes from the hyperparameter table: d_ff = 4 × d_model. This is the conventional ratio for transformer FFN blocks; it could be shrunk or enlarged, trading capacity against compute.
The Feed-forward parameter count is 32 × 128 + 128 × 32 = 8192 floats, twice the 4096 of the whole attention block. That is typical of real transformers, where the FFN parameters dominate.
(Implementation: Ffn.forward and Ffn.backward in lib/ffn.ml.)
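A sketch of one row through the FFN. It uses the common tanh approximation of GELU; whether attn-ox uses that or the exact erf form is not stated, and the two differ only slightly. ffn_row and the row-major weight layout are assumptions. There are no biases, consistent with the 8192 count above.

```ocaml
(* tanh approximation of GELU(x) = x * Phi(x). *)
let gelu x =
  let c = sqrt (2.0 /. Float.pi) in
  0.5 *. x *. (1.0 +. tanh (c *. (x +. (0.044715 *. x *. x *. x))))

(* w1 is d x h and w2 is h x d, both row-major. *)
let ffn_row ~w1 ~w2 ~d ~h x =
  (* widen: d -> h, then the non-linearity *)
  let hidden =
    Array.init h (fun k ->
        let s = ref 0.0 in
        for j = 0 to d - 1 do
          s := !s +. (x.(j) *. w1.((j * h) + k))
        done;
        gelu !s)
  in
  (* narrow: h -> d *)
  Array.init d (fun j ->
      let s = ref 0.0 in
      for k = 0 to h - 1 do
        s := !s +. (hidden.(k) *. w2.((k * d) + j))
      done;
      !s)
```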
11. Loss
§4–§10 described the forward pass: how the model turns input ids into 8 × 32 logits that represent its prediction. During training, we need a way to score that prediction; the score is called the loss. The gradient descent step (§13) will adjust the model’s parameters to make the loss go down.
The standard loss for predicting a probability distribution over a discrete set is softmax cross-entropy. For each scored position (positions 4 and 5, the ones flagged mask[i] = true in our setup, where c1 and c2 are the targets):
p[i] = softmax(logits[i]) (a probability over 32 tokens)
loss_i = -log(p[i][target[i]]) (negative log-prob of the right token)
target[i] is the correct token at position i, the one we want the model to predict. For the 8 + a = 12 example, target[4] = 1 (the high digit) and target[5] = 2 (the low digit). p[i][target[i]] is the probability the model assigned to that correct token.
Concretely, look back at the 8 × 32 matrix in §2.
- p[4] was [·, 1.00, ·, ·, ...]: the model put 1.00 (0.999 unrounded) on token 1 at position 4. Since target[4] = 1, that’s the correct token, so p[4][target[4]] = 0.999 and loss_4 = -log(0.999) ≈ 0.001.
- p[5] was [·, ·, 0.99, ·, ...], 0.993 on token 2. Since target[5] = 2, that’s correct too, so p[5][target[5]] = 0.993 and loss_5 = -log(0.993) ≈ 0.007.
The total loss for this example is the mean of loss_4 and loss_5: (0.001 + 0.007) / 2 ≈ 0.004. (During training, the loss for a batch is the mean of loss_i over all scored positions of all examples in the batch.)
A few reference points for calibrating the loss number:
- If the model assigns 100% probability to the correct token, loss = 0.
- If 50%, loss ≈ 0.69.
- If 1%, loss ≈ 4.6.
- A uniform 1/32 prediction gives loss = log(32) ≈ 3.47.
The gradient of the loss with respect to the logits has a closed form that doesn’t need the log numerically:
d_logits[i] = (p[i] - one_hot(target[i])) / num_scored (if mask[i])
= 0 (otherwise)
one_hot(j) is a 32-element vector with a 1 at index j and 0 everywhere else. For our example, target[5] = 2 (the digit “2” is the right answer at position 5), so one_hot(target[5]) = one_hot(2) = [0, 0, 1, 0, 0, ..., 0]. The marker 1 lands at index 2 because that is where target points.
Subtracting one_hot(target[i]) from p[i] subtracts 1 from the entry for the correct token and leaves the others unchanged. So d_logits[i] is positive everywhere except at the right token, where it is negative (since p < 1 there). Gradient descent moves against the gradient, so it raises the right token’s logit and lowers every other logit: exactly what we want.
This is the starting point of the backward pass.
(Implementation: Loss.forward_and_grad in lib/loss.ml.)
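The loss and its closed-form gradient for one scored position fit in a few lines. This is a sketch, not Loss.forward_and_grad itself; loss_and_grad is a hypothetical name, and the softmax subtracts the row maximum so exp never overflows.

```ocaml
let softmax logits =
  let m = Array.fold_left Float.max neg_infinity logits in
  let exps = Array.map (fun z -> exp (z -. m)) logits in
  let total = Array.fold_left ( +. ) 0.0 exps in
  Array.map (fun v -> v /. total) exps

(* Returns (loss_i, d_logits[i]) for one scored position. *)
let loss_and_grad ~num_scored logits target =
  let p = softmax logits in
  let loss = -.(log p.(target)) in
  let grad =
    Array.mapi
      (fun t pt ->
        let one_hot = if t = target then 1.0 else 0.0 in
        (pt -. one_hot) /. float_of_int num_scored)
      p
  in
  (loss, grad)
```

With all-zero logits over 32 tokens the loss is log(32), one of the reference points above, and each gradient row sums to zero because p sums to 1.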
12. Backward pass
We have the loss (§11) as a single floating-point number summarising how wrong the model is. Gradient descent (§13) will update each of the 13,760 parameters in the direction that decreases the loss. To do that, it needs, for every single parameter, a number saying “if you nudge me up by ε, the loss changes by about gradient × ε”. That number is the gradient of the loss with respect to that parameter, and we need 13,760 of them, one per parameter. Computing them all is the backward pass (or backprop).
The loss depends on the logits; the logits depend on Y'; Y' depends on Y; Y depends on the block’s parameters and on X; X depends on the embedding tables and on the input ids. Each depends-on is a function we already wrote in the forward pass. The chain rule says: differentiate one step at a time, working backward from the loss and multiplying derivatives as we go. Apply it to every operation of the forward pass, in reverse order, and we end up with a gradient for every parameter.
Every block in this model has a backward function that does the chain rule, with shapes and indices written out. Verbose, but it makes the math visible. And test/grad_check.ml verifies every step numerically against finite-difference perturbations of the inputs (so a sign error or off-by-one anywhere fails a test).
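The idea behind a finite-difference check can be sketched generically: nudge one input by ±h, measure how much f changes, and compare with the analytic gradient. This is not the code in test/grad_check.ml; numeric_grad, ce and ce_grad are hypothetical names, and the example checks the softmax cross-entropy gradient from §11.

```ocaml
(* Central difference: (f(x + h e_i) - f(x - h e_i)) / 2h for each i. *)
let numeric_grad ?(h = 1e-5) f x =
  Array.mapi
    (fun i _ ->
      let bump d =
        let x' = Array.copy x in
        x'.(i) <- x'.(i) +. d;
        f x'
      in
      (bump h -. bump (-.h)) /. (2.0 *. h))
    x

let max_abs_diff a b =
  Array.fold_left Float.max 0.0 (Array.map2 (fun u v -> Float.abs (u -. v)) a b)

let softmax z =
  let m = Array.fold_left Float.max neg_infinity z in
  let e = Array.map (fun v -> exp (v -. m)) z in
  let s = Array.fold_left ( +. ) 0.0 e in
  Array.map (fun v -> v /. s) e

(* f(z) = -log softmax(z)[t], with analytic gradient p - one_hot(t). *)
let ce t z = -.(log (softmax z).(t))
let ce_grad t z = Array.mapi (fun i p -> if i = t then p -. 1.0 else p) (softmax z)
```

The analytic and numeric gradients agree to within about 1e-7, while flipping the sign of the analytic gradient produces a large mismatch: the kind of error such a test is designed to catch.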
Each block’s backward function takes one argument and produces one result. The input is d_out, the gradient of the loss with respect to the block’s output. It was computed by whatever ran after this block (the next block’s backward function, or the loss directly for the last block). The output is d_in, the gradient of the loss with respect to the block’s input. It is passed to whatever ran before this block. It also accumulates gradients into the block’s parameter tensors (e.g. dWq for the attention’s Wq).
Each forward pass caches whatever it will need for its backward pass. LayerNorm caches the per-row mean and 1/std. Attention caches Q, K, V, and the post-softmax probabilities. FFN caches the pre-GELU activations. The backward function then runs the chain rule using those cached values, never re-running forward.
The full reverse trip:
d_logits (from softmax-CE in lib/loss.ml)
│
│ Output projection: d_y_norm = d_logits @ E
│ dE += d_logits^T @ Y'
▼
d_y_norm
│ Final LN backward
▼
d_y (gradient at block output)
│ Block backward (chain through both residuals)
▼
d_x (gradient at block input)
│ Embedding backward
▼
─── done ───
every weight matrix has a gradient
A residual connection produces two gradient paths back to its input. One goes through the sublayer, the other through the identity. Both contribute, and they sum at the input. The block backward in lib/transformer_block.ml makes this explicit.
Ffn.backward t.ffn ~d_out ~d_in:d_ln2_out;
Layernorm.backward t.ln2 ~d_out:d_ln2_out ~d_in:d_x_mid;
Ops.add_inplace ~dst:d_x_mid ~src:d_out; (* residual: dX' += d_out *)
Attention.backward t.attn ~d_out:d_x_mid ~d_in:d_ln1_out;
Layernorm.backward t.ln1 ~d_out:d_ln1_out ~d_in;
Ops.add_inplace ~dst:d_in ~src:d_x_mid; (* residual: dX += d_x_mid *)
13. Optimiser
“Every parameter” means every individual float, all 13,760 of them. Each cell of every learned matrix (E, P, Wq, Wk, Wv, Wo, W1, W2, plus the LayerNorm γs and βs) gets its own gradient and its own update. The formulas below all apply element-wise, scalar by scalar.
The simplest update rule is
param := param - lr * grad
lr is a small positive number (around 0.001 in our setup) called the learning rate that controls how big each step is. Too small and training crawls; too big and the loss bounces around or diverges instead of settling. The minus sign moves the parameter against the gradient, since the gradient points up the loss surface and we want to go down.
That formula is plain gradient descent. In practice, transformers train better with AdamW, which keeps a per-parameter running average of the gradient and its square:
m := 0.9 * m + 0.1 * grad
v := 0.999 * v + 0.001 * grad * grad
m_hat := m / (1 - 0.9^step)
v_hat := v / (1 - 0.999^step)
param := param - lr * (m_hat / (sqrt(v_hat) + 1e-8) + 0.01 * param)
The m_hat / sqrt(v_hat) term normalises by gradient magnitude, so different parameters can move at appropriate rates without manual tuning. The final + 0.01 * param is weight decay, a gentle pull toward zero that stops weights from growing without bound.
(Implementation: Adam.step in lib/adam.ml.)
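Element by element, the AdamW rule above transcribes directly. In this sketch, adam_state, init and adamw_step are hypothetical names rather than Adam.step’s real signature; the default hyperparameters are the ones in the formulas.

```ocaml
type adam_state = { m : float array; v : float array; mutable step : int }

let init n = { m = Array.make n 0.0; v = Array.make n 0.0; step = 0 }

(* One AdamW update, applied in place to every element of param. *)
let adamw_step ?(lr = 0.001) ?(b1 = 0.9) ?(b2 = 0.999) ?(eps = 1e-8)
    ?(wd = 0.01) st ~param ~grad =
  st.step <- st.step + 1;
  let bc1 = 1.0 -. (b1 ** float_of_int st.step) in
  let bc2 = 1.0 -. (b2 ** float_of_int st.step) in
  Array.iteri
    (fun i g ->
      st.m.(i) <- (b1 *. st.m.(i)) +. ((1.0 -. b1) *. g);
      st.v.(i) <- (b2 *. st.v.(i)) +. ((1.0 -. b2) *. g *. g);
      let m_hat = st.m.(i) /. bc1 and v_hat = st.v.(i) /. bc2 in
      param.(i) <-
        param.(i) -. (lr *. ((m_hat /. (sqrt v_hat +. eps)) +. (wd *. param.(i)))))
    grad
```

On the very first step the bias corrections make m_hat = grad and v_hat = grad², so every parameter moves by about lr × (sign(grad) + 0.01 × param), whatever the gradient’s size. That is the “appropriate rates without manual tuning” behaviour in its simplest form.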
14. Training loop
In the training loop, a step is one parameter update, one application of the AdamW rule from §13. Over the entire run, we take 5000 steps. A batch is the group of training examples processed together within a single step. Their gradients are averaged before AdamW runs. We use batch_size = 16, so each step processes 16 examples. (A single example would give a noisy gradient estimate; averaging 16 smooths it out.)
The loop lives in lib/train.ml. For each step:
- Zero all parameter gradients.
- For each of the 16 examples in the batch:
  - Run forward → logits.
  - Compute loss and d_logits.
  - Run backward → gradients added into every parameter’s grad tensor.
- Divide every gradient tensor by batch_size (= 16), so we have the mean per-example gradient instead of the sum.
- Apply the AdamW update rule (§13) to every parameter, using those averaged gradients.
The lr (from §13) doesn’t stay at 0.001 the whole way through. For the first 50 steps, it ramps linearly from 0 up to 0.001, a learning-rate warmup. AdamW’s running averages m, v need a few steps to stabilise; taking large early steps before they have meaningful values can throw the model into a bad region. Warmup avoids that.
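The warmup is a one-line schedule. This sketch assumes the rate stays constant at 0.001 after warmup; the text only describes the first 50 steps, so any later decay would be an addition. lr_at is a hypothetical name.

```ocaml
(* Linear ramp from 0 at step 0 to base at step warmup, then flat
   (the flat part is an assumption). *)
let lr_at ?(base = 0.001) ?(warmup = 50) step =
  if step < warmup then base *. float_of_int step /. float_of_int warmup
  else base
```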
Every 250 steps, the trainer evaluates all 256 examples and prints the loss and accuracy. (5000 steps × 16 examples per step = 80,000 forward and backward passes during the whole run, which takes about 50 seconds.)
15. Training vs inference
Training and inference run almost the same code: both call Model.forward, which produces logits from ids. The difference is what else we keep around.
During training we allocate three things in addition to the parameters:
| What | Why | Size |
|---|---|---|
| AdamW state m, v per parameter | running averages for the optimiser | 2 × 13,760 floats |
| Gradient buffer per parameter | accumulator for backward | 1 × 13,760 floats |
| Forward activation caches per block | inputs needed by backward | a few KB |
So the per-step memory footprint is roughly 4 × the parameter count (parameters, gradients, m and v), plus caches. For 13,760 parameters that’s tiny, but at the scale of a real LLM (billions of parameters) this 4× multiplier is a dominant constraint on the hardware you need for training. It’s why you can deploy a 7B model on a laptop but can’t train one there.
For inference, we need only the trained parameter values:
E 32 × 32 = 1,024 floats (used twice: input embedding + output projection)
P 8 × 32 = 256 floats
Wq, Wk, Wv, Wo 32 × 32 = 1,024 each, 4,096 total
W1 32 × 128 = 4,096 floats
W2 128 × 32 = 4,096 floats
γ, β × 3 32 × 2 × 3 = 192 floats
─────────
total 13,760 floats ≈ 110 KB on disk at 8 bytes/float
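The total can be checked by computing it from the shapes. This sketch assumes no bias vectors anywhere, which is what the list above implies; param_count is a hypothetical helper.

```ocaml
(* Tied E, P, four attention matrices, two FFN matrices, and three
   LayerNorms with a gamma and a beta each. *)
let param_count ~vocab_size ~seq_len ~d_model ~d_ff =
  let e = vocab_size * d_model in
  let p = seq_len * d_model in
  let attn = 4 * d_model * d_model in
  let ffn = (d_model * d_ff) + (d_ff * d_model) in
  let ln = 3 * 2 * d_model in
  e + p + attn + ffn + ln
```

With the default hyperparameters it returns 13,760. Evaluated with d_model = 4 and d_ff = 16 (keeping d_ff = 4 × d_model), the same formula gives 376.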
To use the trained model:
1. encode the question: ids = [BOS; 8; +; a; =; PAD; PAD; PAD]
2. logits = Model.forward model ~ids (same code as training)
3. answer_c1 = argmax of logits row 4
4. answer_c2 = argmax of logits row 5
The gradients aren’t computed, the optimiser isn’t called, and the caches aren’t needed. The model itself is the same code; we just stop calling Model.backward.
In this project we don’t actually save and reload the trained parameters to disk; we train and infer in the same process. Doing so would be a few lines: Marshal the Model.t after training, restore it later.
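A sketch of that save/restore with the standard library’s Marshal module. The weights record is a hypothetical stand-in for Model.t. Marshal is not type-safe: the annotation on load is a promise the compiler cannot check, so the file must come from the same program version.

```ocaml
type weights = { e : float array; p : float array }

let save path (w : weights) =
  let oc = open_out_bin path in
  Fun.protect ~finally:(fun () -> close_out oc) (fun () ->
      Marshal.to_channel oc w [])

(* The result type is asserted, not checked. *)
let load path : weights =
  let ic = open_in_bin path in
  Fun.protect ~finally:(fun () -> close_in ic) (fun () ->
      (Marshal.from_channel ic : weights))
```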
16. A real training run
Output from dune exec bin/main.exe:
attn-ox: d_model=32 heads=2 d_ff=128 seq=8 vocab=32 batch=16 lr=0.001 params=13760
step 250 loss=0.9444 digit_acc=0.678 ex_acc=0.355
step 500 loss=0.6424 digit_acc=0.803 ex_acc=0.605
step 750 loss=0.4394 digit_acc=0.850 ex_acc=0.699
step 1000 loss=0.4011 digit_acc=0.883 ex_acc=0.766
step 1250 loss=0.2981 digit_acc=0.979 ex_acc=0.957
step 1500 loss=0.2204 digit_acc=0.965 ex_acc=0.930
step 1750 loss=0.1533 digit_acc=1.000 ex_acc=1.000
step 2000 loss=0.0981 digit_acc=0.990 ex_acc=0.980
step 2250 loss=0.0675 digit_acc=1.000 ex_acc=1.000
step 2500 loss=0.0504 digit_acc=1.000 ex_acc=1.000
step 2750 loss=0.0630 digit_acc=1.000 ex_acc=1.000
step 3000 loss=0.0244 digit_acc=1.000 ex_acc=1.000
step 3250 loss=0.0195 digit_acc=1.000 ex_acc=1.000
step 3500 loss=1.0294 digit_acc=0.867 ex_acc=0.750
step 3750 loss=0.0219 digit_acc=1.000 ex_acc=1.000
step 4000 loss=0.0125 digit_acc=1.000 ex_acc=1.000
step 4500 loss=0.0083 digit_acc=1.000 ex_acc=1.000
step 5000 loss=0.0056 digit_acc=1.000 ex_acc=1.000
final: digit_acc=1.000 ex_acc=1.000
digit_acc is the fraction of answer digits the model gets right (512 in total: two per example, 256 examples). ex_acc is the fraction of examples where both digits are right. The model first hits 100% example accuracy at step 1750 and holds it from step 2250 onward, apart from one brief regression at step 3500, from gradient noise, which it recovers from by step 3750.
The first reported loss, at step 250, is 0.94, well below log(32) ≈ 3.47 (a uniform guess over the 32-token vocabulary), because warmup completes in 50 steps and the model has already started fitting before the first report. It drops to 0.0056 by the end.
A few sample predictions on random examples after training:
sample predictions:
9 + 2 = 0b (truth 0b) OK
3 + b = 0e (truth 0e) OK ← b is the hex digit for 11
3 + 9 = 0c (truth 0c) OK
1 + 7 = 08 (truth 08) OK
9 + 9 = 12 (truth 12) OK ← carry (9+9=18, hex 12)
d + 9 = 16 (truth 16) OK ← carry
b + b = 16 (truth 16) OK ← carry
8 + f = 17 (truth 17) OK ← carry
8 + a = 12 (truth 12) OK ← the example we'll inspect in §17
17. What the model learned
Here is the model’s behaviour on 8 + a after training, captured by Inspect.dump_example. Recall that a is the hex digit for 10, so 8 + a = 18 (decimal), which is 12 in hex.
────────── inspecting 8 + a ──────────
input ids : BOS 8 + a = 1 2 PAD
positions : BOS 8 + a = c1 c2 PAD
top-3 predictions at the answer positions:
pos 4 (=, target=1): 1=0.999 2=0.001 9=0.000
pos 5 (c1, target=2): 2=0.993 3=0.005 1=0.002
The model is confident: 99.9% probability on the correct first answer digit, 99.3% on the second. (These are the top three entries of rows 4 and 5 of the full 8 × 32 probability distribution shown in §2; the “top-3 predictions” summary just hides the small values.)
Now look at the attention probabilities. Each row is a probability distribution: row i shows where position i reaches when computing its output. The causal mask blanks the upper triangle.
head 0 (rows attend to cols, causal mask hides upper triangle)
BOS 8 + a = c1 c2 PAD
BOS 1.000 · · · · · · ·
8 0.706 0.294 · · · · · ·
+ 0.274 0.458 0.268 · · · · ·
a 0.368 0.156 0.358 0.119 · · · ·
= 0.068 0.259 0.054 0.524 0.095 · · ·
c1 0.048 0.192 0.035 0.278 0.050 0.398 · ·
c2 0.123 0.195 0.097 0.156 0.105 0.166 0.158 ·
PAD 0.112 0.097 0.122 0.106 0.130 0.161 0.148 0.123
head 1 (rows attend to cols, causal mask hides upper triangle)
BOS 8 + a = c1 c2 PAD
BOS 1.000 · · · · · · ·
8 0.801 0.199 · · · · · ·
+ 0.316 0.369 0.315 · · · · ·
a 0.346 0.050 0.535 0.069 · · · ·
= 0.095 0.347 0.076 0.378 0.105 · · ·
c1 0.056 0.497 0.033 0.265 0.038 0.111 · ·
c2 0.138 0.111 0.175 0.162 0.171 0.077 0.165 ·
PAD 0.125 0.039 0.166 0.072 0.194 0.222 0.080 0.102
Two rows of head 1 are doing the work. They are rows 4 and 5, the rows the loss actually scores.
Head 1, row = (position 4). Reading along that row in the matrix: 0.347 on column 8 (position 1, where x lives) and 0.378 on column a (position 3, where y lives). When the model is sitting at the = token, head 1 reaches back and pulls in both operands, roughly equally, because computing the first answer digit requires knowing whether x + y ≥ 16 (which depends on both).
That weighting is the §9.1 mechanism running for this head: it is softmax(Q[4] · K[j] / sqrt(16)) evaluated at all j. Through training, the query Q[4] has learned to match the keys at both operand positions, not at any other positions. After softmax, mass concentrates on positions 1 and 3.
Head 1’s output at row = is Σ_j probs[4][j] * V[j], a weighted sum of value vectors. Most of the weight falls on the value vectors at the two operand positions, V[1] and V[3]. These two vectors are not the literal digits 8 and a; they are learned 16-element vectors that the model has shaped during training to encode whatever it needs about each operand. Head 1’s combined output is therefore dominated by a roughly 50/50 mix of V[1] and V[3]. This vector is added into the central 32-element state at position 4 via the attention residual + (§7’s diagram).
That central running state, the vertical column in the §7 diagram (the value of X, then X', then Y at each position), is usually called the residual stream. Each sub-block adds its output to it rather than replacing it. From here, the residual stream at position 4 passes through the Feed-forward block (which adds another contribution), then the final LayerNorm, then the output projection. Only at the end does it become the prediction 1 (the high digit of 12). Head 1’s job is routing: matching Q against K to choose where to read V from. Producing the answer digit is the job of the rest of the pipeline.
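The weighted sum and the residual add are both small operations. The plain-OCaml sketch below assumes probs is one row of the attention matrix, values holds one value vector per position, and delta is already the full-width update (for example, head 0’s and head 1’s 16-element outputs placed side by side with Array.append). Any output projection applied before the add is not shown in this post, so none is included. head_output and residual_add are hypothetical names, and the tiny 2-element vectors stand in for the real 16- and 32-element ones.

```ocaml
(* Head output at one row: sum over j of probs.(j) * values.(j),
   i.e. a weighted sum of the value vectors. *)
let head_output ~(probs : float array) ~(values : float array array) =
  let out = Array.make (Array.length values.(0)) 0.0 in
  Array.iteri
    (fun j p -> Array.iteri (fun c v -> out.(c) <- out.(c) +. (p *. v)) values.(j))
    probs;
  out

(* Residual update: the sub-block's output is added to the stream, not substituted for it. *)
let residual_add ~(stream : float array) ~(delta : float array) =
  Array.mapi (fun c x -> x +. delta.(c)) stream

(* Toy example: half the weight on each of two value vectors, as at row =. *)
let mixed = head_output ~probs:[| 0.5; 0.5 |] ~values:[| [| 2.0; 0.0 |]; [| 0.0; 4.0 |] |]
let updated = residual_add ~stream:[| 1.0; 1.0 |] ~delta:mixed
```

Keeping the stream and adding to it, rather than overwriting it, is what lets later blocks still see the original token and position information at position 4.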
The pattern is position-based, so for any input this row of head 1 attends to positions 1 and 3, where the two operands always sit. The columns are labelled 8 and a only because those are the operand values in this example.
Head 1, row c1 (position 5). 0.497 on column 8 (position 1) plus 0.265 on column a (position 3): again, both operands. Plus 0.111 on column c1 (position 5, the row’s own position). Position 5 already has information about both operands mixed in from the previous step, so the residual stream at c1 ends up holding both x and y. That’s exactly what the Feed-forward block needs to compute the second answer digit.
Head 0 has a similar pattern but a less even split: 0.259 on column 8 and 0.524 on column a at row =. With only two heads, both do meaningful work here; the model evidently needs both sources of information at every answer position, since each answer digit depends on both x and y.
18. Did the model learn addition?
§16 reports 100% accuracy, but it is measured on the same 256 examples the model was trained on. That left me wondering whether the model had learned the rule x + y → digits or had merely memorised the 256 input/output pairs it was given during training.
The test is to train on a random subset only and then evaluate on the remainder. A rule-learner should get the unseen cases right; a model that had simply memorised its training data would fail on them.
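A sketch of such a split in plain OCaml, assuming a Fisher-Yates shuffle seeded through Random.State. The split function and its parameters are hypothetical, and because the shuffle used by bin/long_train.exe is not shown here, seed 42 in this sketch will not reproduce its held-out set.

```ocaml
(* All 256 (x, y) pairs of single hex digits, in lexicographic order. *)
let all_pairs = List.init 256 (fun n -> (n / 16, n mod 16))

(* Shuffle a copy with a seeded Fisher-Yates pass, then cut it into train/held-out. *)
let split ~seed ~n_train pairs =
  let a = Array.of_list pairs in
  let st = Random.State.make [| seed |] in
  for i = Array.length a - 1 downto 1 do
    let j = Random.State.int st (i + 1) in (* uniform in 0..i *)
    let t = a.(i) in
    a.(i) <- a.(j);
    a.(j) <- t
  done;
  let train = Array.sub a 0 n_train in
  let held_out = Array.sub a n_train (Array.length a - n_train) in
  (train, held_out)

(* 230 training pairs (about 90%) and 26 held out, as in the run below. *)
let train, held_out = split ~seed:42 ~n_train:230 all_pairs
```

Because the held-out pairs are removed before training starts, nothing about them can leak into the weights; any correct answer on them has to come from what the model extracted from the other 230.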
Model size matters too. With d_model = 32, there are roughly 54 parameters per training example, which makes memorisation an easy option. With d_model = 4, there are only 376 parameters in total, just above one parameter per example, which should leave too little room to memorise and push the model towards a rule-based solution.
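The 376 figure can be reproduced with a simple count. This breakdown is a reconstruction under stated assumptions, not read from the attn-ox source: one transformer layer, a 32-token vocabulary and 8 positions, d × d Q/K/V/output projections, a feed-forward layer 4 × d_model wide, three LayerNorms with gain and bias, no biases on the linear layers, and an output projection tied to the token embedding. The function name params is hypothetical.

```ocaml
(* Hypothetical parameter count for a one-layer transformer of width d, under the
   assumptions listed above (tied output projection, no linear biases). *)
let params d =
  let vocab = 32 and ctx = 8 in
  let embed = (vocab * d) + (ctx * d) in (* token + position embeddings *)
  let attn = 4 * d * d in (* W_Q, W_K, W_V, W_O *)
  let ffn = 2 * (d * 4 * d) in (* up (d -> 4d) and down (4d -> d) projections *)
  let norms = 3 * (2 * d) in (* gain and bias for each of three LayerNorms *)
  embed + attn + ffn + norms (* = 12 d^2 + 46 d *)
```

Under these assumptions the formula gives 376 for d_model = 4 and 140 for d_model = 2, matching the figures in this section. For d_model = 32 it gives 13,760, and 13,760 / 256 ≈ 53.8, consistent with the roughly 54 parameters per example quoted above.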
bin/long_train.exe uses d_model = 4 (376 parameters: the same architecture, just narrower vectors), trains on 230 hex examples (a random 90% of the 256) and evaluates on the 26 remaining examples it has never seen. Here’s what happens over training:
d_model=4, 376 params, 230 train, 26 held-out:
step train_ex_acc held_ex_acc
1000 0.270 0.115
2000 0.739 0.615
3000 0.865 0.692
4000 0.943 0.885
5000 0.974 1.000 ← held-out hits 100% BEFORE training does
6000 0.991 1.000
7000 0.996 0.962 ← brief regression
8000 1.000 1.000 ← full convergence on both
...
50000 1.000 1.000 (stays at 100/100 for the rest of the run)
Held-out accuracy goes from 0.115 at step 1000 to 1.000 at step 5000 and, apart from a brief dip to 0.962 at step 7000, stays there. All 26 held-out (x, y) pairs, which the model never saw during training, get correct answers.
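The column names train_ex_acc and held_ex_acc suggest whole-example accuracy. A minimal sketch of such a metric, assuming an example counts as correct only when both answer digits c1 and c2 match; target, example_accuracy and no_carry are hypothetical names, not taken from attn-ox.

```ocaml
(* Target digits for x + y: c1 is the high (carry) digit, c2 the low digit.
   For 8 + a this is (1, 2), i.e. 12 in hex. *)
let target (x, y) = ((x + y) / 16, (x + y) mod 16)

(* Example-level accuracy: a pair scores only if both predicted digits are correct. *)
let example_accuracy predict pairs =
  let correct =
    List.fold_left (fun n p -> if predict p = target p then n + 1 else n) 0 pairs
  in
  float_of_int correct /. float_of_int (List.length pairs)

let all_pairs = List.init 256 (fun n -> (n / 16, n mod 16))

(* A predictor that always gets the low digit right but never carries. *)
let no_carry (x, y) = (0, (x + y) mod 16)
```

A partial rule can still score well on this metric: no_carry is right on the 136 pairs with x + y < 16, so example_accuracy no_carry all_pairs is 0.53125, whereas only the full rule reaches 1.0.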
There’s a giveaway in the trajectory: held-out hits 100% before training does. At step 5000 the model is at 0.974 on training (still missing some) but already 1.000 on held-out. The only way that can happen is if the model has found a generalising solution, one that computes answers rather than looking them up, and so gets right even inputs it has never been shown. There is no way to “memorise the held-out cases without seeing them”; the only way to get them right is to have extracted the rule.
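This is closely related to grokking (Power et al., 2022), in which small transformers trained on algorithmic tasks switch abruptly from memorising to generalising; in Power et al.’s experiments, that switch comes long after training accuracy has already reached 100%.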
What does the default d_model = 32 do at 90/10? It’s more nuanced, and the contrast is informative. Across multiple random splits of the data (different split_seed), training to 50,000 steps:
| split_seed | held-out trajectory |
|---|---|
| 42 | groks at step 5000, stable at 100% for the rest |
| 7 | groks briefly at step 10000, drifts back to 0.92 by step 35000 |
| 100 | never groks, plateau at 0.85–0.92 |
| 999 | groks at step 20000, drifts back to 0.92 by step 50000 |
So with d_model = 32, grokking is possible but unstable. Some seeds find the rule and keep it; others find it briefly and drift back to memorisation-shaped solutions; others never find it at all. With d_model = 4, in the same setup, grokking happens reliably: once the small model has found the rule, it stays there for the rest of training.
The reason is the loss landscape. At d_model = 32 (about 54 parameters per training example), there’s room to memorise and room to encode the rule. Both are valid solutions.
At d_model = 4 (about 1.6 parameters per example), there’s no memorising solution that fits the training set. The only minimum that drives the loss to zero is the rule-extracting one. Gradient descent finds it and there’s nowhere else to go.
A sweep over training fractions confirms the picture. At d_model = 4, with 50,000 training steps and varying how much of the 256-example dataset is held out:
| Train | Held-out | Train acc | Held-out acc | Verdict |
|---|---|---|---|---|
| 128 | 128 | 1.000 | 0.953 | memorises, plateaus |
| 153 | 103 | 1.000 | 0.981 | almost groks |
| 179 | 77 | 1.000 | 1.000 | groks fully |
| 204 | 52 | 1.000 | 1.000 | groks fully |
| 230 | 26 | 1.000 | 1.000 | groks fully |
So with d_model = 4, of the fractions tested, 70% of the 256 examples (179) is the smallest at which the model groks fully. Below that, even 50,000 steps isn’t enough: held-out accuracy plateaus.
The threshold fits the parameter-count argument: at 70% (179 training examples), the model has 376 parameters / 179 examples ≈ 2.1 params per example, while at 60% (153 examples) it has ≈ 2.5. Somewhere between those two ratios the easy memorisation path comes back, the model takes it, and held-out accuracy lags. At lower ratios, the model is forced into the rule.
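The ratio for every row of the sweep is a one-line calculation. This small OCaml snippet, with the 376-parameter count taken from the text, computes and prints it for each training-set size.

```ocaml
(* Parameters per training example for each row of the d_model = 4 sweep. *)
let ratios =
  List.map (fun n -> (n, 376.0 /. float_of_int n)) [ 128; 153; 179; 204; 230 ]

let () =
  List.iter
    (fun (n, r) -> Printf.printf "%3d train examples: %.2f params/example\n" n r)
    ratios
```

The ratios come out at about 2.94, 2.46, 2.10, 1.84 and 1.63; only the runs at 2.10 and below grokked fully.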
With d_model = 2 (140 parameters, well below the memorisation capacity needed for 256 examples), neither training nor held-out accuracy rose above random guessing.
A 376-parameter transformer trained on 70% of the 256 hex addition examples reliably extracts the rule of addition and applies it to inputs it has never seen.
19. SIMD optimisation
lib/ops.ml is the only module that uses OxCaml extensions. The matmul kernels run on AVX 256-bit vectors via Ocaml_simd_avx.Float64x4 (four float64s per vector, fused multiply-add). The loop pattern is “reorder the loops so the inner one walks contiguous memory, then SIMD the inner loop”:
```ocaml
let matmul ~a ~b ~c =
  ...
  Array.fill cd 0 (m * n) 0.0;
  let nv = n / 4 * 4 in                       (* largest multiple of 4 <= n *)
  for i = 0 to m - 1 do
    let c_off = i * n in
    for p = 0 to k - 1 do
      let a_ip = ad.(i * k + p) in
      let av = Float64x4.set1 (f_of a_ip) in  (* broadcast scalar *)
      let b_off = p * n in
      let mutable j = 0 in                    (* let mutable, not ref *)
      (* C[i, j..j+3] += A[i, p] * B[p, j..j+3], four columns at a time *)
      while j < nv do
        let bv = Float64x4.Float_array.unsafe_get bd ~idx:(b_off + j) in
        let cv = Float64x4.Float_array.unsafe_get cd ~idx:(c_off + j) in
        Float64x4.Float_array.unsafe_set cd ~idx:(c_off + j)
          (Float64x4.mul_add av bv cv);       (* one FMA: cv + av*bv *)
        j <- j + 4
      done;
      (* scalar tail for the remaining n mod 4 columns *)
      while j < n do
        cd.(c_off + j) <- cd.(c_off + j) +. (a_ip *. bd.(b_off + j));
        j <- j + 1
      done
    done
  done
```
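For readers without an OxCaml switch, the loop reordering itself can be seen in plain OCaml. The sketch below (hypothetical matmul_ijp and matmul_ipj, row-major flat float arrays, no SIMD) compares the textbook i-j-p order, whose inner loop strides down a column of b, with the i-p-j order used by the kernel above, whose inner loop walks a row of b and a row of c contiguously.

```ocaml
(* Row-major matrices in flat float arrays: a is m x k, b is k x n, result is m x n. *)

(* Textbook i-j-p order: the inner loop reads b with stride n (down a column). *)
let matmul_ijp ~m ~k ~n a b =
  let c = Array.make (m * n) 0.0 in
  for i = 0 to m - 1 do
    for j = 0 to n - 1 do
      let s = ref 0.0 in
      for p = 0 to k - 1 do
        s := !s +. (a.((i * k) + p) *. b.((p * n) + j))
      done;
      c.((i * n) + j) <- !s
    done
  done;
  c

(* Reordered i-p-j order: the inner loop walks row p of b and row i of c
   contiguously, which is the loop the SIMD kernel vectorises. *)
let matmul_ipj ~m ~k ~n a b =
  let c = Array.make (m * n) 0.0 in
  for i = 0 to m - 1 do
    for p = 0 to k - 1 do
      let a_ip = a.((i * k) + p) in
      for j = 0 to n - 1 do
        c.((i * n) + j) <- c.((i * n) + j) +. (a_ip *. b.((p * n) + j))
      done
    done
  done;
  c
```

Both versions accumulate each c[i][j] over p in the same order, so they return bit-for-bit identical results; only the memory access pattern changes, which is what makes the inner loop of the second one easy to vectorise.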
The OxCaml-specific bits:
Float64x4 is an unboxed 256-bit vector type. It is passed in YMM registers, with no heap allocation.
Float64x4.mul_add av bv cv compiles to one vfmadd instruction computing cv + av * bv on 4 doubles in parallel.
Float64x4.Float_array.unsafe_get arr ~idx:i loads arr[i..i+3] into a vector register with no allocation.
f_of and f_to use the %unbox_float and %box_float intrinsics to convert between OCaml’s boxed float and the unboxed float#.
let mutable j = 0 is OxCaml syntax for a mutable local that doesn’t allocate a ref cell.
Bench results for the matmul shapes that appear in the model:
=== matmul 32x32x32 === scalar 80.8 µs SIMD 16.8 µs 4.8×
=== matmul 8x32x16 === scalar 9.7 µs SIMD 2.4 µs 4.1×
=== matmul 8x32x128 === scalar 77.4 µs SIMD 15.7 µs 4.9×
=== matmul 8x128x32 === scalar 87.9 µs SIMD 15.1 µs 5.8×
Max numerical difference between scalar and SIMD: about 1e-17. The grad-check tests pass for the SIMD version with the same tolerance as the scalar version (FMA is, if anything, more accurate, because it rounds a * b + c once instead of rounding the product first).
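OCaml’s standard library exposes the same operation on scalars as Float.fma, which makes the accuracy point easy to check (this uses Float.fma, not the Float64x4 kernel). In the illustrative values below, the exact product a × b = 1 - 2^-60 is not representable as a double and rounds to 1.0, so the separate multiply-then-add loses the small term entirely, while the fused version rounds only once and keeps it.

```ocaml
(* a * b = 1 - 2^-60 exactly, which is not representable: rounding it gives 1.0. *)
let a = 1.0 +. ldexp 1.0 (-30)
let b = 1.0 -. ldexp 1.0 (-30)
let c = -1.0

(* Separate multiply and add: the product is rounded to 1.0 first, so the result is 0.0. *)
let two_roundings = (a *. b) +. c

(* Fused multiply-add: a * b + c is computed exactly and rounded once, giving -2^-60. *)
let one_rounding = Float.fma a b c
```

Because the fused result is the correctly rounded value of a * b + c, it is never further from the exact answer than the two-step version, which is consistent with the SIMD kernel differing slightly from the scalar one without being less accurate.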
End-to-end training time: 17.6 s scalar → 10.3 s SIMD (1.7× overall). The end-to-end speedup is smaller than the kernel speedup because softmax, per-head attention scoring, LayerNorm, AdamW, GELU and the embedding scatter are still scalar.
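As a rough check of that explanation, Amdahl’s law can estimate how much of the scalar run time the matmuls accounted for. This assumes every matmul sped up by about 5× (the bench shows 4.1× to 5.8×) and that nothing else changed; kernel_fraction is a hypothetical helper.

```ocaml
(* Amdahl's law: overall = 1 / ((1 - f) + f / kernel), where f is the fraction of the
   original (scalar) run time spent in the accelerated kernels and kernel is their
   speedup. Rearranged to solve for f from a measured overall speedup: *)
let kernel_fraction ~overall ~kernel =
  (1.0 -. (1.0 /. overall)) /. (1.0 -. (1.0 /. kernel))

let overall = 17.6 /. 10.3 (* measured end-to-end speedup, about 1.71 *)

(* Assumption: a uniform matmul speedup; try the middle and both ends of the bench range. *)
let f = kernel_fraction ~overall ~kernel:5.0
let f_low = kernel_fraction ~overall ~kernel:5.8
let f_high = kernel_fraction ~overall ~kernel:4.1
```

Under that assumption f comes out at about 0.52, and between about 0.50 and 0.55 across the bench range: roughly half of the 17.6 s scalar run was matmul, so the still-scalar half caps the overall gain well below the kernel speedup.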
20. Reproducing
```shell
$ git clone https://github.com/mtelvers/attn-ox
$ cd attn-ox
$ opam exec --switch 5.2.0+ox -- dune build
$ opam exec --switch 5.2.0+ox -- dune runtest # 12 tests, ~50ms
$ opam exec --switch 5.2.0+ox -- dune exec bin/main.exe # train, ~10s
$ opam exec --switch 5.2.0+ox -- dune exec bench/bench.exe # SIMD bench
```
The seed is fixed in Train.run, so the loss curve and the attention maps in §17 are reproducible.