TL;DR: This notebook will show you:
- How to call Mistral's batch inference API
- How to pass images (encoded in base64) in your API calls to Mistral's VLM (here Pixtral-12B)
- How to fine-tune Pixtral-12B on an image classification problem in order to improve its accuracy.
For additional references, check out the docs.
from IPython.display import clear_output
!pip install mistralai==1.9.3
clear_output()
Prepare the dataset
We will use AID, a scene classification dataset introduced by Xia et al., hosted on Kaggle under a Public Domain license.
To download it, you will have to generate your Kaggle API token:
- Go to your Kaggle account settings and find the API section,
- Click "Create New API Token" → this will download kaggle.json.
- Upload kaggle.json to Google Colab.
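Alternatively, the Kaggle CLI can read credentials from the KAGGLE_USERNAME and KAGGLE_KEY environment variables, so you can skip the file upload. A minimal sketch (the placeholder values are hypothetical; copy the real ones from your kaggle.json):
import os
# Hypothetical placeholders: replace with the values from your kaggle.json
os.environ["KAGGLE_USERNAME"] = "your_kaggle_username"
os.environ["KAGGLE_KEY"] = "your_kaggle_api_key"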
Download and parse the data
def is_colab_runtime() -> bool:
    try:
        import google.colab
        return True
    except ImportError:
        return False
if is_colab_runtime():
    from google.colab import files
    # This will prompt you to upload kaggle.json
    print("Please upload your kaggle.json file below:")
    files.upload()
    clear_output()
from pathlib import Path
if (Path() / "kaggle.json").exists():
    !mkdir -p ~/.kaggle
    !cp kaggle.json ~/.kaggle/
    !chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d jiayuanchengala/aid-scene-classification-datasets
# This might take a few minutes
!unzip aid-scene-classification-datasets.zip -d satellite_dataset
clear_output()
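# Optional sanity check (assuming the download and unzip above succeeded):
# list a few of the class folders we are about to parse
!ls satellite_dataset/AID | head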
The dataset consists of:
- satellite images (jpg files)
- each image belongs to a specific class (e.g. airport, commercial, dense residential, medium residential, park, forest, farmland, etc.)
We first transform this dataset into something usable for fine-tuning:
- Create pairs of (image, label)
- Load the images and encode them in base64 (the format expected by the Pixtral API)
- Downscale the images to be a bit more memory-efficient
Note that smaller, specialized vision models could potentially achieve comparable performance levels. This cookbook aims to guide you through the process of effectively fine-tuning Mistral’s Vision Language Model (VLM) using a straightforward example, and to demonstrate its impact on basic classification metrics. More advanced applications of fine-tuning could include interactions like "speak with an image" or generating image captions.
from pathlib import Path
import pandas as pd
from PIL import Image
import base64
import io
from sklearn.model_selection import train_test_split
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Extract pairs of (image, label)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
root_dir = Path() / "satellite_dataset" / "AID"
data = []
for d in root_dir.iterdir():
    if not d.is_dir():
        continue
    data.extend([{"label": d.name, "img_path": p} for p in d.iterdir()])
dataset_df = pd.DataFrame(data)
classes = [*dataset_df["label"].unique()]
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Load image and encode in base64 (this might take a few minutes)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# NOTE: This is not needed here, but a nice additional step would be to
# resize images so their longest edge is 1024 pixels (if the image is too big)
# For more details see: https://docs.mistral.ai/capabilities/vision/
def encode_image_to_base64(image_path: str | Path) -> str:
    image = Image.open(image_path)
    # Resize the image by a factor of 0.5
    new_size = (image.width // 2, image.height // 2)
    image = image.resize(new_size, Image.Resampling.LANCZOS)
    buffer = io.BytesIO()
    image.save(buffer, format='JPEG')
    buffer.seek(0)
    encoded_string = (
        "data:image/jpeg;base64,"
        + base64.b64encode(buffer.read()).decode('utf-8')
    )
    return encoded_string
dataset_df["img_b64"] = [
encode_image_to_base64(img_path)
for img_path in dataset_df["img_path"]
]
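# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Optional: a sketch of the "longest edge" resize mentioned above.
# This helper is illustrative and not used in this cookbook; check
# https://docs.mistral.ai/capabilities/vision/ for the actual limits.
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
def resize_longest_edge(image: Image.Image, max_edge: int = 1024) -> Image.Image:
    longest = max(image.width, image.height)
    if longest <= max_edge:
        return image  # Already small enough: keep as-is
    scale = max_edge / longest
    new_size = (int(image.width * scale), int(image.height * scale))
    return image.resize(new_size, Image.Resampling.LANCZOS)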
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Split dataset in train / test
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
train_df, test_df = train_test_split(
    dataset_df,
    test_size=0.2,
    random_state=42,
    stratify=dataset_df["label"]
)
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)
# Release a bit of memory
del dataset_df
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Quick look at the data
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
print("Classes:", classes)
print("Size train:", len(train_df))
print("Size test:", len(test_df))
train_df.head()
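# Optional sanity check: the stratified split should preserve class
# proportions across train and test (shown here for a few classes)
print(train_df["label"].value_counts(normalize=True).round(3).head())
print(test_df["label"].value_counts(normalize=True).round(3).head())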
from IPython.display import display, HTML
def display_image(dataset_df: pd.DataFrame, idx: int) -> None:
    img_b64 = dataset_df["img_b64"].iloc[idx]
    label = dataset_df["label"].iloc[idx]
    display(HTML(f'<h2>{label}</h2><img src="{img_b64}" />'))
display_image(train_df, idx=0)
display_image(train_df, idx=2)
Prepare a JSONL dataset for the Mistral API
# NOTE: We had to engineer a slightly complex system prompt as the current
# fine-tuning of Pixtral-12B (July 2nd, 2025) misses a few features that more
# recent models offer:
# - The Classifier Factory in LaPlateforme only supports ministral-3b for now
#   See: https://docs.mistral.ai/capabilities/finetuning/classifier_factory/
# - Structured outputs are not supported by Pixtral-12B
from pydantic import BaseModel
class ClassifierOutput(BaseModel):
    image_description: str
    label: str
instruction_prompt = f"""
Classify the following image into the category it belongs to.
- These category labels are {'; '.join(classes)}.
- Output your result using exclusively the following schema: {ClassifierOutput.model_fields}
- Put your results between a json tag
```json
```
""".strip()
print("Instruction prompt:")
print(instruction_prompt)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Template functions to generate valid objects according to
# Mistral's schemas
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
from mistralai.models import UserMessage, SystemMessage, AssistantMessage
def _template_inference_message(
    img_b64: str
) -> list[SystemMessage | UserMessage]:
    return [
        SystemMessage(content=instruction_prompt),
        UserMessage(content=[{"type": "image_url", "image_url": img_b64}])
    ]
def _assistant_message(label: str) -> str:
    return f"""
```json
{{
  "image_description": "An aerial view of a {label.lower()}",
  "label": "{label}"
}}
```
""".strip()
def _template_finetuning_message(
    img_b64: str, label: str
) -> list[SystemMessage | UserMessage | AssistantMessage]:
    return [
        SystemMessage(content=instruction_prompt),
        UserMessage(content=[{"type": "image_url", "image_url": img_b64}]),
        AssistantMessage(content=_assistant_message(label))
    ]
_template_inference_message(train_df["img_b64"].iloc[0])
_template_finetuning_message(train_df["img_b64"].iloc[0], train_df["label"].iloc[0])
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# We will use the Batch API for inference on the test dataset,
# and thus apply the proper formatting
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
inference_data = []
for row_id, row in test_df.iterrows():
    messages = _template_inference_message(img_b64=row["img_b64"])
    inference_data.append(
        {
            "custom_id": str(row_id),
            "body": {
                "messages": [m.model_dump() for m in messages]
            }
        }
    )
inference_data[0]
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# We will finetune Pixtral on the train dataset
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
finetuning_data = []
for row_id, row in train_df.iterrows():
    messages = _template_finetuning_message(img_b64=row["img_b64"], label=row["label"])
    # Each JSONL record must be an object with a "messages" key
    finetuning_data.append({"messages": [m.model_dump() for m in messages]})
finetuning_data[0]
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Small util function to convert a list of dictionaries into a
# jsonl file for upload to LaPlateforme
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
import json
def dict_to_jsonl(objects: list[dict], path: Path | str) -> None:
    with open(path, 'w') as f:
        for obj in objects:
            f.write(json.dumps(obj) + '\n')
dict_to_jsonl(inference_data, path=(Path() / "test_satellite_for_batch_inference.jsonl"))
dict_to_jsonl(finetuning_data, path=(Path() / "train_satellite_for_finetuning.jsonl"))
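# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Optional: if a dataset ever exceeds the batch API's size limits
# (see the note further below), it can be sharded into several files.
# Illustrative sketch only: `max_records` is an arbitrary value, not
# an official limit - check Mistral's docs for the current ones.
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
def dict_to_jsonl_sharded(
    objects: list[dict], stem: str, max_records: int = 10_000
) -> list[Path]:
    paths = []
    for start in range(0, len(objects), max_records):
        path = Path() / f"{stem}_{start // max_records}.jsonl"
        dict_to_jsonl(objects[start:start + max_records], path=path)
        paths.append(path)
    return paths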
Classify satellite images using Pixtral-12B
Get a baseline using the off-the-shelf Pixtral-12B
We will use the "base" Pixtral-12B (no fine-tuning) to get a baseline on this classification task.
from getpass import getpass
from mistralai import Mistral
api_key = getpass("Type your API Key")
client = Mistral(api_key=api_key)
pixtral = "pixtral-12b-2409"
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Helper function to extract & parse JSON content from the
# model response (it mimics the logic of structured output mode
# which unfortunately is not supported on Pixtral-12B)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
import re
import json
def parse_content(content: str) -> ClassifierOutput | None:
    match = re.search(r'```json\s*(.*?)\s*```', content, re.DOTALL)
    if not match:
        return None
    json_str = match.group(1).strip()
    try:
        return ClassifierOutput(**json.loads(json_str))
    except (json.JSONDecodeError, ValueError):
        return None
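# Quick sanity check of the parser on a handcrafted response (illustrative)
_sample = 'Sure!\n```json\n{"image_description": "An aerial view of a bridge", "label": "Bridge"}\n```'
print(parse_content(_sample))  # -> image_description='An aerial view of a bridge' label='Bridge'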
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# We call the "base" Pixtral-12B on an example that works
# well (right label is retrieved)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
IDX = 0
user_message = _template_inference_message(train_df["img_b64"].iloc[IDX])
results = client.chat.complete(
    model=pixtral,
    messages=user_message,
    temperature=0
)
results = parse_content(results.choices[0].message.content)
display_image(train_df, idx=IDX)
print(results)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# We call the "base" Pixtral-12B on an example that does not work
# well (wrong label is retrieved)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
IDX = 4
user_message = _template_inference_message(train_df["img_b64"].iloc[IDX])
results = client.chat.complete(
    model=pixtral,
    messages=user_message,
    temperature=0
)
results = parse_content(results.choices[0].message.content)
display_image(train_df, idx=IDX)
print(results)
Since we have many images to classify, we will use Mistral's Batch API (and might have to wait a little bit).
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Launch a batch inference job
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# NOTE: This can also be done via LaPlateforme UI by drag-and-dropping
# the input files
# NOTE: There might be some size limitations, both in terms of number
# of records and file size
# Upload the dataset
with open(Path() / "test_satellite_for_batch_inference.jsonl", "rb") as f:
    batch_inference_data = client.files.upload(
        file={
            "file_name": "test_satellite_for_batch_inference.jsonl",
            "content": f
        },
        purpose="batch"
    )
# Launch the job
created_job = client.batch.jobs.create(
    input_files=[batch_inference_data.id],
    model=pixtral,
    endpoint="/v1/chat/completions",
    metadata={"job_type": "testing"}
)
print("file ID:", batch_inference_data.id)
print("job ID:", created_job.id)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Util context manager to timeout operations (in case the batch API
# takes longer than expected)
# NOTE: signal.SIGALRM is Unix-only; this helper will not work on Windows
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
import signal
from contextlib import contextmanager

class TimeoutException(Exception):
    pass

@contextmanager
def timeout(seconds: int, error_message="Operation timed out"):
    def _handle_timeout(signum, frame):
        raise TimeoutException(error_message)
    # Set the signal handler and an alarm
    signal.signal(signal.SIGALRM, _handle_timeout)
    signal.alarm(seconds)
    try:
        yield
    finally:
        # Cancel the alarm regardless of success or failure
        signal.alarm(0)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Retrieve job results (with a timeout of 1h)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
import time
with timeout(3600):
    while True:
        retrieved_job = client.batch.jobs.get(job_id=created_job.id)
        if retrieved_job.status not in ["RUNNING", "QUEUED"]:
            print(f"Job finished with status {retrieved_job.status}")
            break
        time.sleep(30)  # Pause for 30s

if retrieved_job.status != "SUCCESS":
    raise RuntimeError("Batch job failed")
output_file_stream = client.files.download(file_id=retrieved_job.output_file)
results_df = pd.read_json(path_or_buf=output_file_stream, lines=True)
results_df.head()
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Apply processing to the retrieved results
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
parsed_results_df = results_df.copy()
parsed_results_df[f"predicted_{pixtral}"] = (
    parsed_results_df["response"]
    .apply(lambda response: parse_content(response["body"]["choices"][0]["message"]["content"]))
    .apply(lambda x: x.label if x else None)
)
parsed_results_df = (
    parsed_results_df.sort_values("custom_id")
    .reset_index(drop=True)
    [["custom_id", f"predicted_{pixtral}"]]
)
parsed_results_df.head()
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Let's compare against our "ground_truth"
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
comparison_df = test_df.copy()
comparison_df["custom_id"] = comparison_df.reset_index(drop=False)["index"]
comparison_df = comparison_df.merge(parsed_results_df)
comparison_df.head()
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Display traditional classification metrics
# ✅ Good baseline model: some classes are very well handled
# (e.g. Desert, Railway station, etc.)
# ❌ But there are limitations
# - some labels are "hallucinated" (i.e. not present in the system prompt)
# - some classes are not well handled (e.g. River vs. Bridge)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
y_true = [*comparison_df["label"]]
y_pred = [
    # NOTE: Pixtral-12B might have hallucinated some categories
    # We'll put them under the "WRONG" label
    label if label in classes else "WRONG"
    for label in comparison_df[f"predicted_{pixtral}"]
]
labels = [*classes, "WRONG"]
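# Overall accuracy at a glance (the classification report below includes it too)
print("Accuracy:", round(accuracy_score(y_true, y_pred), 3))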
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Confusion matrix
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
cm = confusion_matrix(y_true, y_pred, labels=labels)
print("Confusion Matrix:")
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
fig, ax = plt.subplots(figsize=(8, 6))
disp.plot(ax=ax, cmap='Oranges', colorbar=False)
plt.title(f"Confusion Matrix ({pixtral})", fontsize=18, fontweight='bold')
plt.xticks(rotation=90)
plt.show()
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Classification report
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
print("Classification Report:")
print(classification_report(y_true, y_pred, labels=labels, target_names=labels))
Improve the results by fine-tuning Pixtral-12B
Note: Here the objective of the fine-tuning is to improve the classification accuracy of the VLM. Fine-tuning can also help align the tone of the assistant with a desired style.
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Launch a finetuning job
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# NOTE: This can also be done via LaPlateforme UI by drag-and-dropping
# the input files
# NOTE: There might be some size limitations, both in terms of number
# of records and file size
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Upload your training dataset
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
with open(Path() / "train_satellite_for_finetuning.jsonl", "rb") as f:
    finetuning_data = client.files.upload(
        file={
            "file_name": "train_satellite_for_finetuning.jsonl",
            "content": f
        },
        purpose="fine-tune"
    )
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Create the job
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
#
# Fine-tuning is both an art and a science
#
# Our recommended approach for fine-tuning our models is to do so iteratively.
# Start with a smaller dataset, e.g. 200-1000 examples (and potentially smaller
# models), larger learning rates and fewer epochs; once you start to see
# results moving in the right direction, scale these up.
#
# Once you are somewhat comfortable with the fine-tuning data used for the
# model, you want to start playing around with the different fine-tuning
# hyperparameters.
#
# Learning Rate - The learning rate defines how much the model adjusts at each
# training step. A bigger learning rate makes the model move faster, so you
# may need less training time; however, it may also overshoot and land in a
# poor local minimum.
#
# Epochs - This is the number of full passes over the whole dataset the
# training will do. It's somewhat proportional to the training time.
#
# Batch Size - This is the number of examples used in each gradient step to
# update the fine-tuned weights. A higher batch size typically leads to
# smoother training, but with small datasets you typically use a smaller
# batch size.
#
# Here is an equation that summarizes the relationship between the parameters:
# Epochs = Steps × Batch Size / Total Number of Training Samples
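#
# For example (illustrative numbers, not this dataset's actual counts):
# with 8,000 training samples, a batch size of 8 and 2 epochs, training
# performs 2 × 8,000 / 8 = 2,000 gradient steps.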
created_job = client.fine_tuning.jobs.create(
    model="pixtral-12b-latest",
    training_files=[{"file_id": finetuning_data.id, "weight": 1}],
    validation_files=[],
    hyperparameters={
        "epochs": 2,
        "learning_rate": 0.0001,
    },
    auto_start=True  # This could be set to False to let you manually validate the run
)
print("file ID:", finetuning_data.id)
print("job ID:", created_job.id)
with timeout(3600):
    while True:
        retrieved_job = client.fine_tuning.jobs.get(job_id=created_job.id)
        if retrieved_job.status not in ["VALIDATING", "RUNNING", "QUEUED"]:
            print(f"Job finished with status {retrieved_job.status}")
            break
        time.sleep(30)  # Pause for 30s
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# See all the details associated with a finetuning job
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
retrieved_job
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Visualize the evolution of the loss during FT
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
def plot_loss(job_id: str, model: str) -> None:
    # Extract step_number and train_loss from the job data
    retrieved_job = client.fine_tuning.jobs.get(job_id=job_id)
    step_numbers = [checkpoint.step_number for checkpoint in retrieved_job.checkpoints]
    train_losses = [checkpoint.metrics.train_loss for checkpoint in retrieved_job.checkpoints]
    # Plot the data
    plt.figure(figsize=(12, 8))
    plt.plot(step_numbers, train_losses, marker='o', linestyle='-', color='b', label='Train Loss')
    # Adding titles and labels
    plt.title(f'Finetune {model} on your data - Train Loss', fontsize=16, fontweight='bold')
    plt.xlabel('Step Number', fontsize=14)
    plt.ylabel('Train Loss', fontsize=14)
    # Adding grid for better readability
    plt.grid(True, linestyle='--', alpha=0.7)
    # Adding legend
    plt.legend(loc='best')
    # Customizing the appearance
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    # Adding a background color
    plt.gca().set_facecolor('#f0f0f0')
    # Display the plot
    plt.show()
plot_loss(job_id=retrieved_job.id, model=pixtral)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Retrieve the ID of the fine-tuned model
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
fine_tuned_pixtral = client.fine_tuning.jobs.get(job_id=created_job.id).fine_tuned_model
fine_tuned_pixtral
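# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Optional: quick single-example check with the fine-tuned model,
# reusing the same helpers as for the baseline calls above
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
results = client.chat.complete(
    model=fine_tuned_pixtral,
    messages=_template_inference_message(test_df["img_b64"].iloc[0]),
    temperature=0
)
print(parse_content(results.choices[0].message.content))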
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Make a batch inference with fine-tuned model
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
created_job = client.batch.jobs.create(
    input_files=[batch_inference_data.id],
    model=fine_tuned_pixtral,
    endpoint="/v1/chat/completions",
    metadata={"job_type": "testing"}
)
print("file ID:", batch_inference_data.id)
print("job ID:", created_job.id)
with timeout(3600):
    while True:
        retrieved_job = client.batch.jobs.get(job_id=created_job.id)
        if retrieved_job.status not in ["RUNNING", "QUEUED"]:
            print(f"Job finished with status {retrieved_job.status}")
            break
        time.sleep(30)  # Pause for 30s

if retrieved_job.status != "SUCCESS":
    raise RuntimeError("Batch job failed")
output_file_stream = client.files.download(file_id=retrieved_job.output_file)
results_df = pd.read_json(path_or_buf=output_file_stream, lines=True)
results_df.head()
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Apply processing to the retrieved results (same as previously)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
parsed_results_df = results_df.copy()
parsed_results_df[f"predicted_{fine_tuned_pixtral}"] = (
    parsed_results_df["response"]
    .apply(lambda response: parse_content(response["body"]["choices"][0]["message"]["content"]))
    .apply(lambda x: x.label if x else None)
)
parsed_results_df = (
    parsed_results_df.sort_values("custom_id")
    .reset_index(drop=True)
    [["custom_id", f"predicted_{fine_tuned_pixtral}"]]
)
parsed_results_df.head()
model_comparison_df = comparison_df.merge(parsed_results_df)
model_comparison_df.head()
len(y_true), len(y_pred)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Get classification metrics on the new model
# 🎉 🎉 🎉 Good news: performance is improved:
# - overall accuracy improvement (from ~0.5 to ~0.8)
# - almost no more hallucinations
# - huge performance gains on some classes
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
y_true = [*model_comparison_df["label"]]
y_pred = [
    # NOTE: Pixtral-12B might have hallucinated some categories
    # We'll put them under the "WRONG" label
    label if label in classes else "WRONG"
    for label in model_comparison_df[f"predicted_{fine_tuned_pixtral}"]
]
labels = classes + (["WRONG"] if "WRONG" in y_pred else [])
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Confusion matrix
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
cm = confusion_matrix(y_true, y_pred, labels=labels)
print("Confusion Matrix:")
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
fig, ax = plt.subplots(figsize=(8, 6))
disp.plot(ax=ax, cmap='Oranges', colorbar=False)
plt.title(f"Confusion Matrix ({fine_tuned_pixtral})", fontsize=18, fontweight='bold')
plt.xticks(rotation=90)
plt.show()
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Classification report
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
print("Classification Report:")
print(classification_report(y_true, y_pred, labels=labels, target_names=labels))
Wrapping up: Finetuning Pixtral-12B really improves its classification accuracy
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Compare the two models:
# - Huge performance gain on some classes
# - A few classes see a small drop in performance, but overall
# model performance is improved
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
def per_class_accuracy(y_true, y_pred, classes):
    accuracies = []
    for cls in classes:
        idx = np.where(np.array(y_true) == cls)[0]
        if len(idx) == 0:
            accuracies.append(np.nan)
        else:
            acc = accuracy_score(np.array(y_true)[idx], np.array(y_pred)[idx])
            accuracies.append(acc)
    return accuracies
y_true = model_comparison_df["label"]
y_pred_pixtral = model_comparison_df[f"predicted_{pixtral}"]
y_pred_ft_pixtral = model_comparison_df[f"predicted_{fine_tuned_pixtral}"]
classes = sorted(set(y_true))
# Compute per-class accuracy
acc_pixtral = per_class_accuracy(y_true, y_pred_pixtral, classes)
acc_ft_pixtral = per_class_accuracy(y_true, y_pred_ft_pixtral, classes)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Plotting
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
x = np.arange(len(classes))
width = 0.35
fig, ax = plt.subplots(figsize=(10, 6))
bars1 = ax.bar(x - width/2, acc_pixtral, width, label=pixtral, color='#F39C12')
bars2 = ax.bar(x + width/2, acc_ft_pixtral, width, label=fine_tuned_pixtral, color='#D35400')
ax.set_xlabel('Class', fontsize=14)
ax.set_ylabel('Accuracy', fontsize=14)
ax.set_title('Per-Class Accuracy Comparison', fontsize=16, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(classes)
ax.set_ylim(0, 1.0)
ax.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.xticks(rotation=90)
plt.show()
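# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
# Overall accuracy, side by side (hallucinated labels
# count as errors, same "WRONG" convention as above)
# ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
for col in (f"predicted_{pixtral}", f"predicted_{fine_tuned_pixtral}"):
    preds = [p if p in classes else "WRONG" for p in model_comparison_df[col]]
    print(col, "->", round(accuracy_score(model_comparison_df["label"], preds), 3))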