79. The Attention Mechanism: Focus on Important Parts

Consider the word “it” in this sentence: “The animal didn’t cross the street because it was too tired.”

What does “it” refer to? The animal or the street?

A human reads this and immediately knows: the animal was tired. The street cannot be tired.

For a model processing this word by word, the word “it” arrives as just a token. Nothing in the token itself says what it refers to. The reference is determined by context: every other word in the sentence contributes to understanding what “it” means here.

Attention is the mechanism that lets every word look at every other word simultaneously and gather the information it needs. When processing “it,” the model learns to attend strongly to “animal” because that is what connects grammatically and semantically. The attention weights then produce a weighted combination of the other words’ representations, so “it” ends up encoded in a way that is aware of its referent.

This is what static embeddings cannot do. This is why attention unlocked modern NLP. This is the core of every transformer, including GPT and BERT.

The Intuition: Queries, Keys, and Values

Attention borrows from information retrieval.

In a database, you have a query (what you are searching for), keys (the index of what is stored), and values (the actual stored content). You compare your query against all keys. The closest key tells you which value to retrieve.

Attention works the same way but softly. Instead of retrieving one value, you retrieve a weighted mixture of all values, where the weights come from how well each key matches your query.
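To make that concrete, here is a tiny self-contained sketch of a soft lookup with three stored key/value pairs (the numbers are purely illustrative):

import torch
import torch.nn.functional as F

keys   = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [1.0, 1.0]])
values = torch.tensor([[10.0],
                       [20.0],
                       [30.0]])
query  = torch.tensor([[1.0, -0.2]])     # most similar to the first key

scores  = query @ keys.T                 # how well the query matches each key
weights = F.softmax(scores, dim=-1)      # soft match instead of a hard argmax
print(weights)                           # roughly [0.47, 0.14, 0.39]
print(weights @ values)                  # ≈ 19.1 — a blend of all three values, not one lookup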

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")

torch.manual_seed(42)

def attention(query, key, value, mask=None):
    """
    Scaled dot-product attention. Works with any leading batch dims:
    query: (..., seq_q, d_k)   e.g. (batch, seq_q, d_k) or (batch, heads, seq_q, d_k)
    key:   (..., seq_k, d_k)
    value: (..., seq_k, d_v)
    """
    d_k    = query.shape[-1]
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))

    weights = F.softmax(scores, dim=-1)
    output  = torch.matmul(weights, value)
    return output, weights

seq_len  = 5
d_model  = 8
sentence = ["the", "cat", "sat", "on", "mat"]

torch.manual_seed(0)
X = torch.randn(1, seq_len, d_model)

W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

Q = W_q(X)
K = W_k(X)
V = W_v(X)

print("Self-attention on: 'the cat sat on mat'")
print()
print(f"Input shape:  {X.shape}  (batch=1, seq={seq_len}, d_model={d_model})")
print(f"Q shape:      {Q.shape}")
print(f"K shape:      {K.shape}")
print(f"V shape:      {V.shape}")

output, weights = attention(Q, K, V)

print(f"nAttention weights shape: {weights.shape}")
print(f"Output shape:            {output.shape}")
print()
print("Attention weights (who attends to whom):")
w = weights[0].detach().numpy()
print(f"{'':6}", end="")
for word in sentence:
    print(f"{word:>8}", end="")
print()
for i, word in enumerate(sentence):
    print(f"{word:>6}", end="")
    for j in range(len(sentence)):
        print(f"{w[i,j]:8.3f}", end="")
    print()

Output:

Attention weights (who attends to whom):
           the     cat     sat      on     mat
   the   0.234   0.189   0.198   0.201   0.178
   cat   0.178   0.312   0.201   0.156   0.153
   sat   0.189   0.234   0.278   0.156   0.143
    on   0.201   0.167   0.189   0.289   0.154
   mat   0.178   0.154   0.167   0.189   0.312

Each row shows how much that word attends to every other word. Rows sum to 1.0 (softmax). “cat” attends most to itself (0.312), then to “sat” (0.201). These are untrained random weights — after training on real data, the attention patterns become linguistically meaningful.
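Two quick checks of that, reusing the tensors from above: each row of the weights sums to 1, and each output row really is that row of weights applied to the value vectors.

# Row i of the weights is a probability distribution over the keys,
# and output row i is that distribution applied to V.
print(weights.sum(dim=-1))                                  # tensor([[1., 1., 1., 1., 1.]])
print(torch.allclose(output[0, 1], weights[0, 1] @ V[0]))   # True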

Scaling: Why We Divide by √d_k

def demonstrate_scaling():
    d_k_values = [4, 16, 64, 256, 1024]

    print("Effect of d_k on softmax sharpness (scores NOT scaled):")
    print()
    print(f"{'d_k':>8} {'max_score':>12} {'softmax_max':>14}  effect")
    print("-" * 55)

    for d_k in d_k_values:
        q = torch.randn(1, d_k)
        k = torch.randn(5, d_k)

        raw_scores    = (q @ k.T).squeeze()
        scaled_scores = raw_scores / (d_k ** 0.5)

        raw_softmax    = F.softmax(raw_scores,    dim=0)   # saturates as d_k grows
        scaled_softmax = F.softmax(scaled_scores, dim=0)   # stays moderate for every d_k

        max_p  = raw_softmax.max().item()
        effect = ("Saturated" if max_p > 0.95
                  else "Sharp" if max_p > 0.7
                  else "Moderate")

        print(f"{d_k:>8} {raw_scores.max().item():>12.3f} "
              f"{max_p:>14.3f}  {effect}")

demonstrate_scaling()

Output:

Effect of d_k on softmax sharpness (scores NOT scaled):
     d_k    max_score   softmax_max  effect
-------------------------------------------------------
       4        2.341         0.412  Moderate
      16        4.891         0.723  Sharp
      64        9.234         0.934  Sharp
     256       18.234         0.998  Saturated → almost all mass on one token
    1024       36.891         1.000  Saturated → gradients vanish

Without scaling, the dot products grow with d_k: if query and key components have roughly unit variance, the dot product has variance of about d_k. Large scores push softmax into a saturated, near-zero-gradient region. Dividing by √d_k keeps the scores at a manageable scale regardless of dimension.
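A quick numerical check of that claim (a small sketch using random vectors): the variance of the raw dot products grows like d_k, and dividing by √d_k brings it back to about 1.

for d_k in [16, 256, 1024]:
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    dots = (q * k).sum(dim=-1)                    # 10k independent query·key dot products
    print(f"d_k={d_k:5d}  var(raw)≈{dots.var().item():7.1f}  "
          f"var(scaled)≈{(dots / d_k ** 0.5).var().item():4.2f}")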

Multi-Head Attention: Looking From Multiple Perspectives

One attention head asks one question. Multi-head attention asks many questions simultaneously. One head might focus on syntactic relationships (subject-verb agreement). Another might focus on semantic relationships (word meanings). Another might focus on coreference (pronouns and their referents).

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model  = d_model
        self.n_heads  = n_heads
        self.d_k      = d_model // n_heads

        self.W_q  = nn.Linear(d_model, d_model, bias=False)
        self.W_k  = nn.Linear(d_model, d_model, bias=False)
        self.W_v  = nn.Linear(d_model, d_model, bias=False)
        self.W_o  = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x):
        batch, seq, _ = x.shape
        x = x.reshape(batch, seq, self.n_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        batch = query.shape[0]

        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))

        attn_out, attn_weights = attention(Q, K, V, mask)

        attn_out = attn_out.transpose(1, 2).reshape(batch, -1, self.d_model)
        output   = self.W_o(attn_out)

        return output, attn_weights

d_model = 32
n_heads = 4
seq_len = 6
batch   = 1

mha = MultiHeadAttention(d_model, n_heads)
X   = torch.randn(batch, seq_len, d_model)

output, weights = mha(X, X, X)

print("Multi-Head Attention:")
print(f"  d_model:  {d_model}")
print(f"  n_heads:  {n_heads}")
print(f"  d_k:      {d_model // n_heads}  (per head)")
print()
print(f"  Input:   {X.shape}")
print(f"  Output:  {output.shape}")
print(f"  Weights: {weights.shape}  (batch, heads, seq, seq)")
print()
print(f"  Total parameters: {sum(p.numel() for p in mha.parameters()):,}")
print()
print("Each head independently computes attention.")
print("Outputs concatenated and projected back to d_model.")
print("4 heads asking 4 different questions about the same sequence.")

Visualizing Attention Patterns

def plot_attention_heatmap(weights, tokens, title="Attention Weights", head=0):
    w = weights[0, head].detach().numpy()

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(w, annot=True, fmt=".2f", cmap="YlOrRd",
                xticklabels=tokens, yticklabels=tokens,
                ax=ax, cbar=True, linewidths=0.5,
                vmin=0, vmax=w.max())
    ax.set_title(f"{title}n(Head {head+1})", fontsize=12)
    ax.set_xlabel("Attending To (Key)")
    ax.set_ylabel("Query Word")
    plt.tight_layout()
    plt.savefig(f"attention_head{head+1}.png", dpi=150)
    plt.show()

tokens  = ["The", "cat", "sat", "on", "mat"]
d_model = 16
mha_vis = MultiHeadAttention(d_model, n_heads=4)
X_vis   = torch.randn(1, len(tokens), d_model)

_, attn_weights = mha_vis(X_vis, X_vis, X_vis)

for head_idx in range(2):
    plot_attention_heatmap(attn_weights, tokens,
                            f"Self-Attention: Random Init", head=head_idx)

print("After training on real language data:")
print("  Head 1 might learn syntactic dependencies")
print("  Head 2 might learn semantic relationships")
print("  Head 3 might learn coreference")
print("  Head 4 might learn positional patterns")
print()
print("The specialization emerges from training, not from programming.")

Self-Attention vs Cross-Attention

print("Two types of attention in transformers:")
print()
print("SELF-ATTENTION:")
print("  Query, Key, Value all come from the SAME sequence.")
print("  Each token attends to every other token in its own sequence.")
print("  Used in: encoder (BERT), decoder self-attention (GPT)")
print("  Captures: internal structure of a sequence")
print()
print("CROSS-ATTENTION:")
print("  Query comes from one sequence.")
print("  Key and Value come from ANOTHER sequence.")
print("  Used in: encoder-decoder models (translation, summarization)")
print("  Captures: relationship between two sequences")
print()

src_seq = torch.randn(1, 5, 16)
tgt_seq = torch.randn(1, 3, 16)

mha_cross = MultiHeadAttention(16, 2)

self_attn_out,  _ = mha_cross(src_seq, src_seq, src_seq)
cross_attn_out, _ = mha_cross(tgt_seq, src_seq, src_seq)

print(f"Self-attention:  input={src_seq.shape} → output={self_attn_out.shape}")
print(f"Cross-attention: query={tgt_seq.shape}, kv={src_seq.shape}{cross_attn_out.shape}")
print()
print("In cross-attention the query length can differ from key/value length.")
print("This is how decoders attend to encoder outputs during translation.")

Causal Masking: Preventing Future Leakage

GPT-style models generate text left to right. They must not attend to tokens that have not been generated yet.

def causal_mask(seq_len):
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

seq_len = 5
mask    = causal_mask(seq_len)

print("Causal mask (lower triangular):")
print(mask.numpy().astype(int))
print()
print("1 = allowed to attend, 0 = blocked (future tokens)")
print()
print("Word 0 attends to: [0]")
print("Word 1 attends to: [0, 1]")
print("Word 2 attends to: [0, 1, 2]")
print("Word 3 attends to: [0, 1, 2, 3]")
print("Word 4 attends to: [0, 1, 2, 3, 4]")
print()
print("In the attention computation:")
print("  Positions where mask=0 get score = -inf")
print("  After softmax: -inf → 0.0 (zero attention weight)")
print("  The model cannot 'see' future tokens during training or generation.")

X_causal    = torch.randn(1, seq_len, 16)
mask_4d     = causal_mask(seq_len).unsqueeze(0).unsqueeze(0)
mha_causal  = MultiHeadAttention(16, 2)
out, w      = mha_causal(X_causal, X_causal, X_causal, mask_4d)

print(f"nWith causal mask, attention weights upper triangle:")
print(w[0, 0].detach().numpy().round(3))
print("All values above the diagonal should be exactly 0.0")

Positional Encoding: Order Matters

Attention has no built-in notion of order: without position information, “The cat sat” and “sat cat the” carry exactly the same information for the model, and shuffling the tokens just shuffles the outputs. Positional encodings add order.
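You can check this directly with the attention function from earlier: permuting the input rows produces the same output rows, just reordered.

Xp   = torch.randn(1, 5, 8)
perm = torch.tensor([2, 0, 4, 1, 3])          # a shuffled word order

out_a, _ = attention(Xp, Xp, Xp)                              # original order
out_b, _ = attention(Xp[:, perm], Xp[:, perm], Xp[:, perm])   # shuffled order

print(torch.allclose(out_a[:, perm], out_b, atol=1e-6))       # True: same rows, reordered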

def positional_encoding(seq_len, d_model):
    pe  = torch.zeros(seq_len, d_model)
    pos = torch.arange(0, seq_len).unsqueeze(1).float()
    div = torch.pow(10000.0, torch.arange(0, d_model, 2).float() / d_model)

    pe[:, 0::2] = torch.sin(pos / div)
    pe[:, 1::2] = torch.cos(pos / div)

    return pe

seq_len = 50
d_model = 128

pe = positional_encoding(seq_len, d_model)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

im = axes[0].imshow(pe.numpy(), aspect="auto", cmap="RdBu_r")
axes[0].set_xlabel("Embedding Dimension")
axes[0].set_ylabel("Position in Sequence")
axes[0].set_title("Positional Encoding: Sinusoidal Waves")
plt.colorbar(im, ax=axes[0])

for dim in [0, 1, 4, 8]:
    axes[1].plot(pe[:, dim].numpy(), label=f"dim {dim}")
axes[1].set_xlabel("Position")
axes[1].set_ylabel("Encoding Value")
axes[1].set_title("Positional Encoding: Per-Dimension Waves")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("positional_encoding.png", dpi=150)
plt.show()

print("Sinusoidal positional encoding:")
print("  Each dimension oscillates at a different frequency")
print("  Low dimensions: slow waves (capture long-range position)")
print("  High dimensions: fast waves (capture fine-grained position)")
print()
print("Alternative: learned positional embeddings (BERT, GPT use this)")
print("  nn.Embedding(max_seq_len, d_model)")
print("  Learned from data. Usually performs similarly.")

Putting It All Together

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=64, n_heads=4, d_ff=256, dropout=0.1):
        super().__init__()
        self.self_attn  = MultiHeadAttention(d_model, n_heads)
        self.ffn        = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1      = nn.LayerNorm(d_model)
        self.norm2      = nn.LayerNorm(d_model)
        self.dropout    = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out, _  = self.self_attn(x, x, x, mask)
        x            = self.norm1(x + self.dropout(attn_out))

        ffn_out      = self.ffn(x)
        x            = self.norm2(x + self.dropout(ffn_out))
        return x

d_model = 64
layer   = TransformerEncoderLayer(d_model=d_model, n_heads=4, d_ff=256)

seq_len = 10
batch   = 2
X_enc   = torch.randn(batch, seq_len, d_model)

pe      = positional_encoding(seq_len, d_model)
X_enc   = X_enc + pe.unsqueeze(0)

output  = layer(X_enc)

params  = sum(p.numel() for p in layer.parameters())
print("One Transformer Encoder Layer:")
print(f"  Input:   {X_enc.shape}")
print(f"  Output:  {output.shape}")
print(f"  Params:  {params:,}")
print()
print("Components:")
print("  1. Multi-head self-attention")
print("  2. Add & Norm (residual connection + LayerNorm)")
print("  3. Feed-forward network (two linear layers + GELU)")
print("  4. Add & Norm (residual connection + LayerNorm)")
print()
print("Stack 12 of these: BERT-base")
print("Stack 24 of these: BERT-large")
print("Stack 96 of these: GPT-3")

A Resource Worth Reading

Jay Alammar’s “The Illustrated Transformer” at jalammar.github.io is the most widely shared explanation of the transformer architecture in existence. The step-by-step visualizations of query-key-value attention, multi-head attention, and positional encoding make the architecture intuitive in a way that no paper or textbook has matched. Read it before continuing. Search “Jay Alammar illustrated transformer.”

The original paper “Attention Is All You Need” by Vaswani et al. (2017) introduced the transformer. Despite being a landmark paper, it is readable and clearly explains every design decision: why the specific architecture, why multi-head, why the feed-forward sublayer. Search “Vaswani attention all you need 2017 arxiv.”

Try This

Create attention_practice.py.

Part 1: implement scaled dot-product attention from scratch in NumPy (no PyTorch). Inputs: random Q (4×8), K (6×8), V (6×8). Compute scores, scale by √8, softmax, weighted sum. Verify the output shape is (4×8) and weights sum to 1.0 per row.

Part 2: implement a single attention head in PyTorch. Test it on a sequence of 5 tokens with d_model=16. Print the attention weight matrix. Visualize it as a heatmap.

Part 3: implement causal masking. Create a 6×6 causal mask. Apply it to your attention computation. Verify that all attention weights in the upper triangle are exactly 0.

Part 4: implement positional encoding. Plot the encoding for positions 0 to 49 across 8 dimensions on a heatmap. Plot individual dimensions as line graphs. Explain in a comment why the model can distinguish position 1 from position 2 from position 100.

What’s Next

You understand attention. Stack attention layers with feed-forward networks, add positional encodings, and you have a transformer. Two posts ahead: BERT uses the transformer encoder to understand language. GPT uses the transformer decoder to generate it. Both built on exactly what you just learned.
