Fine-Tuning Gemma 3 on TPU for Medical Q&A with Keras and JAX

I wanted to build a medical Q&A fine-tuning project that stayed genuinely TPU-native on Kaggle. This project uses Google’s `gemma-3-1b-it`, KerasHub, the JAX backend, and a Kaggle TPU to fine-tune on medical dialogue data from ChatDoctor. I also compared the result against a much larger Gemma 4 model served remotely.

The short version is:

– Gemma 3 training pipeline ran end to end on TPU

– LoRA fine-tuning completed in under a minute on a `TPU v5e-8`

– Qualitative answers improved more than benchmark accuracy

– Gemma 4 remained much stronger on MedMCQA in zero-shot mode

It is worth noting that conversational fine-tuning and benchmark gains are not always the same thing.

You can check out the Kaggle notebook here.


What I Built

A pipeline that does five things:

1. Verifies a Kaggle TPU environment with Keras locked to the JAX backend.

2. Loads and formats the ChatDoctor dataset into Gemma chat-style prompt/response pairs.

3. Evaluates base Gemma 3 1B on a fixed 100-question MedMCQA slice.

4. Fine-tunes Gemma 3 1B with LoRA on TPU.

5. Compares the final model against Gemma 4 through OpenRouter for a zero-shot reference point.

The active local training path is fully TPU-native for Gemma 3. The Gemma 4 section is intentionally separate and remote, so I label it clearly as an API comparison rather than TPU-local inference.

Environment and TPU Setup

This run was executed in Kaggle with:

Python `3.12.13`
NumPy `2.4.3`
datasets `4.8.3`
Keras `3.13.2`
JAX `0.9.2`

and `8` TPU devices visible to JAX. The critical setup detail was making sure Keras saw `KERAS_BACKEND=jax` before its first import:

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before Keras is imported
import keras
import jax
print(keras.backend.backend())  # expect "jax"
print(jax.devices())            # expect 8 TPU devices

That sounds minor, but in notebook environments it is often the difference between a clean TPU run and a confusing backend mismatch.

Dataset Choice: ChatDoctor for Fine-Tuning, MedMCQA for Benchmarking

I used two datasets with different roles:

– `LinhDuong/chatdoctor-200k` for conversational medical fine-tuning

– `medmcqa` for a multiple-choice medical benchmark

From the final run:

– raw ChatDoctor examples: `207,408`

– valid formatted examples: `207,405`

– train split used: `1,800`

– validation split used: `200`

– MedMCQA evaluation size: `100`

ChatDoctor was reformatted into Gemma chat turns:

text = (
    f"<start_of_turn>user\n"
    f"You are a helpful medical assistant. Answer the patient's question clearly and safely.\n\n"
    f"{patient_msg}<end_of_turn>\n"
    f"<start_of_turn>model\n"
    f"{doctor_msg}<end_of_turn>"
)

That structure let me keep preprocessing inside KerasHub rather than building a separate tokenization and training pipeline by hand.
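To make that reformatting step concrete, here is a minimal sketch of the conversion. The field names (`"input"`, `"output"`) are assumptions about the ChatDoctor schema, not confirmed from the notebook; the skip-on-empty check mirrors the three examples dropped as invalid.

```python
def format_example(example):
    # Field names ("input", "output") are assumed; verify against the
    # actual LinhDuong/chatdoctor-200k schema before running.
    patient_msg = (example.get("input") or "").strip()
    doctor_msg = (example.get("output") or "").strip()
    if not patient_msg or not doctor_msg:
        return None  # drop malformed rows, as the notebook does
    # Standard Gemma chat-turn layout: user turn, then model turn.
    return (
        "<start_of_turn>user\n"
        "You are a helpful medical assistant. Answer the patient's question clearly and safely.\n\n"
        f"{patient_msg}<end_of_turn>\n"
        "<start_of_turn>model\n"
        f"{doctor_msg}<end_of_turn>"
    )
```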

Baseline: Gemma 3 1B Before Fine-Tuning

For the baseline, I loaded `google/gemma-3-1b-it` directly through KerasHub:

baseline_preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
    "hf://google/gemma-3-1b-it",
    sequence_length=256,
)
baseline_gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(
    "hf://google/gemma-3-1b-it",
    preprocessor=baseline_preprocessor,
)
baseline_gemma_lm.compile(sampler="greedy")

On the `100`-question MedMCQA slice, the base model scored 30/100, i.e., 30.0% accuracy. This gave me the “before fine-tuning” reference point for the rest of the project.
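For reference, the multiple-choice evaluation loop can be sketched roughly as follows. The prompt wording and the answer-parsing helper are my reconstruction, not the notebook's exact code:

```python
def build_mcq_prompt(question, options):
    # Format one MedMCQA item as a four-option multiple-choice prompt
    # (hypothetical wording, not the notebook's exact template).
    letters = "ABCD"
    lines = [question]
    lines += [f"{letter}. {option}" for letter, option in zip(letters, options)]
    lines.append("Answer with a single letter (A, B, C, or D).")
    return "\n".join(lines)

def parse_answer(generated_text):
    # Take the first standalone A-D token in the generated text.
    # A crude heuristic: it can misfire on sentences starting with "A".
    for token in generated_text.replace(".", " ").replace(")", " ").split():
        if token.upper() in {"A", "B", "C", "D"}:
            return token.upper()
    return None
```

Accuracy is then the fraction of the 100 questions where the parsed letter matches the gold answer.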

Fine-Tuning Setup: LoRA on TPU

The model has `999,885,952` parameters in total. After enabling LoRA, only `2,609,152` parameters (roughly 0.26%) were trainable, which made this practical on a Kaggle TPU without attempting to fully fine-tune the model.
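The trainable fraction follows directly from the two parameter counts reported above:

```python
# Parameter counts reported by the notebook's final run.
total_params = 999_885_952
lora_trainable = 2_609_152

pct = 100 * lora_trainable / total_params
print(f"Trainable fraction with LoRA: {pct:.2f}%")  # about 0.26%
```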

The LoRA step was simple:

gemma_lm.backbone.enable_lora(rank=16)

Training configuration from the final run:

sequence length: `256`
train samples: `1,800`
validation samples: `200`
batch size: `1`
epochs: `1`
learning rate: `1e-4`
LoRA rank: `16`

The model was compiled with SGD plus sparse categorical cross-entropy:

optimizer = keras.optimizers.SGD(learning_rate=1e-4)
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
    sampler="greedy",
)

TPU Training Results

The training run completed successfully on TPU, and the first batch behaved exactly the way TPU users expect: slower up front because of XLA compilation, then much faster steady-state execution.

Final TPU training stats from the notebook:

hardware: `Google Cloud TPU (8 devices)`
model: `google/gemma-3-1b-it`
framework: `Keras 3.13.2 + JAX backend`
method: `LoRA (rank=16)`
train examples: `1800`
batch size: `1`
epochs: `1`
total train time: `0.7 minutes`
throughput: `~10512 tokens/sec`
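The reported throughput is roughly consistent with the run configuration, assuming every token in every training example counts toward the total:

```python
# Cross-check the reported throughput against the run configuration.
tokens_seen = 1_800 * 256        # train examples x sequence length
wall_seconds = 0.7 * 60          # reported end-to-end training time
implied_throughput = tokens_seen / wall_seconds
print(round(implied_throughput))  # in the same ballpark as the reported ~10,512 tokens/sec
```

The small gap is plausibly XLA compilation overhead in the first step plus rounding in the reported 0.7 minutes.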

Training metrics:
train loss: `1.8927`
validation loss: `1.7304`
train token accuracy: `0.3061`
validation token accuracy: `0.3129`

The token-level accuracy is a language-model training signal, not a “medical question answering accuracy” metric, so I treated it as a health check rather than the headline result.

Did Fine-Tuning Improve the Model?

This is where the story gets interesting. On MedMCQA, the answer was: not really.

The fine-tuned Gemma 3 model scored 30/100, i.e., 30.0% accuracy. That is exactly flat versus the baseline on this benchmark.

But the qualitative outputs still changed. In the side-by-side examples, the fine-tuned model became a bit more direct and safety-oriented. For example, on the chest pain and shortness of breath prompt:

– the baseline answer was cautious and explanatory

– the fine-tuned answer escalated more clearly toward urgent medical evaluation

That pattern showed up again in the other examples. The model did not suddenly become benchmark-dominant, but it did become more aligned with the tone of medical guidance in the fine-tuning data.

This is an important practical lesson: fine-tuning on conversational domain data can meaningfully change response style and safety emphasis without moving a multiple-choice benchmark very much.

Held-Out ROUGE-L Check

I also ran a lightweight held-out check on 20 validation examples from ChatDoctor and measured ROUGE-L between model outputs and reference answers.

Final result: ROUGE-L: 0.106

I would not oversell that number. It is best treated as a coarse similarity signal, not a substitute for clinical quality evaluation. Still, it adds one more perspective beyond MedMCQA.
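For readers who want to reproduce the check without an extra dependency, a token-level ROUGE-L F1 can be computed directly from the longest common subsequence. This is my own minimal implementation, not the notebook's code, and it uses naive whitespace tokenization:

```python
def lcs_len(a, b):
    # Classic dynamic-programming longest common subsequence over token lists.
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            dp[i + 1][j + 1] = dp[i][j] + 1 if x == y else max(dp[i][j + 1], dp[i + 1][j])
    return dp[len(a)][len(b)]

def rouge_l(candidate, reference):
    # Token-level ROUGE-L F1 between a model output and a reference answer.
    c, r = candidate.lower().split(), reference.lower().split()
    lcs = lcs_len(c, r)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(c), lcs / len(r)
    return 2 * precision * recall / (precision + recall)
```

Averaging `rouge_l` over the 20 held-out pairs gives a single score comparable to the `0.106` reported above.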

Gemma 4 Comparison Through OpenRouter

For a stronger reference model, I added an optional Gemma 4 section through OpenRouter:

– model: `google/gemma-4-26b-a4b-it`

– execution mode: remote API inference

– not TPU-local

The Gemma 4 comparison was useful, but it was not part of the TPU training path. It was there to answer a product question: how does a fine-tuned small model compare to a larger newer model out of the box?

On the same 100 MedMCQA questions, Gemma 4 scored 68/100, i.e., 68.0% accuracy. That is a large gap over both the base and fine-tuned Gemma 3 runs.

The qualitative outputs were stronger too. On the chest pain example, Gemma 4 moved immediately into emergency-style guidance, explicitly telling the user to treat it as a potential emergency and seek immediate care. On the fatigue, thirst, and frequent urination example, it cleanly recognized the classic diabetes-related symptom triad.
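OpenRouter exposes an OpenAI-style chat-completions API, so the remote calls reduce to building a small JSON payload. This payload builder is my sketch; only the model id comes from this run, and the actual request is a POST to OpenRouter's chat-completions endpoint with an API-key header:

```python
def build_openrouter_payload(question_prompt):
    # OpenAI-style chat-completions payload as accepted by OpenRouter.
    # The model id is the one used in this comparison.
    return {
        "model": "google/gemma-4-26b-a4b-it",
        "messages": [{"role": "user", "content": question_prompt}],
        "temperature": 0.0,  # deterministic answers for benchmarking
    }

payload = build_openrouter_payload("Which vitamin deficiency causes scurvy?")
```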

Final Results

Additional metrics:

– ROUGE-L on held-out ChatDoctor examples: `0.106`

– trainable parameters with LoRA: `2,609,152`

– total parameters: `999,885,952`

– end-to-end TPU training time: `0.7 minutes`

What This Project Actually Shows

The biggest takeaway is not that “fine-tuning always beats a better base model.” My final run does not support that claim.

What it does show is:

– Keras plus JAX is a practical TPU training stack on Kaggle

– Gemma 3 1B can be fine-tuned end to end on a TPU with a very small adaptation budget

– medical dialogue fine-tuning can shift qualitative behavior even when benchmark accuracy stays flat

– a much larger newer model can still dominate factual medical MCQ evaluation in zero-shot mode

Closing Thoughts

The Gemma 3 training path worked exactly the way I wanted: simple setup, clean TPU execution, small trainable footprint, and fast iteration. The results were also a good reminder that not every successful training run produces a flashy benchmark jump. Sometimes the real win is building a robust pipeline, understanding what changed, and being honest about what did not.

Overall, I successfully fine-tuned Gemma 3 1B for medical dialogue on Kaggle TPU with Keras and JAX, saw qualitative improvements in response behavior, and confirmed that a remote Gemma 4 baseline still substantially outperformed it on MedMCQA.


Fine-Tuning Gemma 3 on TPU for Medical Q&A with Keras and JAX was originally published in Google Developer Experts on Medium, where people are continuing the conversation by highlighting and responding to this story.
