OCR at scale via Mistral's Batch API

Used: OCR, Batch Inference

Setup

First, let's install mistralai and datasets:

!pip install mistralai datasets

We can now set up our client. You can create an API key on La Plateforme.

from mistralai import Mistral

api_key = "API_KEY"
client = Mistral(api_key=api_key)
ocr_model = "mistral-ocr-latest"

Without Batch

As an example, let's use Mistral OCR to extract text from multiple images.

We will use a dataset containing raw image data. To send this data via an image URL, we need to encode it in base64. For more information, please visit our Vision Documentation.

import base64
from io import BytesIO
from PIL import Image

def encode_image_data(image_data):
    try:
        # Ensure image_data is bytes
        if isinstance(image_data, bytes):
            # Directly encode bytes to base64
            return base64.b64encode(image_data).decode('utf-8')
        else:
            # Convert image data to bytes if it's not already
            buffered = BytesIO()
            image_data.save(buffered, format="JPEG")
            return base64.b64encode(buffered.getvalue()).decode('utf-8')
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
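
As a quick sanity check (purely illustrative, reusing the imports above), we can round-trip a small generated image through the helper and verify the output decodes to a valid JPEG:

# Illustrative check: encode a generated image and verify the JPEG magic bytes
test_image = Image.new("RGB", (8, 8), color="white")
encoded = encode_image_data(test_image)
assert encoded is not None
assert base64.b64decode(encoded)[:2] == b"\xff\xd8"  # JPEG files start with FF D8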

For this demo, we will use a simple dataset containing numerous documents and scans in image format. Specifically, we will use the HuggingFaceM4/DocumentVQA dataset, loaded via the datasets library.

We will download only 100 samples for this demonstration.

from datasets import load_dataset

n_samples = 100
dataset = load_dataset("HuggingFaceM4/DocumentVQA", split="train", streaming=True)
subset = list(dataset.take(n_samples))
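
To get a feel for the data, you can inspect the first record; treat the exact field names as version-dependent, but there should be an 'image' field holding the image data we encode below:

print(subset[0].keys())  # the 'image' field is what we pass to encode_image_data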

With our subset of 100 samples ready, we can loop through each image to extract the text.

We will save the results in a new dataset and export it as a JSON file.

from tqdm import tqdm

ocr_dataset = []
for sample in tqdm(subset):
    image_data = sample['image']  # 'image' contains the actual image data

    # Encode the image data to base64
    base64_image = encode_image_data(image_data)
    image_url = f"data:image/jpeg;base64,{base64_image}"

    # Process the image using Mistral OCR
    response = client.ocr.process(
        model=ocr_model,
        document={
            "type": "image_url",
            "image_url": image_url,
        }
    )

    # Store the image data and OCR content in the new dataset
    ocr_dataset.append({
        'image': base64_image,
        'ocr_content': response.pages[0].markdown # Since we are dealing with single images, there will be only one page
    })

import json

with open('ocr_dataset.json', 'w') as f:
    json.dump(ocr_dataset, f, indent=4)
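
To sanity-check the export, you might preview the markdown extracted from the first sample (assuming at least one request succeeded):

print(ocr_dataset[0]['ocr_content'][:500])  # first 500 characters of the first result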

Perfect, we have extracted all text from the 100 samples. However, this process can be made more cost-efficient using Batch Inference.

With Batch

To use Batch Inference, we need to create a JSONL file containing all the image data and request information for our batch.

Let's create a function called create_batch_file to handle this task by generating a file in the proper format.

def create_batch_file(image_urls, output_file):
    with open(output_file, 'w') as file:
        for index, url in enumerate(image_urls):
            # One request per line; custom_id lets us map each result back to its input
            entry = {
                "custom_id": str(index),
                "body": {
                    "document": {
                        "type": "image_url",
                        "image_url": url
                    },
                    "include_image_base64": True  # echo the image back in the results
                }
            }
            file.write(json.dumps(entry) + '\n')

The next step is to encode each image to base64 and collect the resulting data URLs.

image_urls = []
for sample in tqdm(subset):
    image_data = sample['image']  # 'image' contains the actual image data

    # Encode the image data to base64 and add the url to the list
    base64_image = encode_image_data(image_data)
    image_url = f"data:image/jpeg;base64,{base64_image}"
    image_urls.append(image_url)

We can now create our batch file.

batch_file = "batch_file.jsonl"
create_batch_file(image_urls, batch_file)
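
Optionally, inspect the first line of the generated file to confirm the format; the output is truncated here since the base64 payload is long:

with open(batch_file) as f:
    print(f.readline()[:120] + "...")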

With the batch file ready, we can upload it to the API.

batch_data = client.files.upload(
    file={
        "file_name": batch_file,
        "content": open(batch_file, "rb"),
    },
    purpose="batch"
)

The file is uploaded, but the batch inference has not started yet. To initiate it, we need to create a job.

created_job = client.batch.jobs.create(
    input_files=[batch_data.id],
    model=ocr_model,
    endpoint="/v1/ocr",
    metadata={"job_type": "testing"}
)

Our batch is ready and running!

We can retrieve its status and progress as follows:

retrieved_job = client.batch.jobs.get(job_id=created_job.id)
print(f"Status: {retrieved_job.status}")
print(f"Total requests: {retrieved_job.total_requests}")
print(f"Failed requests: {retrieved_job.failed_requests}")
print(f"Successful requests: {retrieved_job.succeeded_requests}")
print(
    f"Percent done: {round((retrieved_job.succeeded_requests + retrieved_job.failed_requests) / retrieved_job.total_requests * 100, 2)}%"
)

Let's automate this feedback loop and download the results once they are ready!

import time
from IPython.display import clear_output

while retrieved_job.status in ["QUEUED", "RUNNING"]:
    retrieved_job = client.batch.jobs.get(job_id=created_job.id)

    clear_output(wait=True)  # Clear the previous output so the status refreshes in place
    print(f"Status: {retrieved_job.status}")
    print(f"Total requests: {retrieved_job.total_requests}")
    print(f"Failed requests: {retrieved_job.failed_requests}")
    print(f"Successful requests: {retrieved_job.succeeded_requests}")
    print(
        f"Percent done: {round((retrieved_job.succeeded_requests + retrieved_job.failed_requests) / retrieved_job.total_requests * 100, 2)}%"
    )
    time.sleep(2)

# Download the results and save them to disk
output_file_stream = client.files.download(file_id=retrieved_job.output_file)
with open("batch_results.jsonl", "wb") as f:
    f.write(output_file_stream.read())

Done! With this method, you can perform OCR tasks in bulk in a very cost-effective way.
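
As a final step, you may want to map the results back to your inputs. Below is a minimal parsing sketch; it assumes each line of the downloaded file carries the custom_id we set earlier and the OCR output under response.body, so check the actual structure of your file before relying on it:

# Minimal parsing sketch; the field layout below is an assumption about the batch output
results = {}
with open("batch_results.jsonl") as f:
    for line in f:
        entry = json.loads(line)
        body = entry["response"]["body"]
        # Single images yield a single page, as in the non-batch example
        results[entry["custom_id"]] = body["pages"][0]["markdown"]

print(results.get("0", "")[:500])  # preview the first document's markdown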