Build a Reasoning LLM from Scratch in Python

Build a reasoning LLM from scratch in Python — a BPE tokenizer, RoPE attention, SwiGLU transformer blocks, and chain-of-thought dual-loss training. No APIs, no wrappers, just pure PyTorch.

No APIs. No wrappers. No shortcuts. Just pure Python and PyTorch, building a transformer that actually thinks before it answers.

Quick Stats

| Metric | Value |
|---|---|
| 📄 Core Python files | 6 |
| 💻 Total lines of code | ~600 |
| 🔌 External AI APIs used | 0 |
| 🧠 Architecture | Decoder-only Transformer |
| ⚡ Special feature | Dual-loss CoT training |

1. What Makes a Model a Reasoning Model?

A regular language model predicts the next token given a context window — nothing more. A reasoning model does the same thing, but its training data and inference strategy are specifically designed to produce chains of intermediate steps before emitting a final answer.

Think of it as the difference between blurting out "5" versus writing out the long-division working that arrives at "5".

Concretely, three things separate a reasoning model from a plain LLM:

  • Scratchpad tokens: special tokens like <think> that mark a span the model uses as working memory
  • Chain-of-thought (CoT) training data: the corpus includes step-by-step solutions, not just final answers
  • Weighted scratchpad loss: the training loss covers the intermediate steps as well as the final answer, so the model is trained to reproduce the reference reasoning and not only the right result

📊 Standard LLM vs Reasoning LLM

flowchart LR
    subgraph STD["🔵 STANDARD LLM"]
        direction LR
        A1["📝 Input Prompt"] --> B1["⚙️ Transformer"] --> C1["💬 Final Answer"]
    end

    subgraph RSN["🟢 REASONING LLM"]
        direction LR
        A2["📝 Input Prompt"] --> B2["⚙️ Transformer"]
        B2 --> T1["🧠 think:\nStep 1...\nStep 2...\nStep 3..."]
        T1 --> C2["✅ answer\n(verified)"]
    end

    STD ~~~ RSN

💡 Key insight: The model doesn’t become smarter by magic. It becomes smarter because it learned — from training data — that writing out intermediate steps leads to more reliable final answers.

2. The Architecture at a Glance

Our reasoning LLM is a decoder-only transformer — the same family as GPT, LLaMA, and DeepSeek. We add one twist: a small vocabulary of special reasoning tokens and a modified loss function that weights scratchpad steps alongside final answers.

📊 Full Model Architecture

flowchart TD
    IN["📥 Input IDs (B, T)"]
    IN --> EMB

    EMB["🔤 Token Embedding\nvocab_size × d_model"]
    EMB --> POS

    POS["📍 RoPE Positional Encoding\nbuilt into attention Q, K"]
    POS --> BLK

    subgraph BLK["🔁 Transformer Block × N layers"]
        direction TB
        N1["RMSNorm"] --> ATT["Masked Self-Attention\n+ Residual"]
        ATT --> N2["RMSNorm"]
        N2 --> FFN["SwiGLU Feed-Forward\n+ Residual"]
    end

    BLK --> NORM["RMSNorm (final)"]
    NORM --> LMH["🎯 LM Head + Softmax\nd_model → vocab_size"]

    LMH --> TH["🟢 think tokens\nloss × α = 0.5"]
    LMH --> AN["🔵 answer tokens\nloss × 1.0"]

    TH --> LOSS["⚡ DUAL LOSS\nL = α·L_think + L_answer"]
    AN --> LOSS

3. Step 1 — Build a Tokenizer

A tokenizer converts raw text into integer IDs the model can process. We implement a simple Byte-Pair Encoding (BPE) tokenizer that also registers our special reasoning tokens.

📊 Special Token Registry

| Token | ID | Purpose |
|---|---|---|
| <pad> | 0 | Padding, ignored in loss |
| <unk> | 1 | Unknown token fallback |
| <bos> | 2 | Begin of sequence |
| <eos> | 3 | End of sequence |
| <think> | 4 | Begin scratchpad reasoning |
| </think> | 5 | End scratchpad reasoning |
| <answer> | 6 | Begin final answer |
| </answer> | 7 | End final answer |

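
# tokenizer.py
from collections import defaultdict, Counter
import re

SPECIAL_TOKENS = {
    "<pad>":     0,
    "<unk>":     1,
    "<bos>":     2,
    "<eos>":     3,
    "<think>":   4,    # begin scratchpad
    "</think>":  5,    # end scratchpad
    "<answer>":  6,    # begin final answer
    "</answer>": 7,    # end final answer
}

class BPETokenizer:
    def __init__(self, vocab_size: int = 8000):
        self.vocab_size = vocab_size
        self.vocab      = dict(SPECIAL_TOKENS)
        self.inv_vocab  = {v: k for k, v in self.vocab.items()}
        self.merges     = []   # learned merges, in the order they were made

    def _get_pairs(self, word_freqs):
        """Count adjacent symbol pairs, weighted by word frequency."""
        pairs = defaultdict(int)
        for word, freq in word_freqs.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs

    def train(self, corpus: str) -> None:
        """Train BPE on a raw text corpus string."""
        # Special tokens are atomic: keep them out of the merge statistics.
        words      = [w for w in re.findall(r'\S+', corpus.lower())
                      if w not in SPECIAL_TOKENS]
        word_freqs = Counter(' '.join(list(w)) + ' </w>' for w in words)

        # Seed the vocabulary with every character (plus the </w> end-of-word marker).
        next_id = len(self.vocab)
        for word in word_freqs:
            for ch in word.split():
                if ch not in self.vocab:
                    self.vocab[ch]          = next_id
                    self.inv_vocab[next_id] = ch
                    next_id += 1

        # Repeatedly merge the most frequent adjacent pair.
        while len(self.vocab) < self.vocab_size:
            pairs = self._get_pairs(word_freqs)
            if not pairs:
                break
            best   = max(pairs, key=pairs.get)
            merged = ''.join(best)
            self.merges.append(best)
            if merged not in self.vocab:   # two different pairs can spell the same string
                self.vocab[merged]      = next_id
                self.inv_vocab[next_id] = merged
                next_id += 1
            pattern    = re.compile(r'(?<!\S)' + re.escape(' '.join(best)) + r'(?!\S)')
            # A callable replacement avoids backslash-escape handling in `merged`.
            word_freqs = {pattern.sub(lambda m: merged, w): f
                          for w, f in word_freqs.items()}

    def _bpe(self, word: str) -> list[str]:
        """Split one word into subword symbols by replaying the merges in order.
        Simple but slow: every merge is tried on every word."""
        symbols = list(word) + ['</w>']
        for a, b in self.merges:
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and symbols[i] == a and symbols[i + 1] == b:
                    out.append(a + b)
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            symbols = out
        return symbols

    def encode(self, text: str) -> list[int]:
        unk = SPECIAL_TOKENS["<unk>"]
        ids = [SPECIAL_TOKENS["<bos>"]]
        for word in text.split():
            if word in SPECIAL_TOKENS:             # e.g. <think>, </answer>
                ids.append(SPECIAL_TOKENS[word])
            else:                                  # training lowercases, so encode does too
                ids.extend(self.vocab.get(s, unk) for s in self._bpe(word.lower()))
        return ids + [SPECIAL_TOKENS["<eos>"]]

    def decode(self, ids: list[int]) -> str:
        pieces = []
        for i in ids:
            tok = self.inv_vocab.get(i, '<unk>')
            # Special tokens stand alone; </w> marks the end of a regular word.
            pieces.append(tok + ' ' if tok in SPECIAL_TOKENS else tok.replace('</w>', ' '))
        return ''.join(pieces).strip()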

⚠️ Note: The <think>, </think>, <answer>, and </answer> tokens are hardcoded at IDs 4–7, so they are always present regardless of the training corpus. The model learns to emit them by seeing CoT-formatted training examples.
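
To make the merge loop in train() concrete, here is a minimal standalone sketch of a single BPE step on toy data. The word counts and the helper names (`most_frequent_pair`, `merge_pair`) are illustrative assumptions, not part of tokenizer.py.

```python
from collections import Counter

def most_frequent_pair(words):
    """Return (pair, count) for the most frequent adjacent symbol pair."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs.most_common(1)[0]

def merge_pair(words, pair):
    """Replace every occurrence of `pair` with one merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if symbols[i:i + 2] == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = freq
    return merged

# Toy word counts (illustrative): each word is a tuple of symbols ending in </w>.
words = {
    ("l", "o", "w", "</w>"): 5,
    ("l", "o", "w", "e", "r", "</w>"): 2,
    ("n", "e", "w", "</w>"): 3,
}
pair, count = most_frequent_pair(words)   # ("w", "</w>") occurs 5 + 3 = 8 times
words = merge_pair(words, pair)           # "low" is now ("l", "o", "w</w>")
```

Training simply repeats this step until the vocabulary reaches its target size or no pairs remain.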

4. Step 2 — Embeddings & Positional Encoding

Embeddings map integer token IDs to dense vectors. We add Rotary Positional Encoding (RoPE) — the same technique used in LLaMA and Mistral — directly inside the attention layer.

📊 RoPE vs Sinusoidal Positional Encoding

| Feature | Sinusoidal | RoPE ✓ (we use this) |
|---|---|---|
| Extra parameters | None | None |
| Long sequences | Degrades | Extrapolates well |
| Where applied | After embedding layer | Inside attention Q, K |
| Relative position | Implicit only | Explicit |
| Used by | Original Transformer (GPT-2 and BERT use learned absolute positions instead) | LLaMA, Mistral, Qwen |


# embeddings.py
import torch, torch.nn as nn, math

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.scale = math.sqrt(d_model)

    def forward(self, x):   # x: (B, T)
        return self.embed(x) * self.scale


def precompute_rope(d_head: int, max_seq: int, base: int = 10_000):
    """Precompute RoPE frequency matrix as complex numbers."""
    theta = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
    pos   = torch.arange(max_seq).float()
    freqs = torch.outer(pos, theta)              # (T, d_head/2)
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rope(q, k, freqs):
    """Apply rotary encoding to query and key tensors."""
    q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_ = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    q_rot = torch.view_as_real(q_ * freqs).flatten(3)
    k_rot = torch.view_as_real(k_ * freqs).flatten(3)
    return q_rot.type_as(q), k_rot.type_as(k)
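
The property that makes RoPE attractive is that the attention score between a rotated query and a rotated key depends only on how far apart the two positions are. The sketch below checks this numerically with plain Python complex numbers; the helper names and the example vectors are assumptions made for illustration, and the frequencies follow the same formula as precompute_rope.

```python
import cmath

def rope_rotate(vec, pos, base=10_000):
    """Rotate consecutive (even, odd) pairs of `vec` by position-dependent angles."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = base ** (-i / d)             # same frequencies as precompute_rope
        z = complex(vec[i], vec[i + 1])      # treat each pair as one complex number
        out.append(z * cmath.exp(1j * pos * theta))
    return out

def score(q, k, m, n):
    """Dot product between a query at position m and a key at position n."""
    qr, kr = rope_rotate(q, m), rope_rotate(k, n)
    return sum((a * b.conjugate()).real for a, b in zip(qr, kr))

q = [0.3, -1.2, 0.8, 0.5]    # illustrative query vector (d_head = 4)
k = [1.0, 0.4, -0.7, 0.2]    # illustrative key vector

near = score(q, k, 5, 2)       # offset 3, early in the sequence
far  = score(q, k, 103, 100)   # offset 3, much later in the sequence
other = score(q, k, 5, 4)      # offset 1
```

`near` and `far` agree up to floating-point error because both pairs are three positions apart, while `other` differs because the offset changed.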

5. Step 3 — Self-Attention from Scratch

Masked multi-head self-attention is the computational heart of every transformer. Each token attends to itself and all preceding tokens, and the causal mask prevents it from seeing the future.

📊 Causal Attention Mask & Multi-Head Layout

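Figure: causal mask for 4 tokens (left) and multi-head layout (right). Rows are query positions and columns are key positions; the upper triangle of QKᵀ is set to −∞ so those positions are masked. The scores shown are illustrative.

| | x1 | x2 | x3 | x4 |
|---|---|---|---|---|
| x1 | 1.0 | −∞ | −∞ | −∞ |
| x2 | 0.6 | 1.0 | −∞ | −∞ |
| x3 | 0.3 | 0.5 | 1.0 | −∞ |
| x4 | 0.2 | 0.4 | 0.7 | 1.0 |

Each head computes softmax( Q·Kᵀ / √d_k ) · V with d_head = d_model / n_heads, and the heads are combined as Concat(h1, h2, … hH) × W_O → output.
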

# attention.py
import torch, torch.nn as nn, torch.nn.functional as F
from embeddings import precompute_rope, apply_rope

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_seq: int, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head  = d_model // n_heads
        self.qkv     = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out     = nn.Linear(d_model, d_model,     bias=False)
        self.drop    = nn.Dropout(dropout)
        freqs = precompute_rope(self.d_head, max_seq)
        self.register_buffer('freqs', freqs)

    def forward(self, x):          # x: (B, T, D)
        B, T, D = x.shape
        qkv     = self.qkv(x).split(D, dim=-1)
        q, k, v = [
            t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
            for t in qkv
        ]
        q, k = apply_rope(q, k, self.freqs[:T])
        attn = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.drop.p if self.training else 0.0
        )
        attn = attn.transpose(1, 2).contiguous().view(B, T, D)
        return self.out(attn)
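
F.scaled_dot_product_attention hides the masking and softmax, so it can help to see that step written out by hand. The following is a rough single-head sketch in plain Python, using made-up scores; it illustrates what is_causal=True does conceptually and is not the kernel PyTorch runs.

```python
import math

def causal_attention_weights(scores):
    """Row-wise softmax over a square score matrix with future positions masked."""
    T = len(scores)
    weights = []
    for i in range(T):
        # Position i may only look at positions 0..i; the rest act like -inf.
        visible = scores[i][: i + 1]
        peak = max(visible)                           # subtract max for stability
        exps = [math.exp(s - peak) for s in visible]
        total = sum(exps)
        weights.append([e / total for e in exps] + [0.0] * (T - i - 1))
    return weights

scores = [  # illustrative Q·Kᵀ / sqrt(d_k) values
    [1.0, 9.0, 9.0],
    [0.5, 0.5, 9.0],
    [0.0, 1.0, 2.0],
]
W = causal_attention_weights(scores)
```

The large scores above the diagonal never influence the result: the first token can only attend to itself, and the second splits its attention evenly between the two positions it is allowed to see.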

6. Step 4 — The Transformer Block

Each transformer block wraps the attention layer with a feed-forward network and RMS normalisation. For the feed-forward network we use SwiGLU, a gated activation that LLaMA and Mistral also use.
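
SwiGLU needs three weight matrices instead of two, which is why model.py shrinks the hidden width by a factor of 2/3. The arithmetic below is a small sketch of that trade-off; the function names are hypothetical, and the comparison assumes a conventional ReLU FFN with hidden width ffn_mult × d_model.

```python
def swiglu_hidden(d_model, ffn_mult=4):
    """Hidden width used in this post: two thirds of ffn_mult * d_model."""
    return int(d_model * ffn_mult * 2 / 3)

def ffn_params(d_model, ffn_mult=4):
    """Weight counts (no biases) for a two-matrix ReLU FFN and a three-matrix SwiGLU FFN."""
    relu   = 2 * d_model * (ffn_mult * d_model)
    swiglu = 3 * d_model * swiglu_hidden(d_model, ffn_mult)
    return relu, swiglu

relu, swiglu = ffn_params(256)   # 524,288 vs 523,776 weights
```

With d_model = 256 the two designs land within about 0.1% of each other, so the gated variant comes at essentially the same parameter budget.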

📊 SwiGLU vs ReLU FFN

flowchart LR
    subgraph RELU["ReLU FFN (old)"]
        direction TB
        r1["x"] --> r2["Linear"] --> r3["ReLU"] --> r4["Linear"] --> r5["output"]
    end

    subgraph SWIGLU["SwiGLU FFN ✓ (we use)"]
        direction TB
        s1["x"] --> sg["Gate: Linear → SiLU"]
        s1      --> su["Up:   Linear"]
        sg --> sm["⊗ element-wise multiply"]
        su --> sm
        sm --> sd["Down: Linear"] --> s5["output"]
    end

# model.py
import torch, torch.nn as nn, torch.nn.functional as F
from attention  import CausalSelfAttention
from embeddings import TokenEmbedding

class RMSNorm(nn.Module):
    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.g, self.eps = nn.Parameter(torch.ones(d)), eps

    def forward(self, x):
        return x * x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() * self.g


class SwiGLUFFN(nn.Module):
    def __init__(self, d_model: int, hidden: int):
        super().__init__()
        self.gate = nn.Linear(d_model, hidden, bias=False)
        self.up   = nn.Linear(d_model, hidden, bias=False)
        self.down = nn.Linear(hidden,  d_model, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, max_seq, ffn_mult=4, dropout=0.1):
        super().__init__()
        self.attn  = CausalSelfAttention(d_model, n_heads, max_seq, dropout)
        self.ffn   = SwiGLUFFN(d_model, int(d_model * ffn_mult * 2 / 3))
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))   # pre-norm + residual
        x = x + self.ffn(self.norm2(x))
        return x


class ReasoningLLM(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=8,
                 n_layers=6, max_seq=512, dropout=0.1):
        super().__init__()
        self.embed   = TokenEmbedding(vocab_size, d_model)
        self.blocks  = nn.ModuleList([
            TransformerBlock(d_model, n_heads, max_seq, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.norm    = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying — cuts params by 15–30%, zero accuracy cost
        self.lm_head.weight = self.embed.embed.weight

    def forward(self, ids):            # (B, T)
        x = self.embed(ids)
        for block in self.blocks:
            x = block(x)
        return self.lm_head(self.norm(x))  # (B, T, vocab_size)
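
The one-line RMSNorm.forward above is compact, so here is the same computation spelled out on a plain list. This is an illustrative sketch with an assumed helper name; unlike LayerNorm, it only rescales and never subtracts the mean.

```python
import math

def rms_norm(x, gain=None, eps=1e-6):
    """Divide x by its root-mean-square, then apply a per-element gain."""
    gain = gain or [1.0] * len(x)
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * g for v, g in zip(x, gain)]

x = [3.0, -4.0, 0.0, 0.0]   # mean of squares = 25 / 4, so the RMS is 2.5
y = rms_norm(x)             # approximately [1.2, -1.6, 0.0, 0.0]
```

After normalisation the vector has a root-mean-square of about 1, whatever the scale of the input was.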

7. Step 5 — Chain-of-Thought Training Strategy

This is the key differentiator. Every training example is formatted as a problem → scratchpad → answer triple. We compute loss over both spans with separate weights.
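
Written out, the weighted loss computed by reasoning_loss in loss.py can be read as a weighted average of per-token cross-entropies. The notation here ($\ell_t$ for the cross-entropy at position $t$, $w_t$ for its weight) is introduced for this explanation only.

```latex
L \;=\; \frac{\sum_{t} w_t \,\ell_t}{\sum_{t} w_t}
  \;=\; \frac{\alpha \sum_{t \in \text{think}} \ell_t \;+\; \sum_{t \in \text{answer}} \ell_t}{\sum_{t} w_t},
\qquad
w_t =
\begin{cases}
\alpha & \text{if token } t \text{ is in the scratchpad span} \\
1      & \text{if token } t \text{ is in the answer span} \\
0      & \text{otherwise (problem text, padding)}
\end{cases}
```

The numerator has the same shape as the α·L_think + L_answer shorthand in the architecture diagram; the denominator only normalises by the total weight so that long and short sequences contribute on a comparable scale.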

📊 CoT Training Sequence & Loss Mask

Training sequence layout and loss weight per token:

| Segment | Loss weight |
|---|---|
| Problem | 0.0 |
| <think> scratchpad </think> | α = 0.5 |
| <answer> answer </answer> | 1.0 |

Example training sample:

<bos> Q: What is 17×23? <think> 17×23 = 17×20 + 17×3 = 340 + 51 = 391 </think> <answer> 391 </answer> <eos>

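
# loss.py
import torch, torch.nn.functional as F

THINK_START, THINK_END = 4, 5
ANS_START,   ANS_END   = 6, 7

def build_loss_mask(ids: torch.Tensor, alpha: float = 0.5) -> torch.Tensor:
    """Per-token loss weights: alpha inside <think>...</think>, 1.0 inside
    <answer>...</answer>, 0.0 elsewhere. The delimiter tokens are weighted
    with their span so the model is also trained to emit them."""
    B, T    = ids.shape
    weights = torch.zeros(B, T, device=ids.device)
    rows    = ids.tolist()   # one host copy instead of B*T .item() calls
    for b in range(B):
        in_think = in_answer = False
        for t, tok in enumerate(rows[b]):
            # Opening tags switch the span on before the weight is assigned ...
            if   tok == THINK_START: in_think  = True
            elif tok == ANS_START:   in_answer = True
            if   in_think:  weights[b, t] = alpha
            elif in_answer: weights[b, t] = 1.0
            # ... and closing tags switch it off afterwards.
            if   tok == THINK_END: in_think  = False
            elif tok == ANS_END:   in_answer = False
    return weights


def reasoning_loss(logits, targets, ids, alpha: float = 0.5):
    """Weighted cross-entropy. logits[:, t] is scored against targets[:, t], so
    the weights are derived from `targets` (the tokens being predicted).
    `ids` is kept only so existing call sites keep working."""
    B, T, V = logits.shape
    mask    = build_loss_mask(targets, alpha)
    # reshape, not view: slices such as logits[:, :-1] are not contiguous
    ce      = F.cross_entropy(
                  logits.reshape(-1, V), targets.reshape(-1), reduction='none'
              ).view(B, T)
    return (ce * mask).sum() / mask.sum().clamp(min=1)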
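
To see the weighting on a concrete sequence, the sketch below applies the same idea to a list of token strings in plain Python. It assumes the delimiter tokens are weighted together with their span (so the model is trained to emit them), and the per-token loss values are invented purely for illustration.

```python
def span_weights(tokens, alpha=0.5):
    """alpha inside <think>...</think>, 1.0 inside <answer>...</answer>, else 0."""
    weights, state = [], None
    for tok in tokens:
        if tok == "<think>":
            state = "think"
        elif tok == "<answer>":
            state = "answer"
        weights.append({"think": alpha, "answer": 1.0}.get(state, 0.0))
        if tok in ("</think>", "</answer>"):
            state = None
    return weights

def weighted_mean(losses, weights):
    """Weighted average of per-token losses, guarding against an all-zero mask."""
    return sum(l * w for l, w in zip(losses, weights)) / max(sum(weights), 1.0)

tokens = ["<bos>", "Q:", "17×23?", "<think>", "340+51=391", "</think>",
          "<answer>", "391", "</answer>", "<eos>"]
weights = span_weights(tokens)
# Illustrative per-token cross-entropy values (not measured):
losses = [9.0, 9.0, 9.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 9.0]
loss = weighted_mean(losses, weights)   # (0.5*6 + 1.0*3) / 4.5
```

The problem tokens and <eos> carry large losses here but contribute nothing, because their weight is zero.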

8. Step 6 — The Training Loop

📊 Learning Rate Schedule

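Figure: cosine LR schedule with warmup (x-axis: training steps, y-axis: LR). The learning rate ramps up linearly to its peak over the first 200 warmup steps, then follows a cosine decay down to a floor of 0.1× the peak.
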
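
# train.py
import math
import torch
from torch.utils.data   import DataLoader
from torch.nn.utils.rnn import pad_sequence
from model     import ReasoningLLM
from loss      import reasoning_loss
from tokenizer import BPETokenizer

CFG = {
    "vocab_size": 8000,  "d_model": 256,   "n_heads": 8,
    "n_layers":   6,     "max_seq": 512,   "lr":      3e-4,
    "batch":      32,    "epochs":  20,    "alpha":   0.5,
    "warmup":     200,   "clip":    1.0,
}

# Placeholder data (assumption): the single sample from Step 5, so the script runs.
# Replace `samples` with your own CoT-formatted strings.
samples = ["Q: What is 17×23? <think> 17×23 = 17×20 + 17×3 = 340 + 51 = 391 "
           "</think> <answer> 391 </answer>"]
tok = BPETokenizer(CFG["vocab_size"])
tok.train(" ".join(samples))

def collate(batch):
    """Encode, truncate to max_seq, and right-pad with <pad> (ID 0)."""
    seqs = [torch.tensor(tok.encode(s)[:CFG["max_seq"]]) for s in batch]
    ids  = pad_sequence(seqs, batch_first=True, padding_value=0)
    return ids, ids.clone()   # targets are the same IDs; the shift happens in the loop

dataloader      = DataLoader(samples, batch_size=CFG["batch"],
                             shuffle=True, collate_fn=collate)
steps_per_epoch = len(dataloader)

device = "cuda" if torch.cuda.is_available() else "cpu"
model  = ReasoningLLM(**{k: CFG[k] for k in
           ["vocab_size","d_model","n_heads","n_layers","max_seq"]}).to(device)

optim  = torch.optim.AdamW(model.parameters(), lr=CFG["lr"],
                            weight_decay=0.1, betas=(0.9, 0.95))

def lr_schedule(step, warmup, total):
    """LR multiplier: linear warmup, then cosine decay to a floor of 0.1."""
    if step < warmup:
        return step / warmup
    pct = (step - warmup) / max(1, total - warmup)
    return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * pct))

total_steps = CFG["epochs"] * steps_per_epoch
scheduler   = torch.optim.lr_scheduler.LambdaLR(
    optim, lambda s: lr_schedule(s, CFG["warmup"], total_steps)
)

for epoch in range(CFG["epochs"]):
    model.train()
    epoch_loss = 0.0
    for ids, targets in dataloader:
        ids, targets = ids.to(device), targets.to(device)
        logits       = model(ids)
        # Position t predicts token t+1: drop the last logit and the first target.
        pred = logits[:, :-1].contiguous()
        gold = targets[:, 1:].contiguous()
        # The weight mask must line up with the predicted tokens, so pass `gold`.
        loss = reasoning_loss(pred, gold, gold, alpha=CFG["alpha"])
        optim.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["clip"])
        optim.step()
        scheduler.step()
        epoch_loss += loss.item()

    print(f"Epoch {epoch+1:02d}  loss={epoch_loss/steps_per_epoch:.4f}"
          f"  lr={scheduler.get_last_lr()[0]:.2e}")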
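
To get a feel for the schedule, the snippet below evaluates the same lr_schedule function at a few steps. The total of 1000 steps is an assumed value chosen only to make the numbers easy to read.

```python
import math

def lr_schedule(step, warmup, total):
    """LR multiplier: linear warmup, then cosine decay to a floor of 0.1."""
    if step < warmup:
        return step / warmup
    pct = (step - warmup) / max(1, total - warmup)
    return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * pct))

PEAK_LR, WARMUP, TOTAL = 3e-4, 200, 1000   # TOTAL is assumed for illustration
multipliers = {s: lr_schedule(s, WARMUP, TOTAL) for s in (0, 100, 200, 600, 1000)}
for step, mult in multipliers.items():
    print(f"step {step:4d}  multiplier {mult:.2f}  lr {PEAK_LR * mult:.2e}")
```

The multiplier is 0 at step 0, 0.5 halfway through warmup, 1.0 at the peak, 0.55 at the midpoint of the decay, and 0.1 at the end.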

⚠️ Training data format is everything. Your dataloader must yield sequences formatted as <bos> problem <think> steps </think> <answer> answer </answer> <eos>. Without CoT-formatted data, the model skips the scratchpad entirely.
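
One way to produce strings in that format, and to pull the final answer back out of generated text, is sketched below. The helper names format_sample and extract_answer are hypothetical; <bos> and <eos> are left out because the tokenizer's encode() adds them.

```python
import re

def format_sample(problem, steps, answer):
    """Join a problem, its reasoning steps, and the final answer into one training string."""
    scratchpad = " ".join(steps)
    return f"{problem} <think> {scratchpad} </think> <answer> {answer} </answer>"

def extract_answer(text):
    """Return the text between <answer> and </answer>, or None if the span is missing."""
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", text, flags=re.DOTALL)
    return match.group(1) if match else None

sample = format_sample(
    "Q: What is 17×23?",
    ["17×23 = 17×20 + 17×3", "= 340 + 51", "= 391"],
    "391",
)
answer = extract_answer(sample)   # "391"
```

Keeping the tags separated by spaces matters with a whitespace-based tokenizer like the one in this post, since a tag glued to a word would not be recognised as a special token.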

9. Optimisation Tips

📊 Hyperparameter Cheatsheet

| Parameter | Toy (CPU) | Small (8 GB) | Tip |
|---|---|---|---|
| d_model | 128–256 | 512–768 | divisible by n_heads |
| n_layers | 4–6 | 12–24 | deeper > wider for CoT |
| n_heads | 4–8 | 8–16 | d_head ≥ 32 |
| lr | 3e-4 | 1e-4 → 1e-5 | warmup + cosine |
| alpha (α) | 0.5 | 0.3–0.7 | ↑ if no <think> emitted |
| max_seq | 256–512 | 1024–2048 | scratchpad needs room |
| batch_size | 8–16 | 32–128 | use grad accumulation |
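
The head-related tips in the table are easy to check before building a model. The helper below is a hypothetical sketch (check_config is not part of the codebase); the evenness check is included because the RoPE code in embeddings.py rotates dimensions in pairs.

```python
def check_config(d_model, n_heads, min_d_head=32):
    """Return a list of problems with a (d_model, n_heads) pair; empty means OK."""
    problems = []
    if d_model % n_heads != 0:
        problems.append("d_model must be divisible by n_heads")
        return problems
    d_head = d_model // n_heads
    if d_head % 2 != 0:
        problems.append("d_head must be even: RoPE rotates dimensions in pairs")
    if d_head < min_d_head:
        problems.append(f"d_head={d_head} is below the suggested minimum of {min_d_head}")
    return problems

ok       = check_config(256, 8)   # d_head = 32
too_thin = check_config(128, 8)   # d_head = 16
uneven   = check_config(250, 8)   # 250 is not divisible by 8
```

Note that the smallest toy setting combined with the largest head count (d_model = 128, n_heads = 8) falls below the d_head ≥ 32 tip, so pick n_heads = 4 in that case.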

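Weight tying: Share weights between TokenEmbedding and lm_head. For a small model like this one it cuts the parameter count by roughly 15–30%, usually with little or no loss in accuracy. GPT-2 ties its embeddings; many larger models, including the original LLaMA releases, keep the two matrices separate.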
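
The size of the saving depends on how large the embedding table is relative to the rest of the network. As a rough check, the sketch below counts weights for the default configuration in this post (vocab 8000, d_model 256, 6 layers); count_params is a hypothetical helper that mirrors the layer shapes in model.py and attention.py.

```python
def count_params(vocab_size=8000, d_model=256, n_layers=6, ffn_mult=4, tied=True):
    """Count weights for the default configuration used in this post."""
    hidden    = int(d_model * ffn_mult * 2 / 3)
    embedding = vocab_size * d_model
    attention = 3 * d_model * d_model + d_model * d_model   # qkv + out, no biases
    ffn       = 3 * d_model * hidden                        # gate, up, down
    norms     = 2 * d_model                                 # two RMSNorm gains per block
    blocks    = n_layers * (attention + ffn + norms)
    lm_head   = 0 if tied else vocab_size * d_model
    return embedding + blocks + d_model + lm_head           # d_model = final RMSNorm gain

tied, untied = count_params(tied=True), count_params(tied=False)
saving = 1 - tied / untied   # about 0.23
```

For this configuration tying removes about 2.0M of 8.8M weights, roughly 23%. The fraction shrinks as the transformer blocks grow relative to the vocabulary.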

Gradient clipping: Clip the global gradient norm at 1.0. Transformer training can produce large gradients unexpectedly, especially early in CoT training.
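
Clipping by global norm rescales all gradients by one shared factor, so the update direction is preserved and only its length is capped. The sketch below shows the idea on plain lists; it is an illustration of the technique with made-up gradient values, not the PyTorch implementation used in train.py.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    """Scale all gradients by one factor so their combined L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [[g * scale for g in tensor] for tensor in grads], total

grads = [[3.0, 0.0], [0.0, 4.0]]   # illustrative gradients for two parameters
clipped, norm_before = clip_by_global_norm(grads)
norm_after = math.sqrt(sum(g * g for tensor in clipped for g in tensor))
```

Here the combined norm of 5.0 is scaled down to just under 1.0, while gradients that are already small pass through unchanged.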

Pre-norm vs post-norm: Use pre-norm (apply RMSNorm to the input of each sublayer, inside the residual branch). Pre-norm is generally more stable to train in deeper networks.

Good open-source CoT datasets:

| Dataset | Domain | CoT Format |
|---|---|---|
| GSM8K | Grade school math | Step-by-step |
| MetaMath | Math (augmented) | Chain-of-thought |
| NuminaMath | Competition math | Full working |
| OpenWebMath | General math text | Semi-structured |
| WizardLM-CoT | General reasoning | Multi-step |

10. What’s Next?

📊 Extension Roadmap

flowchart TD
    BASE["✅ Current Build\nTokenizer + Transformer + CoT Loss"]

    BASE --> KV["KV Cache\nFast autoregressive generation\n⚡ Medium"]
    BASE --> GQA["Grouped Query Attention\nLess VRAM · Faster inference\n⚡ Medium"]
    BASE --> DPO["RLHF / DPO\nReward correct reasoning steps\n🔥 Hard"]

    KV  --> SPEC["Speculative Decoding\n3–4× faster generation\n🔥 Hard"]
    GQA --> SPEC
    DPO --> SPEC

    SPEC --> MOE["Mixture of Experts\nScale params ≠ scale compute\n🔥 Hard"]

| Extension | What It Adds | Difficulty |
|---|---|---|
| KV Cache | Fast autoregressive generation | Medium |
| Grouped Query Attention | Less VRAM, faster inference | Medium |
| RLHF / DPO on CoT traces | Reward correct reasoning steps | Hard |
| Speculative Decoding | 3–4× faster generation | Hard |
| Mixture of Experts (MoE) | Scale params without scaling compute | Hard |

Conclusion

Reasoning in LLMs is not magic; it is a structured engineering choice. By teaching the model to write its work inside <think> tokens, and training on those scratchpad traces with a weighted dual loss, you get a model that is trained to show its steps, which tends to reduce careless errors on multi-step problems.

You have now built every layer from first principles:

  • BPE Tokenizer with special <think> / <answer> tokens
  • RoPE Positional Encoding baked directly into attention
  • Causal Multi-Head Self-Attention via F.scaled_dot_product_attention, which can use FlashAttention kernels when available
  • SwiGLU Feed-Forward + RMSNorm Transformer Block
  • Full ReasoningLLM class with weight tying
  • Dual-loss (reasoning_loss) weighting scratchpad vs answer spans
  • AdamW + cosine LR Training loop with gradient clipping

The entire codebase — tokenizer.py, embeddings.py, attention.py, model.py, loss.py, train.py — is under 600 lines of pure Python and PyTorch.

No external AI API was used. No wrappers. No shortcuts. That is your reasoning model.

Khushal Jethava

Machine Learning Engineer at Codiste, specializing in Generative AI, NLP, and Computer Vision. Building production AI systems with Python.

This post is licensed under CC BY 4.0 by the author.