Building dependable, accurate machine learning models increasingly depends on combining data from several modalities, such as text and images. By drawing on multiple input forms, multimodal classification enables a deeper understanding of the data and supports richer predictions for applications such as object recognition, sentiment analysis, and content moderation.
Since PaliGemma can process text and image data concurrently, it is a strong model for multimodal classification tasks. By fine-tuning it on your specific use case, you can adapt it to your own data and labels.
In this article, I will guide you through fine-tuning the PaliGemma model for multimodal classification. Whether you are new to machine learning or a practitioner exploring recent advances in multimodal learning, this guide will walk you through the tools you need and how to adapt the model to your own dataset. Let’s explore multimodal classification with PaliGemma!
Requirements:
- Python 3.7+
- PyTorch
- Transformers
- PEFT (Parameter-Efficient Fine-Tuning)
- bitsandbytes
- pandas
- Pillow (PIL)
- scikit-learn
- Google Colab or an NVIDIA T4 GPU (or better)
Step 1: Installation
The first step is to install the required libraries:
pip install transformers peft bitsandbytes pandas pillow scikit-learn torch
Step 2: Import Required Libraries
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
from torch.utils.data import Dataset, DataLoader, Subset
from peft import get_peft_model, LoraConfig, TaskType
from sklearn.model_selection import KFold
from PIL import Image
import pandas as pd
import torch
import os
import math
Step 3: Define the Dataset Classes
To load the data, create a dataset class. The dataset class in this example loads a text-image pair for a given index, converts it into tensors that PaliGemma can consume, and returns them together with a sentiment label: 0 for negative, 1 for neutral, or 2 for positive.
“Text Input: [input]\n\nSentiment:” is the prompt used in this case. This format tells the model what we want it to produce from the input: the sentiment.
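For instance, here is what the rendered prompt looks like for a made-up input sentence:
# A quick illustration of the prompt format with a hypothetical input
prompt_template = "Text Input: {}\n\nSentiment:"
print(prompt_template.format("I loved the view from the hotel room!"))
# Output:
# Text Input: I loved the view from the hotel room!
#
# Sentiment: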
class TextImageDataset(Dataset):
    def __init__(self, dataset_csv_path: str, processor):
        # The CSV in this example contains columns for
        #   text_input: str
        #   image_path: str
        #   sentiment: int in [0, 1, 2]
        #   id: str
        self.dataset = pd.read_csv(dataset_csv_path)
        self.processor = processor
        self.sentiment_tokens = ['Negative', 'Neutral', 'Positive']
        self.sentiment_ids = processor.tokenizer.convert_tokens_to_ids(self.sentiment_tokens)
        self.prompt_template = "Text Input: {}\n\nSentiment:"

    def __len__(self):
        # Required by DataLoader and KFold.split
        return len(self.dataset)

    def __getitem__(self, idx):
        dataset_item = self.dataset.iloc[idx]
        prompt = self.prompt_template.format(dataset_item['text_input'])
        inputs = self.processor(text=prompt, images=Image.open(dataset_item['image_path']), return_tensors="pt")
        inputs['labels'] = torch.tensor([dataset_item['sentiment']])
        inputs['dataset_id'] = dataset_item['id']
        # The returned inputs contain
        #   input_ids: Tensor[int] of shape (1, -1)
        #   attention_mask: Tensor[int] of shape (1, -1)
        #   pixel_values: Tensor[float] of shape (1, 3, 224, 224)
        #   labels: Tensor[int in [0, 1, 2]] of shape (1,)
        #   dataset_id: str
        return inputs
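For reference, the dataset CSV this class expects could look like the following (the file names, texts, and labels here are made up for illustration):
id,text_input,image_path,sentiment
a001,I loved the view from the hotel room!,images/a001.jpg,2
a002,"The room was okay, nothing special.",images/a002.jpg,1
a003,The food was cold and the staff was rude.,images/a003.jpg,0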
Step 4: Define the Data Collator Function
This function prepares batches for training. It left-pads the input IDs and attention masks so that every sequence matches the length of the longest sequence in that batch.
def data_collator_fn(batch):
    input_ids = [item['input_ids'] for item in batch]
    max_len = max([ids.shape[1] for ids in input_ids])
    for item in batch:
        pad_len = max_len - item['input_ids'].shape[1]
        # Left-pad with the pad token; relies on the global `processor` from Step 5
        padding = torch.unsqueeze(torch.ones(pad_len, dtype=torch.int) * processor.tokenizer.pad_token_id, 0)
        attention_padding = torch.unsqueeze(torch.zeros(pad_len, dtype=torch.int), 0)
        item['input_ids'] = torch.cat([padding, item['input_ids']], dim=1)
        item['attention_mask'] = torch.cat([attention_padding, item['attention_mask']], dim=1)
    return {
        'input_ids': torch.cat([item['input_ids'] for item in batch]),
        'attention_mask': torch.cat([item['attention_mask'] for item in batch]),
        'pixel_values': torch.cat([item['pixel_values'] for item in batch]),
        'labels': torch.cat([item['labels'] for item in batch]).to(torch.int64),
        'dataset_id': [item['dataset_id'] for item in batch]
    }
# The returned batch contains
#   input_ids: Tensor[int] of shape (batch_size, batch_max_len)
#   attention_mask: Tensor[int] of shape (batch_size, batch_max_len)
#   pixel_values: Tensor[float] of shape (batch_size, 3, 224, 224)
#   labels: Tensor[int in [0, 1, 2]] of shape (batch_size,)
#   dataset_id: List[str]
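As a quick sanity check, you can collate a couple of items and inspect the padded shapes. This assumes you have already created a TextImageDataset instance named dataset, as done later in Step 9:
# Hypothetical sanity check: collate two items and inspect the shapes
sample_loader = DataLoader(dataset, batch_size=2, collate_fn=data_collator_fn)
sample_batch = next(iter(sample_loader))
print(sample_batch['input_ids'].shape)      # (2, batch_max_len)
print(sample_batch['pixel_values'].shape)   # (2, 3, 224, 224)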
Step 5: Set Up the PaliGemma Model
Initialize the PaliGemma model with 4-bit quantization. The model needs to be quantized and have LoRA applied to it in order to train in a free Google Colab environment.
paligemma_id = "google/paligemma-3b-mix-224"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

def paligemma_loader_fn():
    return PaliGemmaForConditionalGeneration.from_pretrained(
        paligemma_id, device_map="auto", quantization_config=bnb_config
    )

processor = AutoProcessor.from_pretrained(paligemma_id)
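If you want to confirm that quantization worked before committing to training, you can load the model once and check its memory footprint. This is an optional check, not part of the original pipeline:
# Optional: verify the 4-bit model fits comfortably in GPU memory
model = paligemma_loader_fn()
print(f"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
del model
torch.cuda.empty_cache()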
Step 6: Define Training and Evaluation Functions
Define the functions used for training and evaluation. Because we only want to classify the sentiment as negative, neutral, or positive, we can take the model’s prediction at the last token position and look only at the predicted logits for the Negative, Neutral, and Positive tokens. We also backpropagate a cross-entropy loss computed over those three logits only.
crossentropy = torch.nn.CrossEntropyLoss()

paligemma_sentiment_tokens = ['Negative', 'Neutral', 'Positive']
paligemma_sentiment_ids = processor.tokenizer.convert_tokens_to_ids(paligemma_sentiment_tokens)

def paligemma_train(model, dataloader, optimizer, scheduler, num_epochs):
    model.train()
    batch_num = len(dataloader)
    batch_num_of10 = math.ceil(batch_num / 10)
    temp_loss = 0
    for epoch in range(num_epochs):
        for b_i, batch in enumerate(dataloader):
            batch.pop("dataset_id")
            batch = {
                key: value.to(model.device)
                for key, value in batch.items()
            }
            labels = batch.pop('labels')
            outputs = model(**batch)
            # Logits at the last token position, restricted to the three sentiment tokens
            sentiment_logits = outputs.logits[:, -1, paligemma_sentiment_ids]
            loss = crossentropy(sentiment_logits, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            temp_loss = loss.item()
            del loss, sentiment_logits, outputs, batch, labels
            # Print a progress marker roughly every 10% of the batches
            if b_i % batch_num_of10 == batch_num_of10 - 1:
                print("=", end="")
        print(f" | Epoch {epoch+1}/{num_epochs} - Loss: {temp_loss}")
        scheduler.step()
def paligemma_evaluate(model, dataloader):
    model.eval()
    total_loss = 0
    result_df = pd.DataFrame(columns=['dataset_id', 'Target Sentiment', 'Predicted Sentiment'])
    with torch.inference_mode():
        for b_i, batch in enumerate(dataloader):
            dataset_id = batch.pop("dataset_id")
            batch = {
                key: value.to(model.device)
                for key, value in batch.items()
            }
            labels = batch.pop('labels')
            outputs = model(**batch)
            sentiment_logits = outputs.logits[:, -1, paligemma_sentiment_ids]
            total_loss += crossentropy(sentiment_logits, labels).item()
            predicted_sentiment = torch.argmax(sentiment_logits, dim=1).cpu()
            result_df = pd.concat([result_df, pd.DataFrame({
                'dataset_id': dataset_id,
                'Target Sentiment': [paligemma_sentiment_tokens[label] for label in labels.cpu()],
                'Predicted Sentiment': [paligemma_sentiment_tokens[pred] for pred in predicted_sentiment]
            })])
    result_df = result_df.reset_index(drop=True)
    return result_df, (total_loss / len(dataloader))
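The function returns a (result_df, average_loss) tuple. A hypothetical call and a quick accuracy computation might look like this, where model is a fine-tuned model and test_dataloader is a held-out loader as constructed inside Step 8:
# Hypothetical usage after training
result_df, avg_loss = paligemma_evaluate(model, test_dataloader)
accuracy = (result_df['Target Sentiment'] == result_df['Predicted Sentiment']).mean()
print(f"Eval loss: {avg_loss:.4f} - Accuracy: {accuracy:.2%}")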
Step 7: Set Up PEFT Configuration
Configure the LoRA (Low-Rank Adaptation) settings for efficient fine-tuning. Ideally, we would also apply LoRA to the up_proj and down_proj modules, but that does not fit in memory on a T4 GPU.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=[
        'o_proj',
        'k_proj',
        'q_proj',
        'gate_proj',
        'v_proj'
    ]
)
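To see how few parameters LoRA actually makes trainable with this configuration, you can wrap the model and use PEFT’s print_trainable_parameters. Note that this optional check loads a fresh copy of the base model, so skip it if GPU memory is tight:
# Optional: inspect the trainable-parameter count of the LoRA-wrapped model
peft_model = get_peft_model(paligemma_loader_fn(), peft_config)
peft_model.print_trainable_parameters()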
Step 8: Implement K-Fold Cross-Validation
This function performs k-fold cross-validation for fine-tuning and evaluation. One advantage of using LoRA here is that after training and evaluating a fold, we can simply call .unload() to discard the learned LoRA adapters and attach a fresh set, so each fold fine-tunes the model from scratch.
def paligemma_evaluate_kfold(model_loader_fn, dataset, k=5, result_save_dir="./", epochs=2, start_fold=1):
    # Unwrap nested Subsets to validate the underlying dataset type
    dataset_check = dataset
    while isinstance(dataset_check, Subset):
        dataset_check = dataset_check.dataset
    if not isinstance(dataset_check, TextImageDataset):
        raise ValueError("dataset must be a TextImageDataset or a subset of it")

    os.makedirs(result_save_dir, exist_ok=True)
    kfold = KFold(n_splits=k, shuffle=True, random_state=42)
    eval_results = []

    model = model_loader_fn()
    assert isinstance(model, PaliGemmaForConditionalGeneration), "model_loader_fn must return a PaliGemmaForConditionalGeneration"

    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
        if fold + 1 < start_fold:
            continue
        print(f"++ Fold {fold+1}/{k} ++")
        model = get_peft_model(model, peft_config)
        train_dataset = Subset(dataset, train_ids)
        test_dataset = Subset(dataset, test_ids)
        train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=data_collator_fn)
        test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=data_collator_fn)
        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
        paligemma_train(model, train_dataloader, optimizer, scheduler, epochs)
        eval_result = paligemma_evaluate(model, test_dataloader)
        result_df = eval_result[0]
        result_df.to_csv(os.path.join(result_save_dir, f"paligemma_ft_fold_{fold+1}.csv"), index=False)
        eval_results.append(eval_result)
        # Save the adapters, then unload them to reset the base model for the next fold
        model.save_pretrained(os.path.join(result_save_dir, f"paligemma_ft_fold_{fold+1}_adapters.pt"))
        model = model.unload()
        print()
    return eval_results
Step 9: Prepare the Dataset and Run Fine-tuning
This code loads the dataset and runs the k-fold cross-validation.
# Load and prepare the dataset
dataset = TextImageDataset("path/to/train_dataset.csv", processor)

# Run fine-tuning with k-fold cross-validation
paligemma_ft_result = paligemma_evaluate_kfold(
    paligemma_loader_fn,
    dataset,
    k=5,
    result_save_dir="paligemma_results",
    epochs=2
)
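If a Colab session dies mid-run, the start_fold argument lets you resume from a later fold. Because KFold is created with a fixed random_state, the splits are identical across runs, so the resumed folds see the same data. For example:
# Example: resume from fold 3 after an interrupted session
paligemma_ft_result = paligemma_evaluate_kfold(
    paligemma_loader_fn,
    dataset,
    k=5,
    result_save_dir="paligemma_results",
    epochs=2,
    start_fold=3
)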
Step 10: Analyze Results
After fine-tuning, you can analyze the results saved in the `result_save_dir`. Each fold will have:
- A CSV file with predictions (`paligemma_ft_fold_X.csv`)
- Saved LoRA adapters (`paligemma_ft_fold_X_adapters.pt`, a directory created by `save_pretrained`)
You can use these to evaluate the model’s performance and make predictions on new data.
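As a sketch, you could aggregate the per-fold prediction CSVs and compute overall metrics with scikit-learn. This assumes the default k=5 and the result_save_dir used in Step 9:
from sklearn.metrics import classification_report

# Combine the per-fold prediction CSVs and report overall metrics
fold_dfs = [pd.read_csv(f"paligemma_results/paligemma_ft_fold_{i}.csv") for i in range(1, 6)]
all_results = pd.concat(fold_dfs, ignore_index=True)
print(classification_report(all_results['Target Sentiment'], all_results['Predicted Sentiment']))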