How I trained a Stance-Aware Cross-Encoder that classifies Indonesian news headlines against claims, starting on a free Colab TPU and scaling out to Cloud TPU v5p with a single @kinetic.run() decorator
Introduction
Misinformation is one of the defining problems of the social-media era, and Indonesia has been hit particularly hard. Hoaks (the Indonesian shorthand for fake news) spread through WhatsApp groups and Twitter threads faster than any fact-checker can keep up. Most published research on automated fake-news detection focuses on English-language data, which leaves practitioners working with Bahasa Indonesia in a frustrating spot: the techniques exist, but the tooling and pre-trained models are scarce.
This article walks through building a real, working multimodal hoax detector for Indonesian news from scratch. The model takes two inputs — a claim (the original assertion, often from social media) and a headline (a news article headline that mentions the same topic) — and predicts whether the article supports the claim (for), refutes it (against), or merely observes it neutrally (observing).
The architecture is a Stance-Aware Cross-Encoder: a Transformer encoder (multi-head self-attention plus feedforward) for each input, followed by a cross-attention layer that lets the claim and headline literally read each other before classification. Built end-to-end with JAX and Flax, trained on TPU.
The deployment story has two halves:
- Free Colab TPU for prototyping — Google gives every Colab user free access to a v5e-1 TPU, which is enough to train this model end-to-end in under an hour at zero cost.
- Cloud TPU v5p via Keras Kinetic for serious training — when you outgrow Colab’s runtime limits, Keras Kinetic lets you ship the same training function to a Cloud TPU pod with a single Python decorator. No Docker, no Kubernetes YAML, no SSH.
By the end, you’ll have:
- A reusable Indonesian tokenizer and dataset loader
- A 4-layer Transformer encoder with stance-aware cross-attention, written in Flax
- A JIT-compiled training loop with Optax and Orbax checkpointing
- A working predict.py that runs new claim-headline pairs through the trained model
- The exact same code, deployed to Cloud TPU via @kinetic.run()
Let’s get into it.
Why This Problem Is Hard (and Interesting)
Naive fake-news detectors look at one piece of text and try to classify it as “real” or “fake.” That’s both technically weak and ethically uncomfortable — a single text rarely carries enough signal, and the “true/false” framing assumes the model has access to ground truth it can’t possibly have.
The stance detection framing is much more honest. Given a claim and a related news article, the model doesn’t decide whether the claim is true; it decides whether this particular article supports, refutes, or merely observes the claim. That’s a question a model can actually answer, and it’s exactly the input a downstream fact-checker needs to make a final call.
Mathematically, the task is a 3-way classification over the interaction of two pieces of text. That word, interaction, is what makes the architecture interesting. You can’t just encode each side independently and concatenate. You need a layer that lets the claim attend to the headline and vice versa, so the model can pick up on subtle cues like negation (“Government denies…”), hedging (“alleged…”), or framing (“according to critics…”).
Why JAX, Flax, Keras Kinetic, and TPU?
- JAX gives me NumPy-style code with automatic differentiation, JIT compilation via XLA, and transparent acceleration on CPU/GPU/TPU.
- Flax sits on top of JAX and lets me write neural networks as nn.Module classes. The model is dense in attention layers, and Flax keeps the parameter management clean.
- Optax for optimization (AdamW with linear warmup + cosine decay) and Orbax for checkpointing — both are part of the JAX ecosystem and JIT-friendly.
- Keras Kinetic is the deployment glue. One decorator turns a local Python function into a remote TPU job, with container caching, log streaming, and automatic GKE provisioning.
- TPU because the workload is dominated by attention matmuls — exactly what TPU systolic arrays are built for. Free on Colab (v5e-1), and Cloud TPU v5p when you need to scale up.
TPU vs GPU for This Workload
A multimodal Transformer with cross-attention is one of the cleanest TPU workloads you can write. Here’s why, and where GPUs still hold their own.
Hardware design
- GPU (NVIDIA A100/H100): general-purpose parallel processor, thousands of CUDA cores, great for arbitrary parallel computation.
- TPU (v5e or v5p): domain-specific accelerator built around a large systolic array (MXU) optimized for dense matrix multiplications.
What dominates the compute in this model
- Multi-head self-attention: softmax(Q Kᵀ / √d) V, which is the Q/K/V projections plus two large matmuls (Q Kᵀ and the weighted sum over V) per head per layer.
- Cross-attention between claim and headline: same shape, just with different inputs feeding Q vs K/V.
- Feedforward blocks: two Dense layers with GELU between them.
All of those are dense matmuls with predictable shapes. The TPU systolic array is purpose-built to chew through exactly this pattern at peak FLOPs. The XLA compiler fuses the entire train_step into a few kernels, and after the first compile, every step runs at full throughput.
Where GPUs still win for stance detection / NLP
- You’re doing token-level decoding with KV-cache and irregular generation lengths (we’re not — we’re doing classification).
- You need a HuggingFace transformers model that’s only available as a PyTorch checkpoint (we’re training from scratch, so this doesn’t apply).
- You want to iterate in a notebook with constant Python control flow that doesn’t JIT cleanly (Colab gives you both a TPU and a notebook, so you don’t have to choose).
Where TPUs win for stance detection / NLP
- Fixed-length sequences (we pad to 64 tokens) → predictable shapes → great XLA compilation.
- The whole train_step JITs into a single fused execution graph.
- pmap / shard_map make multi-chip training a one-liner if you want to scale up.
- Free on Colab, and Cloud TPU v5e is roughly $0.40/chip-hour on Spot.
Rule of thumb
- Quick prototyping in a notebook with Bahasa Indonesia data → free Colab TPU. (This article.)
- Iterative R&D using HuggingFace PyTorch checkpoints → GPU.
- Production training with batchable, JAX-native workloads → Cloud TPU via Kinetic.
- You need to fine-tune a 7B+ Indonesian LLM → that’s a different article (and a different category — vLLM or Tunix).
Now let’s actually build it.
Project Architecture
The pipeline is straightforward:
datasetika.csv (Claim, Judul, Stance)
        │
        ▼
┌─────────────────────┐
│ IndonesianTokenizer │  whitespace + punctuation, vocab from corpus
└─────────────────────┘
        │
        ▼
┌─────────────────────┐
│ FakeNewsDataset     │  stratified train/val/test split, JAX-ready arrays
└─────────────────────┘
        │
        ▼
┌──────────────────────────────────────┐
│  FakeNewsDetector (Flax nn.Module)   │
│  ┌──────────────┐  ┌──────────────┐  │
│  │ Token+Pos    │  │ Token+Pos    │  │
│  │ Embedding    │  │ Embedding    │  │
│  │ (Claim)      │  │ (Headline)   │  │
│  └──────┬───────┘  └──────┬───────┘  │
│         ▼                 ▼          │
│  ┌──────────────┐  ┌──────────────┐  │
│  │ Transformer  │  │ Transformer  │  │
│  │ × N layers   │  │ × N layers   │  │
│  └──────┬───────┘  └──────┬───────┘  │
│         └────────┬────────┘          │
│                  ▼                   │
│       ┌─────────────────────┐        │
│       │ Stance Cross-Encoder│        │
│       │ (cross-attention +  │        │
│       │  diff & product)    │        │
│       └──────────┬──────────┘        │
│                  ▼                   │
│     Dense → Softmax (3 classes)      │
└──────────────────────────────────────┘
        │
        ▼
for / against / observing
The whole thing sits inside one jax.jit-compiled train_step. Now let’s walk through each piece.
Step 1 — Hardware Setup
There are two paths. Pick whichever fits your stage of the project.
Path A: Free Colab TPU (recommended for first run)
- Open colab.research.google.com and create a new notebook.
- Click Runtime → Change runtime type.
- Under Hardware accelerator, select v5e-1 TPU.
- Click Save.
- Verify in a cell:
import jax
print(jax.devices())
# Expected: [TpuDevice(id=0, ...)]
That’s it. You now have a free TPU v5e chip for the duration of your Colab session.
Path B: Cloud TPU via Keras Kinetic (when you outgrow Colab)
Colab is fantastic for prototyping but has runtime limits and gets bumped under load. When you’re ready to run multi-hour training jobs, switch to a real Cloud TPU. The traditional path means provisioning a TPU VM, SSHing in, installing dependencies, and uploading scripts — Kinetic skips all of that.
On your local laptop:
pip install keras-kinetic
gcloud auth application-default login
gcloud config set project YOUR_PROJECT_ID
kinetic up --accelerator v5p-8 --yes
The last command provisions a GKE Autopilot cluster with a TPU v5p-8 node pool. Takes a few minutes the first time, after which you don’t touch infrastructure again until tear-down.
I’ll show the actual @kinetic.run() deployment in Step 6. For now, let’s build the model.
Step 2 — Indonesian Tokenizer and Dataset
Bahasa Indonesia is morphologically less complex than, say, Turkish or Finnish, so a whitespace + punctuation tokenizer with a learned vocabulary works surprisingly well as a baseline. (For production, swap in IndoBERT — I’ll show how at the end of this section.)
The tokenizer reserves four special tokens, builds a frequency-ranked vocabulary from the training corpus, and emits (token_ids, attention_mask) pairs at a fixed length. Standard stuff, but with one subtlety: we tokenize both the Claim and Judul (headline) columns into a shared vocabulary, so a given word maps to the same ID on both sides of the model.
"""
Data Preprocessing for Indonesian Fake News Detection
Tokenizes Claim + Judul columns, encodes Stance labels.
Compatible with JAX/Flax training pipeline.
"""
import re
import numpy as np
import pandas as pd
from collections import Counter
from typing import List, Tuple, Dict
from sklearn.model_selection import train_test_split
LABEL_MAP = {"for": 0, "against": 1, "observing": 2}
ID_TO_LABEL = {v: k for k, v in LABEL_MAP.items()}
class IndonesianTokenizer:
    """
    Lightweight whitespace + punctuation tokenizer for Indonesian text.
    For production, swap with:
        from transformers import AutoTokenizer
        tok = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")
    """
    SPECIAL_TOKENS = {"<PAD>": 0, "<UNK>": 1, "<CLS>": 2, "<SEP>": 3}

    def __init__(self, vocab_size: int = 30_000, min_freq: int = 2):
        self.vocab_size = vocab_size
        self.min_freq = min_freq
        self.word2id: Dict[str, int] = dict(self.SPECIAL_TOKENS)
        self.id2word: Dict[int, str] = {v: k for k, v in self.word2id.items()}

    @staticmethod
    def _clean(text: str) -> str:
        text = text.lower()
        text = re.sub(r"<[^>]+>", " ", text)                     # strip HTML
        text = re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE)   # keep alphanum
        text = re.sub(r"\s+", " ", text).strip()                 # collapse whitespace
        return text

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return IndonesianTokenizer._clean(text).split()

    def build_vocab(self, texts: List[str]) -> None:
        counter: Counter = Counter()
        for t in texts:
            counter.update(self.tokenize(t))
        # Keep the most frequent words, leaving room for the special tokens
        top_tokens = [
            word for word, freq in counter.most_common()
            if freq >= self.min_freq
        ][: self.vocab_size - len(self.SPECIAL_TOKENS)]
        for idx, word in enumerate(top_tokens, start=len(self.SPECIAL_TOKENS)):
            self.word2id[word] = idx
            self.id2word[idx] = word
        print(f"[Tokenizer] Vocabulary size: {len(self.word2id):,} tokens")

    def encode(self, text: str, max_len: int = 128, add_special: bool = True
               ) -> Tuple[np.ndarray, np.ndarray]:
        tokens = self.tokenize(text)
        ids = [self.word2id.get(t, self.SPECIAL_TOKENS["<UNK>"]) for t in tokens]
        if add_special:
            ids = [self.SPECIAL_TOKENS["<CLS>"]] + ids + [self.SPECIAL_TOKENS["<SEP>"]]
        ids = ids[:max_len]                      # truncate to the fixed length
        pad_len = max_len - len(ids)
        mask = [1] * len(ids) + [0] * pad_len    # 1 = real token, 0 = padding
        ids = ids + [self.SPECIAL_TOKENS["<PAD>"]] * pad_len
        return np.array(ids, dtype=np.int32), np.array(mask, dtype=np.int32)

    def encode_batch(self, texts: List[str], max_len: int = 128
                     ) -> Tuple[np.ndarray, np.ndarray]:
        pairs = [self.encode(t, max_len) for t in texts]
        ids_arr = np.stack([p[0] for p in pairs])
        mask_arr = np.stack([p[1] for p in pairs])
        return ids_arr, mask_arr
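To make the fixed-length contract concrete, here is a minimal standalone sketch of just the truncate, pad, and mask step. The helper name pad_and_mask and the token IDs are made up for illustration; it assumes, as in the tokenizer above, that ID 0 is the padding token.

```python
from typing import List, Tuple


def pad_and_mask(ids: List[int], max_len: int, pad_id: int = 0
                 ) -> Tuple[List[int], List[int]]:
    """Truncate or right-pad `ids` to `max_len` and build the attention mask."""
    ids = ids[:max_len]                      # drop anything past the limit
    pad_len = max_len - len(ids)
    mask = [1] * len(ids) + [0] * pad_len    # 1 = real token, 0 = padding
    return ids + [pad_id] * pad_len, mask


# Hypothetical IDs: 2 and 3 stand in for the start and end markers.
short_ids, short_mask = pad_and_mask([2, 17, 42, 3], max_len=6)
long_ids, long_mask = pad_and_mask([2, 17, 42, 99, 7, 3], max_len=4)

print(short_ids, short_mask)
print(long_ids, long_mask)
```

Note that truncation cuts from the right, so an over-long input loses its closing marker. That is harmless for this model, but worth knowing if you later rely on that token.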
The dataset class loads the CSV, builds the tokenizer on the combined Claim + Judul corpus, encodes both columns into fixed-length arrays, and produces stratified train/val/test splits:
class FakeNewsDataset:
    """
    Loads datasetika.csv and prepares JAX-compatible NumPy arrays.
    Expected columns:
        id, Claim, idStance, Judul, Stance
    Stance ∈ {for, against, observing}
    """
    def __init__(
        self,
        csv_path: str,
        claim_max_len: int = 64,
        headline_max_len: int = 64,
        vocab_size: int = 30_000,
        test_size: float = 0.15,
        val_size: float = 0.10,
        random_seed: int = 42,
    ):
        self.claim_max_len = claim_max_len
        self.headline_max_len = headline_max_len
        df = pd.read_csv(csv_path)
        print(f"[Dataset] Loaded {len(df):,} samples")
        print(f"[Dataset] Stance distribution:\n{df['Stance'].value_counts()}\n")

        # One vocabulary built from both text columns
        self.tokenizer = IndonesianTokenizer(vocab_size)
        all_texts = df["Claim"].tolist() + df["Judul"].tolist()
        self.tokenizer.build_vocab(all_texts)

        claim_ids, claim_masks = self.tokenizer.encode_batch(
            df["Claim"].tolist(), claim_max_len)
        headline_ids, headline_masks = self.tokenizer.encode_batch(
            df["Judul"].tolist(), headline_max_len)
        labels = np.array([LABEL_MAP[s] for s in df["Stance"]], dtype=np.int32)

        # Stratified train / val / test
        indices = np.arange(len(df))
        train_idx, test_idx = train_test_split(
            indices, test_size=test_size, stratify=labels, random_state=random_seed)
        # val_size is a fraction of the full dataset, so rescale it
        # relative to what is left after the test split
        train_idx, val_idx = train_test_split(
            train_idx,
            test_size=val_size / (1 - test_size),
            stratify=labels[train_idx],
            random_state=random_seed,
        )
        self.train = self._slice(claim_ids, claim_masks, headline_ids,
                                 headline_masks, labels, train_idx)
        self.val = self._slice(claim_ids, claim_masks, headline_ids,
                               headline_masks, labels, val_idx)
        self.test = self._slice(claim_ids, claim_masks, headline_ids,
                                headline_masks, labels, test_idx)

    @staticmethod
    def _slice(claim_ids, claim_masks, head_ids, head_masks, labels, idx):
        return {
            "claim_ids": claim_ids[idx],
            "claim_masks": claim_masks[idx],
            "headline_ids": head_ids[idx],
            "headline_masks": head_masks[idx],
            "labels": labels[idx],
        }
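The second split uses val_size / (1 - test_size) rather than val_size directly, which is easy to misread. The arithmetic below is a small sketch of why: it is not part of the pipeline, and split_fractions is a hypothetical helper used only to check the numbers.

```python
def split_fractions(test_size: float = 0.15, val_size: float = 0.10):
    """Return (train, val, test, second_split) as fractions of the full dataset."""
    remaining = 1.0 - test_size              # what is left after the test split
    second_split = val_size / remaining      # fraction passed to the second split
    val = remaining * second_split           # works out to val_size of the whole
    train = remaining - val
    return train, val, test_size, second_split


train_frac, val_frac, test_frac, second_split = split_fractions()
print(f"second split fraction: {second_split:.4f}")
print(f"train {train_frac:.2f} / val {val_frac:.2f} / test {test_frac:.2f}")
```

With the defaults, the second call to train_test_split asks for about 11.8% of the remaining 85%, which lands at 10% of the full dataset and leaves 75% for training. scikit-learn rounds to whole samples, so the realized counts can differ by one or two rows.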
Production swap: if you have GPU/TPU memory to spare, replace IndonesianTokenizer with AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1"), set vocab_size to that tokenizer’s vocabulary size, and feed its token IDs into the same downstream model. The subword vocabulary cuts down on out-of-vocabulary tokens. To benefit from IndoBERT’s pre-training on a large Indonesian corpus, you would also need to load its encoder weights, not just its tokenizer.
Step 3 — The Multimodal Architecture (JAX + Flax)
This is where it gets fun. The model has four logical pieces stacked in a Flax nn.Module:
- Token + positional embedding for both inputs
- Transformer encoder stack (multi-head self-attention + feedforward) for each input independently
- Stance-aware cross-encoder that lets claim and headline attend to each other
- Classification head that produces logits over {for, against, observing}
3.1 — Token and Positional Embedding
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Tuple
class TokenEmbedding(nn.Module):
    """Learnable token + positional embeddings."""
    vocab_size: int
    embed_dim: int
    max_len: int = 256
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, token_ids: jnp.ndarray, train: bool = False) -> jnp.ndarray:
        tok_emb = nn.Embed(self.vocab_size, self.embed_dim)(token_ids)  # (B, L, D)
        pos = jnp.arange(token_ids.shape[1])[None, :]                   # (1, L)
        pos_emb = nn.Embed(self.max_len, self.embed_dim)(pos)           # (1, L, D)
        x = tok_emb + pos_emb
        x = nn.LayerNorm()(x)
        x = nn.Dropout(self.dropout_rate, deterministic=not train)(x)
        return x
3.2 — Multi-Head Attention
A clean, from-scratch multi-head attention. Yes, Flax has nn.MultiHeadDotProductAttention built in, but writing it explicitly keeps the article educational and lets you see exactly what the attention mask is doing:
class MultiHeadAttention(nn.Module):
    num_heads: int
    head_dim: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray,
                 mask: jnp.ndarray = None, train: bool = False) -> jnp.ndarray:
        B, Lq, D = q.shape
        Lk = k.shape[1]
        H, Dh = self.num_heads, self.head_dim
        # Project, then split into heads: (B, L, H*Dh) -> (B, H, L, Dh)
        Q = nn.Dense(H * Dh)(q).reshape(B, Lq, H, Dh).transpose(0, 2, 1, 3)
        K = nn.Dense(H * Dh)(k).reshape(B, Lk, H, Dh).transpose(0, 2, 1, 3)
        V = nn.Dense(H * Dh)(v).reshape(B, Lk, H, Dh).transpose(0, 2, 1, 3)
        scores = jnp.einsum("bhqd,bhkd->bhqk", Q, K) / jnp.sqrt(Dh)
        if mask is not None:
            # mask: (B, Lk) -> (B, 1, 1, Lk)
            mask = mask[:, None, None, :]
            scores = jnp.where(mask == 0, -1e9, scores)
        attn = jax.nn.softmax(scores, axis=-1)
        attn = nn.Dropout(self.dropout_rate, deterministic=not train)(attn)
        out = jnp.einsum("bhqk,bhkd->bhqd", attn, V)
        # Merge heads back: (B, H, Lq, Dh) -> (B, Lq, H*Dh)
        out = out.transpose(0, 2, 1, 3).reshape(B, Lq, H * Dh)
        out = nn.Dense(D)(out)
        return out
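To see what the mask does in isolation, here is a plain-Python sketch of the masking step for a single query row. It is an illustration of the same idea (replace scores at padded key positions with -1e9 before the softmax), not the Flax code itself; masked_softmax and the score values are made up.

```python
import math
from typing import List


def masked_softmax(scores: List[float], mask: List[int]) -> List[float]:
    """Softmax over `scores`, with positions where mask == 0 pushed to -1e9."""
    masked = [s if m else -1e9 for s, m in zip(scores, mask)]
    peak = max(masked)                           # subtract the max for stability
    exps = [math.exp(s - peak) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Four key positions; the last two are padding (mask = 0).
weights = masked_softmax([2.0, 1.0, 0.5, 3.0], [1, 1, 0, 0])
print([round(w, 4) for w in weights])
```

The padded position with the highest raw score (3.0) still ends up with zero weight, so padding can never leak into the attended representation regardless of what the embeddings at those positions contain.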
3.3 — Transformer Block
Standard pre-norm Transformer block with a feedforward expansion factor of 4:
class TransformerBlock(nn.Module):
    num_heads: int
    head_dim: int
    ff_dim: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x: jnp.ndarray, mask: jnp.ndarray = None,
                 train: bool = False) -> jnp.ndarray:
        # Self-attention sublayer
        h = nn.LayerNorm()(x)
        h = MultiHeadAttention(self.num_heads, self.head_dim, self.dropout_rate)(
            h, h, h, mask=mask, train=train)
        h = nn.Dropout(self.dropout_rate, deterministic=not train)(h)
        x = x + h
        # Feedforward sublayer
        h = nn.LayerNorm()(x)
        h = nn.Dense(self.ff_dim)(h)
        h = nn.gelu(h)
        h = nn.Dense(x.shape[-1])(h)
        h = nn.Dropout(self.dropout_rate, deterministic=not train)(h)
        x = x + h
        return x
3.4 — Stance Cross-Encoder (the Interesting Part)
This is what makes the architecture multimodal-aware rather than just two encoders bolted together. After the claim and headline are independently encoded, the cross-encoder lets each one attend to the other, then extracts an interaction vector by concatenating the difference and element-wise product of the two pooled representations:
class StanceCrossEncoder(nn.Module):
    """Cross-attention between claim and headline + interaction features."""
    num_heads: int
    head_dim: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, claim: jnp.ndarray, headline: jnp.ndarray,
                 claim_mask: jnp.ndarray, headline_mask: jnp.ndarray,
                 train: bool = False) -> jnp.ndarray:
        # Claim attends to headline
        claim_attended = MultiHeadAttention(
            self.num_heads, self.head_dim, self.dropout_rate
        )(claim, headline, headline, mask=headline_mask, train=train)
        # Headline attends to claim
        head_attended = MultiHeadAttention(
            self.num_heads, self.head_dim, self.dropout_rate
        )(headline, claim, claim, mask=claim_mask, train=train)

        # Mean-pool with mask
        def masked_mean(x, m):
            m = m[..., None].astype(x.dtype)
            return (x * m).sum(axis=1) / jnp.maximum(m.sum(axis=1), 1.0)

        c = masked_mean(claim_attended, claim_mask)      # (B, D)
        h = masked_mean(head_attended, headline_mask)    # (B, D)
        # Stance-aware interaction features
        interaction = jnp.concatenate(
            [c, h, jnp.abs(c - h), c * h], axis=-1
        )  # (B, 4D)
        return interaction
The four components — c, h, |c – h|, c * h — capture different aspects of how the two inputs relate. This is a classic NLI (natural language inference) trick and it works well here: |c – h| flags semantic distance, while c * h highlights agreement.
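A toy example makes the roles of the four components easier to see. The sketch below builds the same [c, h, |c - h|, c * h] vector with plain Python lists; the two-dimensional vectors are invented for illustration and interaction_features is a hypothetical helper, not part of the model code.

```python
from typing import List


def interaction_features(c: List[float], h: List[float]) -> List[float]:
    """Concatenate c, h, |c - h| and c * h, mirroring the cross-encoder output."""
    diff = [abs(a - b) for a, b in zip(c, h)]
    prod = [a * b for a, b in zip(c, h)]
    return c + h + diff + prod


# The first dimension agrees, the second has opposite signs.
claim_vec = [1.0, -2.0]
headline_vec = [1.0, 2.0]
features = interaction_features(claim_vec, headline_vec)
print(features)
```

In the dimension where the two vectors agree, the difference is 0 and the product is positive. Where they point in opposite directions, the difference is large and the product goes negative, which is the kind of signal a classifier can use to separate "for" from "against".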
3.5 — Putting It All Together
class FakeNewsDetector(nn.Module):
    vocab_size: int
    embed_dim: int = 256
    num_heads: int = 8
    head_dim: int = 32
    ff_dim: int = 1024
    num_layers: int = 4
    dropout_rate: float = 0.1
    num_classes: int = 3

    @nn.compact
    def __call__(self, claim_ids, claim_mask, headline_ids, headline_mask,
                 train: bool = False):
        # Embed both inputs
        claim_x = TokenEmbedding(self.vocab_size, self.embed_dim,
                                 dropout_rate=self.dropout_rate)(claim_ids, train)
        headline_x = TokenEmbedding(self.vocab_size, self.embed_dim,
                                    dropout_rate=self.dropout_rate)(headline_ids, train)
        # Independent Transformer stacks
        for _ in range(self.num_layers):
            claim_x = TransformerBlock(self.num_heads, self.head_dim,
                                       self.ff_dim, self.dropout_rate)(
                claim_x, mask=claim_mask, train=train)
            headline_x = TransformerBlock(self.num_heads, self.head_dim,
                                          self.ff_dim, self.dropout_rate)(
                headline_x, mask=headline_mask, train=train)
        # Cross-encoder interaction
        interaction = StanceCrossEncoder(self.num_heads, self.head_dim,
                                         self.dropout_rate)(
            claim_x, headline_x, claim_mask, headline_mask, train=train)
        # Classification head
        h = nn.Dense(self.embed_dim)(interaction)
        h = nn.gelu(h)
        h = nn.Dropout(self.dropout_rate, deterministic=not train)(h)
        logits = nn.Dense(self.num_classes)(h)
        return logits
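A 4-layer encoder stack with embed_dim=256, num_heads=8, and ff_dim=1024 lands at around 7M parameters in the two encoder stacks, the cross-encoder, and the classification head. The claim and headline each get their own embedding table on top of that, adding roughly vocab_size × 256 parameters per table. Either way the model is small enough to train quickly on a free Colab TPU and large enough to learn meaningful stance representations.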
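The parameter count can be checked by hand from the layer shapes. The sketch below is back-of-the-envelope arithmetic, assuming Flax defaults (Dense layers with a bias, LayerNorm with scale and bias, Embed without a bias) and the 18,734-token vocabulary from the sample run in Step 4; the helper names are hypothetical.

```python
def dense(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out                  # kernel + bias


def layer_norm(d: int) -> int:
    return 2 * d                                 # scale + bias


def attention(d: int, heads: int, head_dim: int) -> int:
    inner = heads * head_dim
    return 3 * dense(d, inner) + dense(inner, d)  # Q, K, V + output projection


def transformer_block(d: int, heads: int, head_dim: int, ff: int) -> int:
    return (2 * layer_norm(d) + attention(d, heads, head_dim)
            + dense(d, ff) + dense(ff, d))


def token_embedding(vocab: int, d: int, max_len: int = 256) -> int:
    return vocab * d + max_len * d + layer_norm(d)


def count_params(vocab: int, d: int = 256, heads: int = 8, head_dim: int = 32,
                 ff: int = 1024, layers: int = 4, classes: int = 3):
    embeddings = 2 * token_embedding(vocab, d)            # claim + headline
    encoders = 2 * layers * transformer_block(d, heads, head_dim, ff)
    cross = 2 * attention(d, heads, head_dim)             # both directions
    head = dense(4 * d, d) + dense(d, classes)
    return embeddings, encoders + cross + head


emb_params, core_params = count_params(vocab=18_734)
print(f"embeddings: {emb_params:,}")
print(f"encoders + cross-encoder + head: {core_params:,}")
print(f"total: {emb_params + core_params:,}")
```

Under those assumptions the non-embedding part comes to about 7.1M parameters and the two embedding tables to about 9.7M, so the total depends heavily on the vocabulary size.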
Step 4 — Training Pipeline (TPU-Optimized)
The training loop is a standard Optax + JAX pattern, with two TPU-specific touches: everything is jax.jit-compiled, and we use Orbax for checkpointing because it handles the JAX pytree parameter format natively.
4.1 — Train State and Loss
import optax
from flax.training import train_state
def create_train_state(rng, model, learning_rate, weight_decay,
                       claim_len, head_len, num_warmup_steps, num_total_steps):
    # Dummy inputs to initialize parameters
    dummy_claim_ids = jnp.ones((1, claim_len), dtype=jnp.int32)
    dummy_claim_mask = jnp.ones((1, claim_len), dtype=jnp.int32)
    dummy_head_ids = jnp.ones((1, head_len), dtype=jnp.int32)
    dummy_head_mask = jnp.ones((1, head_len), dtype=jnp.int32)
    params = model.init(
        rng, dummy_claim_ids, dummy_claim_mask,
        dummy_head_ids, dummy_head_mask, train=False,
    )["params"]
    # Linear warmup + cosine decay
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=learning_rate,
        warmup_steps=num_warmup_steps,
        # Optax's decay_steps is the total schedule length, warmup included
        decay_steps=num_total_steps,
        end_value=learning_rate * 0.1,
    )
    optimizer = optax.adamw(schedule, weight_decay=weight_decay)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer)


def cross_entropy_loss(logits, labels):
    one_hot = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1))
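The shape of the learning-rate schedule is easier to reason about with a few numbers. The function below is a plain-Python sketch of linear warmup followed by cosine decay; it mirrors the shape of the schedule rather than reproducing Optax's implementation, and the step counts (3,000 total, 300 warmup) are illustrative round numbers, not values taken from a run.

```python
import math


def warmup_cosine(step: int, peak: float, warmup_steps: int,
                  total_steps: int, end_value: float) -> float:
    """Linear warmup from 0 to `peak`, then cosine decay down to `end_value`."""
    if step < warmup_steps:
        return peak * step / warmup_steps
    decay_span = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_span)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak - end_value) * cosine


PEAK, WARMUP, TOTAL = 2e-4, 300, 3000
END = PEAK * 0.1
for s in (0, 150, 300, 1650, 3000):
    lr = warmup_cosine(s, PEAK, WARMUP, TOTAL, END)
    print(f"step {s:4d}: lr = {lr:.2e}")
```

The rate climbs linearly for the first 10% of steps, peaks at 2e-4, passes the midpoint between peak and floor halfway through the decay, and settles at 2e-5 on the final step.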
4.2 — JIT-Compiled Train and Eval Steps
@jax.jit
def train_step(state, batch, dropout_rng):
    def loss_fn(params):
        logits = state.apply_fn(
            {"params": params},
            batch["claim_ids"], batch["claim_masks"],
            batch["headline_ids"], batch["headline_masks"],
            train=True,
            rngs={"dropout": dropout_rng},
        )
        loss = cross_entropy_loss(logits, batch["labels"])
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    accuracy = (jnp.argmax(logits, axis=-1) == batch["labels"]).mean()
    return state, loss, accuracy


@jax.jit
def eval_step(state, batch):
    logits = state.apply_fn(
        {"params": state.params},
        batch["claim_ids"], batch["claim_masks"],
        batch["headline_ids"], batch["headline_masks"],
        train=False,
    )
    loss = cross_entropy_loss(logits, batch["labels"])
    accuracy = (jnp.argmax(logits, axis=-1) == batch["labels"]).mean()
    return loss, accuracy
The @jax.jit is doing a lot of work here. On the first call, XLA compiles the entire forward pass, loss computation, gradient calculation, and parameter update into a single fused execution graph. After that one-time compile (roughly 20 seconds, judging by the gap between epoch 1 and later epochs in the sample run in 4.4), every subsequent step reuses the compiled program.
4.3 — The Training Loop
import pickle
import time
from pathlib import Path

import orbax.checkpoint as ocp


def train(args):
    # Load and prepare data
    ds = FakeNewsDataset(
        args.data_path,
        claim_max_len=args.claim_len,
        headline_max_len=args.head_len,
    )
    # Persist the tokenizer so inference can reload the same vocabulary
    with open("tokenizer.pkl", "wb") as f:
        pickle.dump(ds.tokenizer, f)

    # Create model and train state
    model = FakeNewsDetector(
        vocab_size=len(ds.tokenizer.word2id),
        embed_dim=args.embed_dim,
        num_heads=args.num_heads,
    )
    rng = jax.random.PRNGKey(0)
    init_rng, train_rng = jax.random.split(rng)
    steps_per_epoch = len(ds.train["labels"]) // args.batch_size
    num_total_steps = steps_per_epoch * args.epochs
    num_warmup_steps = int(0.1 * num_total_steps)
    state = create_train_state(
        init_rng, model, args.lr, weight_decay=0.01,
        claim_len=args.claim_len, head_len=args.head_len,
        num_warmup_steps=num_warmup_steps,
        num_total_steps=num_total_steps,
    )

    # Orbax checkpointer
    ckpt_dir = Path(args.output_dir).resolve()
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    checkpointer = ocp.PyTreeCheckpointer()
    best_val_acc = 0.0

    for epoch in range(args.epochs):
        epoch_start = time.time()
        train_rng, shuffle_rng = jax.random.split(train_rng)
        # Shuffle training indices
        perm = jax.random.permutation(shuffle_rng, len(ds.train["labels"]))

        # Training pass (the last partial batch is dropped)
        train_loss, train_acc = 0.0, 0.0
        for step in range(steps_per_epoch):
            batch_idx = perm[step * args.batch_size:(step + 1) * args.batch_size]
            batch = {k: v[batch_idx] for k, v in ds.train.items()}
            train_rng, dropout_rng = jax.random.split(train_rng)
            state, loss, acc = train_step(state, batch, dropout_rng)
            train_loss += float(loss)
            train_acc += float(acc)
        train_loss /= steps_per_epoch
        train_acc /= steps_per_epoch

        # Validation pass
        val_loss, val_acc = 0.0, 0.0
        val_steps = len(ds.val["labels"]) // args.batch_size
        for step in range(val_steps):
            batch = {k: v[step * args.batch_size:(step + 1) * args.batch_size]
                     for k, v in ds.val.items()}
            loss, acc = eval_step(state, batch)
            val_loss += float(loss)
            val_acc += float(acc)
        val_loss /= val_steps
        val_acc /= val_steps

        elapsed = time.time() - epoch_start
        print(f"Epoch {epoch+1:3d}/{args.epochs} | "
              f"train loss {train_loss:.4f} acc {train_acc:.4f} | "
              f"val loss {val_loss:.4f} acc {val_acc:.4f} | "
              f"{elapsed:.1f}s")

        # Save best checkpoint
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            checkpointer.save(ckpt_dir / "best", state.params, force=True)
            print(f" ✓ Saved checkpoint (val_acc={val_acc:.4f})")
4.4 — Running on Free Colab TPU
In a Colab cell:
class Args:
    data_path = "datasetika.csv"
    output_dir = "checkpoints"
    epochs = 20
    batch_size = 64
    lr = 2e-4
    claim_len = 64
    head_len = 64
    embed_dim = 256
    num_heads = 8


train(Args)
Sample output on Colab v5e-1:
[Dataset] Loaded 12,847 samples
[Dataset] Stance distribution:
observing 5,932
for 4,221
against 2,694
[Tokenizer] Vocabulary size: 18,734 tokens
Epoch 1/20 | train loss 0.9932 acc 0.5414 | val loss 0.8821 acc 0.6102 | 41.3s
Epoch 2/20 | train loss 0.7881 acc 0.6587 | val loss 0.7024 acc 0.7011 | 22.8s
Epoch 3/20 | train loss 0.6014 acc 0.7321 | val loss 0.6398 acc 0.7314 | 22.5s
...
Epoch 20/20 | train loss 0.2741 acc 0.9001 | val loss 0.4012 acc 0.8412 | 22.6s
The first epoch is slower because XLA is compiling the train_step. From epoch 2 onward, every epoch takes ~22 seconds on a free Colab v5e-1. Total wall time: about 8 minutes for 20 epochs.
Step 5 — Inference on New Claims
Once trained, classifying a new claim-headline pair is a one-shot forward pass. Load the tokenizer and checkpointed parameters, encode the inputs, run the model:
import pickle
from pathlib import Path

import orbax.checkpoint as ocp


def load_model(checkpoint_path: str, tokenizer_path: str,
               claim_len: int = 64, headline_len: int = 64,
               embed_dim: int = 256, num_heads: int = 8, num_layers: int = 4):
    with open(tokenizer_path, "rb") as f:
        tokenizer: IndonesianTokenizer = pickle.load(f)
    model = FakeNewsDetector(
        vocab_size=len(tokenizer.word2id),
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_layers=num_layers,
    )
    checkpointer = ocp.PyTreeCheckpointer()
    # Resolve to an absolute path, matching how train() saves the checkpoint
    params = checkpointer.restore(Path(checkpoint_path).resolve())
    return model, params, tokenizer


def predict(model, params, tokenizer, claim: str, headline: str,
            claim_len: int = 64, headline_len: int = 64):
    claim_ids, claim_mask = tokenizer.encode(claim, claim_len)
    headline_ids, headline_mask = tokenizer.encode(headline, headline_len)
    # Add batch dimension
    claim_ids = claim_ids[None, :]
    claim_mask = claim_mask[None, :]
    headline_ids = headline_ids[None, :]
    headline_mask = headline_mask[None, :]
    logits = model.apply(
        {"params": params},
        claim_ids, claim_mask, headline_ids, headline_mask, train=False,
    )
    probs = jax.nn.softmax(logits, axis=-1)[0]
    pred_idx = int(jnp.argmax(probs))
    return ID_TO_LABEL[pred_idx], {ID_TO_LABEL[i]: float(probs[i]) for i in range(3)}


# Example
model, params, tokenizer = load_model("checkpoints/best", "tokenizer.pkl")
examples = [
    {"claim": "Pemerintah resmi menaikkan harga BBM bulan depan",
     "headline": "Kementerian ESDM bantah rencana kenaikan harga BBM",
     "expected": "against"},
    {"claim": "Vaksin baru efektif 95 persen mencegah penularan",
     "headline": "Studi terbaru: vaksin tunjukkan efikasi 95 persen pada uji klinis fase 3",
     "expected": "for"},
    {"claim": "Presiden umumkan kebijakan baru ekonomi",
     "headline": "Pengamat: masih perlu dikaji lebih lanjut dampaknya",
     "expected": "observing"},
]
for ex in examples:
    pred, probs = predict(model, params, tokenizer, ex["claim"], ex["headline"])
    print(f"\nClaim: {ex['claim']}")
    print(f"Headline: {ex['headline']}")
    print(f"Predicted: {pred} (expected: {ex['expected']})")
    print(f"Probabilities: {probs}")
Step 6 — Scaling Up: Deploying to Cloud TPU with Keras Kinetic
Free Colab TPU is great for prototyping. But Colab sessions have runtime limits (12h max, often less under load), no persistent storage, and you can’t run multiple experiments in parallel. When you want to do real hyperparameter sweeps or train on a larger dataset, it’s time to graduate to Cloud TPU.
The traditional path means provisioning a TPU VM, SSHing in, installing dependencies, uploading scripts, and managing the deployment yourself. Keras Kinetic skips all of that. You decorate your training function with @kinetic.run(accelerator="v5p-8") and call it from your laptop. Kinetic packages the code, builds a container, provisions a GKE cluster with TPUs attached, runs the function on the remote pod, and streams logs back to your local terminal.
6.1 — Wrap Your Training Function
Save the existing tokenizer, dataset, model, and training code as fakenews.py. Then create train_kinetic.py:
import kinetic


@kinetic.run(
    accelerator="v5p-8",
    requirements=[
        "jax[tpu]",
        "flax",
        "optax",
        "orbax-checkpoint",
        "numpy",
        "pandas",
        "scikit-learn",
        "gcsfs",  # lets pandas.read_csv open gs:// paths
    ],
)
def train_remote(data_gcs_path: str, epochs: int = 50, batch_size: int = 128,
                 lr: float = 2e-4):
    # CRITICAL: imports happen *inside* the function. The body runs on
    # the remote TPU pod, so imports resolve against the pod's installed
    # packages, not your laptop's.
    import os
    from types import SimpleNamespace
    os.environ["JAX_PLATFORMS"] = "tpu"
    from fakenews import FakeNewsDataset, FakeNewsDetector, train

    # A namespace instead of a class body: inside a class body,
    # `epochs = epochs` cannot see the enclosing function's arguments
    # and raises NameError.
    args = SimpleNamespace(
        data_path=data_gcs_path,
        output_dir="/tmp/checkpoints",
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        claim_len=64,
        head_len=64,
        embed_dim=256,
        num_heads=8,
    )
    train(args)
    return {"status": "completed", "epochs": epochs}


if __name__ == "__main__":
    # This runs on the *remote* TPU but feels like a local function call.
    result = train_remote(
        data_gcs_path="gs://my-bucket/datasetika.csv",
        epochs=50,
        batch_size=128,
    )
    print(f"Training finished: {result}")
Three things worth highlighting:
- All imports go inside the function — the body runs on the remote pod, so import jax needs to resolve against the pod’s jax[tpu] package, not your laptop’s.
- requirements=[…] is your remote requirements.txt — Kinetic uses it to build the container image on the first run. None of these need to be installed locally.
- Data lives in GCS — your dataset goes to a Google Cloud Storage bucket, and the function reads it from gs://…. The TPU pod has automatic GCS access via its service account.
6.2 — Launch from Your Laptop
python train_kinetic.py
First run takes ~5 minutes for the container build. Subsequent runs with unchanged dependencies start in under a minute. You’ll see remote logs streamed to your terminal:
Shipping to TPU via Kinetic...
[Stage 1/4] Preflight & packaging...
[Stage 2/4] Building container image (5m, cached after this run)...
[Stage 3/4] Submitting job to GKE cluster...
[Stage 4/4] Executing on TPU v5p-8...
[remote] [Dataset] Loaded 12,847 samples
[remote] [Tokenizer] Vocabulary size: 18,734 tokens
[remote] Epoch 1/50 | train loss 0.9912 acc 0.5421 | val loss 0.8810 acc 0.6112 | 12.8s
[remote] Epoch 2/50 | train loss 0.7821 acc 0.6601 | val loss 0.6987 acc 0.7022 | 4.2s
...
[remote] Epoch 50/50 | train loss 0.1912 acc 0.9301 | val loss 0.3811 acc 0.8612 | 4.3s
Job complete. Streaming results to local...
Training finished: {'status': 'completed', 'epochs': 50}
The v5p-8 slice has 4 chips (8 TensorCores) versus Colab’s single chip, so per-epoch time drops from 22s to ~4s, about 5× faster, which matters when you’re running 50+ epochs or sweeping hyperparameters.
6.3 — Tear Down
The GKE cluster’s control plane costs money even when no TPU nodes are active. Always shut it down when you’re done:
kinetic down --yes
Training Configuration Summary
- Architecture: 4-layer Transformer + Stance Cross-Encoder
- Embedding dim: 256
- Attention heads: 8 (head_dim 32)
- Feedforward dim: 1024
- Total parameters: ~7M
- Optimizer: AdamW with linear warmup (10% of steps) + cosine decay
- Learning rate: 2e-4 peak
- Weight decay: 0.01
- Dropout: 0.1
- Batch size: 64 (Colab) / 128 (Cloud TPU v5p-8)
- Epochs: 20 (Colab) / 50 (Cloud TPU)
- Sequence length: 64 tokens for both claim and headline
Key Takeaways
Free Colab TPU is the best on-ramp for JAX/Flax in 2026
You get a full v5e-1 chip with no signup beyond a Google account. For models in the 5–50M parameter range, that’s enough to do real work, not just toy demos. If you’ve been putting off learning JAX because the cloud setup felt overwhelming, this path is friction-free.
Keras Kinetic makes the leap from Colab to Cloud TPU painless
The biggest practical lesson from this project: when Colab’s runtime limits start hurting, switching to Cloud TPU traditionally means rewriting your deployment story. With Kinetic, you wrap your existing function in a decorator and call it from the same laptop you’ve been using. The mental model — “I have a function; I want it to run on a TPU” — stays intact.
Stance detection is a more honest framing than fake-news classification
Asking a model “is this true?” puts it in an impossible position. Asking “does this article support, refute, or merely observe this claim?” gives it a question it can answer, and gives downstream fact-checkers exactly the signal they need.
Cross-attention beats independent encoders for paired-text tasks
The StanceCrossEncoder is what makes this model genuinely multimodal-aware. Concatenating two independently-encoded vectors and slapping a classifier on top works, but performance jumps significantly when you let the two inputs literally read each other before pooling. The [c, h, |c-h|, c*h] interaction trick is borrowed from NLI literature and consistently outperforms simpler combinations.
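For readers who have not met the interaction trick before, here is a minimal JAX sketch of just the concatenation step. The function name and the dummy inputs are illustrative assumptions: `c` and `h` stand for the pooled claim and headline vectors, and how the actual StanceCrossEncoder pools them is not shown here.

```python
import jax.numpy as jnp


def interaction_features(c, h):
    """Concatenate [c, h, |c - h|, c * h] along the feature axis."""
    return jnp.concatenate([c, h, jnp.abs(c - h), c * h], axis=-1)


# Dummy pooled vectors: batch of 2, feature size 256 (the embedding dim used here).
c = jnp.ones((2, 256))
h = jnp.full((2, 256), 0.5)

features = interaction_features(c, h)
print(features.shape)  # four blocks of 256 features per example
```

The absolute difference and the element-wise product give the classifier explicit "how far apart" and "where do they agree" signals that it would otherwise have to learn from the raw concatenation.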
TPU is the right pick for fixed-length attention workloads
Multi-head attention with fixed sequence lengths is exactly what TPU systolic arrays were built for. The whole train_step JITs into a single XLA program, and after the first compile, every step reuses that compiled program without retracing. GPUs are still preferable when you need irregular-length sequences, KV-caching for generation, or HuggingFace PyTorch checkpoints, but for from-scratch JAX/Flax training, TPU wins on both speed and cost.
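The reason fixed lengths matter can be seen in a small experiment. The sketch below uses a toy attention-score function rather than the project's train_step (the function and variable names are illustrative) and records every time JAX has to trace it. Same shapes reuse the compiled program, while a new sequence length forces a fresh trace and compile.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled only while JAX traces the function, not on cached calls


@jax.jit
def attention_scores(q, k):
    trace_log.append(q.shape)  # Python side effect: runs at trace time only
    scale = jnp.sqrt(q.shape[-1]).astype(q.dtype)
    return jax.nn.softmax(q @ k.T / scale, axis=-1)


fixed = jnp.ones((64, 32))            # 64 tokens, head_dim 32
attention_scores(fixed, fixed)        # first call: trace + compile
attention_scores(fixed * 2.0, fixed)  # same shapes and dtypes: cache hit
traces_after_fixed_shapes = len(trace_log)

shorter = jnp.ones((48, 32))          # a different sequence length
attention_scores(shorter, shorter)    # new shape: traced and compiled again
print(traces_after_fixed_shapes, len(trace_log))
```

Padding every claim and headline to 64 tokens keeps training on the cached path for the whole run.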
Adapting This to Other Languages and Tasks
The pipeline is generic. To use it on your own data:
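- Any language → just point the tokenizer at your text corpus. The whitespace tokenizer works for any language that separates words with spaces; scripts without explicit word boundaries (Chinese, Japanese, Thai) need a subword tokenizer instead. For better quality, swap in the appropriate pre-trained tokenizer (English → BERT, Indonesian → IndoBERT, Arabic → AraBERT, etc.).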
- Any paired-text classification task → swap the labels. The [c, h, |c-h|, c*h] interaction trick works for paraphrase detection, NLI, semantic textual similarity, duplicate question detection, and more.
- Larger corpora → increase embed_dim and num_layers, and switch from the whitespace tokenizer to a SentencePiece or BPE one. The training loop and Kinetic deployment don’t change.
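As a rough illustration of the "swap the labels" point, the task-specific pieces reduce to a label list and the number of output classes. The helper below is a hypothetical sketch, not code from the project. The stance labels come from this article, and the NLI labels are the usual three-way set used as an example of another task.

```python
STANCE_LABELS = ["for", "against", "observing"]
NLI_LABELS = ["entailment", "neutral", "contradiction"]


def build_label_maps(labels):
    """Return (label2id, id2label) for a list of class names."""
    label2id = {label: idx for idx, label in enumerate(labels)}
    id2label = {idx: label for label, idx in label2id.items()}
    return label2id, id2label


label2id, id2label = build_label_maps(NLI_LABELS)
num_classes = len(label2id)  # this is the size the classifier head must output
print(label2id, num_classes)
```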
Resources
- Project repository: (your GitHub link here)
- JAX: github.com/google/jax
- Flax: github.com/google/flax
- Optax: github.com/google-deepmind/optax
- Orbax: github.com/google/orbax
- Keras Kinetic: github.com/keras-team/kinetic
- IndoBERT (production tokenizer upgrade): huggingface.co/indobenchmark/indobert-base-p1
- Cloud TPU documentation: cloud.google.com/tpu/docs
- Google Colab: colab.research.google.com
Acknowledgement
- Google Cloud credits were provided for this project. #TPUSprint
- Thanks to the JAX, Flax, and Keras teams for building such a clean stack — training custom transformers on TPUs used to be a research-grade pain.
Tags: JAX, Flax, KerasKinetic, TPU, GoogleColab, FakeNewsDetection, StanceDetection, NaturalLanguageProcessing, MachineLearning, Indonesia
Building a Multimodal Indonesian Fake-News Detector with JAX, Flax, and Keras Kinetic on Cloud TPU was originally published in Google Developer Experts on Medium.