Batch API Basics with Mistral AI

Tags: Vision, Image understanding, +1 more

In this notebook, we will guide you through an example of extracting information from multiple receipts using the Mistral batch API and returning the data in a pandas DataFrame.

!pip install mistralai datasets

Get receipt images

We use a dataset of supermarket receipt images from Hugging Face:

import pandas as pd
from datasets import load_dataset

# Hugging Face dataset ID of the supermarket receipt images used below;
# swap in another dataset ID to experiment with different receipts.
dataset_name = 'shirastromer/supermarket-receipts'

# Download the dataset; it is indexed by split name below.
dataset = load_dataset(dataset_name)

# Convert the 'train' split to a pandas DataFrame. Its `image` column is what
# the later cells pass to format_image.
df = pd.DataFrame(dataset['train'])

# Preview the first few rows.
df.head()

Take a look at one image

Let's start with a single image and extract its information with Pixtral Large.

# Display the receipt stored at row label 1 of the `image` column.
df["image"][1]

Let's extract information from this single image using the Mistral API.

import base64
from io import BytesIO
from typing import Any
from PIL.Image import Image

def format_image(image: "Image") -> str:
    """
    Encode an image as a base64 JPEG data URI.

    JPEG cannot store alpha channels or palettes. Images in any mode other
    than RGB or greyscale ("L") are therefore converted to RGB before saving.
    Without this, saving an RGBA or "P" image as JPEG raises an error.

    Args:
        image (Image): The image to be formatted. It is not modified; the
            conversion produces a new image.

    Returns:
        str: The base64-encoded JPEG with a "data:image/jpeg;base64," prefix.
    """
    # Normalise modes the JPEG encoder cannot write (e.g. RGBA, P).
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    # Serialise to JPEG in memory and base64-encode the bytes.
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    # Add the data URI prefix expected in an "image_url" message part.
    return f"data:image/jpeg;base64,{image_base64}"
from mistralai import Mistral
import os

api_key = os.environ["MISTRAL_API_KEY"]
client = Mistral(api_key=api_key)

# Ask for each item's name and price, plus a category from a fixed list.
# The trailing assistant message is marked as a prefix with content "{",
# so the reply continues from an opening brace.
image_part = {"type": "image_url", "image_url": format_image(df.image[1])}
instruction_part = {
    "type": "text",
    "text": "Extract the name and price of each item on the receipt, categorize each item into one of the following categories: 'Medical', 'Food', 'Beverage', 'Travel', or 'Other', and return the results as a well-structured JSON object. The JSON object should include only the fields: name, price, and classification for each item.",
}
messages = [
    {"role": "user", "content": [image_part, instruction_part]},
    {"role": "assistant", "content": "{", "prefix": True},
]

chat_response = client.chat.complete(
    model="pixtral-large-latest",
    messages=messages,
    response_format={"type": "json_object"},
)

# Show the text of the first returned choice.
reply = chat_response.choices[0].message
print(reply.content)

Use the batch API to process many images

Create a batch

Let's process 10 images as an example.

import json
from io import BytesIO

num_samples = 10

# Never request more rows than the dataset holds: df.image[idx] raises for a
# row label that does not exist.
num_requests = min(num_samples, len(df))

# The prompt is identical for every request, so build it once.
prompt = "Identify the name and price of each item on the receipt, categorize each item into one of the following categories: 'Medical', 'Food', 'Beverage', 'Travel', or 'Other', and return the results as a well-structured JSON object. The JSON object should include only the fields: name, price, and classification for each item."

# One JSON-encoded request per receipt; custom_id ties each result back to
# its DataFrame row.
list_of_json = []
for idx in range(num_requests):
    request = {
        "custom_id": str(idx),
        "body": {
            # Raise this if long receipts come back with truncated JSON.
            "max_tokens": 1000,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": format_image(df.image[idx]),
                        },
                        {
                            "type": "text",
                            "text": prompt,
                        },
                    ],
                },
                # Prefix message: the reply continues from "{".
                {
                    "role": "assistant",
                    "content": "{",
                    "prefix": True,
                },
            ],
            "response_format": {"type": "json_object"},
        },
    }
    list_of_json.append(json.dumps(request).encode("utf-8"))

Upload your batch

# Join the encoded requests into one JSON Lines payload (one request per line)
# and upload it for batch processing.
jsonl_bytes = b"\n".join(list_of_json)
batch_data = client.files.upload(
    file={"file_name": "file.jsonl", "content": jsonl_bytes},
    purpose="batch",
)
batch_data

Create a batch job

# Start a batch job that sends every request in the uploaded file to
# pixtral-large-latest on the chat completions endpoint.
job_settings = {
    "input_files": [batch_data.id],
    "model": "pixtral-large-latest",
    "endpoint": "/v1/chat/completions",
    "metadata": {"job_type": "testing"},
}
created_job = client.batch.jobs.create(**job_settings)
created_job

Get batch job details

import time

# Batch jobs run asynchronously. Poll until the job leaves the QUEUED/RUNNING
# states, so that its results can be downloaded afterwards.
retrieved_job = client.batch.jobs.get(job_id=created_job.id)
while retrieved_job.status in ("QUEUED", "RUNNING"):
    time.sleep(2)
    retrieved_job = client.batch.jobs.get(job_id=created_job.id)

total = retrieved_job.total_requests
finished = retrieved_job.succeeded_requests + retrieved_job.failed_requests
# Scale to a percentage before rounding. round(x, 4) * 100 can print float
# noise such as 7.000000000000001. An empty job would otherwise divide by zero.
percent_done = round(100 * finished / total, 2) if total else 0.0

print(f"Status: {retrieved_job.status}")
print(f"Total requests: {retrieved_job.total_requests}")
print(f"Failed requests: {retrieved_job.failed_requests}")
print(f"Successful requests: {retrieved_job.succeeded_requests}")
print(f"Percent done: {percent_done}")

Get batch results

# A job that did not complete successfully may have no output file; fail with
# a clear message instead of passing None to the download call.
if retrieved_job.output_file is None:
    raise RuntimeError(
        f"Batch job {retrieved_job.id} ended with status {retrieved_job.status!r} "
        "and has no output file."
    )

# Download the JSON Lines results and decode them to text.
output = client.files.download(file_id=retrieved_job.output_file).read().decode("utf-8").strip()
print(output)

Extract the results into a pandas DataFrame

def _load_model_json(content):
    """Parse a model reply as JSON, tolerating an optional Markdown code fence.

    Stripping only backtick characters from a fenced reply leaves a leading
    ``json`` language tag behind, which then fails to parse. The fence and
    the tag are both removed here.

    Raises:
        json.JSONDecodeError: if the reply is not valid JSON.
    """
    text = content.strip()
    if text.startswith("```"):
        text = text.strip("`").strip()
        # Drop a language tag such as ```json on the opening fence.
        if text[:4].lower() == "json":
            text = text[4:]
    return json.loads(text)


def _receipt_items(data):
    """Return the list of item dicts found in a parsed reply.

    Accepts a bare list of items or an object wrapping them. The prompt does
    not name the wrapper key, so when "items" is absent the first list-valued
    field is used. Non-dict entries are dropped.
    """
    if isinstance(data, dict):
        items = data.get("items")
        if items is None:
            items = next((value for value in data.values() if isinstance(value, list)), [])
    else:
        items = data
    if not isinstance(items, list):
        return []
    return [item for item in items if isinstance(item, dict)]


# Parse JSON Lines: one JSON object per batch request; ignore blank lines.
lines = [line for line in output.splitlines() if line.strip()]

extracted_data = []
for line in lines:
    parsed_line = json.loads(line)
    custom_id = parsed_line.get("custom_id")
    # Defensive: tolerate a missing or null response/body/choices/message.
    response = parsed_line.get("response") or {}
    body = response.get("body") or {}

    for choice in body.get("choices") or []:
        message_content = (choice.get("message") or {}).get("content") or ""
        try:
            items = _receipt_items(_load_model_json(message_content))
        except json.JSONDecodeError as err:
            # Report rather than silently drop, e.g. a reply truncated by max_tokens.
            print(f"Skipping unparsable reply for custom_id={custom_id}: {err}")
            continue
        for item in items:
            extracted_data.append({
                "custom_id": custom_id,
                "name": item.get("name"),
                "price": item.get("price"),
                "classification": item.get("classification"),
            })

# Fixed columns keep the frame's shape stable even when nothing was extracted.
df_output = pd.DataFrame(
    extracted_data, columns=["custom_id", "name", "price", "classification"]
)
df_output