Inside TinyLM: How I Built a Transformer I Could Actually Read

January 8, 2026 · 15 min read · 📝 Article
machine-learning transformers pytorch cuda

A walkthrough of building a transformer from scratch - RMSNorm vs LayerNorm, RoPE's rotation trick, the KV cache that makes generation not suck, and what training curves actually tell you.

I wanted to understand how transformers work: not the attention diagram everyone draws, but the code that runs when you call model.forward().

So I opened the LLaMA source. Then Hugging Face's implementation. Then GPT-NeoX. Each time I hit the same wall: thousands of files, abstraction layers, configuration systems that needed their own documentation. I could follow the math on paper, but I couldn't point to the line where Q meets K.

So I built my own: ~6,800 lines of Python, two architecture presets, and a transformer small enough to read in an afternoon.

This isn't a tutorial. It's what I learned building it: what clicked, what surprised me, and enough implementation detail that you could build your own.


The Core Loop

Every transformer forward pass is basically the same loop. Here's a simplified version of the real forward path:

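def forward(self, x: torch.Tensor) -> torch.Tensor:
    # x: [batch, seq_len] token indices
    # pos_ctx, cache, start_pos: positional context, KV cache, and decode
    # offset; their setup is elided in this simplified version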

    h = self.embedding(x)  # [batch, seq_len, dim]

    for i, block in enumerate(self.blocks):
        h = block(h, pos_ctx, cache, layer_idx=i, start_pos=start_pos)

    h = self.norm(h)           # Final normalization
    logits = self.head(h)      # [batch, seq_len, vocab_size]
    return logits

Structurally it's boring: embed → repeat block → norm → head. The real complexity lives inside block():

def forward(self, x, pos_ctx, cache, layer_idx, start_pos):
    # Pre-norm: normalize before each sublayer
    x = x + self.attn(self.norm1(x), pos_ctx, cache, layer_idx, start_pos)
    x = x + self.mlp(self.norm2(x))
    return x

Two residual branches. Two normalizations. One attention and one MLP. Repeat N times.

What makes this hard in practice isn't the shape of the loop. It's the constant fight against instability, bandwidth, and latency.


The Parts That Surprised Me

RMSNorm: Fewer Moving Parts

LayerNorm does two things: centers (subtract mean) and scales (divide by std). RMSNorm drops the centering:

# LayerNorm: two reductions (mean, then variance)
y = (x - mean(x)) / sqrt(var(x) + ε) * γ + β
#   centering: x - mean(x) · learned scale: γ · learned bias: β

# RMSNorm: one reduction (mean of squares)
y = x / sqrt(mean(x²) + ε) * γ

Why does dropping mean subtraction work? The RMSNorm paper's core claim is empirical: you can drop mean-centering and still train well. What matters is controlling activation scale. Without normalization, residual streams tend to drift and variance accumulates with depth.

In practical terms: RMSNorm gives you a stabilizer with one reduction instead of two.
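To make the difference concrete, here is a minimal pure-Python sketch of both normalizations over a single row. It is not the TinyLM implementation; the function names and the eps default are my own choices for illustration.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-6):
    """LayerNorm over one row: center, then scale by the standard deviation."""
    n = len(x)
    mean = sum(x) / n                                # reduction 1
    var = sum((v - mean) ** 2 for v in x) / n        # reduction 2
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm over one row: scale by the root mean square, no centering."""
    n = len(x)
    mean_sq = sum(v * v for v in x) / n              # the only reduction
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * g for v, g in zip(x, gamma)]

row = [1.0, 2.0, 3.0, 4.0]
ones, zeros = [1.0] * 4, [0.0] * 4

ln_out = layer_norm(row, ones, zeros)   # zero mean, unit variance
rms_out = rms_norm(row, ones)           # unit mean square, mean left alone
```

When the input already has zero mean, the two produce the same output (with γ = 1 and β = 0), which is one way to see that RMSNorm is LayerNorm minus the centering step.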

The CUDA kernel makes the dataflow concrete:

// One block per row, templatized for fp16/fp32
template<typename scalar_t>
__global__ void rmsnorm_fwd_kernel(
    const scalar_t* __restrict__ x, const scalar_t* __restrict__ w,
    scalar_t* y, float* inv_rms_out, int hidden, float eps) {

    int row = blockIdx.x;
    const scalar_t* x_row = x + row * hidden;

    // Parallel reduce: each thread accumulates a strided slice in fp32
    float sumsq = 0.f;
    for (int i = threadIdx.x; i < hidden; i += blockDim.x)
        sumsq += to_float(x_row[i]) * to_float(x_row[i]);
    float reduced = blockReduceSum<float>(sumsq);  // warp shuffles

    // Shared memory broadcasts inv_rms to the whole block
    __shared__ float s_inv_rms;
    if (threadIdx.x == 0) {
        s_inv_rms = rsqrtf(reduced / hidden + eps);
        inv_rms_out[row] = s_inv_rms;  // cache for backward
    }
    __syncthreads();

    // Type conversion (fp16/fp32) around the fp32 math; w is the learned scale γ
    scalar_t* y_row = y + row * hidden;
    for (int i = threadIdx.x; i < hidden; i += blockDim.x)
        y_row[i] = from_float<scalar_t>(
            to_float(x_row[i]) * s_inv_rms * to_float(w[i]));
}

On NVIDIA GPUs, rsqrtf is a fast device intrinsic (SFU-backed), and CUDA best practices recommend calling it explicitly rather than writing 1.0f / sqrtf(x). Each block computes it once per row, broadcasts it via shared memory, and reuses it across threads. The reduction uses warp shuffles (no atomics). Global memory traffic is the unavoidable part (read x, read weight, write y), but we avoid extra global traffic just to compute the reduction.

I cache inv_rms for backward. The backward needs the same value, and recomputing it would burn bandwidth for no gain.
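To show why the cached value is enough, here is a sketch of the input gradient written in terms of inv_rms, checked against finite differences. This is my own derivation from y = x * inv_rms * w, not a transcription of the TinyLM backward kernel, and all names below are hypothetical.

```python
import math

def rmsnorm_forward(x, w, eps=1e-6):
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    y = [v * inv_rms * g for v, g in zip(x, w)]
    return y, inv_rms                      # inv_rms is saved for backward

def rmsnorm_backward_x(grad_y, x, w, inv_rms):
    """Gradient w.r.t. x, reusing the cached inv_rms instead of re-reducing x*x."""
    n = len(x)
    gw = [g * wi for g, wi in zip(grad_y, w)]
    dot = sum(a * v for a, v in zip(gw, x))          # the one reduction needed
    coeff = dot * inv_rms ** 3 / n
    return [a * inv_rms - v * coeff for a, v in zip(gw, x)]

x = [0.5, -1.5, 2.0, 0.25]
w = [1.0, 0.5, -2.0, 1.5]
grad_y = [0.1, -0.2, 0.3, 0.4]

_, inv_rms = rmsnorm_forward(x, w)
analytic = rmsnorm_backward_x(grad_y, x, w, inv_rms)

def loss(xs):
    """Scalar whose gradient w.r.t. y is exactly grad_y."""
    y, _ = rmsnorm_forward(xs, w)
    return sum(g * v for g, v in zip(grad_y, y))

# Central finite differences as an independent check.
h = 1e-6
numeric = []
for j in range(len(x)):
    plus = x[:j] + [x[j] + h] + x[j + 1:]
    minus = x[:j] + [x[j] - h] + x[j + 1:]
    numeric.append((loss(plus) - loss(minus)) / (2 * h))
```

The backward still needs one reduction of its own (the dot product of grad_y * w with x), but it never has to recompute the mean of squares.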


RoPE (Rotary Position Embedding): Position as Rotation

Most positional encodings add a position vector to the embedding. RoPE does something cleaner: it rotates query and key vectors as a function of position.

A 512-dim query becomes 256 independent 2D rotations:

q = [q₀, q₁, q₂, q₃, q₄, q₅, ...]      # 512 dimensions
       └──┘   └──┘   └──┘
       2D     2D     2D          # 256 pairs

At position t, rotate each pair by t × θᵢ, where θᵢ = 1 / base^(2i/d) (inv_freq[i] in the code):

pair 0:   rotate by t × θ₀    # fast rotation
pair 1:   rotate by t × θ₁    # medium
pair 2:   rotate by t × θ₂    # slower
...
pair 255: rotate by t × θ₂₅₅  # very slow

Why this gives relative position. The rotations "subtract" in the dot product: a query rotated by t_q θ dotted with a key rotated by t_k θ depends on (t_q - t_k)θ. Position enters the attention score only through how far apart the tokens are, not through their absolute indices.

Why multiple frequencies. Some pairs rotate quickly (good for local precision), others rotate slowly (good for long-range structure). It's the same intuition as Fourier features.
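The relative-position property is easy to check numerically. The sketch below rotates adjacent pairs in plain Python (the helper names and the example vectors are mine) and compares scores for the same offset at very different absolute positions.

```python
import math

def rope_rotate(vec, pos, base=10000.0):
    """Rotate adjacent pairs (v[0], v[1]), (v[2], v[3]), ... by pos * inv_freq[i]."""
    dim = len(vec)
    out = []
    for i in range(dim // 2):
        inv_freq = 1.0 / (base ** (2 * i / dim))
        angle = pos * inv_freq
        c, s = math.cos(angle), math.sin(angle)
        a, b = vec[2 * i], vec[2 * i + 1]
        out += [a * c - b * s, a * s + b * c]
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.3, -1.2, 0.8, 0.5, -0.7, 1.1, 0.2, -0.4]
k = [1.0, 0.4, -0.6, 0.9, 0.3, -0.2, 0.5, 0.7]

# Same offset (t_q - t_k = 3) at two very different absolute positions.
score_near = dot(rope_rotate(q, 5), rope_rotate(k, 2))
score_far = dot(rope_rotate(q, 1005), rope_rotate(k, 1002))

# A different offset (t_q - t_k = 1) gives a different score.
score_other = dot(rope_rotate(q, 5), rope_rotate(k, 4))
```

score_near and score_far agree up to floating-point error, while score_other differs: the score tracks the offset, not where the pair sits in the sequence.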

Implementation-wise, RoPE is mostly precompute + a cheap per-token rotate:

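def precompute(self, max_seq_len, device):
    # One frequency per 2D pair: inv_freq[i] = 1 / base^(2i/dim)
    idx = torch.arange(0, self.dim, 2, device=device).float()
    inv_freq = 1.0 / (self.base ** (idx / self.dim))
    t = torch.arange(max_seq_len, device=device).float()
    freqs = torch.einsum('t,f->tf', t, inv_freq)  # [max_seq_len, dim/2]

    self.sin = torch.sin(freqs)
    self.cos = torch.cos(freqs)

def apply(self, x, start_pos):
    # x: [..., seq_len, dim] (assumed layout: sequence on the second-to-last axis)
    seq_len = x.shape[-2]
    sin = self.sin[start_pos:start_pos + seq_len]
    cos = self.cos[start_pos:start_pos + seq_len]

    # Rotate adjacent pairs (x0, x1), (x2, x3), ... and re-interleave them
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = torch.stack([x1*cos - x2*sin, x1*sin + x2*cos], dim=-1)
    return out.flatten(-2)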

No learned parameters, no additive interference with content embeddings, and it often extrapolates better than learned absolute embeddings, though performance can still degrade at lengths far beyond training.


SwiGLU: The Third Projection

A standard transformer MLP is "up → activation → down." It works, but the nonlinearity is doing too much: it decides what to compute and what to keep.

SwiGLU splits those roles:

gate = silu(W_gate @ x)      # gate: controls flow (should we use this?)
up   = W_up @ x              # up: computes values (what do we compute?)
out  = W_down @ (gate * up)  # down: projects back to dim

The two projections see the same input but learn different functions. A dimension only contributes if both agree: gate * up.

Three projections sounds like 50% more parameters. The trick is to shrink the hidden size:

Standard: 2 × d × 4d     = 8d²
SwiGLU:   3 × d × (8d/3) = 8d²  ← same parameter count (with this hidden size)

Then round for GPU-friendly alignment:

hidden_dim = int(dim * 4 * 2 / 3)
hidden_dim = 256 * ((hidden_dim + 255) // 256)
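
The rounding step means the parameter counts are only approximately equal. A small sketch that wraps the two lines above in a helper (swiglu_hidden_dim and mlp_params are names I made up for illustration) and counts weights, ignoring biases:

```python
def swiglu_hidden_dim(dim, multiple_of=256):
    """Hidden size of about 8*dim/3, rounded up to a multiple of `multiple_of`."""
    hidden = int(dim * 4 * 2 / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

def mlp_params(dim):
    """Weight counts (no biases) for a standard 4x MLP and a SwiGLU MLP."""
    standard = 2 * dim * (4 * dim)               # up + down
    swiglu = 3 * dim * swiglu_hidden_dim(dim)    # gate + up + down
    return standard, swiglu

for dim in (384, 512, 4096):
    standard, swiglu = mlp_params(dim)
    print(dim, swiglu_hidden_dim(dim), standard, swiglu, round(swiglu / standard, 3))
```

At dim=384 the hidden size lands exactly on 1024 and the two counts match. At dim=512 it rounds up from 1365 to 1536, which makes the SwiGLU MLP 12.5% larger than the standard one, so "same parameter count" depends on how kindly the rounding treats your dim.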

Making Generation Not Suck

The KV Cache Problem

Training sees full sequences. Generation emits one token at a time. If you naively call the model on the growing prefix, you recompute key/value projections for the entire history on every step.

But keys and values for past tokens don't change. Cache them once, reuse them forever.

Without caching

# O(n²) total: recompute K,V for the whole prefix each step
for i in range(n):
    logits = model(all_tokens[:i+1])

With caching

# Pre-allocate once
K = zeros(max_seq, dim)
V = zeros(max_seq, dim)

# Each new token at position pos:
K[pos], V[pos] = project(new_token)
attn = Q @ K[:pos + 1].T   # scores over the history plus the new token

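With the cache, each step projects only the new token and attends over the cached prefix, so per-step work is roughly O(n) in the current length. Total K/V projection work drops from O(n²) to O(n).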
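
Here is a toy, single-head version in plain Python that shows the two loops produce the same outputs while doing very different amounts of projection work. The project function is a hypothetical stand-in for the real Q/K/V projections, and lists stand in for tensors.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q, keys, values):
    """Single-head attention for one query over lists of key/value vectors."""
    scale = 1.0 / math.sqrt(len(q))
    weights = softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys])
    return [sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))]

def project(token_vec):
    """Stand-in for the Q/K/V projections (hypothetical fixed maps)."""
    q = [2.0 * x for x in token_vec]
    k = [x + 0.5 for x in token_vec]
    v = [-x for x in token_vec]
    return q, k, v

tokens = [[0.1, 0.2], [0.4, -0.3], [-0.5, 0.9], [0.7, 0.7]]
max_seq = 8

# Without a cache: re-project the whole prefix at every step.
uncached, projections_uncached = [], 0
for i in range(len(tokens)):
    qkv = [project(t) for t in tokens[: i + 1]]
    projections_uncached += i + 1
    uncached.append(attend(qkv[-1][0], [p[1] for p in qkv], [p[2] for p in qkv]))

# With a cache: pre-allocate, then one projection and one indexed write per step.
K = [None] * max_seq
V = [None] * max_seq
cached, projections_cached = [], 0
for pos, tok in enumerate(tokens):
    q, K[pos], V[pos] = project(tok)
    projections_cached += 1
    cached.append(attend(q, K[: pos + 1], V[: pos + 1]))
```

For four tokens the uncached loop runs 1 + 2 + 3 + 4 = 10 projections and the cached loop runs 4. The gap grows quadratically with sequence length.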


Pre-allocation Matters

Pre-allocating [batch, n_kv_heads, max_seq_len, head_dim] means zero allocations during decoding: just indexed writes. If you grow tensors with torch.cat, every step allocates a new tensor and copies the whole cache into it, and that churn can fragment memory over time.

Pre-allocation costs you reserved memory, but decoding typically has a known max length anyway.


GQA (Grouped-Query Attention): Trading KV Capacity for Memory

KV cache is often the inference memory bottleneck. The per-layer size is easy to compute:

KV cache bytes per layer = 2 × B × n_kv_heads × T × d_head × bytes_per_elem

(The 2 covers keys and values; B is batch size and T is sequence length.)

For LLaMA-2 7B (fp16, batch=1, seq=2048, n_kv=32, d_head=128):

  • Per layer: ~32 MB
  • Across 32 layers: ~1 GB total

GQA reduces n_kv_heads. Multiple query heads share each KV head:

# MHA (Multi-Head Attention): n_heads = n_kv_heads = 32
# GQA (Grouped-Query Attention): n_heads = 32, n_kv_heads = 8
# MQA (Multi-Query Attention): n_heads = 32, n_kv_heads = 1

For a 7B model (32 layers, fp16, 2048 ctx):

Variant              Per Layer    Total Model
MHA (32 KV heads)    ~32 MB       ~1 GB
GQA (8 KV heads)     ~8 MB        ~256 MB
MQA (1 KV head)      ~1 MB        ~32 MB

These scale linearly with batch size, context length, KV heads, and dtype bytes (fp32 doubles them).
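
The table can be reproduced from the formula in a few lines. This is a sketch with a helper name of my own (kv_cache_bytes), and it reads the table's MB and GB as binary units (2^20 and 2^30 bytes).

```python
def kv_cache_bytes(batch, n_kv_heads, seq_len, head_dim, bytes_per_elem, n_layers=1):
    """Bytes for keys plus values (the leading 2) across n_layers layers."""
    return 2 * batch * n_kv_heads * seq_len * head_dim * bytes_per_elem * n_layers

MiB = 1024 ** 2

# fp16 (2 bytes), batch=1, seq=2048, d_head=128, 32 layers
for name, n_kv in (("MHA", 32), ("GQA", 8), ("MQA", 1)):
    per_layer = kv_cache_bytes(1, n_kv, 2048, 128, 2)
    total = kv_cache_bytes(1, n_kv, 2048, 128, 2, n_layers=32)
    print(f"{name}: {per_layer / MiB:.0f} MiB per layer, {total / MiB:.0f} MiB total")
```

Plugging in a larger batch or a longer context is a quick way to see when the cache, rather than the weights, starts to dominate memory.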

LLaMA 2 70B uses GQA with 8 KV heads for 64 Q heads, which cuts the KV cache 8× relative to MHA. That saving is a big reason GQA shows up in large models.

The implementation is just "repeat KV heads to match Q heads":

def _repeat_kv(self, kv):
    """Expand KV heads to match Q heads."""
    if self.n_rep == 1:
        return kv
    return kv.repeat_interleave(self.n_rep, dim=1)
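
The detail that matters is the grouping order: repeat_interleave repeats each head in place, so consecutive query heads share a KV head. A list-based sketch of the same layout (the function name and head labels are hypothetical):

```python
def repeat_kv_heads(kv_heads, n_rep):
    """List analogue of repeat_interleave: each KV head appears n_rep times in a row."""
    return [head for head in kv_heads for _ in range(n_rep)]

n_heads, n_kv_heads = 8, 2
n_rep = n_heads // n_kv_heads

expanded = repeat_kv_heads(["kv0", "kv1"], n_rep)

# Query head h reads KV head h // n_rep.
mapping = {h: h // n_rep for h in range(n_heads)}
```

Here query heads 0 to 3 read kv0 and heads 4 to 7 read kv1. A tiled layout (kv0, kv1, kv0, kv1, ...) would pair query heads with different KV heads than the grouping intends.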

What the Training Curves Tell You

Figure: Training comparison, 13.77M model on TinyStories (25,000 steps). LLaMA-style (pre-norm, RoPE, SwiGLU): val loss 1.25, val PPL 3.5. GPT-style (post-norm, learned positions, GELU): val loss 1.33, val PPL 3.8. Training config: batch_size 32, seq_len 512, lr 0.0003.

I trained two models on TinyStories, matched on size, data, and hyperparameters. The only difference is the architecture preset: LLaMA-style (pre-norm, RMSNorm, RoPE, SwiGLU) vs GPT-style (post-norm, LayerNorm, learned positions, GELU MLP).

After 25,000 steps:

Architecture    Val Loss    Val PPL
LLaMA-style     1.25        3.49
GPT-style       1.33        3.79
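
The two columns are tied together: perplexity is the exponential of the loss. The sketch below assumes the reported loss is mean per-token cross-entropy in nats, which is my reading of the numbers rather than something the training logs state.

```python
import math

def perplexity(mean_nll_nats):
    """Perplexity from mean per-token cross-entropy measured in nats."""
    return math.exp(mean_nll_nats)

llama_ppl = perplexity(1.25)
gpt_ppl = perplexity(1.33)
```

Because the table's losses are rounded to two decimals, a perplexity recomputed from them can differ from the table in the last digit.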

The more interesting part is the shape: LLaMA-style pulls ahead early and then the curves run roughly parallel. That's a hint the architecture changes are affecting optimization dynamics more than capacity.

None of this is shocking if you've read the papers. What surprised me is how clearly it shows up at tiny scale.

Full training config
model:
  dim: 384
  n_layers: 6
  n_heads: 6
  max_seq_len: 4096

training:
  batch_size: 32
  seq_len: 512
  lr: 0.0003
  weight_decay: 0.1
  warmup_steps: 500
  grad_clip: 1.0

hardware: Single A2000

If You Want to Build Your Own

I'm not going to tell you to use my code. But if you're starting a model project from scratch, this is what would have saved me time.

Start with the forward pass

Write the dumbest possible version first: no cache, no mixed precision, no fused kernels.

def forward(x):
    h = embed(x)
    for block in blocks:
        h = h + attention(norm(h))
        h = h + mlp(norm(h))
    return head(norm(h))

Make gradients and shapes sane. Then optimize.

The registry pattern pays for itself

Registries look like over-engineering until the first time you want to swap one component without touching model code.

@NORM_REGISTRY.register("rmsnorm")
class RMSNorm(nn.Module):
    ...

norm = NORM_REGISTRY.build("rmsnorm", dim=512)
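
If you have not written one before, a registry is only a dictionary plus a decorator. Below is a minimal sketch that supports the register/build calls shown above. The Registry class is my own illustration, not TinyLM's implementation, and the norm classes are plain stand-ins so the example runs without PyTorch.

```python
class Registry:
    """Maps string names to classes so components can be chosen from config."""

    def __init__(self, kind):
        self.kind = kind
        self._classes = {}

    def register(self, name):
        def decorator(cls):
            if name in self._classes:
                raise KeyError(f"{self.kind} '{name}' is already registered")
            self._classes[name] = cls
            return cls          # hand the class back unchanged
        return decorator

    def build(self, name, **kwargs):
        if name not in self._classes:
            known = ", ".join(sorted(self._classes))
            raise KeyError(f"unknown {self.kind} '{name}' (known: {known})")
        return self._classes[name](**kwargs)

NORM_REGISTRY = Registry("norm")

@NORM_REGISTRY.register("rmsnorm")
class RMSNorm:                  # stand-in; a real one would subclass nn.Module
    def __init__(self, dim, eps=1e-6):
        self.dim, self.eps = dim, eps

@NORM_REGISTRY.register("layernorm")
class LayerNorm:                # stand-in as well
    def __init__(self, dim, eps=1e-5):
        self.dim, self.eps = dim, eps

norm = NORM_REGISTRY.build("rmsnorm", dim=512)
```

Swapping the normalization then becomes a one-string change in config, and a typo in that string fails loudly with the list of known names.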

What I'd do differently

Write tests earlier. I had a RoPE bug that only showed up at sequence lengths > 512. A tiny reference-implementation test would've caught it.
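
The kind of test I mean compares the table-based fast path against a slow, obviously correct reference, at positions on both sides of the training length. The sketch below is a pure-Python illustration and not the test from the repo: RopeTable and rope_reference are hypothetical names, and the reference treats each pair as a complex number.

```python
import cmath
import math

class RopeTable:
    """Table-based RoPE: precompute cos/sin per position, then index by start_pos."""

    def __init__(self, dim, max_seq_len, base=10000.0):
        inv_freq = [base ** (-2 * i / dim) for i in range(dim // 2)]
        self.cos = [[math.cos(t * f) for f in inv_freq] for t in range(max_seq_len)]
        self.sin = [[math.sin(t * f) for f in inv_freq] for t in range(max_seq_len)]

    def apply(self, rows, start_pos):
        out = []
        for offset, vec in enumerate(rows):
            cos, sin = self.cos[start_pos + offset], self.sin[start_pos + offset]
            rotated = []
            for i in range(len(vec) // 2):
                a, b = vec[2 * i], vec[2 * i + 1]
                rotated += [a * cos[i] - b * sin[i], a * sin[i] + b * cos[i]]
            out.append(rotated)
        return out

def rope_reference(vec, pos, base=10000.0):
    """Slow reference: each pair is a complex number multiplied by e^(i * angle)."""
    dim = len(vec)
    out = []
    for i in range(dim // 2):
        angle = pos * base ** (-2 * i / dim)
        z = complex(vec[2 * i], vec[2 * i + 1]) * cmath.exp(1j * angle)
        out += [z.real, z.imag]
    return out

rope = RopeTable(dim=4, max_seq_len=1024)
rows = [[0.5, -1.0, 2.0, 0.25], [1.5, 0.3, -0.7, 0.9]]

def worst_error(start_pos):
    got = rope.apply(rows, start_pos)
    want = [rope_reference(vec, start_pos + i) for i, vec in enumerate(rows)]
    return max(abs(g - w) for gr, wr in zip(got, want) for g, w in zip(gr, wr))

# Probe both sides of the 512 boundary, plus the end of the table.
errors = {start: worst_error(start) for start in (0, 511, 512, 700, 1022)}
```

The important part is the choice of start positions: a test that only exercises short sequences from position 0 cannot see a bug that lives past the training length.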

Profile before optimizing. I spent a week on a CUDA kernel for a few percent end-to-end speedup. The real bottleneck was attention, and PyTorch's scaled_dot_product_attention (SDPA) already handles that well. (Note: SDPA picks Flash/memory-efficient/math kernels depending on your PyTorch build, GPU, dtype, and tensor shapes.)

Keep configs boring. Hydra is powerful, but it adds conceptual overhead. For many research codebases, dataclasses + argparse is enough.


The Code

If you want to browse the implementation: github.com/RetamalVictor/TinyLM-Lab

The main files:

  • tinylm/model/transformer.py: the forward pass
  • tinylm/components/normalization/rmsnorm.py: RMSNorm dispatch
  • tinylm/components/positional/rope.py: RoPE
  • tinylm/components/attention/mha.py: MHA/GQA/MQA
  • csrc/rmsnorm_cuda.cu: RMSNorm kernel
