Consider the word “it” in this sentence: “The animal didn’t cross the street because it was too tired.”
What does “it” refer to? The animal or the street?
A human reads this and immediately knows: the animal was tired. The street cannot be tired.
For a model processing this word by word, the word “it” arrives as just a token. Nothing in the token itself says what it refers to. The reference is determined by context: every other word in the sentence contributes to understanding what “it” means here.
Attention is the mechanism that lets every word look at every other word simultaneously and gather the information it needs. When processing “it,” a trained model attends strongly to “animal” because that is the word it connects to grammatically and semantically. The attention weights then form a weighted combination of the value vectors, encoding “it” as a representation aware of its referent.
This is what static embeddings cannot do. This is why attention unlocked modern NLP. This is the core of every transformer, including GPT and BERT.
The Intuition: Queries, Keys, and Values
Attention borrows from information retrieval.
In a database, you have a query (what you are searching for), keys (the index of what is stored), and values (the actual stored content). You compare your query against all keys. The closest key tells you which value to retrieve.
Attention works the same way but softly. Instead of retrieving one value, you retrieve a weighted mixture of all values, where the weights come from how well each key matches your query.
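This soft lookup can be sketched in a few lines of NumPy. The query, keys, and values below are random stand-ins, not trained parameters; the point is only the contrast between retrieving one value and blending all of them:

```python
import numpy as np

rng = np.random.default_rng(0)
query = rng.standard_normal(4)        # what we are looking for
keys = rng.standard_normal((3, 4))    # one key per stored item
values = rng.standard_normal((3, 8))  # one value per stored item

scores = keys @ query                 # similarity of the query to each key

# Hard retrieval (database-style): take the single best-matching value
hard = values[np.argmax(scores)]

# Soft retrieval (attention-style): blend ALL values, weighted by match quality
weights = np.exp(scores) / np.exp(scores).sum()  # softmax over scores
soft = weights @ values               # weighted mixture of all values

print(weights.sum())                  # 1.0 — the weights form a distribution
```

The soft version is differentiable, which is what makes it trainable by gradient descent; argmax is not.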
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")
torch.manual_seed(42)
def attention(query, key, value, mask=None):
    """
    Scaled dot-product attention.
    query: (batch, heads, seq_q, d_k)
    key:   (batch, heads, seq_k, d_k)
    value: (batch, heads, seq_v, d_v) where seq_k == seq_v
    """
    d_k = query.shape[-1]
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, value)
    return output, weights
seq_len = 5
d_model = 8
sentence = ["the", "cat", "sat", "on", "mat"]
torch.manual_seed(0)
X = torch.randn(1, seq_len, d_model)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)
Q = W_q(X)
K = W_k(X)
V = W_v(X)
print("Self-attention on: 'the cat sat on mat'")
print()
print(f"Input shape: {X.shape} (batch=1, seq={seq_len}, d_model={d_model})")
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
output, weights = attention(Q, K, V)
print(f"\nAttention weights shape: {weights.shape}")
print(f"Output shape: {output.shape}")
print()
print("Attention weights (who attends to whom):")
w = weights[0].detach().numpy()
print(f"{'':6}", end="")
for word in sentence:
    print(f"{word:>8}", end="")
print()
for i, word in enumerate(sentence):
    print(f"{word:>6}", end="")
    for j in range(len(sentence)):
        print(f"{w[i,j]:8.3f}", end="")
    print()
Output:
Attention weights (who attends to whom):
           the     cat     sat      on     mat
   the   0.234   0.189   0.198   0.201   0.178
   cat   0.178   0.312   0.201   0.156   0.153
   sat   0.189   0.234   0.278   0.156   0.143
    on   0.201   0.167   0.189   0.289   0.154
   mat   0.178   0.154   0.167   0.189   0.312
Each row shows how much that word attends to every other word. Rows sum to 1.0 (softmax). “cat” attends most to itself (0.312), then to “sat” (0.201). These are untrained random weights — after training on real data, the attention patterns become linguistically meaningful.
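Whenever you implement attention yourself, that row-sum property is worth asserting. A minimal self-contained check (fresh random tensors, not the ones above):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q = torch.randn(1, 5, 8)
K = torch.randn(1, 5, 8)

scores = Q @ K.transpose(-2, -1) / (8 ** 0.5)
weights = F.softmax(scores, dim=-1)  # softmax over the key dimension

# Every query position distributes exactly 1.0 of attention across all keys
assert torch.allclose(weights.sum(dim=-1), torch.ones(1, 5))
print("rows sum to 1:", bool(torch.allclose(weights.sum(dim=-1), torch.ones(1, 5))))
```

If this check fails, the usual culprit is applying softmax over the wrong dimension.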
Scaling: Why We Divide by √d_k
def demonstrate_scaling():
    d_k_values = [4, 16, 64, 256, 1024]
    print("Effect of scaling on softmax sharpness:")
    print()
    print(f"{'d_k':>8} {'max_score':>12} {'raw_max':>10} {'scaled_max':>12}")
    print("-" * 46)
    for d_k in d_k_values:
        q = torch.randn(1, d_k)
        k = torch.randn(5, d_k)
        raw_scores = (q @ k.T).squeeze()
        scaled_scores = raw_scores / (d_k ** 0.5)
        raw_softmax = F.softmax(raw_scores, dim=0)
        scaled_softmax = F.softmax(scaled_scores, dim=0)
        print(f"{d_k:>8} {raw_scores.max().item():>12.3f} "
              f"{raw_softmax.max().item():>10.3f} "
              f"{scaled_softmax.max().item():>12.3f}")
demonstrate_scaling()
Output (values vary by seed):
Effect of scaling on softmax sharpness:

     d_k    max_score    raw_max   scaled_max
----------------------------------------------
       4        2.341      0.598        0.441
      16        4.891      0.907        0.486
      64        9.234      0.991        0.452
     256       18.234      1.000        0.438
    1024       36.891      1.000        0.421
Without scaling, the variance of a dot product between two random d_k-dimensional vectors grows linearly with d_k, so typical score magnitudes grow like √d_k. Large scores saturate the softmax: nearly all the mass lands on one token and the gradients through the other positions vanish. Dividing by √d_k normalizes the scores back to roughly unit variance and keeps the softmax in its trainable range.
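The variance argument can be checked empirically (sample sizes here are arbitrary): for unit-variance components, the dot product's variance grows like d_k, and dividing by √d_k brings it back near 1:

```python
import torch

torch.manual_seed(0)
for d_k in [16, 256, 4096]:
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    dots = (q * k).sum(dim=-1)  # 10,000 sample dot products of random vectors
    print(f"d_k={d_k:>5}  raw var ~ {dots.var().item():8.1f}  "
          f"scaled var ~ {(dots / d_k ** 0.5).var().item():.2f}")
```

The raw variance tracks d_k almost exactly, while the scaled variance stays near 1 regardless of dimension.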
Multi-Head Attention: Looking From Multiple Perspectives
One attention head asks one question. Multi-head attention asks many questions simultaneously. One head might focus on syntactic relationships (subject-verb agreement). Another might focus on semantic relationships (word meanings). Another might focus on coreference (pronouns and their referents).
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x):
        batch, seq, _ = x.shape
        x = x.reshape(batch, seq, self.n_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        batch = query.shape[0]
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))
        attn_out, attn_weights = attention(Q, K, V, mask)
        attn_out = attn_out.transpose(1, 2).reshape(batch, -1, self.d_model)
        output = self.W_o(attn_out)
        return output, attn_weights
d_model = 32
n_heads = 4
seq_len = 6
batch = 1
mha = MultiHeadAttention(d_model, n_heads)
X = torch.randn(batch, seq_len, d_model)
output, weights = mha(X, X, X)
print("Multi-Head Attention:")
print(f" d_model: {d_model}")
print(f" n_heads: {n_heads}")
print(f" d_k: {d_model // n_heads} (per head)")
print()
print(f" Input: {X.shape}")
print(f" Output: {output.shape}")
print(f" Weights: {weights.shape} (batch, heads, seq, seq)")
print()
print(f" Total parameters: {sum(p.numel() for p in mha.parameters()):,}")
print()
print("Each head independently computes attention.")
print("Outputs concatenated and projected back to d_model.")
print("4 heads asking 4 different questions about the same sequence.")
Visualizing Attention Patterns
def plot_attention_heatmap(weights, tokens, title="Attention Weights", head=0):
    w = weights[0, head].detach().numpy()
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(w, annot=True, fmt=".2f", cmap="YlOrRd",
                xticklabels=tokens, yticklabels=tokens,
                ax=ax, cbar=True, linewidths=0.5,
                vmin=0, vmax=w.max())
    ax.set_title(f"{title}\n(Head {head+1})", fontsize=12)
    ax.set_xlabel("Attending To (Key)")
    ax.set_ylabel("Query Word")
    plt.tight_layout()
    plt.savefig(f"attention_head{head+1}.png", dpi=150)
    plt.show()
tokens = ["The", "cat", "sat", "on", "mat"]
d_model = 16
mha_vis = MultiHeadAttention(d_model, n_heads=4)
X_vis = torch.randn(1, len(tokens), d_model)
_, attn_weights = mha_vis(X_vis, X_vis, X_vis)
for head_idx in range(2):
    plot_attention_heatmap(attn_weights, tokens,
                           "Self-Attention: Random Init", head=head_idx)
print("After training on real language data:")
print(" Head 1 might learn syntactic dependencies")
print(" Head 2 might learn semantic relationships")
print(" Head 3 might learn coreference")
print(" Head 4 might learn positional patterns")
print()
print("The specialization emerges from training, not from programming.")
Self-Attention vs Cross-Attention
print("Two types of attention in transformers:")
print()
print("SELF-ATTENTION:")
print(" Query, Key, Value all come from the SAME sequence.")
print(" Each token attends to every other token in its own sequence.")
print(" Used in: encoder (BERT), decoder self-attention (GPT)")
print(" Captures: internal structure of a sequence")
print()
print("CROSS-ATTENTION:")
print(" Query comes from one sequence.")
print(" Key and Value come from ANOTHER sequence.")
print(" Used in: encoder-decoder models (translation, summarization)")
print(" Captures: relationship between two sequences")
print()
src_seq = torch.randn(1, 5, 16)
tgt_seq = torch.randn(1, 3, 16)
mha_cross = MultiHeadAttention(16, 2)
self_attn_out, _ = mha_cross(src_seq, src_seq, src_seq)
cross_attn_out, _ = mha_cross(tgt_seq, src_seq, src_seq)
print(f"Self-attention: input={src_seq.shape} → output={self_attn_out.shape}")
print(f"Cross-attention: query={tgt_seq.shape}, kv={src_seq.shape} → {cross_attn_out.shape}")
print()
print("In cross-attention the query length can differ from key/value length.")
print("This is how decoders attend to encoder outputs during translation.")
Causal Masking: Preventing Future Leakage
GPT-style models generate text left to right. They must not attend to tokens that have not been generated yet.
def causal_mask(seq_len):
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask
seq_len = 5
mask = causal_mask(seq_len)
print("Causal mask (lower triangular):")
print(mask.numpy().astype(int))
print()
print("1 = allowed to attend, 0 = blocked (future tokens)")
print()
print("Word 0 attends to: [0]")
print("Word 1 attends to: [0, 1]")
print("Word 2 attends to: [0, 1, 2]")
print("Word 3 attends to: [0, 1, 2, 3]")
print("Word 4 attends to: [0, 1, 2, 3, 4]")
print()
print("In the attention computation:")
print(" Positions where mask=0 get score = -inf")
print(" After softmax: -inf → 0.0 (zero attention weight)")
print(" The model cannot 'see' future tokens during training or generation.")
X_causal = torch.randn(1, seq_len, 16)
mask_4d = causal_mask(seq_len).unsqueeze(0).unsqueeze(0)
mha_causal = MultiHeadAttention(16, 2)
out, w = mha_causal(X_causal, X_causal, X_causal, mask_4d)
print("\nAttention weights with causal mask:")
print(w[0, 0].detach().numpy().round(3))
print("All values above the diagonal should be exactly 0.0")
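That last claim is easy to turn into an assertion. A standalone sketch using plain scaled dot-product attention (rather than the MultiHeadAttention class) confirms that the strictly upper-triangular weights are exactly zero:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k = 5, 16
Q = torch.randn(1, seq_len, d_k)
K = torch.randn(1, seq_len, d_k)
mask = torch.tril(torch.ones(seq_len, seq_len))  # 1 = allowed, 0 = future

scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)

# exp(-inf) is exactly 0, so every "future" position gets weight 0.0
upper = torch.triu(weights[0], diagonal=1)
print("no attention to future tokens:", bool(torch.all(upper == 0)))
```

Because the mask is applied before softmax, each row still sums to 1.0 over its allowed positions, so no probability mass leaks to the future.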
Positional Encoding: Order Matters
Attention is permutation-invariant. “The cat sat” and “sat cat the” produce the same attention scores without position information. Positional encodings add order.
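The claim is checkable in a few lines: shuffling the input rows just shuffles the self-attention output rows the same way (strictly speaking, the operation is permutation-equivariant), so no ordering signal survives. A sketch with a toy single-head attention:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(5, 16)                # 5 tokens, no positional information
perm = torch.tensor([2, 0, 4, 1, 3])  # an arbitrary reordering of the tokens

def self_attn(x):
    scores = x @ x.T / (x.shape[-1] ** 0.5)
    return F.softmax(scores, dim=-1) @ x

# Attention over the shuffled input == shuffled attention over the original
same = torch.allclose(self_attn(X[perm]), self_attn(X)[perm], atol=1e-6)
print("order carries no information:", same)
```

From attention's point of view the input is a set, not a sequence; positional encodings are what turn it back into a sequence.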
def positional_encoding(seq_len, d_model):
    pe = torch.zeros(seq_len, d_model)
    pos = torch.arange(0, seq_len).unsqueeze(1).float()
    div = torch.pow(10000.0, torch.arange(0, d_model, 2).float() / d_model)
    pe[:, 0::2] = torch.sin(pos / div)
    pe[:, 1::2] = torch.cos(pos / div)
    return pe
seq_len = 50
d_model = 128
pe = positional_encoding(seq_len, d_model)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
im = axes[0].imshow(pe.numpy(), aspect="auto", cmap="RdBu_r")
axes[0].set_xlabel("Embedding Dimension")
axes[0].set_ylabel("Position in Sequence")
axes[0].set_title("Positional Encoding: Sinusoidal Waves")
plt.colorbar(im, ax=axes[0])
for dim in [0, 1, 4, 8]:
    axes[1].plot(pe[:, dim].numpy(), label=f"dim {dim}")
axes[1].set_xlabel("Position")
axes[1].set_ylabel("Encoding Value")
axes[1].set_title("Positional Encoding: Per-Dimension Waves")
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("positional_encoding.png", dpi=150)
plt.show()
print("Sinusoidal positional encoding:")
print(" Each dimension oscillates at a different frequency")
print(" Low dimensions: fast waves (capture fine-grained position)")
print(" High dimensions: slow waves (capture long-range position)")
print()
print("Alternative: learned positional embeddings (BERT, GPT use this)")
print(" nn.Embedding(max_seq_len, d_model)")
print(" Learned from data. Usually performs similarly.")
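A minimal sketch of that learned alternative (the max_seq_len of 512 is an arbitrary illustrative choice, not a value any particular model uses):

```python
import torch
import torch.nn as nn

max_seq_len, d_model = 512, 64
pos_emb = nn.Embedding(max_seq_len, d_model)  # one learnable vector per position

x = torch.randn(2, 10, d_model)               # (batch, seq, d_model) token embeddings
positions = torch.arange(10)                  # position ids 0, 1, ..., 9
x = x + pos_emb(positions)                    # (10, 64) broadcasts over the batch

print(x.shape)  # torch.Size([2, 10, 64])
```

The trade-off: learned embeddings fit the training distribution but cannot extrapolate beyond max_seq_len, whereas sinusoidal encodings are defined for any position.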
Putting It All Together
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=64, n_heads=4, d_ff=256, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x
d_model = 64
layer = TransformerEncoderLayer(d_model=d_model, n_heads=4, d_ff=256)
seq_len = 10
batch = 2
X_enc = torch.randn(batch, seq_len, d_model)
pe = positional_encoding(seq_len, d_model)
X_enc = X_enc + pe.unsqueeze(0)
output = layer(X_enc)
params = sum(p.numel() for p in layer.parameters())
print("One Transformer Encoder Layer:")
print(f" Input: {X_enc.shape}")
print(f" Output: {output.shape}")
print(f" Params: {params:,}")
print()
print("Components:")
print(" 1. Multi-head self-attention")
print(" 2. Add & Norm (residual connection + LayerNorm)")
print(" 3. Feed-forward network (two linear layers + GELU)")
print(" 4. Add & Norm (residual connection + LayerNorm)")
print()
print("Stack 12 of these: BERT-base")
print("Stack 24 of these: BERT-large")
print("Stack 96 of these (decoder variant, with causal masking): GPT-3")
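Stacking really is just repetition. A sketch using PyTorch's built-in nn.TransformerEncoderLayer so it stays self-contained (the hyperparameters are small illustrative values, not any real model's):

```python
import torch
import torch.nn as nn

d_model, n_heads, n_layers = 64, 4, 12  # 12 layers matches BERT-base's depth

# PyTorch's built-in layer has the same attention + FFN + Add&Norm structure
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads,
                                   dim_feedforward=256, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=n_layers)  # 12 cloned layers

x = torch.randn(2, 10, d_model)  # (batch, seq, d_model)
out = encoder(x)
print(out.shape)  # torch.Size([2, 10, 64]) — shape preserved layer after layer
```

Because every layer maps (batch, seq, d_model) to the same shape, depth is purely a matter of how many layers you can afford to train.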
A Resource Worth Reading
Jay Alammar’s “The Illustrated Transformer” at jalammar.github.io is the most widely shared explanation of the transformer architecture in existence. The step-by-step visualizations of query-key-value attention, multi-head attention, and positional encoding make the architecture intuitive in a way that no paper or textbook has matched. Read it before continuing. Search “Jay Alammar illustrated transformer.”
The original paper “Attention Is All You Need” by Vaswani et al. (2017) introduced the transformer. Despite being a landmark paper, it is readable and clearly explains every design decision: why the specific architecture, why multi-head, why the feed-forward sublayer. Search “Vaswani attention all you need 2017 arxiv.”
Try This
Create attention_practice.py.
Part 1: implement scaled dot-product attention from scratch in NumPy (no PyTorch). Inputs: random Q (4×8), K (6×8), V (6×8). Compute scores, scale by √8, softmax, weighted sum. Verify the output shape is (4×8) and weights sum to 1.0 per row.
Part 2: implement a single attention head in PyTorch. Test it on a sequence of 5 tokens with d_model=16. Print the attention weight matrix. Visualize it as a heatmap.
Part 3: implement causal masking. Create a 6×6 causal mask. Apply it to your attention computation. Verify that all attention weights in the upper triangle are exactly 0.
Part 4: implement positional encoding. Plot the encoding for positions 0 to 49 across 8 dimensions on a heatmap. Plot individual dimensions as line graphs. Explain in a comment why the model can distinguish position 1 from position 2 from position 100.
What’s Next
You understand attention. Stack attention layers with feed-forward networks, add positional encodings, and you have a transformer. Two posts ahead: BERT uses the transformer encoder to understand language. GPT uses the transformer decoder to generate it. Both built on exactly what you just learned.