I wanted to build a medical Q&A fine-tuning project that stayed genuinely TPU-native on Kaggle. This project uses Google’s `gemma-3-1b-it`, KerasHub, the JAX backend, and a Kaggle TPU to fine-tune on medical dialogue data from ChatDoctor. I also compared the result against a much larger Gemma 4 model served remotely.

The short version is:
– Gemma 3 training pipeline ran end to end on TPU
– LoRA fine-tuning completed in under a minute on a `TPU v5e-8`
– Qualitative answers improved more than benchmark accuracy
– Gemma 4 remained much stronger on MedMCQA in zero-shot mode
It’s worth noting that conversational fine-tuning and benchmark gains are not always the same thing.
You can check out the Kaggle notebook here.
Medical LLM Fine-Tuning on TPU with Keras + JAX
What I Built
A pipeline that does five things:
1. Verifies a Kaggle TPU environment with Keras locked to the JAX backend.
2. Loads and formats the ChatDoctor dataset into Gemma chat-style prompt/response pairs.
3. Evaluates base Gemma 3 1B on a fixed 100 question MedMCQA slice.
4. Fine-tunes Gemma 3 1B with LoRA on TPU.
5. Compares the final model against Gemma 4 through OpenRouter for a zero-shot reference point.
The active local training path is fully TPU-native for Gemma 3. The Gemma 4 section is intentionally separate and remote, so I label it clearly as an API comparison rather than TPU-local inference.
Environment and TPU Setup
This run was executed in Kaggle with:
Python `3.12.13`
NumPy `2.4.3`
datasets `4.8.3`
Keras `3.13.2`
JAX `0.9.2`
and 8 TPU devices visible to JAX. The critical setup detail was making sure Keras saw `KERAS_BACKEND=jax` before its first import:
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
import jax
print(keras.backend.backend())
print(jax.devices())
That sounds minor, but in notebook environments it is often the difference between a clean TPU run and a confusing backend mismatch.
Dataset Choice: ChatDoctor for Fine-Tuning, MedMCQA for Benchmarking
I used two datasets with different roles:
– `LinhDuong/chatdoctor-200k` for conversational medical fine-tuning
– `medmcqa` for a multiple-choice medical benchmark
From the final run:
– raw ChatDoctor examples: `207,408`
– valid formatted examples: `207,405`
– train split used: `1,800`
– validation split used: `200`
– MedMCQA evaluation size: `100`
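The train/validation numbers above come from a small shuffled subsample of the valid examples. A minimal sketch of that kind of deterministic split (the seed and slicing are illustrative, not the notebook’s exact code):

```python
import random

# Illustrative deterministic split: shuffle once with a fixed seed, then
# slice off small train/validation subsets. The seed value is an assumption.
def split_examples(examples, n_train=1800, n_val=200, seed=42):
    shuffled = examples[:]
    random.Random(seed).shuffle(shuffled)
    return shuffled[:n_train], shuffled[n_train:n_train + n_val]

data = [f"example-{i}" for i in range(5000)]
train, val = split_examples(data)
print(len(train), len(val))  # 1800 200
```

Fixing the seed matters here: with only `1,800` training examples out of `207,405`, a different shuffle gives a noticeably different slice of the data.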
ChatDoctor was reformatted into Gemma chat turns:
text = (
    f"<start_of_turn>user\n"
    f"You are a helpful medical assistant. Answer the patient's question clearly and safely.\n\n"
    f"{patient_msg}<end_of_turn>\n"
    f"<start_of_turn>model\n"
    f"{doctor_msg}<end_of_turn>"
)
That structure let me keep preprocessing inside KerasHub rather than building a separate tokenization training pipeline by hand.
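Wrapping the template in a small filter function also explains the raw-versus-valid counts above (a few malformed rows get dropped). This is a self-contained sketch; the `input`/`output` field names are assumptions about the ChatDoctor schema, not verified column names:

```python
# Sketch of the ChatDoctor -> Gemma chat formatting step.
# The "input"/"output" field names are assumed, not taken from the notebook.
SYSTEM = "You are a helpful medical assistant. Answer the patient's question clearly and safely."

def format_example(example):
    patient_msg = (example.get("input") or "").strip()
    doctor_msg = (example.get("output") or "").strip()
    if not patient_msg or not doctor_msg:
        return None  # drop malformed rows (3 of 207,408 in my run)
    return (
        f"<start_of_turn>user\n"
        f"{SYSTEM}\n\n"
        f"{patient_msg}<end_of_turn>\n"
        f"<start_of_turn>model\n"
        f"{doctor_msg}<end_of_turn>"
    )

sample = {"input": "I have a sore throat.", "output": "Rest, fluids, and see a doctor if it persists."}
print(format_example(sample))
```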
Baseline: Gemma 3 1B Before Fine-Tuning
For the baseline, I loaded `google/gemma-3-1b-it` directly through KerasHub:
baseline_preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
    "hf://google/gemma-3-1b-it",
    sequence_length=256,
)
baseline_gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(
    "hf://google/gemma-3-1b-it",
    preprocessor=baseline_preprocessor,
)
baseline_gemma_lm.compile(sampler="greedy")
On the `100`-question MedMCQA slice, the base model scored 30/100, i.e. 30.0% accuracy. This gave me the “before fine-tuning” reference point for the rest of the project.
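The scoring loop behind that number is simple in spirit: generate an answer, pull out an option letter, and compare it to the gold label. A hedged sketch (the `extract_choice` helper is hypothetical, not the notebook’s exact parsing code):

```python
import re

# Illustrative MedMCQA scoring: extract the first standalone A-D letter
# from a generation and compare it against the gold answer letter.
def extract_choice(generation):
    match = re.search(r"\b([A-D])\b", generation.strip().upper())
    return match.group(1) if match else None

def score(generations, gold_letters):
    correct = sum(
        extract_choice(gen) == gold for gen, gold in zip(generations, gold_letters)
    )
    return correct / len(gold_letters)

# Toy check: two of three "generations" match the gold letters.
print(score(["Answer: B", "The answer is C.", "I am not sure"], ["B", "C", "A"]))
```

Parsing free-form generations into a choice letter is itself a source of noise at this model size, which is worth remembering when reading small MCQ accuracy differences.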
Fine-Tuning Setup: LoRA on TPU
The model had `999,885,952` parameters in total. After enabling LoRA, only `2,609,152` parameters were trainable, which made this practical on Kaggle TPU without attempting a full fine-tune of the entire model.
The LoRA step was simple:
gemma_lm.backbone.enable_lora(rank=16)
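From the reported counts, that adaptation budget works out to roughly a quarter of one percent of the model:

```python
# Trainable fraction implied by the LoRA numbers from this run.
total_params = 999_885_952
trainable_params = 2_609_152

fraction = trainable_params / total_params
print(f"{fraction:.4%} of parameters are trainable")  # ~0.26%
```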
Training configuration from the final run:
sequence length: `256`
train samples: `1,800`
validation samples: `200`
batch size: `1`
epochs: `1`
learning rate: `1e-4`
LoRA rank: `16`
The model was compiled with SGD plus sparse categorical cross-entropy:
optimizer = keras.optimizers.SGD(learning_rate=1e-4)
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
    sampler="greedy",
)
TPU Training Results
The training run completed successfully on TPU, and the first batch behaved exactly the way TPU users expect: slower up front because of XLA compilation, then much faster steady-state execution.
Final TPU training stats from the notebook:
hardware: `Google Cloud TPU (8 devices)`
model: `google/gemma-3-1b-it`
framework: `Keras 3.13.2 + JAX backend`
method: `LoRA (rank=16)`
train examples: `1800`
batch size: `1`
epochs: `1`
total train time: `0.7 minutes`
throughput: `~10512 tokens/sec`
Training metrics:
train loss: `1.8927`
validation loss: `1.7304`
train token accuracy: `0.3061`
validation token accuracy: `0.3129`
The token-level accuracy is a language-model training signal, not a “medical question answering accuracy” metric, so I treated it as a health check rather than the headline result.
Did Fine-Tuning Improve the Model?
This is where the story gets interesting. On MedMCQA, the answer was: not really.
The fine-tuned Gemma 3 model scored 30/100, i.e. 30.0% accuracy. That is exactly flat versus the baseline on this benchmark.
But the qualitative outputs still changed. In the side-by-side examples, the fine-tuned model became a bit more direct and safety-oriented. For example, on the chest pain and shortness of breath prompt:
– the baseline answer was cautious and explanatory
– the fine-tuned answer escalated more clearly toward urgent medical evaluation
That pattern showed up again in the other examples. The model did not suddenly become benchmark-dominant, but it did become more aligned with the tone of medical guidance in the fine-tuning data.
This is an important practical lesson: fine-tuning on conversational domain data can meaningfully change response style and safety emphasis without moving a multiple-choice benchmark very much.
Held-Out ROUGE-L Check
I also ran a lightweight held-out check on 20 validation examples from ChatDoctor and measured ROUGE-L between model outputs and reference answers.
Final result: ROUGE-L: 0.106
I would not oversell that number. It is best treated as a coarse similarity signal, not a substitute for clinical quality evaluation. Still, it adds one more perspective beyond MedMCQA.
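For readers unfamiliar with the metric, ROUGE-L is an F-measure built on the longest common subsequence (LCS) between a candidate and a reference. A minimal pure-Python sketch, illustrative only (real evaluations typically use a library such as `rouge-score`, with stemming and sentence-level handling this omits):

```python
# Minimal ROUGE-L: LCS-based F-measure over whitespace tokens.
def lcs_len(a, b):
    """Length of the longest common subsequence of two token lists."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, tok_a in enumerate(a):
        for j, tok_b in enumerate(b):
            dp[i + 1][j + 1] = (
                dp[i][j] + 1 if tok_a == tok_b else max(dp[i][j + 1], dp[i + 1][j])
            )
    return dp[-1][-1]

def rouge_l(candidate, reference):
    cand, ref = candidate.lower().split(), reference.lower().split()
    lcs = lcs_len(cand, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)

print(round(rouge_l("rest and drink fluids", "please rest and drink plenty of fluids"), 3))  # 0.727
```

A score of 0.106 on long, free-form medical answers mostly says the outputs share some wording with the references, nothing more.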
Gemma 4 Comparison Through OpenRouter
For a stronger reference model, I added an optional Gemma 4 section through OpenRouter:
– model: `google/gemma-4-26b-a4b-it`
– execution mode: remote API inference
– not TPU-local
The Gemma 4 comparison was useful, but it was not part of the TPU training path. It was there to answer a product question: how does a fine-tuned small model compare to a larger newer model out of the box?
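OpenRouter exposes an OpenAI-compatible chat completions endpoint, so the remote call reduces to a small JSON request. A hedged sketch (the `build_request` helper is hypothetical; the model id is the one from this comparison):

```python
import json
import os

# Illustrative request construction for OpenRouter's OpenAI-compatible
# chat completions endpoint. build_request is a hypothetical helper.
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"

def build_request(question, model="google/gemma-4-26b-a4b-it"):
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENROUTER_API_KEY', '')}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": question}],
        "temperature": 0.0,  # deterministic answers for benchmark scoring
    }
    return headers, json.dumps(payload)

headers, body = build_request("A patient reports chest pain. What should they do?")
# requests.post(OPENROUTER_URL, headers=headers, data=body) would send it.
print(json.loads(body)["model"])
```

Keeping this path behind a plain HTTP request also makes the separation explicit: nothing in it touches the TPU or the local weights.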
On the same 100 MedMCQA questions, Gemma 4 scored 68/100, i.e. 68.0% accuracy. That is a large gap over both the base and fine-tuned Gemma 3 runs.
The qualitative outputs were stronger too. On the chest pain example, Gemma 4 moved immediately into emergency-style guidance, explicitly telling the user to treat it as a potential emergency and seek immediate care. On the fatigue, thirst, and frequent urination example, it cleanly recognized the classic diabetes-related symptom triad.
Final Results
– base Gemma 3 1B (zero-shot): `30/100` on MedMCQA
– fine-tuned Gemma 3 1B (LoRA): `30/100` on MedMCQA
– Gemma 4 via OpenRouter (zero-shot): `68/100` on MedMCQA

Additional metrics:
– ROUGE-L on held-out ChatDoctor examples: `0.106`
– trainable parameters with LoRA: `2,609,152`
– total parameters: `999,885,952`
– end-to-end TPU training time: `0.7 minutes`
What This Project Actually Shows
The biggest takeaway is not that fine-tuning always beats a better base model. My final run does not support that claim.
What it does show is:
– Keras with the JAX backend is a practical TPU training stack on Kaggle
– Gemma 3 1B can be fine-tuned end to end on a TPU with a very small adaptation budget
– medical dialogue fine-tuning can shift qualitative behavior even when benchmark accuracy stays flat
– a much larger newer model can still dominate factual medical MCQ evaluation in zero-shot mode
Closing Thoughts
The Gemma 3 training path worked exactly the way I wanted: simple setup, clean TPU execution, a small trainable footprint, and fast iteration. The results were also a good reminder that not every successful training run produces a flashy benchmark jump. Sometimes the real win is building a robust pipeline, understanding what changed, and being honest about what did not.
Overall, I successfully fine-tuned Gemma 3 1B for medical dialogue on Kaggle TPU with Keras and JAX, saw qualitative improvements in response behavior, and confirmed that a remote Gemma 4 baseline still substantially outperformed it on MedMCQA.
Fine-Tuning Gemma 3 on TPU for Medical Q&A with Keras and JAX was originally published in Google Developer Experts on Medium.