Build a Reasoning LLM from Scratch in Python
Build a reasoning LLM from scratch in Python: a BPE tokenizer, RoPE attention, SwiGLU transformer blocks, and chain-of-thought dual-loss training.
No APIs. No wrappers. No shortcuts. Just Python and PyTorch, building a transformer that writes out its reasoning before it answers.
Quick Stats
| Metric | Value |
|---|---|
| 📄 Core Python files | 6 |
| 💻 Total lines of code | ~600 |
| 🔌 External AI APIs used | 0 |
| 🧠 Architecture | Decoder-only Transformer |
| ⚡ Special feature | Dual-loss CoT training |
1. What Makes a Model a Reasoning Model?
A regular language model predicts the next token given a context window — nothing more. A reasoning model does the same thing, but its training data and inference strategy are specifically designed to produce chains of intermediate steps before emitting a final answer.
Think of it as the difference between blurting out "5" versus writing out the long-division working that arrives at "5".
Concretely, three things separate a reasoning model from a plain LLM:
- Scratchpad tokens: special tokens like `<think>` that the model uses as working memory
- Chain-of-thought (CoT) training data: the corpus includes step-by-step solutions, not just final answers
- Verification loss: the loss covers the intermediate steps as well as the final answer, so the model is rewarded for correct working and not only for the right result
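To make the scratchpad idea concrete, here is a minimal sketch of how a generated string could be split into its reasoning and its answer. The helper name `split_reasoning` is hypothetical, and the tag layout is assumed to match the special tokens registered later in this post.

```python
import re


def split_reasoning(text: str) -> tuple[str, str]:
    """Return (scratchpad, answer); a missing span comes back as an empty string."""
    think = re.search(r"<think>(.*?)</think>", text, flags=re.DOTALL)
    answer = re.search(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)
    return (
        think.group(1).strip() if think else "",
        answer.group(1).strip() if answer else "",
    )


sample = "<think> 12 / 4 = 3 ; 3 + 2 = 5 </think> <answer> 5 </answer>"
steps, final = split_reasoning(sample)
print("scratchpad:", steps)
print("answer:", final)
```

Only the text inside `<answer>` would normally be shown to the user. The scratchpad is kept for debugging or for scoring the intermediate steps.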
📊 Standard LLM vs Reasoning LLM
flowchart LR
subgraph STD["🔵 STANDARD LLM"]
direction LR
A1["📝 Input Prompt"] --> B1["⚙️ Transformer"] --> C1["💬 Final Answer"]
end
subgraph RSN["🟢 REASONING LLM"]
direction LR
A2["📝 Input Prompt"] --> B2["⚙️ Transformer"]
B2 --> T1["🧠 think:\nStep 1...\nStep 2...\nStep 3..."]
T1 --> C2["✅ answer\n(verified)"]
end
STD ~~~ RSN
💡 Key insight: The model doesn’t become smarter by magic. It becomes smarter because it learned — from training data — that writing out intermediate steps leads to more reliable final answers.
2. The Architecture at a Glance
Our reasoning LLM is a decoder-only transformer — the same family as GPT, LLaMA, and DeepSeek. We add one twist: a small vocabulary of special reasoning tokens and a modified loss function that weights scratchpad steps alongside final answers.
📊 Full Model Architecture
flowchart TD
IN["📥 Input IDs (B, T)"]
IN --> EMB
EMB["🔤 Token Embedding\nvocab_size × d_model"]
EMB --> POS
POS["📍 RoPE Positional Encoding\nbuilt into attention Q, K"]
POS --> BLK
subgraph BLK["🔁 Transformer Block × N layers"]
direction TB
N1["RMSNorm"] --> ATT["Masked Self-Attention\n+ Residual"]
ATT --> N2["RMSNorm"]
N2 --> FFN["SwiGLU Feed-Forward\n+ Residual"]
end
BLK --> NORM["RMSNorm (final)"]
NORM --> LMH["🎯 LM Head + Softmax\nd_model → vocab_size"]
LMH --> TH["🟢 think tokens\nloss × α = 0.5"]
LMH --> AN["🔵 answer tokens\nloss × 1.0"]
TH --> LOSS["⚡ DUAL LOSS\nL = α·L_think + L_answer"]
AN --> LOSS
3. Step 1 — Build a Tokenizer
A tokenizer converts raw text into integer IDs the model can process. We implement a simple Byte-Pair Encoding (BPE) tokenizer that also registers our special reasoning tokens.
📊 Special Token Registry
```python
# tokenizer.py
from collections import Counter, defaultdict
import re

SPECIAL_TOKENS = {
    "<pad>": 0,
    "<unk>": 1,
    "<bos>": 2,
    "<eos>": 3,
    "<think>": 4,    # begin scratchpad
    "</think>": 5,   # end scratchpad
    "<answer>": 6,   # begin final answer
    "</answer>": 7,  # end final answer
}


class BPETokenizer:
    def __init__(self, vocab_size: int = 8000):
        self.vocab_size = vocab_size
        self.vocab = dict(SPECIAL_TOKENS)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.merges = []  # learned merges, in the order they were learned

    def _add_token(self, token: str) -> None:
        """Register a token under the next free ID (no-op if already known)."""
        if token not in self.vocab:
            idx = len(self.vocab)
            self.vocab[token] = idx
            self.inv_vocab[idx] = token

    def _get_pairs(self, word_freqs):
        """Count adjacent symbol pairs, weighted by word frequency."""
        pairs = defaultdict(int)
        for word, freq in word_freqs.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs

    @staticmethod
    def _merge(word: str, pair) -> str:
        """Merge every occurrence of `pair` in a space-separated symbol string."""
        pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
        # A function replacement avoids backslash escapes being interpreted
        return pattern.sub(lambda m: ''.join(pair), word)

    def train(self, corpus: str) -> None:
        """Train BPE on a raw text corpus string."""
        words = re.findall(r'\S+', corpus.lower())
        word_freqs = Counter(' '.join(w) + ' </w>' for w in words)
        for word in word_freqs:
            for ch in word.split():
                self._add_token(ch)
        while len(self.vocab) < self.vocab_size:
            pairs = self._get_pairs(word_freqs)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            self.merges.append(best)
            self._add_token(''.join(best))
            word_freqs = {self._merge(w, best): f for w, f in word_freqs.items()}

    def encode(self, text: str) -> list[int]:
        ids = [SPECIAL_TOKENS["<bos>"]]
        for word in text.split():
            if word in SPECIAL_TOKENS:  # keep reasoning tags as single tokens
                ids.append(SPECIAL_TOKENS[word])
                continue
            symbols = ' '.join(word.lower()) + ' </w>'
            for pair in self.merges:  # replay merges in learned order (simple, not fast)
                symbols = self._merge(symbols, pair)
            ids.extend(self.vocab.get(s, SPECIAL_TOKENS["<unk>"])
                       for s in symbols.split())
        return ids + [SPECIAL_TOKENS["<eos>"]]

    def decode(self, ids: list[int]) -> str:
        pieces = []
        for i in ids:
            tok = self.inv_vocab.get(i, '<unk>')
            pieces.append(tok + ' ' if tok in SPECIAL_TOKENS else tok)
        # '</w>' marks a word boundary, so it becomes a space again
        return ''.join(pieces).replace('</w>', ' ').strip()
```
⚠️ Note: The `<think>`, `</think>`, `<answer>`, and `</answer>` tokens are hardcoded at IDs 4–7, so they are always present regardless of the training corpus. The model learns to emit them by seeing CoT-formatted training examples.
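A single BPE training step can be traced by hand on a tiny corpus. The sketch below is an illustration only: it stores each word as a tuple of symbols instead of the space-joined strings used in tokenizer.py, and the helper names `count_pairs` and `merge_pair` are hypothetical.

```python
from collections import Counter


def count_pairs(word_freqs: dict[tuple[str, ...], int]) -> Counter:
    """Count adjacent symbol pairs, weighted by how often each word occurs."""
    pairs = Counter()
    for symbols, freq in word_freqs.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs


def merge_pair(symbols: tuple[str, ...], pair: tuple[str, str]) -> tuple[str, ...]:
    """Replace every occurrence of `pair` in one word with the fused symbol."""
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            out.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return tuple(out)


# Toy corpus: "low" x5, "lower" x2, "new" x3, split into characters plus </w>
word_freqs = {
    ("l", "o", "w", "</w>"): 5,
    ("l", "o", "w", "e", "r", "</w>"): 2,
    ("n", "e", "w", "</w>"): 3,
}
pairs = count_pairs(word_freqs)
best = max(pairs, key=pairs.get)
merged = {merge_pair(w, best): f for w, f in word_freqs.items()}
print("most frequent pair:", best, "count:", pairs[best])
print("corpus after one merge:", merged)
```

Here the pair `('w', '</w>')` wins with a count of 8 (5 from "low" and 3 from "new"), so the first learned token is `w</w>`. Training repeats this count-and-merge step until the vocabulary reaches `vocab_size`.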
4. Step 2 — Embeddings & Positional Encoding
Embeddings map integer token IDs to dense vectors. For position information we use Rotary Position Embedding (RoPE), the same technique used in LLaMA and Mistral, applied directly to the queries and keys inside the attention layer.
📊 RoPE vs Sinusoidal Positional Encoding
```python
# embeddings.py
import torch, torch.nn as nn, math


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.scale = math.sqrt(d_model)

    def forward(self, x):  # x: (B, T)
        return self.embed(x) * self.scale


def precompute_rope(d_head: int, max_seq: int, base: int = 10_000):
    """Precompute RoPE frequency matrix as complex numbers."""
    theta = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
    pos = torch.arange(max_seq).float()
    freqs = torch.outer(pos, theta)  # (T, d_head/2)
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rope(q, k, freqs):
    """Apply rotary encoding to query and key tensors."""
    # View consecutive feature pairs as complex numbers: (..., d_head) -> (..., d_head/2)
    q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_ = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    # Complex multiplication rotates each pair by its position-dependent angle
    q_rot = torch.view_as_real(q_ * freqs).flatten(3)
    k_rot = torch.view_as_real(k_ * freqs).flatten(3)
    return q_rot.type_as(q), k_rot.type_as(k)
```
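The property that makes RoPE useful is that the query-key score depends only on the distance between the two positions, not on where they sit in the sequence. The sketch below checks this numerically with the standard library. It assumes the same pairing of features and the same frequencies as `precompute_rope`; the names `rope_rotate` and `score` are hypothetical.

```python
import cmath


def rope_rotate(vec: list[float], pos: int, base: float = 10_000.0) -> list[complex]:
    """Pair features (x0, x1), (x2, x3), ... as complex numbers and rotate
    pair i by the angle pos * theta_i, with theta_i = base ** (-i / d)."""
    d = len(vec)
    rotated = []
    for i in range(0, d, 2):
        theta = base ** (-i / d)
        rotated.append(complex(vec[i], vec[i + 1]) * cmath.exp(1j * pos * theta))
    return rotated


def score(q: list[complex], k: list[complex]) -> float:
    """Real dot product of the underlying feature vectors."""
    return sum((a * b.conjugate()).real for a, b in zip(q, k))


q = [0.5, -1.0, 2.0, 0.3]
k = [1.5, 0.2, -0.7, 1.1]

s_near_a = score(rope_rotate(q, 3), rope_rotate(k, 1))  # distance 2
s_near_b = score(rope_rotate(q, 7), rope_rotate(k, 5))  # distance 2, shifted by 4
s_far = score(rope_rotate(q, 7), rope_rotate(k, 1))     # distance 6

print(s_near_a, s_near_b, s_far)
```

The first two scores agree up to floating-point error because both pairs are two positions apart, while the third differs because the distance changed.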
5. Step 3 — Self-Attention from Scratch
Masked multi-head self-attention is the computational heart of every transformer. Each token attends to all preceding tokens — the causal mask prevents it from seeing the future.
📊 Causal Attention Mask & Multi-Head Layout
```python
# attention.py
import torch, torch.nn as nn, torch.nn.functional as F
from embeddings import precompute_rope, apply_rope


class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_seq: int, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)
        self.drop = nn.Dropout(dropout)
        freqs = precompute_rope(self.d_head, max_seq)
        self.register_buffer('freqs', freqs)

    def forward(self, x):  # x: (B, T, D)
        B, T, D = x.shape
        qkv = self.qkv(x).split(D, dim=-1)
        # (B, T, D) -> (B, n_heads, T, d_head) for each of q, k, v
        q, k, v = [
            t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
            for t in qkv
        ]
        q, k = apply_rope(q, k, self.freqs[:T])
        attn = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.drop.p if self.training else 0.0
        )
        # Merge heads back: (B, n_heads, T, d_head) -> (B, T, D)
        attn = attn.transpose(1, 2).contiguous().view(B, T, D)
        return self.out(attn)
```
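To see what the causal mask does to the attention weights, the following standard-library sketch applies the mask and a row-wise softmax to a small score matrix. It illustrates the masking rule only; it is not how `F.scaled_dot_product_attention` is implemented internally, and the function name is hypothetical.

```python
import math


def causal_attention_weights(scores: list[list[float]]) -> list[list[float]]:
    """Row-wise softmax where position i may only see positions j <= i."""
    T = len(scores)
    weights = []
    for i in range(T):
        # Future positions get -inf, which becomes weight 0 after the softmax
        row = [scores[i][j] if j <= i else float("-inf") for j in range(T)]
        top = max(row)
        exps = [math.exp(s - top) for s in row]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


scores = [
    [0.2, 1.0, -0.5],
    [1.5, 0.3, 0.8],
    [0.1, 0.4, 0.9],
]
for row in causal_attention_weights(scores):
    print([round(w, 3) for w in row])
```

Every row sums to 1 and everything above the diagonal is exactly 0, so the first token can only attend to itself.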
6. Step 4 — The Transformer Block
Each transformer block wraps the attention layer with a feed-forward network and RMS normalisation. We use SwiGLU as the activation, the same choice made by LLaMA and Mistral (Gemma uses the closely related GeGLU).
📊 SwiGLU vs ReLU FFN
flowchart LR
subgraph RELU["ReLU FFN (old)"]
direction TB
r1["x"] --> r2["Linear"] --> r3["ReLU"] --> r4["Linear"] --> r5["output"]
end
subgraph SWIGLU["SwiGLU FFN ✓ (we use)"]
direction TB
s1["x"] --> sg["Gate: Linear → SiLU"]
s1 --> su["Up: Linear"]
sg --> sm["⊗ element-wise multiply"]
su --> sm
sm --> sd["Down: Linear"] --> s5["output"]
end
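The gated data flow in the diagram can be followed on a toy example. This is a plain-Python sketch with hand-picked weight matrices (assumptions chosen to keep the arithmetic readable), not the `nn.Linear`-based module used in model.py.

```python
import math


def silu(x: float) -> float:
    """SiLU (swish) activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def matvec(W: list[list[float]], x: list[float]) -> list[float]:
    """Multiply matrix W (rows x cols) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def swiglu_ffn(x, W_gate, W_up, W_down):
    gate = [silu(g) for g in matvec(W_gate, x)]  # Gate: Linear -> SiLU
    up = matvec(W_up, x)                         # Up: Linear
    hidden = [g * u for g, u in zip(gate, up)]   # element-wise multiply
    return matvec(W_down, hidden)                # Down: Linear


# d_model = 2, hidden = 3
x = [1.0, 2.0]
W_gate = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_up = [[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]]
W_down = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]

out = swiglu_ffn(x, W_gate, W_up, W_down)
print(out)
```

The middle hidden unit has an "up" value of 0, so it contributes nothing however large its gate is. That multiplicative interaction is what separates a gated FFN from the plain Linear, ReLU, Linear stack.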
```python
# model.py
import torch, torch.nn as nn, torch.nn.functional as F
from attention import CausalSelfAttention
from embeddings import TokenEmbedding


class RMSNorm(nn.Module):
    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.g, self.eps = nn.Parameter(torch.ones(d)), eps

    def forward(self, x):
        # Divide by the root-mean-square over the feature dimension, then rescale
        return x * x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() * self.g


class SwiGLUFFN(nn.Module):
    def __init__(self, d_model: int, hidden: int):
        super().__init__()
        self.gate = nn.Linear(d_model, hidden, bias=False)
        self.up = nn.Linear(d_model, hidden, bias=False)
        self.down = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, max_seq, ffn_mult=4, dropout=0.1):
        super().__init__()
        self.attn = CausalSelfAttention(d_model, n_heads, max_seq, dropout)
        # The 2/3 factor keeps the three-matrix SwiGLU FFN close to the size
        # of a two-matrix FFN with hidden width ffn_mult * d_model
        self.ffn = SwiGLUFFN(d_model, int(d_model * ffn_mult * 2 / 3))
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # pre-norm + residual
        x = x + self.ffn(self.norm2(x))
        return x


class ReasoningLLM(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=8,
                 n_layers=6, max_seq=512, dropout=0.1):
        super().__init__()
        self.embed = TokenEmbedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, max_seq, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the LM head reuses the embedding matrix, saving
        # vocab_size * d_model parameters, usually at little or no accuracy cost
        self.lm_head.weight = self.embed.embed.weight

    def forward(self, ids):  # (B, T)
        x = self.embed(ids)
        for block in self.blocks:
            x = block(x)
        return self.lm_head(self.norm(x))  # (B, T, vocab_size)
```
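The hidden width `int(d_model * ffn_mult * 2 / 3)` in `TransformerBlock` may look arbitrary. A short calculation shows what it achieves: a SwiGLU FFN has three weight matrices instead of two, and the 2/3 factor brings its parameter count back to roughly that of a conventional 4x FFN. The helper name `ffn_params` is hypothetical; the formulas follow from the bias-free `nn.Linear` shapes in model.py.

```python
def ffn_params(d_model: int, ffn_mult: int = 4) -> dict[str, int]:
    """Compare a two-matrix ReLU FFN with the three-matrix SwiGLU FFN (no biases)."""
    relu_hidden = d_model * ffn_mult
    swiglu_hidden = int(d_model * ffn_mult * 2 / 3)
    return {
        "relu_hidden": relu_hidden,
        "relu_params": 2 * d_model * relu_hidden,       # up + down
        "swiglu_hidden": swiglu_hidden,
        "swiglu_params": 3 * d_model * swiglu_hidden,   # gate + up + down
    }


for name, value in ffn_params(256).items():
    print(f"{name:>14}: {value:,}")
```

For `d_model = 256` the SwiGLU hidden width is 682 and the two variants differ by only 512 parameters (523,776 versus 524,288).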
7. Step 5 — Chain-of-Thought Training Strategy
This is the key differentiator. Every training example is formatted as a problem → scratchpad → answer triple. We compute loss over both spans with separate weights.
📊 CoT Training Sequence & Loss Mask
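```python
# loss.py
import torch, torch.nn.functional as F

THINK_START, THINK_END = 4, 5
ANS_START, ANS_END = 6, 7


def build_loss_mask(ids: torch.Tensor, alpha: float = 0.5) -> torch.Tensor:
    """Per-token loss weights: alpha inside <think>...</think>, 1.0 inside
    <answer>...</answer>, 0 elsewhere (prompt, padding).

    The delimiter tokens are weighted with their span, so the model is also
    trained to open and close the scratchpad and the answer.
    """
    rows = []
    for row in ids.tolist():
        in_think = in_answer = False
        weights = []
        for tok in row:
            if tok == THINK_START:
                in_think = True
            elif tok == ANS_START:
                in_answer = True
            weights.append(alpha if in_think else 1.0 if in_answer else 0.0)
            if tok == THINK_END:
                in_think = False
            elif tok == ANS_END:
                in_answer = False
        rows.append(weights)
    return torch.tensor(rows, dtype=torch.float32, device=ids.device)


def reasoning_loss(logits, targets, ids=None, alpha: float = 0.5):
    """Weighted cross-entropy over scratchpad and answer tokens.

    The mask is built from `targets`, so each weight lines up with the token
    being predicted. `ids` is accepted only so existing call sites keep
    working; it is not used.
    """
    B, T, V = logits.shape
    mask = build_loss_mask(targets, alpha)
    # reshape, not view: sliced inputs such as logits[:, :-1] are non-contiguous
    ce = F.cross_entropy(
        logits.reshape(-1, V), targets.reshape(-1), reduction='none'
    ).view(B, T)
    return (ce * mask).sum() / mask.sum().clamp(min=1)
```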
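The last line of `reasoning_loss` is a weighted average, `sum(w * ce) / max(sum(w), 1)`, which is a normalised form of the α·L_think + L_answer idea from the architecture diagram. The plain-Python sketch below works through it on four made-up per-token losses (the numbers and the helper names are illustrative assumptions) to show how `alpha` shifts the balance between scratchpad and answer.

```python
def span_weights(roles: list[str], alpha: float) -> list[float]:
    """Map each token's role to its loss weight: think -> alpha, answer -> 1, else 0."""
    table = {"think": alpha, "answer": 1.0}
    return [table.get(role, 0.0) for role in roles]


def weighted_loss(token_losses: list[float], weights: list[float]) -> float:
    """Weighted mean of per-token losses, with the divisor clamped to at least 1."""
    total = sum(weights)
    return sum(l * w for l, w in zip(token_losses, weights)) / max(total, 1.0)


roles = ["prompt", "think", "think", "answer"]
token_losses = [2.0, 1.0, 1.0, 3.0]  # made-up cross-entropy values

for alpha in (0.0, 0.5, 1.0):
    loss = weighted_loss(token_losses, span_weights(roles, alpha))
    print(f"alpha={alpha}: loss={loss:.4f}")
```

With `alpha = 0` only the answer token counts (loss 3.0). Raising `alpha` pulls the result towards the scratchpad tokens, giving 2.0 at `alpha = 0.5`. Prompt tokens never contribute.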
8. Step 6 — The Training Loop
📊 Learning Rate Schedule
```python
# train.py
import torch, math
from model import ReasoningLLM
from loss import reasoning_loss

CFG = {
    "vocab_size": 8000, "d_model": 256, "n_heads": 8,
    "n_layers": 6, "max_seq": 512, "lr": 3e-4,
    "batch": 32, "epochs": 20, "alpha": 0.5,
    "warmup": 200, "clip": 1.0,
}

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ReasoningLLM(**{k: CFG[k] for k in
    ["vocab_size", "d_model", "n_heads", "n_layers", "max_seq"]}).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=CFG["lr"],
                          weight_decay=0.1, betas=(0.9, 0.95))


def lr_schedule(step, warmup, total):
    """LR multiplier: linear warmup, then cosine decay from 1.0 down to 0.1."""
    if step < warmup:
        return step / warmup
    pct = (step - warmup) / max(1, total - warmup)
    return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * pct))


# NOTE: `dataloader` is not defined in this post. It is assumed to be a
# torch DataLoader yielding (ids, targets) batches of shape (B, T), with
# targets holding the same tokens as ids (the one-step shift is done below).
steps_per_epoch = len(dataloader)
total_steps = CFG["epochs"] * steps_per_epoch
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optim, lambda s: lr_schedule(s, CFG["warmup"], total_steps)
)

for epoch in range(CFG["epochs"]):
    model.train()
    epoch_loss = 0.0
    for ids, targets in dataloader:
        ids, targets = ids.to(device), targets.to(device)
        logits = model(ids)
        # Position t predicts token t + 1; .contiguous() keeps the slices
        # safe to flatten inside the loss
        loss = reasoning_loss(logits[:, :-1].contiguous(),
                              targets[:, 1:].contiguous(),
                              ids[:, :-1], alpha=CFG["alpha"])
        optim.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["clip"])
        optim.step()
        scheduler.step()
        epoch_loss += loss.item()
    print(f"Epoch {epoch+1:02d} loss={epoch_loss/steps_per_epoch:.4f}"
          f" lr={scheduler.get_last_lr()[0]:.2e}")
```
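The shape of the learning-rate schedule is easiest to see by evaluating it at a few steps. The snippet repeats the `lr_schedule` formula from train.py so it runs on its own. The total of 1000 steps is an arbitrary value chosen for illustration.

```python
import math


def lr_schedule(step: int, warmup: int, total: int) -> float:
    """LR multiplier: linear warmup, then cosine decay from 1.0 down to 0.1."""
    if step < warmup:
        return step / warmup
    pct = (step - warmup) / max(1, total - warmup)
    return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * pct))


base_lr, warmup, total = 3e-4, 200, 1000  # total is a made-up example value

for step in (0, 100, 200, 600, 1000):
    mult = lr_schedule(step, warmup, total)
    print(f"step {step:>4}: multiplier={mult:.2f}  lr={base_lr * mult:.2e}")
```

The multiplier climbs linearly from 0 to 1 over the first 200 steps, passes 0.55 halfway through the decay, and ends at 0.1, so the final learning rate is one tenth of the peak.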
⚠️ Training data format is everything. Your `dataloader` must yield sequences formatted as `<bos> problem <think> steps </think> <answer> answer </answer> <eos>`. Without CoT-formatted data, the model skips the scratchpad entirely.
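As a rough sketch of that data format, the helpers below build one training string from a problem, its steps, and its answer, and pad a batch of token-ID lists to a common length. Both function names and the `wrap` flag are hypothetical. The padding ID 0 matches `<pad>` in `SPECIAL_TOKENS`.

```python
def format_example(problem: str, steps: list[str], answer: str, wrap: bool = True) -> str:
    """Lay out one example as problem, scratchpad, answer.

    Set wrap=False if the tokenizer adds <bos>/<eos> itself when encoding.
    """
    body = f"{problem} <think> {' '.join(steps)} </think> <answer> {answer} </answer>"
    return f"<bos> {body} <eos>" if wrap else body


def pad_batch(seqs: list[list[int]], pad_id: int = 0) -> list[list[int]]:
    """Right-pad every sequence with pad_id up to the longest one in the batch."""
    width = max(len(s) for s in seqs)
    return [s + [pad_id] * (width - len(s)) for s in seqs]


text = format_example("2 + 3 * 4 = ?", ["3 * 4 = 12 ;", "2 + 12 = 14"], "14")
print(text)
print(pad_batch([[2, 9, 4, 3], [2, 3]]))
```

Padded positions lie outside both the `<think>` and `<answer>` spans, so they receive zero weight in the loss.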
9. Optimisation Tips
📊 Hyperparameter Cheatsheet
| Parameter | Toy (CPU) | Small (8 GB) | Tip |
|---|---|---|---|
| d_model | 128–256 | 512–768 | divisible by n_heads |
| n_layers | 4–6 | 12–24 | deeper > wider for CoT |
| n_heads | 4–8 | 8–16 | d_head ≥ 32 |
| lr | 3e-4 | 1e-4 → 1e-5 | warmup + cosine |
| alpha (α) | 0.5 | 0.3–0.7 | ↑ if no <think> emitted |
| max_seq | 256–512 | 1024–2048 | scratchpad needs room |
| batch_size | 8–16 | 32–128 | use grad accumulation |
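Weight tying: Share weights between TokenEmbedding and lm_head. With the default config above (vocab_size 8000, d_model 256, 6 layers) this removes about 2.0M of roughly 8.8M parameters (about 23%), usually with little or no accuracy cost. GPT-2 ties these weights, and the saving matters most in small models, where the embedding matrix is a large share of the total.
Gradient clipping: Clip the gradient norm at 1.0 (the `clip` value in CFG). Transformer training can produce large gradients unexpectedly, especially early in CoT training.
Pre-norm vs post-norm: Use pre-norm (apply RMSNorm before the sublayer). Pre-norm is generally more stable than post-norm in deeper networks.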
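The effect of weight tying on model size can be checked with a back-of-the-envelope count. The function below is a hypothetical helper, not part of the codebase. It mirrors the layer shapes in model.py (bias-free linear layers, two RMSNorm vectors per block, one final RMSNorm) and assumes the vocabulary actually reaches `vocab_size`.

```python
def count_params(vocab_size: int = 8000, d_model: int = 256, n_layers: int = 6,
                 ffn_mult: int = 4, tie: bool = True) -> int:
    """Parameter count for the ReasoningLLM layout, with or without weight tying."""
    hidden = int(d_model * ffn_mult * 2 / 3)
    attn = 3 * d_model * d_model + d_model * d_model  # fused qkv + output projection
    ffn = 3 * d_model * hidden                        # gate + up + down
    norms = 2 * d_model                               # two RMSNorm gains per block
    block = attn + ffn + norms
    embed = vocab_size * d_model
    head = 0 if tie else vocab_size * d_model         # tied head adds nothing
    return n_layers * block + d_model + embed + head  # + final RMSNorm


tied = count_params(tie=True)
untied = count_params(tie=False)
print(f"tied:   {tied:,}")
print(f"untied: {untied:,}")
print(f"saving: {(untied - tied) / untied:.1%}")
```

For the default configuration this gives 6,766,848 parameters tied versus 8,814,848 untied. The saving shrinks as the model gets deeper or wider, because the transformer blocks then dominate the count.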
Good open-source CoT datasets:
| Dataset | Domain | CoT Format |
|---|---|---|
| GSM8K | Grade school math | Step-by-step |
| MetaMath | Math (augmented) | Chain-of-thought |
| NuminaMath | Competition math | Full working |
| OpenWebMath | General math text | Semi-structured |
| WizardLM-CoT | General reasoning | Multi-step |
10. What’s Next?
📊 Extension Roadmap
flowchart TD
BASE["✅ Current Build\nTokenizer + Transformer + CoT Loss"]
BASE --> KV["KV Cache\nFast autoregressive generation\n⚡ Medium"]
BASE --> GQA["Grouped Query Attention\nLess VRAM · Faster inference\n⚡ Medium"]
BASE --> DPO["RLHF / DPO\nReward correct reasoning steps\n🔥 Hard"]
KV --> SPEC["Speculative Decoding\n3–4× faster generation\n🔥 Hard"]
GQA --> SPEC
DPO --> SPEC
SPEC --> MOE["Mixture of Experts\nScale params ≠ scale compute\n🔥 Hard"]
| Extension | What It Adds | Difficulty | Pure Python? |
|---|---|---|---|
| KV Cache | Fast autoregressive generation | Medium | ✅ |
| Grouped Query Attention | Less VRAM, faster inference | Medium | ✅ |
| RLHF / DPO on CoT traces | Reward correct reasoning steps | Hard | ✅ |
| Speculative Decoding | 3–4× faster generation | Hard | ✅ |
| Mixture of Experts (MoE) | Scale params without scaling compute | Hard | ✅ |
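To give a feel for the first extension, here is a toy KV cache for a single attention head, written with plain lists. It is a sketch of the idea only, not a drop-in change to `CausalSelfAttention`: there are no learned projections, no RoPE, and no batching, and the names `attend`, `full_recompute`, and `KVCache` are hypothetical.

```python
import math


def attend(q: list[float], keys: list[list[float]], values: list[list[float]]) -> list[float]:
    """Softmax attention of one query over the given keys and values."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    z = sum(exps)
    return [sum(e / z * v[j] for e, v in zip(exps, values)) for j in range(len(values[0]))]


def full_recompute(qs, ks, vs):
    """Causal attention recomputed from scratch at every position."""
    return [attend(qs[t], ks[: t + 1], vs[: t + 1]) for t in range(len(qs))]


class KVCache:
    """Keep past keys and values so each new token only adds its own."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, q, k, v):
        self.keys.append(k)
        self.values.append(v)
        return attend(q, self.keys, self.values)


qs = [[0.1, 0.3], [0.5, -0.2], [-0.4, 0.9]]
ks = [[0.7, 0.1], [-0.3, 0.6], [0.2, 0.2]]
vs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

cache = KVCache()
incremental = [cache.step(q, k, v) for q, k, v in zip(qs, ks, vs)]
print(incremental)
```

The incremental outputs match a full recomputation at every position, but during generation each step only has to produce one new query, key, and value instead of reprocessing the whole prefix.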
Conclusion
Reasoning in LLMs is not magic; it is a structured engineering choice. By teaching the model to write its work between `<think>` and `</think>` tokens, and training on those scratchpad traces with a weighted dual loss, you give it a way to work through multi-step problems one step at a time, which tends to reduce careless errors.
You have now built every layer from first principles:
- ✅ BPE Tokenizer with special `<think>` / `<answer>` tokens
- ✅ RoPE Positional Encoding baked directly into attention
- ✅ Causal Multi-Head Self-Attention via `F.scaled_dot_product_attention`, which can use FlashAttention kernels where available
- ✅ SwiGLU Feed-Forward + RMSNorm Transformer Block
- ✅ Full `ReasoningLLM` class with weight tying
- ✅ Dual-loss (`reasoning_loss`) weighting scratchpad vs answer spans
- ✅ AdamW + cosine LR Training loop with gradient clipping
The entire codebase — tokenizer.py, embeddings.py, attention.py, model.py, loss.py, train.py — is under 600 lines of pure Python and PyTorch.
No external AI API was used. No wrappers. No shortcuts. That is your reasoning model.
